Skip to content

Commit

Permalink
feat: manage deriving DecidableEq for nested inductive types
Browse files Browse the repository at this point in the history
  • Loading branch information
arthur-adjedj committed Aug 14, 2024
1 parent 82b7cfd commit 901d3cd
Show file tree
Hide file tree
Showing 11 changed files with 775 additions and 380 deletions.
48 changes: 29 additions & 19 deletions src/Lean/Elab/Deriving/BEq.lean
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,19 @@ namespace Lean.Elab.Deriving.BEq
open Lean.Parser.Term
open Meta

def mkBEqHeader (indVal : InductiveVal) : TermElabM Header := do
mkHeader `BEq 2 indVal
def mkBEqHeader (argNames : Array Name) (nestedOcc : NestedOccurence) : TermElabM Header := do
mkHeader `BEq 2 argNames nestedOcc

def mkMatch (header : Header) (indVal : InductiveVal) (auxFunName : Name) : TermElabM Term := do
def mkMatch (ctx : Context) (header : Header) (e : Expr) (fvars : Array Expr) : TermElabM Term := do
let f := e.getAppFn
let ind := f.constName!
let lvls := f.constLevels!
let indVal ← getConstInfoInduct ind
let discrs ← mkDiscrs header indVal
let alts ← mkAlts
let alts ← mkAlts indVal lvls
`(match $[$discrs],* with $alts:matchAlt*)
where
mkElseAlt : TermElabM (TSyntax ``matchAltExpr) := do
mkElseAlt (indVal : InductiveVal): TermElabM (TSyntax ``matchAltExpr) := do
let mut patterns := #[]
-- add `_` pattern for indices
for _ in [:indVal.numIndices] do
Expand All @@ -30,11 +34,15 @@ where
let altRhs ← `(false)
`(matchAltExpr| | $[$patterns:term],* => $altRhs:term)

mkAlts : TermElabM (Array (TSyntax ``matchAlt)) := do
mkAlts (indVal : InductiveVal) (lvl : List Level): TermElabM (Array (TSyntax ``matchAlt)) := do
let mut alts := #[]
for ctorName in indVal.ctors do
let args := e.getAppArgs
let ctorInfo ← getConstInfoCtor ctorName
let alt ← forallTelescopeReducing ctorInfo.type fun xs type => do
let subargs := args[:ctorInfo.numParams]
let ctorApp := mkAppN (mkConst ctorInfo.name lvl) subargs
let ctorType ← inferType ctorApp
let alt ← forallTelescopeReducing ctorType fun xs type => do
let type ← Core.betaReduce type -- we 'beta-reduce' to eliminate "artificial" dependencies
let mut patterns := #[]
-- add `_` pattern for indices
Expand All @@ -45,7 +53,7 @@ where
let mut rhs ← `(true)
let mut rhs_empty := true
for i in [:ctorInfo.numFields] do
let pos := indVal.numParams + ctorInfo.numFields - i - 1
let pos := indVal.numParams + ctorInfo.numFields - subargs.size - i - 1
let x := xs[pos]!
if type.containsFVar x.fvarId! then
-- If resulting type depends on this field, we don't need to compare
Expand All @@ -59,7 +67,7 @@ where
let xType ← inferType x
if (← isProp xType) then
continue
if xType.isAppOf indVal.name then
if let some auxFunName ← ctx.getFunName? header xType fvars then
if rhs_empty then
rhs ← `($(mkIdent auxFunName):ident $a:ident $b:ident)
rhs_empty := false
Expand Down Expand Up @@ -88,19 +96,21 @@ where
patterns := patterns.push (← `(@$(mkIdent ctorName):ident $ctorArgs2.reverse:term*))
`(matchAltExpr| | $[$patterns:term],* => $rhs:term)
alts := alts.push alt
alts := alts.push (← mkElseAlt)
alts := alts.push (← mkElseAlt indVal)
return alts

def mkAuxFunction (ctx : Context) (i : Nat) : TermElabM Command := do
let auxFunName := ctx.auxFunNames[i]!
let indVal := ctx.typeInfos[i]!
let header ← mkBEqHeader indVal
let mut body ← mkMatch header indVal auxFunName
let nestedOcc := ctx.typeInfos[i]!
let argNames := ctx.typeArgNames[i]!
let header ← mkBEqHeader argNames nestedOcc
let binders := header.binders
Term.elabBinders binders fun xs => do
let type ← Term.elabTerm header.targetType none
let mut body ← mkMatch ctx header type xs
if ctx.usePartial then
let letDecls ← mkLocalInstanceLetDecls ctx `BEq header.argNames
body ← mkLet letDecls body
let binders := header.binders
if ctx.usePartial then
`(private partial def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Bool := $body:term)
else
`(private def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Bool := $body:term)
Expand All @@ -115,8 +125,8 @@ def mkMutualBlock (ctx : Context) : TermElabM Syntax := do
end)

private def mkBEqInstanceCmds (declName : Name) : TermElabM (Array Syntax) := do
let ctx ← mkContext "beq" declName
let cmds := #[← mkMutualBlock ctx] ++ (← mkInstanceCmds ctx `BEq #[declName])
let ctx ← mkContext "beq" declName false
let cmds := #[← mkMutualBlock ctx] ++ (← mkInstanceCmds ctx `BEq)
trace[Elab.Deriving.beq] "\n{cmds}"
return cmds

Expand All @@ -125,8 +135,8 @@ private def mkBEqEnumFun (ctx : Context) (name : Name) : TermElabM Syntax := do
`(private def $(mkIdent auxFunName):ident (x y : $(mkIdent name)) : Bool := x.toCtorIdx == y.toCtorIdx)

private def mkBEqEnumCmd (name : Name): TermElabM (Array Syntax) := do
let ctx ← mkContext "beq" name
let cmds := #[← mkBEqEnumFun ctx name] ++ (← mkInstanceCmds ctx `BEq #[name])
let ctx ← mkContext "beq" name false
let cmds := #[← mkBEqEnumFun ctx name] ++ (← mkInstanceCmds ctx `BEq)
trace[Elab.Deriving.beq] "\n{cmds}"
return cmds

Expand Down
74 changes: 38 additions & 36 deletions src/Lean/Elab/Deriving/DecEq.lean
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,25 @@ Authors: Leonardo de Moura
prelude
import Lean.Meta.Transform
import Lean.Meta.Inductive
import Lean.Meta.InferType
import Lean.Elab.Deriving.Basic
import Lean.Elab.Deriving.Util
import Lean.Elab.Binders

namespace Lean.Elab.Deriving.DecEq
open Lean.Parser.Term
open Meta

def mkDecEqHeader (indVal : InductiveVal) : TermElabM Header := do
mkHeader `DecidableEq 2 indVal
def mkDecEqHeader (argNames : Array Name) (nestedOcc : NestedOccurence) : TermElabM Header := do
mkHeader `DecidableEq 2 argNames nestedOcc

def mkMatch (ctx : Context) (header : Header) (indVal : InductiveVal) : TermElabM Term := do
def mkMatch (ctx : Context) (header : Header) (e : Expr) (fvars : Array Expr) : TermElabM Term := do
let f := e.getAppFn
let ind := f.constName!
let lvls := f.constLevels!
let indVal ← getConstInfoInduct ind
let discrs ← mkDiscrs header indVal
let alts ← mkAlts
let alts ← mkAlts indVal lvls
`(match $[$discrs],* with $alts:matchAlt*)
where
mkSameCtorRhs : List (Ident × Ident × Option Name × Bool) → TermElabM Term
Expand All @@ -31,14 +37,14 @@ where
`(if h : @$a = @$b then
by subst h; exact $sameCtor:term
else
isFalse (by intro n; injection n; apply h _; assumption))
isFalse (by intro n; injection n; contradiction))
if let some auxFunName := recField then
-- add local instance for `a = b` using the function being defined `auxFunName`
`(let inst := $(mkIdent auxFunName) @$a @$b; $rhs)
else
return rhs

mkAlts : TermElabM (Array (TSyntax ``matchAlt)) := do
mkAlts (indVal : InductiveVal) (lvl : List Level): TermElabM (Array (TSyntax ``matchAlt)) := do
let mut alts := #[]
for ctorName₁ in indVal.ctors do
let ctorInfo ← getConstInfoCtor ctorName₁
Expand All @@ -48,7 +54,10 @@ where
for _ in [:indVal.numIndices] do
patterns := patterns.push (← `(_))
if ctorName₁ == ctorName₂ then
let alt ← forallTelescopeReducing ctorInfo.type fun xs type => do
let args := e.getAppArgs
let ctorApp := mkAppN (mkConst ctorInfo.name lvl) args[:ctorInfo.numParams]
let ctorType ← inferType ctorApp
let alt ← forallTelescopeReducing ctorType fun xs type => do
let type ← Core.betaReduce type -- we 'beta-reduce' to eliminate "artificial" dependencies
let mut patterns := patterns
let mut ctorArgs1 := #[]
Expand All @@ -59,7 +68,7 @@ where
ctorArgs2 := ctorArgs2.push (← `(_))
let mut todo := #[]
for i in [:ctorInfo.numFields] do
let x := xs[indVal.numParams + i]!
let x := xs[i]!
if type.containsFVar x.fvarId! then
-- If resulting type depends on this field, we don't need to compare
ctorArgs1 := ctorArgs1.push (← `(_))
Expand All @@ -70,11 +79,8 @@ where
ctorArgs1 := ctorArgs1.push a
ctorArgs2 := ctorArgs2.push b
let xType ← inferType x
let indValNum :=
ctx.typeInfos.findIdx?
(xType.isAppOf ∘ ConstantVal.name ∘ InductiveVal.toConstantVal)
let recField := indValNum.map (ctx.auxFunNames[·]!)
let isProof ← isProp xType
let recField ← ctx.getFunName? header xType fvars
let isProof ← isProp xType
todo := todo.push (a, b, recField, isProof)
patterns := patterns.push (← `(@$(mkIdent ctorName₁):ident $ctorArgs1:term*))
patterns := patterns.push (← `(@$(mkIdent ctorName₁):ident $ctorArgs2:term*))
Expand All @@ -87,47 +93,43 @@ where
alts := alts.push (← `(matchAltExpr| | $[$patterns:term],* => $rhs:term))
return alts

def mkAuxFunction (ctx : Context) (auxFunName : Name) (indVal : InductiveVal): TermElabM (TSyntax `command) := do
let header ← mkDecEqHeader indVal
let body ← mkMatch ctx header indVal
def mkAuxFunction (ctx : Context) (auxFunName : Name) (argNames : Array Name) (nestedOcc : NestedOccurence): TermElabM (TSyntax `command) := do
let header ← mkDecEqHeader argNames nestedOcc
let binders := header.binders
let target₁ := mkIdent header.targetNames[0]!
let target₂ := mkIdent header.targetNames[1]!
let termSuffix ← if indVal.isRec
let termSuffix ← if ctx.auxFunNames.size > 1 || nestedOcc.getIndVal.isRec
then `(Parser.Termination.suffix|termination_by structural $target₁)
else `(Parser.Termination.suffix|)
let type ← `(Decidable ($target₁ = $target₂))
Term.elabBinders binders fun xs => do
let type ← Term.elabTerm header.targetType none
let body ← mkMatch ctx header type xs
let type ← `(Decidable ($target₁ = $target₂))
`(private def $(mkIdent auxFunName):ident $binders:bracketedBinder* : $type:term := $body:term
$termSuffix:suffix)

def mkAuxFunctions (ctx : Context) : TermElabM (TSyntax `command) := do
let mut res : Array (TSyntax `command) := #[]
for i in [:ctx.auxFunNames.size] do
let auxFunName := ctx.auxFunNames[i]!
let indVal := ctx.typeInfos[i]!
res := res.push (← mkAuxFunction ctx auxFunName indVal)
let auxFunName := ctx.auxFunNames[i]!
let nestedOcc := ctx.typeInfos[i]!
let argNames := ctx.typeArgNames[i]!
res := res.push (← mkAuxFunction ctx auxFunName argNames nestedOcc)
`(command| mutual $[$res:command]* end)

def mkDecEqCmds (indVal : InductiveVal) : TermElabM (Array Syntax) := do
let ctx ← mkContext "decEq" indVal.name
let cmds := #[← mkAuxFunctions ctx] ++ (← mkInstanceCmds ctx `DecidableEq #[indVal.name] (useAnonCtor := false))
let cmds := #[← mkAuxFunctions ctx] ++ (← mkInstanceCmds ctx `DecidableEq (useAnonCtor := false))
trace[Elab.Deriving.decEq] "\n{cmds}"
return cmds

open Command

def mkDecEq (declName : Name) : CommandElabM Bool := do
def mkDecEq (declName : Name) : CommandElabM Unit := do
let indVal ← getConstInfoInduct declName
if indVal.isNested then
return false -- nested inductive types are not supported yet
else
let cmds ← liftTermElabM <| mkDecEqCmds indVal
-- `cmds` can have a number of syntax nodes quadratic in the number of constructors
-- and thus create as many info tree nodes, which we never make use of but which can
-- significantly slow down e.g. the unused variables linter; avoid creating them
withEnableInfoTree false do
cmds.forM elabCommand
return true
let cmds ← liftTermElabM <| mkDecEqCmds indVal
withEnableInfoTree false do
cmds.forM elabCommand

partial def mkEnumOfNat (declName : Name) : MetaM Unit := do
let indVal ← getConstInfoInduct declName
Expand Down Expand Up @@ -195,15 +197,15 @@ def mkDecEqEnum (declName : Name) : CommandElabM Unit := do
trace[Elab.Deriving.decEq] "\n{cmd}"
elabCommand cmd

def mkDecEqInstance (declName : Name) : CommandElabM Bool := do
def mkDecEqInstance (declName : Name) : CommandElabM Unit := do
if (← isEnumType declName) then
mkDecEqEnum declName
return true
else
mkDecEq declName

def mkDecEqInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
declNames.foldlM (fun b n => andM (pure b) (mkDecEqInstance n)) true
declNames.forM mkDecEqInstance
return true

builtin_initialize
registerDerivingHandler `DecidableEq mkDecEqInstanceHandler
Expand Down
Loading

0 comments on commit 901d3cd

Please sign in to comment.