Skip to content

Instantly share code, notes, and snippets.

@neel-krishnaswami
Created May 5, 2026 15:27
Show Gist options
  • Select an option

  • Save neel-krishnaswami/09b525a6d6d630f358c8c8a206ff872b to your computer and use it in GitHub Desktop.

Select an option

Save neel-krishnaswami/09b525a6d6d630f358c8c8a206ff872b to your computer and use it in GitHub Desktop.
(* Type inference for the STLC, but with bidirectional constraint generation *)
(* This is basically unification-based type inference *without* let-generalization.
*)
(* Existential (unification) type variables are identified by name. *)
type evar = string
(* Simple types: an existential variable [TVar a], a function type
   [Arrow (domain, codomain)], or the base type [Bool]. *)
type tp = TVar of evar | Arrow of tp * tp | Bool
(* The type of constraints. We assume there are no shadowed or duplicated
variables in constraints, which we enforce by generating fresh
variables for any new name.

- [Eq (t1, t2)]     : the two types must be equal.
- [Exists (a, c)]   : binds the existential variable [a] in [c].
- [And (c1, c2)]    : both constraints must hold.
- [Top]             : the trivially true constraint. *)
type constr =
| Eq of tp * tp
| Exists of evar * constr
| And of constr * constr
| Top
(* [exist [x1; ...; xn] c] wraps [c] in one existential binder per
   variable, outermost first: the result is ∃x1. ... ∃xn. c.
   An empty list returns [c] unchanged. *)
let exist vars c =
  (* Walk the names from the innermost binder (xn) outwards, so that x1
     ends up as the outermost quantifier. *)
  List.fold_left (fun body a -> Exists (a, body)) c (List.rev vars)
(* Lambda terms *)
(* Term variables are identified by name. *)
type var = string
(* The term language:
   - [Var x]            : variable reference
   - [Lam (x, e)]       : lambda abstraction binding [x] in body [e]
   - [App (e1, e2)]     : application of [e1] to [e2]
   - [Annot (e, tp)]    : term [e] annotated with type [tp]
   - [Blit b]           : boolean literal
   - [If (e1, e2, e3)]  : conditional on [e1] with branches [e2] and [e3]
   - [Let (x, e1, e2)]  : binds [x] to [e1] in [e2] *)
type exp =
| Var of var
| Lam of var * exp
| App of exp * exp
| Annot of exp * tp
| Blit of bool
| If of exp * exp * exp
| Let of var * exp * exp
(* Typechecking monad: a computation threads an integer counter (used to
   make fresh names) and may stop with an error, represented by [None]. *)
module TC = struct
  (* A computation takes the current counter and either fails ([None]) or
     yields a result together with the updated counter. *)
  type 'a t = M of (int -> ('a * int) option)

  (* [return x] succeeds with [x] and leaves the counter untouched. *)
  let return x = M (fun counter -> Some (x, counter))

  (* [fail] aborts the computation, whatever the counter is. *)
  let fail = M (fun _ -> None)

  (* [map f m] applies [f] to the result of [m]; failure and the counter
     are passed through unchanged. *)
  let map f (M m) =
    M (fun counter ->
        Option.map (fun (a, counter') -> (f a, counter')) (m counter))

  (* Sequencing (monadic bind): run the first computation, feed its result
     to the continuation, and run what the continuation returns starting
     from the updated counter. A failure in the first computation skips
     the continuation. *)
  let ( let+ ) (M first) continue =
    M (fun counter ->
        Option.bind (first counter) (fun (value, counter') ->
            let (M rest) = continue value in
            rest counter'))

  (* [fresh name] returns a new name built from [name] and the current
     counter value, and increments the counter. *)
  let fresh name =
    M (fun counter ->
        Some (Printf.sprintf "_%s_%d" name counter, counter + 1))
end
module Destruct = struct
  open TC

  (* Destructors.

     A destructor takes a type and tries to view it as having a particular
     shape. Its result is a triple (components, new evars, constraint):
     - if the type already has the requested shape, the components are
       read off directly, no evars are created and the constraint is Top;
     - if the type is an evar, fresh evars are generated for any
       components, and the constraint equates the evar with the shape;
     - any other type makes the computation fail. *)

  (* View a type as a function type; the components are (domain, codomain). *)
  let arrow tp =
    match tp with
    | Arrow (dom_tp, cod_tp) -> return ((dom_tp, cod_tp), [], Top)
    | Bool -> fail
    | TVar a ->
      (* The domain name is generated before the codomain name. *)
      let+ dom = fresh "dom" in
      let+ cod = fresh "cod" in
      let shape = Arrow (TVar dom, TVar cod) in
      return ((TVar dom, TVar cod), [ dom; cod ], Eq (TVar a, shape))

  (* View a type as Bool; there are no components and no new evars. *)
  let bool tp =
    match tp with
    | Bool -> return ((), [], Top)
    | Arrow _ -> fail
    | TVar a -> return ((), [], Eq (TVar a, Bool))
end
module Bidir = struct
  open TC

  (* Typing contexts associate term variables with types. *)
  type ctx = (var * tp) list

  (* [lookup ctx x] returns the first type bound to [x] in [ctx], and fails
     when [x] is unbound. *)
  let lookup ctx x =
    Option.fold ~none:fail ~some:return (List.assoc_opt x ctx)

  (* Bidirectional typechecking.

     1. Checking:  Σ; Γ ⊢ e ⇐ τ ↝ C

        Inputs: the evars in scope Σ (the [dom] argument), a context Γ and
        an expected type τ whose evars lie in Σ, and a term e. Output: a
        constraint C whose free evars are bounded by Σ. Checking never
        extends the evar scope: every evar it creates is existentially
        bound inside the constraint it returns.

     2. Synthesis:  Σ; Γ ⊢ e ⇒ τ ↝ C ⊣ Σ'

        Inputs: Σ, Γ and e as above. Output: a type τ, a constraint C, and
        the list of newly created evars; together with Σ these bound the
        free evars of τ and C. Synthesis *can* extend the evar scope, so
        its caller is responsible for binding the returned evars.

     The rules follow the usual bidirectional recipe, except that
       a) types are taken apart with the destructors in [Destruct] rather
          than by pattern matching, and
       b) when synthesis meets a checking-only term it does not signal an
          error: it makes up a fresh evar and checks against it. *)
  let rec check dom ctx exp tp =
    match exp with
    | Blit _ ->
      let+ ((), evars, shape) = Destruct.bool tp in
      return (exist evars shape)
    | Lam (x, body) ->
      let+ ((arg_tp, res_tp), evars, shape) = Destruct.arrow tp in
      let+ body_c = check (dom @ evars) ((x, arg_tp) :: ctx) body res_tp in
      return (exist evars (And (shape, body_c)))
    | If (cond, thn, els) ->
      let+ cond_c = check dom ctx cond Bool in
      let+ thn_c = check dom ctx thn tp in
      let+ els_c = check dom ctx els tp in
      return (And (cond_c, And (thn_c, els_c)))
    | Let (x, bound, body) ->
      let+ (bound_tp, bound_c, evars) = synth dom ctx bound in
      let+ body_c = check (evars @ dom) ((x, bound_tp) :: ctx) body tp in
      return (exist evars (And (bound_c, body_c)))
    | Var _ | App _ | Annot _ ->
      (* Synthesizing forms: infer a type, require it to equal the
         expected one, and close over the evars synthesis introduced. *)
      let+ (inferred, c, evars) = synth dom ctx exp in
      return (exist evars (And (c, Eq (tp, inferred))))

  and synth dom ctx exp =
    match exp with
    | Var x ->
      let+ tp = lookup ctx x in
      return (tp, Top, [])
    | Annot (e, tp) ->
      let+ c = check dom ctx e tp in
      return (tp, c, [])
    | App (fn, arg) ->
      let+ (fn_tp, fn_c, fn_evars) = synth dom ctx fn in
      let+ ((arg_tp, res_tp), shape_evars, shape) = Destruct.arrow fn_tp in
      let+ arg_c = check (fn_evars @ shape_evars @ dom) ctx arg arg_tp in
      return (res_tp, And (fn_c, And (shape, arg_c)), fn_evars @ shape_evars)
    | Blit _ | Lam _ | If _ | Let _ ->
      (* This is the case that differs from the usual presentation:
         rather than erroring out on a checking-only term, introduce a
         fresh evar and check the term against it. *)
      let+ a = fresh "check" in
      let+ c = check (a :: dom) ctx exp (TVar a) in
      return (TVar a, c, [ a ])
end
(* An implementation of parallel substitutions *)
module Subst = struct
  (* A substitution is an association list from evars to types. *)
  type subst = S of (evar * tp) list

  (* Our invariant is that substitutions are total, so identity
     substitutions map evars to the corresponding type. *)

  (* [id dom] returns the identity substitution: dom ⊢ id dom : dom.
     Note that since weakening is sound, if dom' ⊇ dom, then
     dom' ⊢ id dom : dom as well. *)
  let id dom = S (List.map (fun a -> (a, TVar a)) dom)

  (* If dom |- s : dom' and dom' |- tp then dom |- apply s tp.
     Raises [Not_found] if [tp] mentions an evar that [s] does not map,
     i.e. when the totality invariant above is violated. When an evar is
     listed more than once, the first entry wins. *)
  let rec apply (S s as sub) tp =
    match tp with
    | Bool -> Bool
    | Arrow (tp1, tp2) -> Arrow (apply sub tp1, apply sub tp2)
    | TVar a -> List.assoc a s

  (* If dom'' |- s1 : dom' and dom' |- s2 : dom
     then dom'' |- compose s1 s2 : dom.
     [s1] is applied to every type in the range of [s2]. *)
  let compose s1 (S s2) = S (List.map (fun (x, tp) -> (x, apply s1 tp)) s2)

  (* If dom ⊢ s1 : dom1 and dom ⊢ s2 : dom2 and dom1 ∩ dom2 = ∅, then
     dom ⊢ s1 ** s2 : dom1, dom2 *)
  let ( ** ) (S s1) (S s2) = S (s1 @ s2)

  (* If dom |- e then dom ⊢ (singleton e x) : x *)
  let singleton e x = S [ (x, e) ]

  (* [restrict a s] removes the variable [a] from the domain of [s].
     Names are compared structurally: the physical inequality operator
     would keep an entry whose name is an equal but distinct string. *)
  let restrict a (S s) =
    S (List.filter (fun (b, _) -> not (String.equal b a)) s)
end
(* A simple implementation of unification *)
module Solve = struct
  open Subst

  (* occurs a tp: Does the variable a occur in tp? *)
  let rec occurs a tp =
    match tp with
    | TVar b -> a = b
    | Bool -> false
    | Arrow (tp1, tp2) -> occurs a tp1 || occurs a tp2

  (* unify dom tp1 tp2
     If dom ⊢ tp1 and dom ⊢ tp2 then:
     - if unify dom tp1 tp2 = None, then there is no unifier
     - if unify dom tp1 tp2 = Some (s, dom'), then s is the most general
       unifier and dom' is dom without the evars that s eliminated. *)
  let rec unify dom tp1 tp2 =
    let ( let+ ) = Option.bind in
    match (tp1, tp2) with
    | Bool, Bool -> Some (id dom, dom)
    | Arrow (tp1, tp2), Arrow (tp1', tp2') ->
      let+ (s, dom') = unify dom tp1 tp1' in
      (* Unify the codomains under the solution found for the domains. *)
      let+ (s', dom'') = unify dom' (apply s tp2) (apply s tp2') in
      Some (compose s' s, dom'')
    | TVar a, TVar b when a = b -> Some (id dom, dom)
    | TVar a, tp | tp, TVar a ->
      if occurs a tp then None (* occurs check: no finite solution *)
      else
        (* Eliminate [a]. The comparison must be structural: with physical
           inequality an equal but distinct string would stay in [dom'],
           and its identity entry would then shadow the binding a := tp. *)
        let dom' = List.filter (fun b -> not (String.equal b a)) dom in
        Some (id dom' ** singleton tp a, dom')
    | Bool, Arrow _ | Arrow _, Bool -> None

  (* Apply a substitution to a constraint. (NB: Weakening is implicitly
     used in the `id [x] ** s` subterm of the exists case.) *)
  let rec apply_constr s = function
    | Top -> Top
    | And (c1, c2) -> And (apply_constr s c1, apply_constr s c2)
    | Eq (tp1, tp2) -> Eq (apply s tp1, apply s tp2)
    | Exists (x, c) -> Exists (x, apply_constr (id [ x ] ** s) c)

  (* solve dom c solves a set of constraints.
     If dom ⊢ c then
     - If solve dom c = None, there is no solution
     - If solve dom c = Some (s, dom') then s is the minimal solution
       (i.e., every solution to c factors through s), and dom' is the
       scope of the evars that may occur in the range of s. *)
  let rec solve dom c =
    let ( let+ ) = Option.bind in
    match c with
    | Top -> Some (id dom, dom)
    | And (c1, c2) ->
      let+ (s1, dom') = solve dom c1 in
      let+ (s2, dom'') = solve dom' (apply_constr s1 c2) in
      Some (compose s2 s1, dom'')
    | Eq (tp1, tp2) -> unify dom tp1 tp2
    | Exists (x, c) ->
      let+ (s, dom') = solve (x :: dom) c in
      (* [x] leaves the domain of the substitution, but the scope returned
         is the solved scope [dom'], not the incoming [dom]. If [x] was not
         eliminated (for instance another evar was solved to [x]), it can
         still occur in the range of [s] and so has to stay in scope;
         otherwise a later [apply] would fail to find it. This relies on
         bound names being globally fresh, so keeping [x] cannot clash. *)
      Some (restrict x s, dom')
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment