Skip to content

Instantly share code, notes, and snippets.

@p-pavel
Created April 29, 2026 08:24
Show Gist options
  • Select an option

  • Save p-pavel/577f4eca801f6cc4e87e0bf7fff40893 to your computer and use it in GitHub Desktop.

Select an option

Save p-pavel/577f4eca801f6cc4e87e0bf7fff40893 to your computer and use it in GitHub Desktop.
Tagless final in Lean4
module
import Lean.Elab.Deriving.Basic
public meta import Lean.Elab.Command
open Lean Elab Command Meta
public section
/-- Generic final/tagless term shape (`repr` / `Sym` in Oleg Kiselyov's notes).

A term is represented by all of its interpretations: given any interpreter
instance `C a` (for any carrier `a : Type u`), it produces an `a`.
This is essentially the Church/final encoding.

Because we quantify over carriers `a : Type u`, `Final C` lives in
`Type (max (u + 1) v)`: one universe above the carriers, and at least as high
as the interpreter records `C a : Type v`. The result universe is stated
explicitly below so that the signature documents it. -/
abbrev Final (C : Type u → Type v) : Type (max (u + 1) v) :=
  {a : Type u} → C a → a
/-- The type of a "final instance generator" for the signature `C`.

`Cup` is a copy of the class family `C` lifted to the universe in which
`Final C` lives; `FinalInstGenTy C Cup` is just `Cup (Final C)`, i.e. the type
of a lifted instance whose carrier is the final encoding of `C`. -/
def FinalInstGenTy (C : Type u → Type v) (Cup : Type (max (u + 1) v) → Type w) : Type w :=
  Cup <| Final C
/-- Resolve the projection function for field `fieldName` of the class/structure `clsName`.

Field names coming from `getStructureFields` are short (e.g. `zero`), so they are
looked up in the structure's field metadata rather than as global constants. A
root-level declaration that merely shares the field's name (say a user `def zero`)
must not be mistaken for the projection.

An already-qualified projection name (e.g. `Abc.zero`) that exists is returned
as-is. Otherwise the result falls back to `clsName ++ fieldName`. -/
private meta def resolveFieldProjName (clsName fieldName : Name) : TermElabM Name := do
  let env ← getEnv
  if let some projFn := getProjFnForField? env clsName fieldName then
    return projFn
  -- Keep accepting callers that pass a fully qualified projection name.
  if clsName.isPrefixOf fieldName && env.contains fieldName then
    return fieldName
  return clsName ++ fieldName
/-- Validate that `clsName` is a class `#deriveFinal` can handle.

Requirements:
- it is a structure-like class (declared with `class ... where`);
- it has exactly one parameter (the carrier) and no indices.

Throws a user-facing error otherwise. -/
private meta def checkClassShape (clsName : Name) : CommandElabM Unit := do
  let env ← getEnv
  unless isClass env clsName do
    throwError "#deriveFinal expects a class declaration; `{clsName}` is not a class"
  let info ← getConstInfo clsName
  let .inductInfo i := info
    | throwError "#deriveFinal expects `{clsName}` to elaborate to an inductive class"
  unless i.numParams == 1 && i.numIndices == 0 do
    throwError "#deriveFinal supports classes with exactly one parameter and no indices; `{clsName}` has {i.numParams} params and {i.numIndices} indices"
  -- The field list is later read with `getStructureFields`, which is only meaningful
  -- for structures. Reject e.g. `class inductive` declarations up front.
  unless isStructure env clsName do
    throwError "#deriveFinal expects a structure class declared with `class ... where`; `{clsName}` is not a structure"
/-- Build the syntax of the `Final C` value for field `fieldName` of class `clsName`.

Let the field be `f : {a} → [C a] → T0 → … → Tk → a`.

- The generated value is `fun x0 … xk => fun {a} (t : C a) => @C.f a t y0 … yk`.
- Each `yi` is `xi t` when `Ti` is the carrier: a recursive subterm, itself a
  `Final C`, interpreted with the same instance.
- Otherwise `yi` is just `xi`: an ordinary parameter such as `Nat`.
- A nullary field produces `fun {a} (t : C a) => @C.f a t`.

The carrier universe is taken from the expected `Final C` type rather than being
fixed to `Type`.

Throws if:
- the field lacks the carrier/instance header;
- the field has extra implicit binders;
- the field does not return the carrier;
- a non-carrier argument's type still mentions the carrier (e.g. `a → a`,
  `List a`), since such arguments cannot be passed through unchanged. -/
private meta def mkFinalFieldValue (clsName : Name) (fieldName : Name) : TermElabM (TSyntax `term) := do
  let fieldProjName ← resolveFieldProjName clsName fieldName
  let clsConst ← mkConstWithLevelParams clsName
  let projConst ← mkConstWithLevelParams fieldProjName
  forallTelescopeReducing (← inferType projConst) fun xs resultTy => do
    if xs.size < 2 then
      throwError "#deriveFinal: field `{fieldName}` has unexpected type (need at least carrier and instance binders)"
    let aVar := xs[0]!
    let instVar := xs[1]!
    unless (← isDefEq (← inferType instVar) (mkApp clsConst aVar)) do
      throwError "#deriveFinal: second binder of `{fieldName}` is not an instance binder for `{clsName}`"
    -- Classify each field-specific argument: `true` means it is carrier-typed (recursive).
    let mut argIsCarrier : Array Bool := #[]
    for x in xs[2:] do
      unless (← x.fvarId!.getDecl).binderInfo.isExplicit do
        throwError "#deriveFinal: field `{fieldName}` has implicit binders beyond class header, unsupported"
      let xTy ← inferType x
      let isCarrier ← isDefEq xTy aVar
      -- Non-carrier arguments are forwarded unchanged, so their type must not
      -- depend on the carrier: a `Final C` value would not fit there.
      if !isCarrier && xTy.containsFVar aVar.fvarId! then
        throwError "#deriveFinal: field `{fieldName}` has an argument whose type mentions the carrier in an unsupported way (only plain carrier arguments are supported){indentExpr xTy}"
      argIsCarrier := argIsCarrier.push isCarrier
    unless (← isDefEq resultTy aVar) do
      throwError
        "#deriveFinal only supports fields returning the class carrier; field `{fieldName}` returns{indentExpr resultTy}"
    let argIds : Array Ident := (Array.range argIsCarrier.size).map fun i =>
      mkIdent (Name.mkSimple s!"x{i}")
    -- `mkCIdent` refers to the constants directly, so shadowing or the user's
    -- current namespace cannot redirect them.
    let clsId := mkCIdent clsName
    let fieldConst := mkCIdent fieldProjName
    -- Recursive (carrier) arguments are interpreted with the same instance `t`.
    let appArgs ← (argIds.zip argIsCarrier).mapM fun (argId, isCarrier) =>
      if isCarrier then `(($argId t)) else `($argId)
    -- Carrier and instance are passed positionally via `@` rather than `(a := a)`:
    -- the projection's carrier binder is named after the class parameter, which
    -- need not be `a`.
    if argIds.isEmpty then
      `(fun {a} (t : $clsId a) => @$fieldConst a t)
    else
      `(fun $argIds* => fun {a} (t : $clsId a) => @$fieldConst a t $appArgs*)
/-- Elaborate `instance : C (Final C) where …`, the canonical final/tagless instance
for the class `clsName`, with one generated value per class field.

The class shape is validated first; each field's value comes from
`mkFinalFieldValue`. -/
private meta def deriveFinalFor (clsName : Name) : CommandElabM Unit := do
  checkClassShape clsName
  let fields := getStructureFields (← getEnv) clsName
  let structFields ← runTermElabM fun _ => do
    fields.mapM fun fieldName => do
      let fieldId := mkIdent (Name.mkSimple fieldName.getString!)
      let valStx ← mkFinalFieldValue clsName fieldName
      `(Parser.Term.structInstField| $fieldId:ident := $valStx)
  -- Refer to the class and to the root-level `Final` abbreviation by constant.
  -- The previous `(← getCurrNamespace) ++ `Final` pointed at a nonexistent
  -- `Ns.Final` whenever the command was used inside `namespace Ns`.
  let clsIdent := mkCIdent clsName
  let finalIdent := mkCIdent ``Final
  let instDecl ←
    `(instance : $clsIdent ($finalIdent $clsIdent) where
        $[$structFields]*)
  -- This is a nested elaboration from within a command elaborator, so use
  -- `elabCommand`; `elabCommandTopLevel` is meant only for the outermost call.
  elabCommand instDecl
/-- `#deriveFinal C` derives the canonical final/tagless instance `C (Final C)`
for a class signature `C`.

Supported fragment in this file:
- a structure-style class with exactly one parameter (the carrier) and no indices,
- every field returns the carrier and has no implicit binders beyond the class header,
- explicit field arguments are either carrier terms (recursive) or ordinary
  parameters (`Nat`, `Bool`, `String`, ...).

Intentionally out of scope here:
HOAS, indexed/multi-sorted families, and effectful/higher-kinded encodings. -/
syntax (name := deriveFinalCmd) "#deriveFinal " ident : command

@[command_elab deriveFinalCmd] meta def elabDeriveFinal : CommandElab
  | `(#deriveFinal $clsId:ident) => do
    -- Resolve the written identifier to a unique global constant (recording hover info).
    let clsName ← liftCoreM (realizeGlobalConstNoOverloadWithInfo clsId)
    deriveFinalFor clsName
  | _ => throwUnsupportedSyntax
module
import SignalProto.DeriveFinal
/-!
Small user-facing examples for `Final` / `#deriveFinal`.
The macro implementation lives in `SignalProto.DeriveFinal`.
-/
/-- A minimal signature over the carrier `a`: a constant `zero` and a unary
operation `suc` (interpreted as `0` and `n + 1` by `instNat` below). -/
class Abc a where
  /-- The constant `zero`. -/
  zero : a
  /-- The unary operation `suc`. -/
  suc : a → a
/-- Signature of arithmetic expressions over the carrier `a`: natural-number
literals, addition and multiplication. -/
class Arith a where
  /-- Embed a natural-number literal. -/
  lit : Nat → a
  /-- Addition of two subterms. -/
  add : a → a → a
  /-- Multiplication of two subterms. -/
  mul : a → a → a
-- Generate the canonical `Abc (Final Abc)` and `Arith (Final Arith)` instances.
#deriveFinal Abc
#deriveFinal Arith
-- Check that instance synthesis now finds the generated instances.
#synth Abc (Final Abc)
#synth Arith (Final Arith)
universe u v
/- Universe check: `Final` is universe polymorphic, and `Final C` lives one universe
above the carriers it quantifies over, i.e. in `Type (max (u + 1) v)`. -/
#check (Final : (Type u → Type v) → Type (max (u + 1) v))
/-- Final (tagless) terms over the `Abc` signature. -/
abbrev AbcTerm := Final Abc
/-- Final (tagless) terms over the `Arith` signature. -/
abbrev ArithTerm := Final Arith
/-- Evaluate `Abc` terms as natural numbers: `zero` is `0`, `suc` adds one. -/
instance instNat : Abc Nat :=
  ⟨0, fun k => k + 1⟩
/-- Evaluate `Arith` terms as natural numbers using `Nat` addition and multiplication. -/
instance instArithNat : Arith Nat :=
  ⟨fun k => k, (· + ·), (· * ·)⟩
/-- Pretty-print `Arith` terms as fully parenthesized infix strings. -/
instance instArithPretty : Arith String :=
  ⟨fun k => toString k,
   fun lhs rhs => s!"({lhs} + {rhs})",
   fun lhs rhs => s!"({lhs} * {rhs})"⟩
/-- Carrier for the node-counting interpretation of `Arith` (see `instArithNodeCount`). -/
structure NodeCount where
  /-- Number of nodes counted so far. -/
  n : Nat
/-- Count expression nodes: a literal is one node, and each binary operation adds
one node on top of its two operands. -/
instance instArithNodeCount : Arith NodeCount where
  lit := fun _ => { n := 1 }
  add := fun lhs rhs => { n := lhs.n + rhs.n + 1 }
  mul := fun lhs rhs => { n := lhs.n + rhs.n + 1 }
/-- A hand-written final instance for `Abc`, with the same shape `#deriveFinal`
generates. It is a plain `def` (not an `instance`), so it does not compete with
the derived instance during instance search. -/
def manualAbcFinal : Abc (Final Abc) where
  zero := fun {_} inst => @Abc.zero _ inst
  suc sub := fun {_} inst => @Abc.suc _ inst (sub inst)
-- The explicit `(a := AbcTerm)` / `(a := ArithTerm)` arguments fix the carrier to the
-- final encoding, so the instances generated by `#deriveFinal` are the ones used.
/-- `suc zero` as a final `Abc` term. -/
def one : AbcTerm := Abc.suc (a := AbcTerm) (Abc.zero (a := AbcTerm))
/-- `suc one` as a final `Abc` term. -/
def two : AbcTerm := Abc.suc (a := AbcTerm) one
/-- `1 + 2` as a final `Arith` term. -/
def three : ArithTerm := Arith.add (a := ArithTerm) (Arith.lit (a := ArithTerm) 1) (Arith.lit (a := ArithTerm) 2)
/-- `2 + 3 * 4` as a final `Arith` term. -/
def expr : ArithTerm :=
  Arith.add (a := ArithTerm)
    (Arith.lit (a := ArithTerm) 2)
    (Arith.mul (a := ArithTerm) (Arith.lit (a := ArithTerm) 3) (Arith.lit (a := ArithTerm) 4))
-- Interpreting a final term = applying it to an interpreter instance.
-- Expected values follow from the instances defined above.
-- expected: 1
#eval (one instNat)
-- expected: 2
#eval (two instNat)
-- expected: 3
#eval (three instArithNat)
-- expected: 14
#eval (expr instArithNat)
-- expected: "(2 + (3 * 4))"
#eval (expr instArithPretty)
-- expected: 5 (three literals plus two binary operations)
#eval (expr instArithNodeCount).n
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment