Created
May 5, 2026 15:27
-
-
Save neel-krishnaswami/09b525a6d6d630f358c8c8a206ff872b to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(* Type inference for the STLC, but with bidirectional constraint generation. *)
(* This is basically unification-based type inference *without*
   let-generalization. *)
(* Existential (unification) variables are identified by their name. *)
type evar = string

(* Simple types: a type variable, a function type, or the boolean type. *)
type tp =
  | TVar of evar
  | Arrow of tp * tp
  | Bool
(* The type of constraints. We assume there are no shadowed or duplicated
   variables in constraints, which we enforce by generating fresh
   variables for any new name. *)
type constr =
  | Eq of tp * tp            (* the two types must be equal *)
  | Exists of evar * constr  (* binds an evar in the inner constraint *)
  | And of constr * constr   (* conjunction of two constraints *)
  | Top                      (* the trivial constraint *)
(* [exist [x1; ...; xn] c] wraps [c] in one existential per variable,
   outermost first: ∃x1. ... ∃xn. c.  An empty list returns [c] unchanged. *)
let rec exist vars c =
  match vars with
  | [] -> c
  | a :: rest -> Exists (a, exist rest c)
(* Lambda terms *)

(* Term variables are identified by their name. *)
type var = string

(* Abstract syntax of expressions. *)
type exp =
  | Var of var              (* variable reference *)
  | Lam of var * exp        (* lambda abstraction (bound variable, body) *)
  | App of exp * exp        (* application (function, argument) *)
  | Annot of exp * tp       (* expression with a type annotation *)
  | Blit of bool            (* boolean literal *)
  | If of exp * exp * exp   (* conditional (condition, then, else) *)
  | Let of var * exp * exp  (* let binding (name, bound term, body) *)
(* Typechecking monad: supplies fresh names and lets a computation stop
   with an error. *)
module TC = struct
  (* A computation receives the current name counter and either stops
     ([None]) or produces a result paired with the next counter value. *)
  type 'a t = M of (int -> ('a * int) option)

  (* [map f m] applies [f] to the result of [m]; the counter produced by
     [m] is passed through untouched. *)
  let map f (M run) =
    M (fun counter ->
        Option.map (fun (value, counter') -> (f value, counter')) (run counter))

  (* [return x] succeeds with [x] and leaves the counter as it was. *)
  let return value = M (fun counter -> Some (value, counter))

  (* Sequencing: run the first computation, then feed its result and the
     updated counter to the continuation.  A failure in the first
     computation short-circuits the rest. *)
  let ( let+ ) (M first) k =
    M (fun counter ->
        Option.bind (first counter) (fun (value, counter') ->
            let (M next) = k value in
            next counter'))

  (* The computation that always stops with an error. *)
  let fail = M (fun _ -> None)

  (* [fresh name] builds an identifier from [name] and the current counter
     value, then advances the counter by one. *)
  let fresh name =
    M (fun counter -> Some (Printf.sprintf "_%s_%d" name counter, counter + 1))
end
module Destruct = struct
  open TC

  (* Destructors.

     Each destructor takes a type and yields a triple
       (components, fresh evars, constraint).
     When the type already has the requested shape, its components are
     returned with no new evars and the trivial constraint.  When the type
     is an evar, the result carries a constraint equating that evar with
     the requested shape (introducing fresh evars for the components where
     the shape has any).  Any other type makes the computation fail. *)

  (* View a type as a function type, yielding (domain, codomain). *)
  let arrow tp =
    match tp with
    | TVar a ->
      let+ dom = fresh "dom" in
      let+ cod = fresh "cod" in
      let dom_tp = TVar dom and cod_tp = TVar cod in
      return ((dom_tp, cod_tp), [ dom; cod ], Eq (TVar a, Arrow (dom_tp, cod_tp)))
    | Arrow (tp1, tp2) -> return ((tp1, tp2), [], Top)
    | Bool -> fail

  (* View a type as the boolean type; there are no components to return. *)
  let bool tp =
    match tp with
    | Bool -> return ((), [], Top)
    | TVar _ as tv -> return ((), [], Eq (tv, Bool))
    | Arrow _ -> fail
end
module Bidir = struct
  open TC

  (* Typing contexts: an association list from term variables to types. *)
  type ctx = (var * tp) list

  (* [lookup ctx x] yields the type bound to [x] in [ctx]; the computation
     fails when [x] is unbound. *)
  let lookup ctx x = Option.fold ~none:fail ~some:return (List.assoc_opt x ctx)

  (* Bidirectional typechecking.

     1. Checking:  Σ; Γ ⊢ e ⇐ τ ↝ C

        Inputs are the evars currently in scope Σ, a context Γ and an
        expected type τ (both mentioning only evars from Σ), and a term e.
        The output is a constraint C whose free evars are bounded by Σ.
        Checking never extends the scope: any evar it needs is bound by an
        existential inside C.

     2. Synthesis:  Σ; Γ ⊢ e ⇒ τ ↝ C ⊣ Σ'

        Inputs are Σ, Γ and e.  The outputs are a type τ, a constraint C
        and the evars that extend Σ to a scope Σ' ≥ Σ bounding the free
        evars of τ and C.  Synthesis *can* extend the scope; the caller is
        responsible for binding the new evars.

     Compared with the usual bidirectional rules:
       a) types are taken apart with the destructors from [Destruct]
          rather than by pattern matching, and
       b) when synthesis meets a checking form, it does not signal an
          error; it invents a fresh evar and checks against it.

     In the code below [dom] plays the role of Σ.  It is only extended and
     passed along; neither function inspects it. *)
  let rec check dom ctx exp tp =
    match exp with
    | Blit _ ->
      let+ ((), vars, shape) = Destruct.bool tp in
      return (exist vars shape)
    | Lam (x, body) ->
      let+ ((arg_tp, res_tp), vars, shape) = Destruct.arrow tp in
      let+ body_c = check (dom @ vars) ((x, arg_tp) :: ctx) body res_tp in
      return (exist vars (And (shape, body_c)))
    | If (cond, thn, els) ->
      let+ cond_c = check dom ctx cond Bool in
      let+ thn_c = check dom ctx thn tp in
      let+ els_c = check dom ctx els tp in
      return (And (cond_c, And (thn_c, els_c)))
    | Let (x, bound, body) ->
      let+ (bound_tp, bound_c, xs) = synth dom ctx bound in
      let+ body_c = check (xs @ dom) ((x, bound_tp) :: ctx) body tp in
      return (exist xs (And (bound_c, body_c)))
    | (Var _ | App _ | Annot _) as e ->
      (* Synthesising forms: infer a type, equate it with the expected
         one, and bind whatever evars synthesis introduced. *)
      let+ (inferred, c, xs) = synth dom ctx e in
      return (exist xs (And (c, Eq (tp, inferred))))

  and synth dom ctx exp =
    match exp with
    | Var x ->
      let+ tp = lookup ctx x in
      return (tp, Top, [])
    | Annot (e, tp) ->
      let+ c = check dom ctx e tp in
      return (tp, c, [])
    | App (fn, arg) ->
      let+ (fn_tp, fn_c, fn_vars) = synth dom ctx fn in
      let+ ((arg_tp, res_tp), arrow_vars, arrow_c) = Destruct.arrow fn_tp in
      let+ arg_c = check (fn_vars @ arrow_vars @ dom) ctx arg arg_tp in
      return (res_tp, And (fn_c, And (arrow_c, arg_c)), fn_vars @ arrow_vars)
    | (Lam _ | Blit _ | If _ | Let _) as e ->
      (* This is the case which differs from the usual presentation:
         instead of erroring out, introduce a fresh evar and check [e]
         against it. *)
      let+ a = fresh "check" in
      let+ c = check (a :: dom) ctx e (TVar a) in
      return (TVar a, c, [ a ])
end
(* An implementation of parallel substitutions *)
module Subst = struct
  (* A substitution is an association list from evars to types. *)
  type subst = S of (evar * tp) list

  (* Our invariant is that substitutions are total, so identity
     substitutions map evars to the corresponding type. *)

  (* [id dom] returns dom ⊢ id[dom] : dom.
     Note that since weakening is sound, if dom' ⊇ dom, then
     dom' ⊢ id dom : dom as well. *)
  let id dom = S (List.map (fun a -> (a, TVar a)) dom)

  (* If dom |- s : dom' and dom' |- tp then dom |- apply s tp.
     Raises [Not_found] when [tp] mentions an evar that [s] has no entry
     for, i.e. when the totality invariant is violated. *)
  let rec apply (S s) = function
    | Bool -> Bool
    | Arrow (tp1, tp2) -> Arrow (apply (S s) tp1, apply (S s) tp2)
    | TVar a -> List.assoc a s

  (* If dom'' |- s1 : dom' and dom' |- s2 : dom
     then dom'' |- compose s1 s2 : dom.
     Every type in the range of [s2] is rewritten by [s1]. *)
  let compose s1 (S s2) = S (List.map (fun (x, tp) -> (x, apply s1 tp)) s2)

  (* If dom ⊢ s1 : dom1 and dom ⊢ s2 : dom2 and dom1 ∩ dom2 = ∅, then
     dom ⊢ s1 ** s2 : dom1, dom2. *)
  let ( ** ) (S s1) (S s2) = S (s1 @ s2)

  (* If dom |- e then dom ⊢ (singleton e x) : x. *)
  let singleton e x = S [ (x, e) ]

  (* [restrict a s] removes every entry for the variable [a] from [s].
     Names are compared structurally: evars are strings, and physical
     inequality would keep an entry whose key merely has the same
     characters as [a] without being the very same string object. *)
  let restrict a (S s) = S (List.filter (fun (b, _) -> a <> b) s)
end
(* A simple implementation of unification *)
module Solve = struct
  open Subst

  (* [occurs a tp]: does the variable [a] occur in [tp]? *)
  let rec occurs a = function
    | TVar b -> a = b
    | Bool -> false
    | Arrow (tp1, tp2) -> occurs a tp1 || occurs a tp2

  (* [unify dom tp1 tp2]

     If dom ⊢ tp1 and dom ⊢ tp2 then:
     - if [unify dom tp1 tp2 = None], there is no unifier;
     - if [unify dom tp1 tp2 = Some (s, dom')], then [s] is the most
       general unifier and [dom'] is [dom] without the evars that [s]
       eliminated.

     May raise [Not_found] (from [Subst.apply]) if the types mention an
     evar that is not in [dom]. *)
  let rec unify dom tp1 tp2 =
    let ( let+ ) = Option.bind in
    match tp1, tp2 with
    | Bool, Bool -> Some (id dom, dom)
    | Arrow (arg1, res1), Arrow (arg2, res2) ->
      let+ (s, dom') = unify dom arg1 arg2 in
      (* The codomains are unified under the solution for the domains. *)
      let+ (s', dom'') = unify dom' (apply s res1) (apply s res2) in
      Some (compose s' s, dom'')
    | TVar a, TVar b when a = b -> Some (id dom, dom)
    | TVar a, tp | tp, TVar a ->
      if occurs a tp then None (* occurs check: no finite solution *)
      else
        (* Drop [a] from the scope.  Names are strings, so they must be
           compared structurally; physical inequality would leave [a] in
           the scope whenever the two strings are distinct objects. *)
        let dom' = List.filter (fun b -> b <> a) dom in
        Some (id dom' ** singleton tp a, dom')
    | Bool, Arrow _ | Arrow _, Bool -> None

  (* Apply a substitution to a constraint. (NB: Weakening is implicitly
     used in the [id [x] ** s] subterm of the exists case; the bound
     variable is mapped to itself ahead of any entry in [s].) *)
  let rec apply_constr s = function
    | Top -> Top
    | And (c1, c2) -> And (apply_constr s c1, apply_constr s c2)
    | Eq (tp1, tp2) -> Eq (apply s tp1, apply s tp2)
    | Exists (x, c) -> Exists (x, apply_constr (id [ x ] ** s) c)

  (* [solve dom c] solves a set of constraints.

     If dom ⊢ c then
     - if [solve dom c = None], there is no solution;
     - if [solve dom c = Some (s, dom')], then [s] is the minimal solution
       (i.e., every solution to [c] factors through [s]). *)
  let rec solve dom c =
    let ( let+ ) = Option.bind in
    match c with
    | Top -> Some (id dom, dom)
    | And (c1, c2) ->
      let+ (s1, dom') = solve dom c1 in
      (* The second conjunct is solved under the solution of the first. *)
      let+ (s2, dom'') = solve dom' (apply_constr s1 c2) in
      Some (compose s2 s1, dom'')
    | Eq (tp1, tp2) -> unify dom tp1 tp2
    | Exists (x, c) ->
      (* NOTE(review): the scope computed for the body is discarded and
         the incoming [dom] is returned as-is, so evars eliminated inside
         the body still appear in the result scope -- confirm this is
         intended. *)
      let+ (s, _dom') = solve (x :: dom) c in
      Some (restrict x s, dom)
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment