{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

-- | A guided tour of GADTs: from an untyped expression evaluator, through a
-- typed one, to singleton naturals and length-indexed vectors.
module TeachingGADTs where

import Control.Applicative
import Control.Monad
import Unsafe.Coerce

-- Motivation: how would you write a "typed" expression language and its evaluator?

-- We use GADT *syntax* here but we don't use GADT *semantic* features.

-- | Introducing GADT syntax: here's how you'd write 'Maybe'.
data Maybe' a where
  Just'    :: a -> Maybe' a
  Nothing' :: Maybe' a

-- | An expression language whose Haskell type says nothing about the type of
-- value an expression denotes.
data Exp where
  PlusExp :: Exp -> Exp -> Exp
  LessExp :: Exp -> Exp -> Exp
  IntExp  :: Int -> Exp
  BoolExp :: Bool -> Exp
  IfTeExp :: Exp -> Exp -> Exp -> Exp

-- | Evaluate an 'Exp' to an 'Int' ('Left') or a 'Bool' ('Right').
--
-- Type-checking happens at run-time during evaluation, and we can easily get
-- it wrong: we don't get compile-time type-checking for our DSL. A failed
-- pattern match in the 'Maybe' monad yields 'Nothing', which is how an
-- ill-typed expression is reported.
eval :: Exp -> Maybe (Either Int Bool)
eval (PlusExp e1 e2) = do
  Left i1 <- eval e1
  Left i2 <- eval e2
  pure . Left $ i1 + i2
eval (LessExp e1 e2) = do
  Left i1 <- eval e1
  Left i2 <- eval e2
  pure . Right $ i1 < i2
eval (IntExp i) = Just $ Left i
eval (BoolExp b) = Just $ Right b
eval (IfTeExp e1 e2 e3) = do
  Right b <- eval e1
  v2 <- eval e2
  v3 <- eval e3
  -- Both branches must have the same "type"; this line would be easy to forget!
  guard (bothLeft v2 v3 || bothRight v2 v3)
  pure $ if b then v2 else v3
  where
    bothLeft (Left _) (Left _) = True
    bothLeft _ _ = False
    bothRight (Right _) (Right _) = True
    bothRight _ _ = False

-- | A small expression language, indexed by the type each expression denotes.
data Expr a where
  Plus :: (Num a) => Expr a -> Expr a -> Expr a
  Less :: (Ord a) => Expr a -> Expr a -> Expr Bool
  Lift :: a -> Expr a
  IfTe :: Expr Bool -> Expr a -> Expr a -> Expr a

-- ... is equivalent to ...

-- | The same language as 'Expr', written with existential quantification and
-- explicit type-equality constraints instead of GADT syntax.
data Expr' a
  = forall x. (Num x, a ~ x) => Plus' (Expr' x) (Expr' x)
  | forall x. (Ord x, a ~ Bool) => Less' (Expr' x) (Expr' x)
  | forall x. (x ~ a) => Lift' x
  | forall x b. (x ~ a, b ~ Bool) => IfTe' (Expr' b) (Expr' x) (Expr' x)

-- | Evaluate a typed expression. The result type is the expression's index,
-- so no run-time tagging or failure case is needed.
evaluate :: Expr a -> a -- wow! types!
evaluate expr = case expr of
  Plus e1 e2 ->
    let v1 = evaluate e1
        v2 = evaluate e2
     in v1 + v2
  Less e1 e2 ->
    let v1 = evaluate e1
        v2 = evaluate e2
     in v1 < v2
  Lift e -> e
  IfTe e1 e2 e3 ->
    let v1 = evaluate e1
        v2 = evaluate e2
        v3 = evaluate e3
     in if v1 then v2 else v3

-- | Witness of an equality proof.
data x :~: y where
  Refl :: x :~: x

-- | Discharge an equality proof.
--
-- Notice what happens when you pass @undefined@: the match on 'Refl' forces
-- the proof, so a bottom proof makes the whole call bottom.
coerce :: x :~: y -> x -> y
coerce Refl x = x

-- | Witness of a constraint.
data Dict c a where
  Dict :: c a => Dict c a

-- | Discharge a constraint proof: matching on 'Dict' brings @c a@ into scope
-- for the continuation. The constrained argument type requires @RankNTypes@.
withDict :: Dict c a -> (c a => b) -> b
withDict Dict x = x

-- | Specialized version of 'Dict': an example to "Show" what's going on.
data IsShow a where
  IsShow :: (Show a) => IsShow a

-- | A weird GADT: the index records whether the payload is known to be an 'Int'.
data IsInt (x :: Bool) where
  Yep     :: (a ~ Int) => a -> IsInt 'True
  Perhaps :: a -> IsInt 'False

-- | This is total! 'Perhaps' cannot inhabit @IsInt 'True@.
getMeAnIntPlease :: IsInt 'True -> Int
getMeAnIntPlease (Yep x) = x

-- | Ordinary value-level Nat.
data Nat = S Nat | Z

-- | Singleton value-level Nat which witnesses a type-level Nat
-- (we get the type-level Nat via DataKinds).
data SNat (n :: Nat) where
  SZ :: SNat Z
  SS :: SNat n -> SNat (S n)

-- | Addition of 'SNat's which preserves the type index properly.
plus :: SNat n -> SNat m -> SNat (n + m)
plus SZ n = n
plus (SS n) m = SS (plus n m)

-- Derive instances for GADTs this way.
deriving instance Show (SNat n)

-- | Length-indexed list.
data Vec (n :: Nat) (a :: *) where
  Nil  :: Vec Z a
  Cons :: a -> Vec n a -> Vec (S n) a

deriving instance Show a => Show (Vec n a)

-- An excursion: what we would do with DataKinds, but WITHOUT TypeFamilies & GADTs.
-- Here we use phantom types and MODULE ABSTRACTION to manually verify & enforce invariants.

-- | A list tagged with a phantom length. Nothing ties the tag to the actual
-- list, so the invariant holds only if every way of building a 'FakeVec'
-- respects it. But OH NO: by deriving 'Read' we accidentally let anyone build
-- one at any phantom length, and we are now sad.
data FakeVec (n :: Nat) (a :: *) = FakeVec [a]
  deriving (Read)

-- | The empty 'FakeVec'.
fakeNil :: FakeVec Z a
fakeNil = FakeVec []

-- | Prepend an element, bumping the phantom length.
fakeCons :: a -> FakeVec n a -> FakeVec (S n) a
fakeCons a (FakeVec as) = FakeVec (a : as)

-- | Split a supposedly non-empty 'FakeVec' into head and tail.
--
-- Calls 'error' if the underlying list is empty, i.e. if the phantom-length
-- invariant has been violated.
fakeVecSplit :: FakeVec (S n) a -> (a, FakeVec n a)
fakeVecSplit (FakeVec (a : as)) = (a, FakeVec as)
fakeVecSplit (FakeVec []) = error "invariant violation: FakeVec is empty!"

-- We can violate the invariant by reading a string at the "wrong" phantom type!

-- | An example 'Vec'.
vec1 :: Vec (S (S (S Z))) Char
vec1 = Cons 'A' (Cons 'B' (Cons 'C' Nil))

-- | Zip two vectors of the same length. This is exhaustive (total)!
zipSame :: Vec n a -> Vec n b -> Vec n (a, b)
zipSame Nil Nil = Nil
zipSame (Cons x xs) (Cons y ys) = Cons (x, y) (zipSame xs ys)

-- | Length of a vector as a singleton Nat ('SNat').
vecLength :: Vec n a -> SNat n
vecLength Nil = SZ
vecLength (Cons _ xs) = SS $ vecLength xs

-- | Type-level addition function (requires TypeFamilies & TypeOperators).
type family a + b where
  Z + n = n
  S n + m = S (n + m)

-- | Append two vectors. GHC will verify this automatically, because it's the
-- exact same recursion pattern as @(+)@.
easyAppend :: Vec m a -> Vec n a -> Vec (m + n) a
easyAppend Nil ys = ys
easyAppend (Cons x xs) ys = Cons x (easyAppend xs ys)

-- But...
-- the type level addition '+' is not *automatically* provable
-- to be commutative, and hence fails to typecheck... unless we PROVE IT using a lemma.
-- | Append two vectors, with the result length written as @n + m@ rather than
-- the @m + n@ that 'easyAppend' produces. Matching on the commutativity proof
-- brings @(m + n) ~ (n + m)@ into scope so the 'easyAppend' result typechecks.
hardAppend :: Vec m a -> Vec n a -> Vec (n + m) a
hardAppend v w =
  case additionCommutative (vecLength v) (vecLength w) of
    Refl -> easyAppend v w

-- | Sub-lemma: zero is a right neutral for @(+)@.
rightNeutral :: SNat n -> n :~: (n + Z)
rightNeutral SZ = Refl
rightNeutral (SS n) =
  case rightNeutral n of
    Refl -> Refl

-- | Sub-lemma: @n + S m = S (n + m)@.
plusSucc :: SNat n -> SNat m -> (n + S m) :~: S (n + m)
plusSucc SZ _ = Refl
plusSucc (SS n) m =
  case plusSucc n m of
    Refl -> Refl

-- | We can use the sub-lemmas to prove, for any given @SNat n@ and @SNat m@,
-- that addition commutes.
--
-- Note: this is O(n^2) and MUST BE EXECUTED AT RUN-TIME.
additionCommutative :: SNat n -> SNat m -> (n + m) :~: (m + n)
additionCommutative SZ n =
  case rightNeutral n of
    Refl -> Refl
additionCommutative (SS m) n =
  case additionCommutative m n of
    Refl ->
      case plusSucc n m of
        Refl -> Refl

-- | We can use a type-level minimum function to type-check a truncating zip.
type family Min (m :: Nat) (n :: Nat) where
  Min Z y = Z
  Min x Z = Z
  Min (S x) (S y) = S (Min x y)

-- | Truncating zip (a la Haskell's ordinary 'zip').
zipMin :: Vec m a -> Vec n b -> Vec (Min m n) (a, b)
zipMin Nil _ = Nil
zipMin _ Nil = Nil
zipMin (Cons x xs) (Cons y ys) = Cons (x, y) (zipMin xs ys)

-- Addendum: making things go fast again.
--
-- Here's how we can make things go fast, unsafely.
-- This is more or less what a dependently typed language can do in some circumstances,
-- because if it is total, it knows that it doesn't actually need to run proofs to
-- make sure they're not bottom.

-- | If there's a runtime-costly proof which you are ABSOLUTELY CERTAIN will never be equal
-- to bottom (i.e. is the result of a DEFINITELY TOTAL function), you can wrap it in this
-- function to avoid ever forcing it and doing the extra work to run the proof.
--
-- The argument is ignored entirely; the result is a 'Refl' cast to the
-- requested type with 'unsafeCoerce', so soundness rests on the caller's
-- certainty that the proof it replaces is valid.
unsafeEraseProof :: forall a b. (a :~: b) -> (a :~: b)
unsafeEraseProof _proof = unsafeCoerce Refl :: a :~: b

-- | Append two vectors using a type requiring addition to be commutative,
-- but skip actually running the proof at runtime.
fastHardAppend :: Vec m a -> Vec n a -> Vec (n + m) a
fastHardAppend v w =
  case unsafeEraseProof (additionCommutative (vecLength v) (vecLength w)) of
    Refl -> easyAppend v w

-- | Existentially quantify over the length of a vector.
data SomeVec a where
  SomeVec :: Vec n a -> SomeVec a

deriving instance Show a => Show (SomeVec a)

-- | Convert a list into an existential-length-ed vector.
toVec :: [a] -> SomeVec a
toVec [] = SomeVec Nil
toVec (x : xs) =
  case toVec xs of
    SomeVec xs' -> SomeVec (Cons x xs')

-- | This will run sloooowwwwwly -- O(n^2).
slowHardAppendTest :: SomeVec a -> SomeVec a -> SomeVec a
slowHardAppendTest (SomeVec x) (SomeVec y) = SomeVec $ hardAppend x y

-- | This will run quick -- O(n).
fastHardAppendTest :: SomeVec a -> SomeVec a -> SomeVec a
fastHardAppendTest (SomeVec x) (SomeVec y) = SomeVec $ fastHardAppend x y

-- For instance, try out:
--
-- > slowHardAppendTest (toVec [0..3000]) (toVec [0..3000])
-- > fastHardAppendTest (toVec [0..3000]) (toVec [0..3000])
--
-- Notice that there is a noticeable pause before slowHardAppendTest begins printing output.
-- This is the time it takes to force the thunk which evaluates to the proof of addition being
-- commutative; that is, this entire time is spent evaluating the call to
-- additionCommutative that hardAppend scrutinises.