(** De Bruijn indices that also remember the source-level variable name. *)
module Index = struct
  (** A binder depth paired with the name the variable had in the source. *)
  type t = int * string

  (** [create i s] is the index at depth [i] for the variable named [s]. *)
  let create i s = (i, s)

  (** [eq a b] holds when both indices have the same depth; the names are
      ignored. *)
  let eq ((i, _) : t) ((j, _) : t) = i = j

  (** [enter_level idx] is [idx] as seen from under one more binder. *)
  let enter_level (i, s) = (i + 1, s)

  (** [shift ~by ~cutoff idx] adds [by] to the depth of [idx] when that depth
      is at least [cutoff]; indices below [cutoff] are returned unchanged. *)
  let shift ~by ~cutoff ((i, s) as idx) =
    if i >= cutoff then (i + by, s) else idx

  (** [pp_index idx] is the source name stored in [idx]. *)
  let pp_index ((_, name) : t) = name
end

(** Types of the calculus. *)
type typ =
  | Tvar  (** printed as ["any"] *)
  | Tpoly  (** printed as ["poly"] *)
  | Tfun of typ * typ  (** function from the first type to the second *)

(** [pp_typ t] renders [t] as a string. *)
let rec pp_typ = function
  | Tvar -> "any"
  | Tpoly -> "poly"
  | Tfun (a, b) -> Format.sprintf "%s -• %s" (pp_typ a) (pp_typ b)

(** Surface terms: variables are referred to by name. *)
type in_ast =
  | Var of string
  | App of in_ast * in_ast
  | Abs of string * typ * in_ast

(** Terms after name resolution: variables are de Bruijn indices. *)
type bruijn_ast =
  | Var of Index.t
  | App of bruijn_ast * bruijn_ast
  | Abs of typ * bruijn_ast

(** Typed terms: every node carries its type in [typ]. *)
type ast = Var of Index.t | App of typed_ast * typed_ast | Abs of typed_ast
and typed_ast = { desc : ast; typ : typ }

(** [pp_ast a] renders a term; variables are shown by their source name. *)
let rec pp_ast = function
  | Var i -> Format.sprintf " %s " (Index.pp_index i)
  | App (t1, t2) -> Format.sprintf "(%s) %s" (pp_typed t1) (pp_typed t2)
  | Abs t -> Format.sprintf "∆. %s" (pp_typed t)

(** [pp_typed t] renders the term part of [t]; the type is not printed. *)
and pp_typed t = pp_ast t.desc

(** Names of the enclosing binders, innermost first. *)
type ctx = string list

(** [find_bruijn v names] is the position of the first occurrence of [v] in
    [names], counting from 0.
    @raise Failure if [v] does not occur in [names]. *)
let find_bruijn v names =
  let rec go depth = function
    | [] -> failwith "Var not defined"
    | x :: rest -> if v = x then depth else go (depth + 1) rest
  in
  go 0 names

(** [bruijn ctx t] replaces every named variable of [t] by its de Bruijn
    index relative to [ctx], pushing binder names as abstractions are entered.
    @raise Failure if a variable is not bound in [ctx] or by an enclosing
    abstraction. *)
let rec bruijn (ctx : ctx) : in_ast -> bruijn_ast = function
  | Var s -> Var (Index.create (find_bruijn s ctx) s)
  | App (a1, a2) -> App (bruijn ctx a1, bruijn ctx a2)
  | Abs (s, t, a) -> Abs (t, bruijn (s :: ctx) a)

(** [unify t1 t2] tells whether [t1] and [t2] are compatible: [Tvar] and
    [Tpoly] are compatible with themselves and with each other, and two
    function types are compatible when their arguments and their results are.
    Any other combination yields [false]; the function never raises. *)
let rec unify t1 t2 =
  match (t1, t2) with
  | Tvar, Tvar | Tpoly, Tpoly | Tpoly, Tvar | Tvar, Tpoly -> true
  | Tfun (a1, r1), Tfun (a2, r2) -> unify a1 a2 && unify r1 r2
  | Tfun _, (Tvar | Tpoly) | (Tvar | Tpoly), Tfun _ -> false

(** [shift_term by cutoff t] adds [by] to the depth of every variable of [t]
    whose depth is at least [cutoff]. [cutoff] grows by one under each
    abstraction so that locally bound variables are left alone. *)
let rec shift_term by cutoff : typed_ast -> typed_ast = function
  | { desc = Var j; typ } -> { desc = Var (Index.shift ~by ~cutoff j); typ }
  | { desc = App (t1, t2); typ } ->
      { desc = App (shift_term by cutoff t1, shift_term by cutoff t2); typ }
  | { desc = Abs body; typ } ->
      { desc = Abs (shift_term by (cutoff + 1) body); typ }

(** [subst_aux i t e] replaces the variables of [e] that are equal to [i]
    (per {!Index.eq}) by [t]. Under an abstraction the target index moves one
    level deeper and the free variables of [t] are shifted up by one, so that
    they keep pointing at the same binders. *)
let rec subst_aux i (t : typed_ast) : typed_ast -> typed_ast = function
  | { desc = Var j; typ = _ } as r -> if Index.eq i j then t else r
  | { desc = App (t1, t2); typ } ->
      { desc = App (subst_aux i t t1, subst_aux i t t2); typ }
  | { desc = Abs tt; typ } ->
      let deeper = Index.enter_level i in
      { desc = Abs (subst_aux deeper (shift_term 1 0 t) tt); typ }

(** [substitution t e] is the body [e] of an abstraction with its bound
    variable (depth 0) replaced by [t], expressed in the scope that encloses
    the abstraction. *)
let substitution t e : typed_ast =
  let bound = Index.create 0 "" in
  (* [t] lives outside the binder: lift it into the body's scope first, then
     lower the result by one level because the binder disappears. *)
  e |> subst_aux bound (shift_term 1 0 t) |> shift_term (-1) 0

(** [type_it ctx t] annotates [t] with types. [ctx] associates the index of
    each variable in scope with its type. An application is typed argument
    first; when its function part is an abstraction the application is
    reduced on the spot with {!substitution}, otherwise an [App] node is
    built.
    @raise Not_found if a variable has no entry in [ctx].
    @raise Failure if the function part does not have a function type, or if
    the argument type is not compatible with the expected one. *)
let rec type_it ctx : bruijn_ast -> typed_ast = function
  | Var i ->
      let _, typ = List.find (fun (a, _) -> Index.eq a i) ctx in
      { desc = Var i; typ }
  | Abs (t, f) ->
      let deeper = List.map (fun (a, b) -> (Index.enter_level a, b)) ctx in
      let body = type_it ((Index.create 0 "", t) :: deeper) f in
      { desc = Abs body; typ = Tfun (t, body.typ) }
  | App (a1, a2) -> (
      let arg = type_it ctx a2 in
      let fn = type_it ctx a1 in
      match fn.typ with
      | Tfun (expected, result) ->
          if not (unify expected arg.typ) then
            failwith
              (Format.sprintf
                 "type mismatch: received %s and was expecting %s"
                 (pp_typ arg.typ) (pp_typ expected))
          else (
            match fn.desc with
            | Abs body -> { desc = (substitution arg body).desc; typ = result }
            (* A function position that is itself an application (curried
               call) is kept as an application node, like a variable head. *)
            | Var _ | App _ -> { desc = App (fn, arg); typ = result })
      | Tvar | Tpoly -> failwith "invalid type applyed as fun")

(** [linearity ctx t] walks [t], removing from [ctx] every variable it meets,
    and returns the entries of [ctx] that were never used. In an application
    the argument is walked first and its leftover context is handed to the
    function part. Entering an abstraction pushes a fresh depth-0 entry (with
    an empty name) and moves the existing entries one level deeper.

    NOTE(review): the list returned from under an abstraction keeps the
    depths of the inner scope; they are not moved back when the abstraction
    is left. An application whose argument is an abstraction therefore hands
    inner-scope depths to its function part, which looks unintended — confirm
    before relying on it.

    @raise Not_found if a variable is not in [ctx] (unbound, or already
    consumed by an earlier occurrence). *)
let rec linearity ctx : typed_ast -> Index.t list = function
  | { desc = Var s; _ } ->
      let same_var a = Index.eq a s in
      if not (List.exists same_var ctx) then raise Not_found;
      List.filter (fun a -> not (same_var a)) ctx
  | { desc = App (a1, a2); _ } ->
      let after_arg = linearity ctx a2 in
      linearity after_arg a1
  | { desc = Abs a; _ } ->
      linearity (Index.create 0 "" :: List.map Index.enter_level ctx) a

(* Demo driver: type-checks four sample terms and prints them, then prints
   the leftover context computed by [linearity] for each of them. *)
let () =
  (* Resolve names, type-check, then print [label] followed by the term. *)
  let show_typed label term =
    let typed = type_it [] (bruijn [] term) in
    let rendered = Format.sprintf "\n%s\n@.@?" (pp_typed typed) in
    print_endline label;
    print_endline rendered;
    typed
  in
  (* Print [label] followed by the names left over by [linearity]. *)
  let show_linearity label typed =
    let leftover = linearity [] typed in
    print_endline label;
    print_endline (String.concat ", " (List.map Index.pp_index leftover))
  in
  let id : in_ast = Abs ("x", Tpoly, Var "x") in
  let ff : in_ast = Abs ("y", Tvar, App (id, Var "y")) in
  let idid : in_ast = Abs ("x", Tvar, App (id, Var "x")) in
  let fapply : in_ast =
    App
      ( Abs ("f", Tfun (Tvar, Tvar), Abs ("y", Tvar, App (Var "f", Var "y"))),
        id )
  in
  let idt = show_typed "identity function" id in
  let fft = show_typed "identity function with y" ff in
  let ididt = show_typed "ididp" idid in
  let fapplyt = show_typed "fapplyp" fapply in
  show_linearity "idl" idt;
  show_linearity "ffl" fft;
  show_linearity "ididl" ididt;
  show_linearity "fapplyl" fapplyt;
  (* NOTE(review): [ididt] is reported a second time here; this looks like a
     copy-paste leftover — confirm before removing. *)
  show_linearity "ididl" ididt;
  print_endline "done"