Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: manage nested inductive types in deriving #3160

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/Init/Data/Sum.lean
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,14 @@ def getRight? : α ⊕ β → Option β
| inr b => some b
| inl _ => none

/-- Define a function on `α ⊕ β` by giving separate definitions on `α` and `β`. -/
protected def elim {α β γ} (f : α → γ) (g : β → γ) : α ⊕ β → γ :=
fun x => Sum.casesOn x f g

@[simp] theorem elim_inl (f : α → γ) (g : β → γ) (x : α) :
Sum.elim f g (inl x) = f x := rfl

@[simp] theorem elim_inr (f : α → γ) (g : β → γ) (x : β) :
Sum.elim f g (inr x) = g x := rfl

end Sum
50 changes: 30 additions & 20 deletions src/Lean/Elab/Deriving/BEq.lean
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,19 @@ namespace Lean.Elab.Deriving.BEq
open Lean.Parser.Term
open Meta

def mkBEqHeader (indVal : InductiveVal) : TermElabM Header := do
mkHeader `BEq 2 indVal
def mkBEqHeader (argNames : Array Name) (indTypeFormer : IndTypeFormer) : TermElabM Header := do
mkHeader `BEq 2 argNames indTypeFormer

def mkMatch (header : Header) (indVal : InductiveVal) (auxFunName : Name) : TermElabM Term := do
def mkMatch (ctx : Context) (header : Header) (e : Expr) (fvars : Array Expr) : TermElabM Term := do
let f := e.getAppFn
let ind := f.constName!
let lvls := f.constLevels!
let indVal ← getConstInfoInduct ind
let discrs ← mkDiscrs header indVal
let alts ← mkAlts
let alts ← mkAlts indVal lvls
`(match $[$discrs],* with $alts:matchAlt*)
where
mkElseAlt : TermElabM (TSyntax ``matchAltExpr) := do
mkElseAlt (indVal : InductiveVal): TermElabM (TSyntax ``matchAltExpr) := do
let mut patterns := #[]
-- add `_` pattern for indices
for _ in [:indVal.numIndices] do
Expand All @@ -30,11 +34,15 @@ where
let altRhs ← `(false)
`(matchAltExpr| | $[$patterns:term],* => $altRhs:term)

mkAlts : TermElabM (Array (TSyntax ``matchAlt)) := do
mkAlts (indVal : InductiveVal) (lvl : List Level): TermElabM (Array (TSyntax ``matchAlt)) := do
let mut alts := #[]
for ctorName in indVal.ctors do
let args := e.getAppArgs
let ctorInfo ← getConstInfoCtor ctorName
let alt ← forallTelescopeReducing ctorInfo.type fun xs type => do
let subargs := args[:ctorInfo.numParams]
let ctorApp := mkAppN (mkConst ctorInfo.name lvl) subargs
let ctorType ← inferType ctorApp
let alt ← forallTelescopeReducing ctorType fun xs type => do
let type ← Core.betaReduce type -- we 'beta-reduce' to eliminate "artificial" dependencies
let mut patterns := #[]
-- add `_` pattern for indices
Expand All @@ -45,7 +53,7 @@ where
let mut rhs ← `(true)
let mut rhs_empty := true
for i in [:ctorInfo.numFields] do
let pos := indVal.numParams + ctorInfo.numFields - i - 1
let pos := indVal.numParams + ctorInfo.numFields - subargs.size - i - 1
let x := xs[pos]!
if type.containsFVar x.fvarId! then
-- If resulting type depends on this field, we don't need to compare
Expand All @@ -59,7 +67,7 @@ where
let xType ← inferType x
if (← isProp xType) then
continue
if xType.isAppOf indVal.name then
if let some auxFunName ← ctx.getFunName? header xType fvars then
if rhs_empty then
rhs ← `($(mkIdent auxFunName):ident $a:ident $b:ident)
rhs_empty := false
Expand Down Expand Up @@ -88,19 +96,21 @@ where
patterns := patterns.push (← `(@$(mkIdent ctorName):ident $ctorArgs2.reverse:term*))
`(matchAltExpr| | $[$patterns:term],* => $rhs:term)
alts := alts.push alt
alts := alts.push (← mkElseAlt)
alts := alts.push (← mkElseAlt indVal)
return alts

def mkAuxFunction (ctx : Context) (i : Nat) : TermElabM Command := do
let auxFunName := ctx.auxFunNames[i]!
let indVal := ctx.typeInfos[i]!
let header ← mkBEqHeader indVal
let mut body ← mkMatch header indVal auxFunName
let auxFunName := ctx.auxFunNames[i]!
let indTypeFormer := ctx.typeInfos[i]!
let argNames := ctx.typeArgNames[i]!
let header ← mkBEqHeader argNames indTypeFormer
let binders := header.binders
Term.elabBinders binders fun xs => do
let type ← Term.elabTerm header.targetType none
let mut body ← mkMatch ctx header type xs
if ctx.usePartial then
let letDecls ← mkLocalInstanceLetDecls ctx `BEq header.argNames
body ← mkLet letDecls body
let binders := header.binders
if ctx.usePartial then
`(private partial def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Bool := $body:term)
else
`(private def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Bool := $body:term)
Expand All @@ -115,8 +125,8 @@ def mkMutualBlock (ctx : Context) : TermElabM Syntax := do
end)

private def mkBEqInstanceCmds (declName : Name) : TermElabM (Array Syntax) := do
let ctx ← mkContext "beq" declName
let cmds := #[← mkMutualBlock ctx] ++ (← mkInstanceCmds ctx `BEq #[declName])
let ctx ← mkContext "beq" declName false
let cmds := #[← mkMutualBlock ctx] ++ (← mkInstanceCmds ctx `BEq)
trace[Elab.Deriving.beq] "\n{cmds}"
return cmds

Expand All @@ -125,8 +135,8 @@ private def mkBEqEnumFun (ctx : Context) (name : Name) : TermElabM Syntax := do
`(private def $(mkIdent auxFunName):ident (x y : $(mkIdent name)) : Bool := x.toCtorIdx == y.toCtorIdx)

private def mkBEqEnumCmd (name : Name): TermElabM (Array Syntax) := do
let ctx ← mkContext "beq" name
let cmds := #[← mkBEqEnumFun ctx name] ++ (← mkInstanceCmds ctx `BEq #[name])
let ctx ← mkContext "beq" name false
let cmds := #[← mkBEqEnumFun ctx name] ++ (← mkInstanceCmds ctx `BEq)
trace[Elab.Deriving.beq] "\n{cmds}"
return cmds

Expand Down
74 changes: 38 additions & 36 deletions src/Lean/Elab/Deriving/DecEq.lean
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,25 @@ Authors: Leonardo de Moura
prelude
import Lean.Meta.Transform
import Lean.Meta.Inductive
import Lean.Meta.InferType
import Lean.Elab.Deriving.Basic
import Lean.Elab.Deriving.Util
import Lean.Elab.Binders

namespace Lean.Elab.Deriving.DecEq
open Lean.Parser.Term
open Meta

def mkDecEqHeader (indVal : InductiveVal) : TermElabM Header := do
mkHeader `DecidableEq 2 indVal
def mkDecEqHeader (argNames : Array Name) (indTypeFormer : IndTypeFormer) : TermElabM Header := do
mkHeader `DecidableEq 2 argNames indTypeFormer

def mkMatch (ctx : Context) (header : Header) (indVal : InductiveVal) : TermElabM Term := do
def mkMatch (ctx : Context) (header : Header) (e : Expr) (fvars : Array Expr) : TermElabM Term := do
let f := e.getAppFn
let ind := f.constName!
let lvls := f.constLevels!
let indVal ← getConstInfoInduct ind
let discrs ← mkDiscrs header indVal
let alts ← mkAlts
let alts ← mkAlts indVal lvls
`(match $[$discrs],* with $alts:matchAlt*)
where
mkSameCtorRhs : List (Ident × Ident × Option Name × Bool) → TermElabM Term
Expand All @@ -31,14 +37,14 @@ where
`(if h : @$a = @$b then
by subst h; exact $sameCtor:term
else
isFalse (by intro n; injection n; apply h _; assumption))
isFalse (by intro n; injection n; contradiction))
if let some auxFunName := recField then
-- add local instance for `a = b` using the function being defined `auxFunName`
`(let inst := $(mkIdent auxFunName) @$a @$b; $rhs)
else
return rhs

mkAlts : TermElabM (Array (TSyntax ``matchAlt)) := do
mkAlts (indVal : InductiveVal) (lvl : List Level): TermElabM (Array (TSyntax ``matchAlt)) := do
let mut alts := #[]
for ctorName₁ in indVal.ctors do
let ctorInfo ← getConstInfoCtor ctorName₁
Expand All @@ -48,7 +54,10 @@ where
for _ in [:indVal.numIndices] do
patterns := patterns.push (← `(_))
if ctorName₁ == ctorName₂ then
let alt ← forallTelescopeReducing ctorInfo.type fun xs type => do
let args := e.getAppArgs
let ctorApp := mkAppN (mkConst ctorInfo.name lvl) args[:ctorInfo.numParams]
let ctorType ← inferType ctorApp
let alt ← forallTelescopeReducing ctorType fun xs type => do
let type ← Core.betaReduce type -- we 'beta-reduce' to eliminate "artificial" dependencies
let mut patterns := patterns
let mut ctorArgs1 := #[]
Expand All @@ -59,7 +68,7 @@ where
ctorArgs2 := ctorArgs2.push (← `(_))
let mut todo := #[]
for i in [:ctorInfo.numFields] do
let x := xs[indVal.numParams + i]!
let x := xs[i]!
if type.containsFVar x.fvarId! then
-- If resulting type depends on this field, we don't need to compare
ctorArgs1 := ctorArgs1.push (← `(_))
Expand All @@ -70,11 +79,8 @@ where
ctorArgs1 := ctorArgs1.push a
ctorArgs2 := ctorArgs2.push b
let xType ← inferType x
let indValNum :=
ctx.typeInfos.findIdx?
(xType.isAppOf ∘ ConstantVal.name ∘ InductiveVal.toConstantVal)
let recField := indValNum.map (ctx.auxFunNames[·]!)
let isProof ← isProp xType
let recField ← ctx.getFunName? header xType fvars
let isProof ← isProp xType
todo := todo.push (a, b, recField, isProof)
patterns := patterns.push (← `(@$(mkIdent ctorName₁):ident $ctorArgs1:term*))
patterns := patterns.push (← `(@$(mkIdent ctorName₁):ident $ctorArgs2:term*))
Expand All @@ -87,47 +93,43 @@ where
alts := alts.push (← `(matchAltExpr| | $[$patterns:term],* => $rhs:term))
return alts

def mkAuxFunction (ctx : Context) (auxFunName : Name) (indVal : InductiveVal): TermElabM (TSyntax `command) := do
let header ← mkDecEqHeader indVal
let body ← mkMatch ctx header indVal
def mkAuxFunction (ctx : Context) (auxFunName : Name) (argNames : Array Name) (indTypeFormer : IndTypeFormer): TermElabM (TSyntax `command) := do
let header ← mkDecEqHeader argNames indTypeFormer
let binders := header.binders
let target₁ := mkIdent header.targetNames[0]!
let target₂ := mkIdent header.targetNames[1]!
let termSuffix ← if indVal.isRec
let termSuffix ← if ctx.auxFunNames.size > 1 || indTypeFormer.getIndVal.isRec
then `(Parser.Termination.suffix|termination_by structural $target₁)
else `(Parser.Termination.suffix|)
let type ← `(Decidable ($target₁ = $target₂))
Term.elabBinders binders fun xs => do
let type ← Term.elabTerm header.targetType none
let body ← mkMatch ctx header type xs
let type ← `(Decidable ($target₁ = $target₂))
`(private def $(mkIdent auxFunName):ident $binders:bracketedBinder* : $type:term := $body:term
$termSuffix:suffix)

def mkAuxFunctions (ctx : Context) : TermElabM (TSyntax `command) := do
let mut res : Array (TSyntax `command) := #[]
for i in [:ctx.auxFunNames.size] do
let auxFunName := ctx.auxFunNames[i]!
let indVal := ctx.typeInfos[i]!
res := res.push (← mkAuxFunction ctx auxFunName indVal)
let auxFunName := ctx.auxFunNames[i]!
let indTypeFormer := ctx.typeInfos[i]!
let argNames := ctx.typeArgNames[i]!
res := res.push (← mkAuxFunction ctx auxFunName argNames indTypeFormer)
`(command| mutual $[$res:command]* end)

def mkDecEqCmds (indVal : InductiveVal) : TermElabM (Array Syntax) := do
let ctx ← mkContext "decEq" indVal.name
let cmds := #[← mkAuxFunctions ctx] ++ (← mkInstanceCmds ctx `DecidableEq #[indVal.name] (useAnonCtor := false))
let cmds := #[← mkAuxFunctions ctx] ++ (← mkInstanceCmds ctx `DecidableEq (useAnonCtor := false))
trace[Elab.Deriving.decEq] "\n{cmds}"
return cmds

open Command

def mkDecEq (declName : Name) : CommandElabM Bool := do
def mkDecEq (declName : Name) : CommandElabM Unit := do
let indVal ← getConstInfoInduct declName
if indVal.isNested then
return false -- nested inductive types are not supported yet
else
let cmds ← liftTermElabM <| mkDecEqCmds indVal
-- `cmds` can have a number of syntax nodes quadratic in the number of constructors
-- and thus create as many info tree nodes, which we never make use of but which can
-- significantly slow down e.g. the unused variables linter; avoid creating them
withEnableInfoTree false do
cmds.forM elabCommand
return true
let cmds ← liftTermElabM <| mkDecEqCmds indVal
withEnableInfoTree false do
cmds.forM elabCommand

partial def mkEnumOfNat (declName : Name) : MetaM Unit := do
let indVal ← getConstInfoInduct declName
Expand Down Expand Up @@ -195,15 +197,15 @@ def mkDecEqEnum (declName : Name) : CommandElabM Unit := do
trace[Elab.Deriving.decEq] "\n{cmd}"
elabCommand cmd

def mkDecEqInstance (declName : Name) : CommandElabM Bool := do
def mkDecEqInstance (declName : Name) : CommandElabM Unit := do
if (← isEnumType declName) then
mkDecEqEnum declName
return true
else
mkDecEq declName

def mkDecEqInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
declNames.foldlM (fun b n => andM (pure b) (mkDecEqInstance n)) true
declNames.forM mkDecEqInstance
return true

builtin_initialize
registerDerivingHandler `DecidableEq mkDecEqInstanceHandler
Expand Down
Loading
Loading