Created
April 29, 2026 08:24
-
-
Save p-pavel/577f4eca801f6cc4e87e0bf7fff40893 to your computer and use it in GitHub Desktop.
Tagless final in Lean 4
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| module | |
| import Lean.Elab.Deriving.Basic | |
| public meta import Lean.Elab.Command | |
| open Lean Elab Command Meta | |
| public section | |
/-- Final (Church-style, "tagless") encoding of the terms of a signature `C`
(the `repr` / `Sym` of Oleg's tagless-final notes).

A term is identified with the collection of all its interpretations: for every carrier
`a : Type u` and every interpreter instance `C a`, it produces an `a`.

Since it quantifies over carriers `a : Type u`, `Final C` itself lives in
`Type (max (u + 1) v)` — one universe above the carriers it ranges over. -/
abbrev Final (C : Type u → Type v) : Type (max (u + 1) v) :=
  ∀ {a : Type u}, C a → a
/-- The type of a "final instance generator".

`C` is a base signature (class family) and `Cup` is a copy of it lifted to carriers in
`Type (max (u + 1) v)`, the universe where `Final C` lives. The result is the type of an
instance of the lifted signature at the final carrier, i.e. `Cup (Final C)`. -/
def FinalInstGenTy (C : Type u → Type v) (Cup : Type (max (u + 1) v) → Type w) : Type w :=
  Cup <| Final C
/-- Name of the projection function for field `fieldName` of the class `clsName`.

The lookup goes through the structure metadata, so a field whose short name coincides
with some unrelated root-level declaration still resolves to the class's own projection
instead of to that declaration. If no projection is recorded, fall back to
`clsName ++ fieldName`. -/
private meta def resolveFieldProjName (clsName fieldName : Name) : TermElabM Name := do
  let env ← getEnv
  pure <| (getProjFnForField? env clsName fieldName).getD (clsName ++ fieldName)
/-- Check that `clsName` names a class that `#deriveFinal` can handle.

Requirements:
- it is a class,
- it elaborates to an inductive type with exactly one parameter (the carrier) and no indices,
- it is structure-like, so that its fields can be enumerated.

Throws a user-facing error describing the first violated requirement. -/
private meta def checkClassShape (clsName : Name) : CommandElabM Unit := do
  let env ← getEnv
  unless isClass env clsName do
    throwError "#deriveFinal expects a class declaration; `{clsName}` is not a class"
  let info ← getConstInfo clsName
  let .inductInfo i := info
    | throwError "#deriveFinal expects `{clsName}` to elaborate to an inductive class"
  unless i.numParams == 1 && i.numIndices == 0 do
    throwError "#deriveFinal supports classes with exactly one parameter and no indices; `{clsName}` has {i.numParams} params and {i.numIndices} indices"
  -- Field enumeration (`getStructureFields`) requires a structure; a `class inductive`
  -- passes the checks above but has no structure fields, so reject it here.
  unless isStructure env clsName do
    throwError "#deriveFinal expects a structure-like class whose fields can be enumerated; `{clsName}` is not a structure (for example, it may be a `class inductive`)"
/-- Build the syntax for one field of the derived `C (Final C)` instance.

For a class field `f : {a} → [C a] → τ0 → … → τ(n-1) → a`, produce
`fun x0 … x(n-1) => fun {a : Type} (t : C a) => @f a t y0 … y(n-1)`.
Each `yi` is `xi t` when `τi` is the carrier: the argument is itself a `Final C` term, so it
is first run against the same interpreter `t`. Otherwise `yi` is just `xi`. When the field
has no explicit arguments, the outer `fun` is omitted.

The projection is applied with `@` and positional arguments, so the generated code does not
depend on the name the class gives its carrier parameter.

Throws when the field lies outside the supported fragment:
- the carrier/instance binders are missing,
- there are non-explicit binders after them,
- an argument's type mentions the carrier without being the carrier (higher-order / HOAS),
- the result type is not the carrier. -/
private meta def mkFinalFieldValue (clsName : Name) (fieldName : Name) : TermElabM (TSyntax `term) := do
  let fieldProjName ← resolveFieldProjName clsName fieldName
  let clsConst ← mkConstWithLevelParams clsName
  let fieldTy ← inferType (← mkConstWithLevelParams fieldProjName)
  forallTelescopeReducing fieldTy fun xs resultTy => do
    if xs.size < 2 then
      throwError "#deriveFinal: field `{fieldName}` has unexpected type (need at least carrier and instance binders)"
    let aVar := xs[0]!
    let instVar := xs[1]!
    unless (← isDefEq (← inferType instVar) (mkApp clsConst aVar)) do
      throwError "#deriveFinal: second binder of `{fieldName}` is not an instance binder for `{clsName}`"
    -- One entry per remaining (user-visible) argument: `true` iff it is carrier-typed.
    let mut argIsCarrier : Array Bool := #[]
    for x in xs[2:] do
      unless (← x.fvarId!.getDecl).binderInfo.isExplicit do
        throwError "#deriveFinal: field `{fieldName}` has implicit binders beyond class header, unsupported"
      let xTy ← inferType x
      if (← isDefEq xTy aVar) then
        argIsCarrier := argIsCarrier.push true
      else if xTy.containsFVar aVar.fvarId! then
        -- Such an argument (e.g. `a → a`) would be forwarded unchanged and fail to typecheck
        -- at carrier `a`; report it here rather than as a confusing elaboration error.
        throwError "#deriveFinal: an argument of field `{fieldName}` has type{indentExpr xTy}\nwhich mentions the carrier without being the carrier; higher-order/HOAS arguments are unsupported"
      else
        argIsCarrier := argIsCarrier.push false
    unless (← isDefEq resultTy aVar) do
      throwError
        "#deriveFinal only supports fields returning the class carrier; field `{fieldName}` returns{indentExpr resultTy}"
    let clsId := mkCIdent clsName
    let fieldConst := mkCIdent fieldProjName
    let argIds : Array Ident :=
      (Array.range argIsCarrier.size).map fun i => mkIdent (Name.mkSimple s!"x{i}")
    -- Carrier-typed arguments are `Final C` terms: interpret them with the same `t`.
    let appArgs : Array (TSyntax `term) ← (argIds.zip argIsCarrier).mapM fun (argId, isCarrier) =>
      if isCarrier then `(($argId t)) else `($argId:ident)
    let body ← `(fun {a : Type} (t : $clsId a) => @$fieldConst a t $appArgs*)
    if argIds.isEmpty then
      pure body
    else
      `(fun $argIds* => $body)
/-- Derive and elaborate `instance : C (Final C) where …` for the class `clsName`.

The class shape is validated first. Each field is then filled with the term built by
`mkFinalFieldValue`. The class and `Final` are referenced through `mkCIdent`, so they
denote exactly those constants regardless of the namespace (or local shadowing) in effect
where `#deriveFinal` is invoked. -/
private meta def deriveFinalFor (clsName : Name) : CommandElabM Unit := do
  checkClassShape clsName
  let fields := getStructureFields (← getEnv) clsName
  let structFields ← runTermElabM fun _ => do
    fields.mapM fun fieldName => do
      let fieldId := mkIdent (Name.mkSimple fieldName.getString!)
      let valStx ← mkFinalFieldValue clsName fieldName
      `(Parser.Term.structInstField| $fieldId:ident := $valStx)
  let clsIdent := mkCIdent clsName
  -- `Final` is declared at the root of this file. Resolving it relative to the caller's
  -- current namespace would look for a nonexistent `Foo.Final` inside `namespace Foo`.
  let finalIdent := mkCIdent ``Final
  let instDecl ←
    `(instance : $clsIdent ($finalIdent $clsIdent) where
        $[$structFields]*)
  Command.elabCommandTopLevel instDecl
/-- `#deriveFinal C` derives and elaborates `instance : C (Final C)`, the canonical
final/tagless instance for the class signature `C`.

In the generated instance, each field `f` of `C` becomes a `Final C` term. That term runs
`f` in whatever interpreter instance it is later applied to, first interpreting any
carrier-typed arguments with that same interpreter.

Supported fragment in this file:
- `C` is a structure-like class with exactly one parameter (the carrier) and no indices,
- every field returns the carrier,
- explicit field arguments may be carrier terms (recursive) or ordinary parameters
  (`Nat`, `Bool`, `String`, ...).

Intentionally out of scope here:
HOAS, indexed/multi-sorted families, and effectful/higher-kinded encodings. -/
syntax (name := deriveFinalCmd) "#deriveFinal " ident : command
/-- Command elaborator for `#deriveFinal C`.

Resolves the identifier `C` to a single global constant (recording hover/info-tree
information for it) and hands the resolved name to `deriveFinalFor`. Any other syntax is
reported as unsupported. -/
@[command_elab deriveFinalCmd] meta def elabDeriveFinal : CommandElab
  | `(#deriveFinal $clsId:ident) =>
    liftCoreM (realizeGlobalConstNoOverloadWithInfo clsId) >>= deriveFinalFor
  | _ => throwUnsupportedSyntax
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| module | |
| import SignalProto.DeriveFinal | |
| /-! | |
| Small user-facing examples for `Final` / `#deriveFinal`. | |
| The macro implementation lives in `SignalProto.DeriveFinal`. | |
| -/ | |
/-- A tiny signature with a constant `zero` and a unary operation `suc`, both returning
the carrier `a`. `instNat` below interprets it as natural numbers. -/
class Abc a where
  zero : a
  suc : a → a
/-- Signature of arithmetic expressions over the carrier `a`:
`Nat` literals, plus binary `add` and `mul`. -/
class Arith a where
  lit : Nat → a
  add : a → a → a
  mul : a → a → a
-- Derive `instance : Abc (Final Abc)` and `instance : Arith (Final Arith)`.
#deriveFinal Abc
#deriveFinal Arith
-- Confirm that instance resolution now finds the derived instances.
#synth Abc (Final Abc)
#synth Arith (Final Arith)
universe u v
/- Universe check: `Final` is universe-polymorphic, and `Final C` lives one universe above
the carriers it quantifies over (`Type (max (u + 1) v)`). -/
#check (Final : (Type u → Type v) → Type (max (u + 1) v))
/-- Terms of the `Abc` language in final encoding (a term is its family of interpretations). -/
abbrev AbcTerm := Final Abc
/-- Terms of the `Arith` language in final encoding. -/
abbrev ArithTerm := Final Arith
/-- `Abc` interpreter into `Nat`: `zero` is `0` and `suc` adds one. -/
instance instNat : Abc Nat :=
  { zero := 0, suc := fun n => n + 1 }
/-- `Arith` interpreter that evaluates an expression to its `Nat` value. -/
instance instArithNat : Arith Nat :=
  { lit := fun n => n, add := (· + ·), mul := (· * ·) }
/-- `Arith` interpreter that renders fully parenthesised infix strings,
e.g. `(2 + (3 * 4))`. -/
instance instArithPretty : Arith String :=
  { lit := fun n => toString n,
    add := fun x y => s!"({x} + {y})",
    mul := fun x y => s!"({x} * {y})" }
/-- Carrier for `instArithNodeCount`: `n` holds the number of nodes in an `Arith` expression. -/
structure NodeCount where
  n : Nat
/-- `Arith` interpreter that counts expression-tree nodes: each literal counts as one node,
and each `add`/`mul` contributes one node plus the nodes of its two operands. -/
instance instArithNodeCount : Arith NodeCount :=
  { lit := fun _ => { n := 1 },
    add := fun x y => { n := x.n + y.n + 1 },
    mul := fun x y => { n := x.n + y.n + 1 } }
/-- A hand-written canonical final instance for `Abc`, with the same shape as the one
`#deriveFinal Abc` generates. Every operation takes the interpreter `inst` it is run with;
`suc` first interprets its subterm `prev` with that same interpreter. -/
def manualAbcFinal : Abc (Final Abc) :=
  { zero := fun {_} inst => inst.zero,
    suc := fun prev {_} inst => inst.suc (prev inst) }
-- Object-language terms built with the derived `Final` instances. The explicit
-- `(a := AbcTerm)` / `(a := ArithTerm)` arguments pin the class carrier to the final
-- encoding; each resulting term is a function awaiting an interpreter instance.
/-- `suc zero`. -/
def one : AbcTerm := Abc.suc (a := AbcTerm) (Abc.zero (a := AbcTerm))
/-- `suc one`. -/
def two : AbcTerm := Abc.suc (a := AbcTerm) one
/-- `lit 1 + lit 2`. -/
def three : ArithTerm := Arith.add (a := ArithTerm) (Arith.lit (a := ArithTerm) 1) (Arith.lit (a := ArithTerm) 2)
/-- `2 + 3 * 4`, i.e. `add (lit 2) (mul (lit 3) (lit 4))`, as a final-encoded term. -/
def expr : ArithTerm :=
  Arith.add (a := ArithTerm)
    (Arith.lit (a := ArithTerm) 2)
    (Arith.mul (a := ArithTerm) (Arith.lit (a := ArithTerm) 3) (Arith.lit (a := ArithTerm) 4))
-- Run the same terms under different interpreters (expected results noted alongside).
#eval (one instNat)                -- 1
#eval (two instNat)                -- 2
#eval (three instArithNat)         -- 3
#eval (expr instArithNat)          -- 14
#eval (expr instArithPretty)       -- "(2 + (3 * 4))"
#eval (expr instArithNodeCount).n  -- 5: three literals + two operators
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment