diff --git a/src/Init/Data/Sum.lean b/src/Init/Data/Sum.lean index ac5ff280b511..ae8692ea0f4f 100644 --- a/src/Init/Data/Sum.lean +++ b/src/Init/Data/Sum.lean @@ -21,4 +21,14 @@ def getRight? : α ⊕ β → Option β | inr b => some b | inl _ => none +/-- Define a function on `α ⊕ β` by giving separate definitions on `α` and `β`. -/ +protected def elim {α β γ} (f : α → γ) (g : β → γ) : α ⊕ β → γ := + fun x => Sum.casesOn x f g + +@[simp] theorem elim_inl (f : α → γ) (g : β → γ) (x : α) : + Sum.elim f g (inl x) = f x := rfl + +@[simp] theorem elim_inr (f : α → γ) (g : β → γ) (x : β) : + Sum.elim f g (inr x) = g x := rfl + end Sum diff --git a/src/Lean/Elab/Deriving/BEq.lean b/src/Lean/Elab/Deriving/BEq.lean index 5e9cb3cc5bba..9e076ebd19e8 100644 --- a/src/Lean/Elab/Deriving/BEq.lean +++ b/src/Lean/Elab/Deriving/BEq.lean @@ -12,15 +12,19 @@ namespace Lean.Elab.Deriving.BEq open Lean.Parser.Term open Meta -def mkBEqHeader (indVal : InductiveVal) : TermElabM Header := do - mkHeader `BEq 2 indVal +def mkBEqHeader (argNames : Array Name) (indTypeFormer : IndTypeFormer) : TermElabM Header := do + mkHeader `BEq 2 argNames indTypeFormer -def mkMatch (header : Header) (indVal : InductiveVal) (auxFunName : Name) : TermElabM Term := do +def mkMatch (ctx : Context) (header : Header) (e : Expr) (fvars : Array Expr) : TermElabM Term := do + let f := e.getAppFn + let ind := f.constName! + let lvls := f.constLevels! + let indVal ← getConstInfoInduct ind let discrs ← mkDiscrs header indVal - let alts ← mkAlts + let alts ← mkAlts indVal lvls `(match $[$discrs],* with $alts:matchAlt*) where - mkElseAlt : TermElabM (TSyntax ``matchAltExpr) := do + mkElseAlt (indVal : InductiveVal): TermElabM (TSyntax ``matchAltExpr) := do let mut patterns := #[] -- add `_` pattern for indices for _ in [:indVal.numIndices] do @@ -30,11 +34,15 @@ where let altRhs ← `(false) `(matchAltExpr| | $[$patterns:term],* => $altRhs:term) - mkAlts : TermElabM (Array (TSyntax ``matchAlt)) := do + mkAlts (indVal : InductiveVal) (lvl : List Level): TermElabM (Array (TSyntax ``matchAlt)) := do let mut alts := #[] for ctorName in indVal.ctors do + let args := e.getAppArgs let ctorInfo ← getConstInfoCtor ctorName - let alt ← forallTelescopeReducing ctorInfo.type fun xs type => do + let subargs := args[:ctorInfo.numParams] + let ctorApp := mkAppN (mkConst ctorInfo.name lvl) subargs + let ctorType ← inferType ctorApp + let alt ← forallTelescopeReducing ctorType fun xs type => do let type ← Core.betaReduce type -- we 'beta-reduce' to eliminate "artificial" dependencies let mut patterns := #[] -- add `_` pattern for indices @@ -45,7 +53,7 @@ where let mut rhs ← `(true) let mut rhs_empty := true for i in [:ctorInfo.numFields] do - let pos := indVal.numParams + ctorInfo.numFields - i - 1 + let pos := indVal.numParams + ctorInfo.numFields - subargs.size - i - 1 let x := xs[pos]! if type.containsFVar x.fvarId! then -- If resulting type depends on this field, we don't need to compare @@ -59,7 +67,7 @@ where let xType ← inferType x if (← isProp xType) then continue - if xType.isAppOf indVal.name then + if let some auxFunName ← ctx.getFunName? header xType fvars then if rhs_empty then rhs ← `($(mkIdent auxFunName):ident $a:ident $b:ident) rhs_empty := false @@ -88,19 +96,21 @@ where patterns := patterns.push (← `(@$(mkIdent ctorName):ident $ctorArgs2.reverse:term*)) `(matchAltExpr| | $[$patterns:term],* => $rhs:term) alts := alts.push alt - alts := alts.push (← mkElseAlt) + alts := alts.push (← mkElseAlt indVal) return alts def mkAuxFunction (ctx : Context) (i : Nat) : TermElabM Command := do - let auxFunName := ctx.auxFunNames[i]! - let indVal := ctx.typeInfos[i]! - let header ← mkBEqHeader indVal - let mut body ← mkMatch header indVal auxFunName + let auxFunName := ctx.auxFunNames[i]! + let indTypeFormer := ctx.typeInfos[i]! + let argNames := ctx.typeArgNames[i]! + let header ← mkBEqHeader argNames indTypeFormer + let binders := header.binders + Term.elabBinders binders fun xs => do + let type ← Term.elabTerm header.targetType none + let mut body ← mkMatch ctx header type xs if ctx.usePartial then let letDecls ← mkLocalInstanceLetDecls ctx `BEq header.argNames body ← mkLet letDecls body - let binders := header.binders - if ctx.usePartial then `(private partial def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Bool := $body:term) else `(private def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Bool := $body:term) @@ -115,8 +125,8 @@ def mkMutualBlock (ctx : Context) : TermElabM Syntax := do end) private def mkBEqInstanceCmds (declName : Name) : TermElabM (Array Syntax) := do - let ctx ← mkContext "beq" declName - let cmds := #[← mkMutualBlock ctx] ++ (← mkInstanceCmds ctx `BEq #[declName]) + let ctx ← mkContext "beq" declName false + let cmds := #[← mkMutualBlock ctx] ++ (← mkInstanceCmds ctx `BEq) trace[Elab.Deriving.beq] "\n{cmds}" return cmds @@ -125,8 +135,8 @@ private def mkBEqEnumFun (ctx : Context) (name : Name) : TermElabM Syntax := do `(private def $(mkIdent auxFunName):ident (x y : $(mkIdent name)) : Bool := x.toCtorIdx == y.toCtorIdx) private def mkBEqEnumCmd (name : Name): TermElabM (Array Syntax) := do - let ctx ← mkContext "beq" name - let cmds := #[← mkBEqEnumFun ctx name] ++ (← mkInstanceCmds ctx `BEq #[name]) + let ctx ← mkContext "beq" name false + let cmds := #[← mkBEqEnumFun ctx name] ++ (← mkInstanceCmds ctx `BEq) trace[Elab.Deriving.beq] "\n{cmds}" return cmds diff --git a/src/Lean/Elab/Deriving/DecEq.lean b/src/Lean/Elab/Deriving/DecEq.lean index a3b6fbf5e0b5..fcd9717992da 100644 --- a/src/Lean/Elab/Deriving/DecEq.lean +++ b/src/Lean/Elab/Deriving/DecEq.lean @@ -6,19 +6,25 @@ Authors: Leonardo de Moura prelude import Lean.Meta.Transform import Lean.Meta.Inductive +import Lean.Meta.InferType import Lean.Elab.Deriving.Basic import Lean.Elab.Deriving.Util +import Lean.Elab.Binders namespace Lean.Elab.Deriving.DecEq open Lean.Parser.Term open Meta -def mkDecEqHeader (indVal : InductiveVal) : TermElabM Header := do - mkHeader `DecidableEq 2 indVal +def mkDecEqHeader (argNames : Array Name) (indTypeFormer : IndTypeFormer) : TermElabM Header := do + mkHeader `DecidableEq 2 argNames indTypeFormer -def mkMatch (ctx : Context) (header : Header) (indVal : InductiveVal) : TermElabM Term := do +def mkMatch (ctx : Context) (header : Header) (e : Expr) (fvars : Array Expr) : TermElabM Term := do + let f := e.getAppFn + let ind := f.constName! + let lvls := f.constLevels! + let indVal ← getConstInfoInduct ind let discrs ← mkDiscrs header indVal - let alts ← mkAlts + let alts ← mkAlts indVal lvls `(match $[$discrs],* with $alts:matchAlt*) where mkSameCtorRhs : List (Ident × Ident × Option Name × Bool) → TermElabM Term @@ -31,14 +37,14 @@ where `(if h : @$a = @$b then by subst h; exact $sameCtor:term else - isFalse (by intro n; injection n; apply h _; assumption)) + isFalse (by intro n; injection n; contradiction)) if let some auxFunName := recField then -- add local instance for `a = b` using the function being defined `auxFunName` `(let inst := $(mkIdent auxFunName) @$a @$b; $rhs) else return rhs - mkAlts : TermElabM (Array (TSyntax ``matchAlt)) := do + mkAlts (indVal : InductiveVal) (lvl : List Level): TermElabM (Array (TSyntax ``matchAlt)) := do let mut alts := #[] for ctorName₁ in indVal.ctors do let ctorInfo ← getConstInfoCtor ctorName₁ @@ -48,7 +54,10 @@ where for _ in [:indVal.numIndices] do patterns := patterns.push (← `(_)) if ctorName₁ == ctorName₂ then - let alt ← forallTelescopeReducing ctorInfo.type fun xs type => do + let args := e.getAppArgs + let ctorApp := mkAppN (mkConst ctorInfo.name lvl) args[:ctorInfo.numParams] + let ctorType ← inferType ctorApp + let alt ← forallTelescopeReducing ctorType fun xs type => do let type ← Core.betaReduce type -- we 'beta-reduce' to eliminate "artificial" dependencies let mut patterns := patterns let mut ctorArgs1 := #[] @@ -59,7 +68,7 @@ where ctorArgs2 := ctorArgs2.push (← `(_)) let mut todo := #[] for i in [:ctorInfo.numFields] do - let x := xs[indVal.numParams + i]! + let x := xs[i]! if type.containsFVar x.fvarId! then -- If resulting type depends on this field, we don't need to compare ctorArgs1 := ctorArgs1.push (← `(_)) @@ -70,11 +79,8 @@ where ctorArgs1 := ctorArgs1.push a ctorArgs2 := ctorArgs2.push b let xType ← inferType x - let indValNum := - ctx.typeInfos.findIdx? - (xType.isAppOf ∘ ConstantVal.name ∘ InductiveVal.toConstantVal) - let recField := indValNum.map (ctx.auxFunNames[·]!) - let isProof ← isProp xType + let recField ← ctx.getFunName? header xType fvars + let isProof ← isProp xType todo := todo.push (a, b, recField, isProof) patterns := patterns.push (← `(@$(mkIdent ctorName₁):ident $ctorArgs1:term*)) patterns := patterns.push (← `(@$(mkIdent ctorName₁):ident $ctorArgs2:term*)) @@ -87,47 +93,43 @@ where alts := alts.push (← `(matchAltExpr| | $[$patterns:term],* => $rhs:term)) return alts -def mkAuxFunction (ctx : Context) (auxFunName : Name) (indVal : InductiveVal): TermElabM (TSyntax `command) := do - let header ← mkDecEqHeader indVal - let body ← mkMatch ctx header indVal +def mkAuxFunction (ctx : Context) (auxFunName : Name) (argNames : Array Name) (indTypeFormer : IndTypeFormer): TermElabM (TSyntax `command) := do + let header ← mkDecEqHeader argNames indTypeFormer let binders := header.binders let target₁ := mkIdent header.targetNames[0]! let target₂ := mkIdent header.targetNames[1]! - let termSuffix ← if indVal.isRec + let termSuffix ← if ctx.auxFunNames.size > 1 || indTypeFormer.getIndVal.isRec then `(Parser.Termination.suffix|termination_by structural $target₁) else `(Parser.Termination.suffix|) - let type ← `(Decidable ($target₁ = $target₂)) + Term.elabBinders binders fun xs => do + let type ← Term.elabTerm header.targetType none + let body ← mkMatch ctx header type xs + let type ← `(Decidable ($target₁ = $target₂)) `(private def $(mkIdent auxFunName):ident $binders:bracketedBinder* : $type:term := $body:term $termSuffix:suffix) def mkAuxFunctions (ctx : Context) : TermElabM (TSyntax `command) := do let mut res : Array (TSyntax `command) := #[] for i in [:ctx.auxFunNames.size] do - let auxFunName := ctx.auxFunNames[i]! - let indVal := ctx.typeInfos[i]! - res := res.push (← mkAuxFunction ctx auxFunName indVal) + let auxFunName := ctx.auxFunNames[i]! + let indTypeFormer := ctx.typeInfos[i]! + let argNames := ctx.typeArgNames[i]! + res := res.push (← mkAuxFunction ctx auxFunName argNames indTypeFormer) `(command| mutual $[$res:command]* end) def mkDecEqCmds (indVal : InductiveVal) : TermElabM (Array Syntax) := do let ctx ← mkContext "decEq" indVal.name - let cmds := #[← mkAuxFunctions ctx] ++ (← mkInstanceCmds ctx `DecidableEq #[indVal.name] (useAnonCtor := false)) + let cmds := #[← mkAuxFunctions ctx] ++ (← mkInstanceCmds ctx `DecidableEq (useAnonCtor := false)) trace[Elab.Deriving.decEq] "\n{cmds}" return cmds open Command -def mkDecEq (declName : Name) : CommandElabM Bool := do +def mkDecEq (declName : Name) : CommandElabM Unit := do let indVal ← getConstInfoInduct declName - if indVal.isNested then - return false -- nested inductive types are not supported yet - else - let cmds ← liftTermElabM <| mkDecEqCmds indVal - -- `cmds` can have a number of syntax nodes quadratic in the number of constructors - -- and thus create as many info tree nodes, which we never make use of but which can - -- significantly slow down e.g. the unused variables linter; avoid creating them - withEnableInfoTree false do - cmds.forM elabCommand - return true + let cmds ← liftTermElabM <| mkDecEqCmds indVal + withEnableInfoTree false do + cmds.forM elabCommand partial def mkEnumOfNat (declName : Name) : MetaM Unit := do let indVal ← getConstInfoInduct declName @@ -195,15 +197,15 @@ def mkDecEqEnum (declName : Name) : CommandElabM Unit := do trace[Elab.Deriving.decEq] "\n{cmd}" elabCommand cmd -def mkDecEqInstance (declName : Name) : CommandElabM Bool := do +def mkDecEqInstance (declName : Name) : CommandElabM Unit := do if (← isEnumType declName) then mkDecEqEnum declName - return true else mkDecEq declName def mkDecEqInstanceHandler (declNames : Array Name) : CommandElabM Bool := do - declNames.foldlM (fun b n => andM (pure b) (mkDecEqInstance n)) true + declNames.forM mkDecEqInstance + return true builtin_initialize registerDerivingHandler `DecidableEq mkDecEqInstanceHandler diff --git a/src/Lean/Elab/Deriving/FromToJson.lean b/src/Lean/Elab/Deriving/FromToJson.lean index c68e477d04bb..6bff3206ff25 100644 --- a/src/Lean/Elab/Deriving/FromToJson.lean +++ b/src/Lean/Elab/Deriving/FromToJson.lean @@ -15,11 +15,11 @@ open Lean.Json open Lean.Parser.Term open Lean.Meta -def mkToJsonHeader (indVal : InductiveVal) : TermElabM Header := do - mkHeader ``ToJson 1 indVal +def mkToJsonHeader (argNames : Array Name) (indTypeFormer : IndTypeFormer) : TermElabM Header := do + mkHeader ``ToJson 1 argNames indTypeFormer -def mkFromJsonHeader (indVal : InductiveVal) : TermElabM Header := do - let header ← mkHeader ``FromJson 0 indVal +def mkFromJsonHeader (argNames : Array Name) (indTypeFormer : IndTypeFormer) : TermElabM Header := do + let header ← mkHeader ``FromJson 0 argNames indTypeFormer let jsonArg ← `(bracketedBinderF|(json : Json)) return {header with binders := header.binders.push jsonArg} @@ -29,7 +29,14 @@ def mkJsonField (n : Name) : CoreM (Bool × Term) := do let s₁ := s.dropRightWhile (· == '?') return (s != s₁, Syntax.mkStrLit s₁) -def mkToJsonBodyForStruct (header : Header) (indName : Name) : TermElabM Term := do +def mkToJson (ctx : Context) (header : Header) (id : TSyntax `term) (type : Expr) (fvars : Array Expr) : TermElabM Term := do + if let some auxFunName ← ctx.getFunName? header type fvars then + `($(mkIdent auxFunName):ident $id) + else ``(toJson $id) + +def mkToJsonBodyForStruct (header : Header) (e : Expr) : TermElabM Term := do + let f := e.getAppFn + let indName := f.constName! let fields := getStructureFieldsFlattened (← getEnv) indName (includeSubobjectFields := false) let fields ← fields.mapM fun field => do let (isOptField, nm) ← mkJsonField field @@ -38,37 +45,40 @@ def mkToJsonBodyForStruct (header : Header) (indName : Name) : TermElabM Term := else ``([($nm, toJson ($target).$(mkIdent field))]) `(mkObj <| List.join [$fields,*]) -def mkToJsonBodyForInduct (ctx : Context) (header : Header) (indName : Name) : TermElabM Term := do +def mkToJsonBodyForInduct (ctx : Context) (header : Header) (e : Expr) (fvars : Array Expr): TermElabM Term := do + let f := e.getAppFn + let indName := f.constName! + let lvls := f.constLevels! let indVal ← getConstInfoInduct indName - let toJsonFuncId := mkIdent ctx.auxFunNames[0]! -- Return syntax to JSONify `id`, either via `ToJson` or recursively -- if `id`'s type is the type we're deriving for. - let mkToJson (id : Ident) (type : Expr) : TermElabM Term := do - if type.isAppOf indVal.name then `($toJsonFuncId:ident $id:ident) - else ``(toJson $id:ident) + let discrs ← mkDiscrs header indVal - let alts ← mkAlts indVal fun ctor args userNames => do + let alts ← mkAlts indVal lvls fun ctor args userNames => do let ctorStr := ctor.name.eraseMacroScopes.getString! match args, userNames with | #[], _ => ``(toJson $(quote ctorStr)) - | #[(x, t)], none => ``(mkObj [($(quote ctorStr), $(← mkToJson x t))]) + | #[(x, t)], none => ``(mkObj [($(quote ctorStr), $(← mkToJson ctx header x t fvars))]) | xs, none => - let xs ← xs.mapM fun (x, t) => mkToJson x t + let xs ← xs.mapM fun (x, t) => mkToJson ctx header x t fvars ``(mkObj [($(quote ctorStr), Json.arr #[$[$xs:term],*])]) | xs, some userNames => let xs ← xs.mapIdxM fun idx (x, t) => do - `(($(quote userNames[idx]!.eraseMacroScopes.getString!), $(← mkToJson x t))) + `(($(quote userNames[idx]!.eraseMacroScopes.getString!), $(← mkToJson ctx header x t fvars))) ``(mkObj [($(quote ctorStr), mkObj [$[$xs:term],*])]) `(match $[$discrs],* with $alts:matchAlt*) where mkAlts - (indVal : InductiveVal) - (rhs : ConstructorVal → Array (Ident × Expr) → Option (Array Name) → TermElabM Term): TermElabM (Array (TSyntax ``matchAlt)) := do + (indVal : InductiveVal) (lvl : List Level) + (rhs : ConstructorVal → Array (Ident × Expr) → Option (Array Name) → TermElabM Term) : TermElabM (Array (TSyntax ``matchAlt)) := do let mut alts := #[] for ctorName in indVal.ctors do + let args := e.getAppArgs let ctorInfo ← getConstInfoCtor ctorName - let alt ← forallTelescopeReducing ctorInfo.type fun xs _ => do + let ctorApp := mkAppN (mkConst ctorInfo.name lvl) args[:ctorInfo.numParams] + let ctorType ← inferType ctorApp + let alt ← forallTelescopeReducing ctorType fun xs _ => do let mut patterns := #[] -- add `_` pattern for indices for _ in [:indVal.numIndices] do @@ -81,7 +91,7 @@ where let mut binders := #[] let mut userNames := #[] for i in [:ctorInfo.numFields] do - let x := xs[indVal.numParams + i]! + let x := xs[i]! let localDecl ← x.fvarId!.getDecl if !localDecl.userName.hasMacroScopes then userNames := userNames.push localDecl.userName @@ -94,7 +104,9 @@ where alts := alts.push alt return alts -def mkFromJsonBodyForStruct (indName : Name) : TermElabM Term := do +def mkFromJsonBodyForStruct (e : Expr) : TermElabM Term := do + let f := e.getAppFn + let indName := f.constName! let fields := getStructureFieldsFlattened (← getEnv) indName (includeSubobjectFields := false) let getters ← fields.mapM (fun field => do let getter ← `(getObjValAs? json _ $(Prod.snd <| ← mkJsonField field)) @@ -106,32 +118,39 @@ def mkFromJsonBodyForStruct (indName : Name) : TermElabM Term := do $[let $fields:ident ← $getters]* return { $[$fields:ident := $(id fields)],* }) -def mkFromJsonBodyForInduct (ctx : Context) (indName : Name) : TermElabM Term := do +def mkFromJsonBodyForInduct (ctx : Context) (header : Header) (e : Expr) (fvars : Array Expr): TermElabM Term := do + let f := e.getAppFn + let indName := f.constName! + let lvls := f.constLevels! let indVal ← getConstInfoInduct indName - let alts ← mkAlts indVal + let alts ← mkAlts indVal lvls let auxTerm ← alts.foldrM (fun xs x => `(Except.orElseLazy $xs (fun _ => $x))) (← `(Except.error "no inductive constructor matched")) `($auxTerm) where - mkAlts (indVal : InductiveVal) : TermElabM (Array Term) := do + mkAlts (indVal : InductiveVal) (lvl : List Level): TermElabM (Array Term) := do let mut alts := #[] for ctorName in indVal.ctors do + let args := e.getAppArgs let ctorInfo ← getConstInfoCtor ctorName - let alt ← do forallTelescopeReducing ctorInfo.type fun xs _ => do - let mut binders := #[] + let ctorApp := mkAppN (mkConst ctorInfo.name lvl) args[:ctorInfo.numParams] + let ctorType ← inferType ctorApp + let alt ← do forallTelescopeReducing ctorType fun xs _ => do + let mut binders := #[] let mut userNames := #[] for i in [:ctorInfo.numFields] do - let x := xs[indVal.numParams + i]! + let x := xs[i]! let localDecl ← x.fvarId!.getDecl if !localDecl.userName.hasMacroScopes then userNames := userNames.push localDecl.userName let a := mkIdent (← mkFreshUserName `a) binders := binders.push (a, localDecl.type) - let fromJsonFuncId := mkIdent ctx.auxFunNames[0]! -- Return syntax to parse `id`, either via `FromJson` or recursively -- if `id`'s type is the type we're deriving for. - let mkFromJson (idx : Nat) (type : Expr) : TermElabM (TSyntax ``doExpr) := - if type.isAppOf indVal.name then `(Lean.Parser.Term.doExpr| $fromJsonFuncId:ident jsons[$(quote idx)]!) - else `(Lean.Parser.Term.doExpr| fromJson? jsons[$(quote idx)]!) + let mkFromJson (idx : Nat) (type : Expr) : TermElabM (TSyntax ``doExpr) := do + if let some auxFunName ← ctx.getFunName? header type fvars then + `(Lean.Parser.Term.doExpr| $(mkIdent auxFunName) jsons[$(quote idx)]!) + else + `(Lean.Parser.Term.doExpr| fromJson? jsons[$(quote idx)]!) let identNames := binders.map Prod.fst let fromJsons ← binders.mapIdxM fun idx (_, type) => mkFromJson idx type let userNamesOpt ← if binders.size == userNames.size then @@ -149,48 +168,49 @@ where let alts' := alts.qsort (fun (_, x) (_, y) => x < y) return alts'.map Prod.fst -def mkToJsonBody (ctx : Context) (header : Header) (e : Expr): TermElabM Term := do - let indName := e.getAppFn.constName! - if isStructure (← getEnv) indName then - mkToJsonBodyForStruct header indName +def mkToJsonBody (ctx : Context) (header : Header) (e : Expr) (fvars : Array Expr): TermElabM Term := do + if isStructure (← getEnv) e.getAppFn.constName! then + mkToJsonBodyForStruct header e else - mkToJsonBodyForInduct ctx header indName + mkToJsonBodyForInduct ctx header e fvars def mkToJsonAuxFunction (ctx : Context) (i : Nat) : TermElabM Command := do - let auxFunName := ctx.auxFunNames[i]! - let header ← mkToJsonHeader ctx.typeInfos[i]! - let binders := header.binders - Term.elabBinders binders fun _ => do - let type ← Term.elabTerm header.targetType none - let mut body ← mkToJsonBody ctx header type - if ctx.usePartial then - let letDecls ← mkLocalInstanceLetDecls ctx ``ToJson header.argNames - body ← mkLet letDecls body - `(private partial def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Json := $body:term) - else - `(private def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Json := $body:term) + let auxFunName := ctx.auxFunNames[i]! + let indTypeFormer := ctx.typeInfos[i]! + let argNames := ctx.typeArgNames[i]! + let header ← mkToJsonHeader argNames indTypeFormer + let binders := header.binders + Term.elabBinders binders fun xs => do + let type ← Term.elabTerm header.targetType none + let mut body ← mkToJsonBody ctx header type xs + if ctx.usePartial then + let letDecls ← mkLocalInstanceLetDecls ctx ``ToJson header.argNames + body ← mkLet letDecls body + `(private partial def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Json := $body:term) + else + `(private def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Json := $body:term) -def mkFromJsonBody (ctx : Context) (e : Expr) : TermElabM Term := do - let indName := e.getAppFn.constName! - if isStructure (← getEnv) indName then - mkFromJsonBodyForStruct indName +def mkFromJsonBody (ctx : Context) (header : Header) (e : Expr) (fvars : Array Expr): TermElabM Term := do + if isStructure (← getEnv) e.getAppFn.constName! then + mkFromJsonBodyForStruct e else - mkFromJsonBodyForInduct ctx indName + mkFromJsonBodyForInduct ctx header e fvars def mkFromJsonAuxFunction (ctx : Context) (i : Nat) : TermElabM Command := do - let auxFunName := ctx.auxFunNames[i]! - let indval := ctx.typeInfos[i]! - let header ← mkFromJsonHeader indval --TODO fix header info - let binders := header.binders - Term.elabBinders binders fun _ => do - let type ← Term.elabTerm header.targetType none - let mut body ← mkFromJsonBody ctx type - if ctx.usePartial || indval.isRec then - let letDecls ← mkLocalInstanceLetDecls ctx ``FromJson header.argNames - body ← mkLet letDecls body - `(private partial def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Except String $(← mkInductiveApp ctx.typeInfos[i]! header.argNames) := $body:term) - else - `(private def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Except String $(← mkInductiveApp ctx.typeInfos[i]! header.argNames) := $body:term) + let auxFunName := ctx.auxFunNames[i]! + let indTypeFormer := ctx.typeInfos[i]! + let argNames := ctx.typeArgNames[i]! + let header ← mkFromJsonHeader argNames indTypeFormer --TODO fix header info + let binders := header.binders + Term.elabBinders binders fun xs => do + let type ← Term.elabTerm header.targetType none + let mut body ← mkFromJsonBody ctx header type xs + if ctx.usePartial || indTypeFormer.getIndVal.isRec then + let letDecls ← mkLocalInstanceLetDecls ctx ``FromJson header.argNames + body ← mkLet letDecls body + `(private partial def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Except String $(← ctx.typeInfos[i]!.mkAppTerm header.argNames) := $body:term) + else + `(private def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Except String $(← ctx.typeInfos[i]!.mkAppTerm header.argNames) := $body:term) def mkToJsonMutualBlock (ctx : Context) : TermElabM Command := do @@ -210,14 +230,14 @@ def mkFromJsonMutualBlock (ctx : Context) : TermElabM Command := do end) private def mkToJsonInstance (declName : Name) : TermElabM (Array Command) := do - let ctx ← mkContext "toJson" declName - let cmds := #[← mkToJsonMutualBlock ctx] ++ (← mkInstanceCmds ctx ``ToJson #[declName]) + let ctx ← mkContext "toJson" declName false + let cmds := #[← mkToJsonMutualBlock ctx] ++ (← mkInstanceCmds ctx ``ToJson) trace[Elab.Deriving.toJson] "\n{cmds}" return cmds private def mkFromJsonInstance (declName : Name) : TermElabM (Array Command) := do - let ctx ← mkContext "fromJson" declName - let cmds := #[← mkFromJsonMutualBlock ctx] ++ (← mkInstanceCmds ctx ``FromJson #[declName]) + let ctx ← mkContext "fromJson" declName false + let cmds := #[← mkFromJsonMutualBlock ctx] ++ (← mkInstanceCmds ctx ``FromJson) trace[Elab.Deriving.fromJson] "\n{cmds}" return cmds diff --git a/src/Lean/Elab/Deriving/Hashable.lean b/src/Lean/Elab/Deriving/Hashable.lean index be7c5ee1d85a..b5e15ddf84dc 100644 --- a/src/Lean/Elab/Deriving/Hashable.lean +++ b/src/Lean/Elab/Deriving/Hashable.lean @@ -13,22 +13,28 @@ open Command open Lean.Parser.Term open Meta -def mkHashableHeader (indVal : InductiveVal) : TermElabM Header := do - mkHeader `Hashable 1 indVal +def mkHashableHeader (argNames : Array Name) (indTypeFormer : IndTypeFormer) : TermElabM Header := do + mkHeader `Hashable 1 argNames indTypeFormer -def mkMatch (ctx : Context) (header : Header) (indVal : InductiveVal) : TermElabM Term := do +def mkMatch (ctx : Context) (header : Header) (e : Expr) (fvars : Array Expr) : TermElabM Term := do + let f := e.getAppFn + let ind := f.constName! + let lvls := f.constLevels! + let indVal ← getConstInfoInduct ind let discrs ← mkDiscrs header indVal - let alts ← mkAlts + let alts ← mkAlts indVal lvls `(match $[$discrs],* with $alts:matchAlt*) where - mkAlts : TermElabM (Array (TSyntax ``matchAlt)) := do + mkAlts (indVal : InductiveVal) (lvl : List Level) : TermElabM (Array (TSyntax ``matchAlt)) := do let mut alts := #[] let mut ctorIdx := 0 - let allIndVals := indVal.all.toArray for ctorName in indVal.ctors do + let args := e.getAppArgs let ctorInfo ← getConstInfoCtor ctorName - let alt ← forallTelescopeReducing ctorInfo.type fun xs _ => do + let ctorApp := mkAppN (mkConst ctorInfo.name lvl) args[:ctorInfo.numParams] + let ctorType ← inferType ctorApp + let alt ← forallTelescopeReducing ctorType fun xs _ => do let mut patterns := #[] -- add `_` pattern for indices for _ in [:indVal.numIndices] do @@ -39,16 +45,14 @@ where for _ in [:indVal.numParams] do ctorArgs := ctorArgs.push (← `(_)) for i in [:ctorInfo.numFields] do - let x := xs[indVal.numParams + i]! + let x := xs[i]! let a := mkIdent (← mkFreshUserName `a) ctorArgs := ctorArgs.push a - let xTy ← whnf (← inferType x) - match xTy.getAppFn with - | .const declName .. => - match allIndVals.findIdx? (· == declName) with - | some x => rhs ← `(mixHash $rhs ($(mkIdent ctx.auxFunNames[x]!) $a:ident)) - | none => rhs ← `(mixHash $rhs (hash $a:ident)) - | _ => rhs ← `(mixHash $rhs (hash $a:ident)) + let xType ← inferType x + if let some auxFunName ← ctx.getFunName? header xType fvars then + rhs ← `(mixHash $rhs ($(mkIdent auxFunName) $a)) + else + rhs ← `(mixHash $rhs (hash $a:ident)) patterns := patterns.push (← `(@$(mkIdent ctorName):ident $ctorArgs:term*)) `(matchAltExpr| | $[$patterns:term],* => $rhs:term) alts := alts.push alt @@ -56,16 +60,18 @@ where return alts def mkAuxFunction (ctx : Context) (i : Nat) : TermElabM Command := do - let auxFunName := ctx.auxFunNames[i]! - let indVal := ctx.typeInfos[i]! - let header ← mkHashableHeader indVal - let mut body ← mkMatch ctx header indVal + let auxFunName := ctx.auxFunNames[i]! + let indTypeFormer := ctx.typeInfos[i]! + let argNames := ctx.typeArgNames[i]! + let header ← mkHashableHeader argNames indTypeFormer + let binders := header.binders + Term.elabBinders binders fun xs => do + let type ← Term.elabTerm header.targetType none + let mut body ← mkMatch ctx header type xs if ctx.usePartial then + -- TODO(Dany): Get rid of this code branch altogether once we have well-founded recursion let letDecls ← mkLocalInstanceLetDecls ctx `Hashable header.argNames body ← mkLet letDecls body - let binders := header.binders - if ctx.usePartial then - -- TODO(Dany): Get rid of this code branch altogether once we have well-founded recursion `(private partial def $(mkIdent auxFunName):ident $binders:bracketedBinder* : UInt64 := $body:term) else `(private def $(mkIdent auxFunName):ident $binders:bracketedBinder* : UInt64 := $body:term) @@ -77,8 +83,8 @@ def mkHashFuncs (ctx : Context) : TermElabM Syntax := do `(mutual $auxDefs:command* end) private def mkHashableInstanceCmds (declName : Name) : TermElabM (Array Syntax) := do - let ctx ← mkContext "hash" declName - let cmds := #[← mkHashFuncs ctx] ++ (← mkInstanceCmds ctx `Hashable #[declName]) + let ctx ← mkContext "hash" declName false + let cmds := #[← mkHashFuncs ctx] ++ (← mkInstanceCmds ctx `Hashable) trace[Elab.Deriving.hashable] "\n{cmds}" return cmds diff --git a/src/Lean/Elab/Deriving/Ord.lean b/src/Lean/Elab/Deriving/Ord.lean index aed902aa4167..a979ca22b3b6 100644 --- a/src/Lean/Elab/Deriving/Ord.lean +++ b/src/Lean/Elab/Deriving/Ord.lean @@ -12,19 +12,26 @@ namespace Lean.Elab.Deriving.Ord open Lean.Parser.Term open Meta -def mkOrdHeader (indVal : InductiveVal) : TermElabM Header := do - mkHeader `Ord 2 indVal +def mkOrdHeader (argNames : Array Name) (indTypeFormer : IndTypeFormer) : TermElabM Header := do + mkHeader `Ord 2 argNames indTypeFormer -def mkMatch (header : Header) (indVal : InductiveVal) : TermElabM Term := do +def mkMatch (ctx : Context) (header : Header) (e : Expr) (fvars : Array Expr) : TermElabM Term := do + let f := e.getAppFn + let ind := f.constName! + let lvls := f.constLevels! + let indVal ← getConstInfoInduct ind let discrs ← mkDiscrs header indVal - let alts ← mkAlts + let alts ← mkAlts indVal lvls `(match $[$discrs],* with $alts:matchAlt*) where - mkAlts : TermElabM (Array (TSyntax ``matchAlt)) := do + mkAlts (indVal : InductiveVal) (lvl : List Level) : TermElabM (Array (TSyntax ``matchAlt)) := do let mut alts := #[] for ctorName in indVal.ctors do + let args := e.getAppArgs let ctorInfo ← getConstInfoCtor ctorName - let alt ← forallTelescopeReducing ctorInfo.type fun xs type => do + let ctorApp := mkAppN (mkConst ctorInfo.name lvl) args[:ctorInfo.numParams] + let ctorType ← inferType ctorApp + let alt ← forallTelescopeReducing ctorType fun xs type => do let type ← Core.betaReduce type -- we 'beta-reduce' to eliminate "artificial" dependencies let mut indPatterns := #[] -- add `_` pattern for indices @@ -39,8 +46,9 @@ where ctorArgs1 := ctorArgs1.push (← `(_)) ctorArgs2 := ctorArgs2.push (← `(_)) for i in [:ctorInfo.numFields] do - let x := xs[indVal.numParams + i]! - if type.containsFVar x.fvarId! || (←isProp (←inferType x)) then + let x := xs[i]! + let xType ← inferType x + if type.containsFVar x.fvarId! || (←isProp xType) then -- If resulting type depends on this field or is a proof, we don't need to compare ctorArgs1 := ctorArgs1.push (← `(_)) ctorArgs2 := ctorArgs2.push (← `(_)) @@ -49,7 +57,12 @@ where let b := mkIdent (← mkFreshUserName `b) ctorArgs1 := ctorArgs1.push a ctorArgs2 := ctorArgs2.push b - rhsCont := fun rhs => `(Ordering.then (compare $a $b) $rhs) >>= rhsCont + let compare ← + if let some auxFunName ← ctx.getFunName? header xType fvars then + `($(mkIdent auxFunName) $a $b) + else + `(compare $a $b) + rhsCont := fun rhs => `(Ordering.then $compare $rhs) >>= rhsCont let lPat ← `(@$(mkIdent ctorName):ident $ctorArgs1:term*) let rPat ← `(@$(mkIdent ctorName):ident $ctorArgs2:term*) let patterns := indPatterns ++ #[lPat, rPat] @@ -63,18 +76,15 @@ where return alts.pop.pop def mkAuxFunction (ctx : Context) (i : Nat) : TermElabM Command := do - let auxFunName := ctx.auxFunNames[i]! - let indVal := ctx.typeInfos[i]! - let header ← mkOrdHeader indVal - let mut body ← mkMatch header indVal - if ctx.usePartial || indVal.isRec then - let letDecls ← mkLocalInstanceLetDecls ctx `Ord header.argNames - body ← mkLet letDecls body - let binders := header.binders - if ctx.usePartial || indVal.isRec then - `(private partial def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Ordering := $body:term) - else - `(private def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Ordering := $body:term) + let auxFunName := ctx.auxFunNames[i]! + let indTypeFormer := ctx.typeInfos[i]! + let argNames := ctx.typeArgNames[i]! + let header ← mkOrdHeader argNames indTypeFormer + let binders := header.binders + Term.elabBinders binders fun xs => do + let type ← Term.elabTerm header.targetType none + let body ← mkMatch ctx header type xs + `(private def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Ordering := $body:term) def mkMutualBlock (ctx : Context) : TermElabM Syntax := do let mut auxDefs := #[] @@ -87,7 +97,7 @@ def mkMutualBlock (ctx : Context) : TermElabM Syntax := do private def mkOrdInstanceCmds (declName : Name) : TermElabM (Array Syntax) := do let ctx ← mkContext "ord" declName - let cmds := #[← mkMutualBlock ctx] ++ (← mkInstanceCmds ctx `Ord #[declName]) + let cmds := #[← mkMutualBlock ctx] ++ (← mkInstanceCmds ctx `Ord) trace[Elab.Deriving.ord] "\n{cmds}" return cmds diff --git a/src/Lean/Elab/Deriving/Repr.lean b/src/Lean/Elab/Deriving/Repr.lean index 0a03ed1cc0e3..d6993f556ab4 100644 --- a/src/Lean/Elab/Deriving/Repr.lean +++ b/src/Lean/Elab/Deriving/Repr.lean @@ -14,13 +14,21 @@ open Lean.Parser.Term open Meta open Std -def mkReprHeader (indVal : InductiveVal) : TermElabM Header := do - let header ← mkHeader `Repr 1 indVal +def mkReprHeader (argNames : Array Name) (indTypeFormer : IndTypeFormer) : TermElabM Header := do + let header ← mkHeader `Repr 1 argNames indTypeFormer return { header with binders := header.binders.push (← `(bracketedBinderF| (prec : Nat))) } -def mkBodyForStruct (header : Header) (indVal : InductiveVal) : TermElabM Term := do +def mkRepr (ctx : Context) (header : Header) (id : TSyntax `term) (type : Expr) (fvars : Array Expr) : TermElabM Term := do + if let some auxFunName ← ctx.getFunName? header type fvars then + ``($(mkIdent auxFunName) $id) + else ``(reprPrec $id) + +def mkBodyForStruct(header : Header) (e : Expr) : TermElabM Term := do + let f := e.getAppFn + let ind := f.constName! + let indVal ← getConstInfoInduct ind let ctorVal ← getConstInfoCtor indVal.ctors.head! let fieldNames := getStructureFields (← getEnv) indVal.name let numParams := indVal.numParams @@ -42,16 +50,23 @@ def mkBodyForStruct (header : Header) (indVal : InductiveVal) : TermElabM Term : fields ← `($fields ++ $fieldNameLit ++ " := " ++ (Format.group (Format.nest $indent (repr ($target.$(mkIdent fieldName):ident))))) `(Format.bracket "{ " $fields:term " }") -def mkBodyForInduct (header : Header) (indVal : InductiveVal) (auxFunName : Name) : TermElabM Term := do +def mkBodyForInduct (ctx : Context) (header : Header) (e : Expr) (fvars : Array Expr): TermElabM Term := do + let f := e.getAppFn + let ind := f.constName! + let lvls := f.constLevels! + let indVal ← getConstInfoInduct ind let discrs ← mkDiscrs header indVal - let alts ← mkAlts + let alts ← mkAlts indVal lvls `(match $[$discrs],* with $alts:matchAlt*) where - mkAlts : TermElabM (Array (TSyntax ``matchAlt)) := do + mkAlts (indVal : InductiveVal) (lvl : List Level) : TermElabM (Array (TSyntax ``matchAlt)) := do let mut alts := #[] for ctorName in indVal.ctors do + let args := e.getAppArgs let ctorInfo ← getConstInfoCtor ctorName - let alt ← forallTelescopeReducing ctorInfo.type fun xs _ => do + let ctorApp := mkAppN (mkConst ctorInfo.name lvl) args[:ctorInfo.numParams] + let ctorType ← inferType ctorApp + let alt ← forallTelescopeReducing ctorType fun xs _ => do let mut patterns := #[] -- add `_` pattern for indices for _ in [:indVal.numIndices] do @@ -63,38 +78,39 @@ where for _ in [:indVal.numParams] do ctorArgs := ctorArgs.push (← `(_)) for i in [:ctorInfo.numFields] do - let x := xs[indVal.numParams + i]! + let x := xs[i]! let a := mkIdent (← mkFreshUserName `a) ctorArgs := ctorArgs.push a let localDecl ← x.fvarId!.getDecl if localDecl.binderInfo.isExplicit then - if (← inferType x).isAppOf indVal.name then - rhs ← `($rhs ++ Format.line ++ $(mkIdent auxFunName):ident $a:ident max_prec) - else if (← isType x <||> isProof x) then + if (← isType x <||> isProof x) then rhs ← `($rhs ++ Format.line ++ "_") else - rhs ← `($rhs ++ Format.line ++ reprArg $a) + let repr ← mkRepr ctx header a (← inferType x) fvars + rhs ← `($rhs ++ Format.line ++ $repr max_prec) patterns := patterns.push (← `(@$(mkIdent ctorName):ident $ctorArgs:term*)) `(matchAltExpr| | $[$patterns:term],* => Repr.addAppParen (Format.group (Format.nest (if prec >= max_prec then 1 else 2) ($rhs:term))) prec) alts := alts.push alt return alts -def mkBody (header : Header) (indVal : InductiveVal) (auxFunName : Name) : TermElabM Term := do - if isStructure (← getEnv) indVal.name then - mkBodyForStruct header indVal +def mkBody (ctx : Context) (header : Header) (e : Expr) (fvars : Array Expr): TermElabM Term := do + if isStructure (← getEnv) e.getAppFn.constName! then + mkBodyForStruct header e else - mkBodyForInduct header indVal auxFunName + mkBodyForInduct ctx header e fvars def mkAuxFunction (ctx : Context) (i : Nat) : TermElabM Command := do - let auxFunName := ctx.auxFunNames[i]! - let indVal := ctx.typeInfos[i]! - let header ← mkReprHeader indVal - let mut body ← mkBody header indVal auxFunName + let auxFunName := ctx.auxFunNames[i]! + let indTypeFormer := ctx.typeInfos[i]! + let argNames := ctx.typeArgNames[i]! + let header ← mkReprHeader argNames indTypeFormer + let binders := header.binders + Term.elabBinders binders fun xs => do + let type ← Term.elabTerm header.targetType none + let mut body ← mkBody ctx header type xs if ctx.usePartial then let letDecls ← mkLocalInstanceLetDecls ctx `Repr header.argNames body ← mkLet letDecls body - let binders := header.binders - if ctx.usePartial then `(private partial def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Format := $body:term) else `(private def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Format := $body:term) @@ -108,8 +124,8 @@ def mkMutualBlock (ctx : Context) : TermElabM Syntax := do end) private def mkReprInstanceCmd (declName : Name) : TermElabM (Array Syntax) := do - let ctx ← mkContext "repr" declName - let cmds := #[← mkMutualBlock ctx] ++ (← mkInstanceCmds ctx `Repr #[declName]) + let ctx ← mkContext "repr" declName false + let cmds := #[← mkMutualBlock ctx] ++ (← mkInstanceCmds ctx `Repr) trace[Elab.Deriving.repr] "\n{cmds}" return cmds diff --git a/src/Lean/Elab/Deriving/Util.lean b/src/Lean/Elab/Deriving/Util.lean index 03f7eecae200..2b55f251a3e6 100644 --- a/src/Lean/Elab/Deriving/Util.lean +++ b/src/Lean/Elab/Deriving/Util.lean @@ -6,6 +6,11 @@ Authors: Leonardo de Moura prelude import Lean.Parser.Term import Lean.Elab.Term +import Lean.Elab.Binders +import Lean.PrettyPrinter +import Lean.Data.Options +import Init.Data.Sum +import Lean.Meta.CollectFVars namespace Lean.Elab.Deriving open Meta @@ -36,59 +41,340 @@ open TSyntax.Compat in /-- Return implicit binder syntaxes for the given `argNames`. The output matches `implicitBinder*`. For example, ``#[`foo,`bar]`` gives `` `({foo} {bar})``. -/ -def mkImplicitBinders (argNames : Array Name) : TermElabM (Array (TSyntax ``Parser.Term.implicitBinder)) := +def mkImplicitBinders (argNames : Array Name) : TermElabM (Array (TSyntax ``Parser.Term.implicitBinder)) := do argNames.mapM fun argName => `(implicitBinderF| { $(mkIdent argName) }) +/-- + Represents `ind params1 ... paramsn`, where `paramsi` is either a nested type formers, a constant name, or a free variable + + Example : + ``` + inductive Foo (α β γ: Type) : Type → Type + + inductive Bar + | foo : A -> B -> Foo A (Option Bar) B Nat → Bar + + ``` + this nested type formers `Foo A (Option Bar) B Nat` is encoded as + ```lean + node Foo #[ + .inr (.bvar 2), + .inl (node Option #[.inr (.leaf Bar #[])]), + .inr (.bvar 1), + ] + ``` + + Remark 1 : Since an inductive can only be nested as a parameter, not an index, of another inductive type, + we assume `params.size == ind.numParams` + + Remark 2 : Free variables are abstracted away as bound variables in nested type formerss. + This is useful when trying to delete duplicate occurences, since the check is now purely syntactical. + -/ +inductive IndTypeFormer : Type := + | node (ind : InductiveVal) (params : Array (IndTypeFormer ⊕ Expr)) + | leaf (ind : InductiveVal) (fvars : Array Expr) + +namespace IndTypeFormer + +instance : Inhabited IndTypeFormer := ⟨leaf default #[]⟩ +instance : Inhabited (IndTypeFormer ⊕ α) := ⟨.inl default⟩ + +partial instance : BEq IndTypeFormer := ⟨go⟩ +where go + | leaf ind₁ _,.leaf ind₂ _ => ind₁.name == ind₂.name + | node ind₁ arr₁,.node ind₂ arr₂ => Id.run do + unless ind₁.name == ind₂.name && arr₁.size == arr₂.size do + return false + for i in [:arr₁.size] do + unless @Sum.instBEq _ _ ⟨go⟩ inferInstance |>.beq arr₁[i]! arr₂[i]! do + return false + return true + | _,_ => false + +partial instance : ToString IndTypeFormer := ⟨go⟩ +where go + | leaf ind e => s!"leaf {ind.name} {e}" + | node ind arr => + let s := arr.map (@instToStringSum _ _ ⟨go⟩ inferInstance |>.toString) + s!"node {ind.name} {s}" + +@[inline] +def getIndVal : IndTypeFormer → InductiveVal + | leaf indVal _ | node indVal _ => indVal + +@[inline] +def getArr : IndTypeFormer → Array (IndTypeFormer ⊕ Expr) + | leaf .. => #[] + | node _ arr => arr + +@[inline] +def isLeaf : IndTypeFormer → Bool + | leaf .. => true + | node .. => false + +@[inline] +def isNode : IndTypeFormer → Bool := not ∘ isLeaf + +partial def containsFVar (fvarId : FVarId) : IndTypeFormer → Bool + | leaf _ e => e.any (Expr.containsFVar · fvarId) + | node _ arr => arr.any (Sum.elim (containsFVar fvarId) (Expr.containsFVar · fvarId)) + +partial def toListofTypeFormers (e : IndTypeFormer) : List IndTypeFormer := + match e with + | .leaf _ _ => [] + | .node _ arr => + let l := flip arr.foldr [] fun indFormer l => + if let .inl indFormer := indFormer then + indFormer.toListofTypeFormers ++ l + else l + e::l + +/-- Return the inductive declaration's type applied to the arguments in `argNames`. -/ +partial def mkAppTerm (indTypeFormer : IndTypeFormer) (argNames : Array Name) : TermElabM Term := do + go indTypeFormer argNames +where + go (indTypeFormer : IndTypeFormer) (argNames : Array Name) : TermElabM Term := do + match indTypeFormer with + | leaf indVal _ => + let f := mkCIdent indVal.name + let numArgs := indVal.numParams + indVal.numIndices + unless argNames.size >= numArgs do + throwError s!"Expected {numArgs} arguments for {indVal.name}, got {argNames}" + let mut args := Array.mkArray numArgs default + for i in [:numArgs] do + let arg := mkIdent argNames[i]! + args := args.set! i arg + `(@$f $args*) + | node indVal arr => + let f := mkCIdent indVal.name + let mut args := #[] + for indTypeFormer? in arr do + match indTypeFormer? with + | .inl indFormer => + let arg ← go indFormer argNames + args := args.push arg + | .inr (.bvar i) => + let some argName := argNames[argNames.size-i-1]? + | throwError s!"Cannot instantiate {indTypeFormer} : not enough arguments given" + let id := mkIdent argName + args := args.push <| ←`($id) + | .inr e => + let tm ← PrettyPrinter.delab e + args := args.push <| ←`($tm) + `(@$f $args*) + +/-- Return the inductive declaration's type applied to the arguments in `argNames`. -/ +partial def mkAppExpr (indTypeFormer : IndTypeFormer) (argNames : Array Expr) : TermElabM Expr := do + let res ← go indTypeFormer argNames + return res +where + go (indTypeFormer : IndTypeFormer) (argNames : Array Expr): TermElabM Expr := do + match indTypeFormer with + | leaf indVal _ => + let numArgs := indVal.numParams + indVal.numIndices + unless argNames.size >= numArgs do + throwError s!"Expected {numArgs} arguments for {indVal.name}, got {argNames}" + let mut args := Array.mkArray numArgs default + for i in [:numArgs] do + let arg := argNames[i]! + args := args.modify i (fun _ => arg) + let name ← Meta.mkConstWithFreshMVarLevels indVal.name + let res := args.foldl mkApp name + return res + | node indVal arr => + let mut args := #[] + for indTypeFormer? in arr do + match indTypeFormer? with + | .inl indFormer => + let arg ← go indFormer argNames + args := args.push arg + | .inr (.bvar i) => + let some argName := argNames[argNames.size-i-1]? + | throwError s!"Cannot instantiate {indTypeFormer} : not enough arguments given" + args := args.push argName + | .inr e => + args := args.push e + let res ← Meta.mkAppOptM indVal.name (args.map some) + return res + +structure Result where + indFormer : IndTypeFormer + args : Subarray Expr + argNames : Array Name + +instance : ToString Result where + toString res := s!"⟨{res.indFormer},{res.args},{res.argNames}⟩" + +instance : BEq Result := ⟨(·.indFormer == ·.indFormer)⟩ + +structure Context where + indNames : List Name + results : List Result + +end IndTypeFormer + +abbrev IndTypeFormerM := StateT IndTypeFormer.Context TermElabM + +def withIndNames (indNames : List Name) (f : IndTypeFormerM Unit) : TermElabM IndTypeFormer.Context := do + let ⟨(),ctx⟩ ← StateT.run f ⟨indNames,[]⟩ + return ctx + +def addResult (x : IndTypeFormer.Result) : IndTypeFormerM Unit := do + let ⟨names,res⟩ ← get + set (⟨names,x::res⟩ : IndTypeFormer.Context) + +def addName (n : Name) : IndTypeFormerM Unit := do + let ⟨names,res⟩ ← get + set (⟨n::names,res⟩ : IndTypeFormer.Context) + +partial def getIndTypeFormersOf (inds : List Name) (e: Expr) (fvars : Array Expr): MetaM (Option IndTypeFormer) := do + let .inl occs ← go e | return none + return occs +where + go (e : Expr) : MetaM (IndTypeFormer ⊕ Expr) := do + let hd := e.getAppFn + let args := e.getAppArgs + let fallback _ := return .inr <| e.abstract fvars + let .const name _ := hd | fallback () + if let some indName := inds.find? (· = name) then + let indVal ← getConstInfoInduct indName + let args := args.map (Expr.instantiateRev · fvars) + return .inl <| .leaf indVal args + else + try + let indVal ← getConstInfoInduct name + let args := args.map (Expr.abstract · fvars) + let indTypeFormersArgs ← args.mapM go + if indTypeFormersArgs.any (· matches .inl _) then + return .inl <| .node indVal indTypeFormersArgs + else fallback () + catch _ => fallback () + +partial def getIndTypeFormers (indNames : List Name) : TermElabM (List IndTypeFormer.Result) := do + let ⟨_,l⟩ ← withIndNames indNames do + for name in indNames do + go name #[] #[] + return l.eraseDups +where + go (indName : Name) (args : Array Expr) (fvars : Array Expr): IndTypeFormerM Unit := do + let indVal ← getConstInfoInduct indName + if !indVal.isNested && args.size == 0 then + return + let consts ← indVal.ctors.mapM getConstInfoCtor + for constInfo in consts do + let instConstInfo ← forallBoundedTelescope constInfo.type args.size fun xs e => + return e.abstract xs |>.instantiateRev args + forallTelescope instConstInfo fun xs _ => do + let mut paramArgs ← fvars.mapM fun e => do + let localDecl ← e.fvarId!.getDecl + mkFreshUserName localDecl.userName.eraseMacroScopes + let mut l := [] + for i in [:constInfo.numParams - args.size] do + let some e := xs[i]? | break + let localDecl ← e.fvarId!.getDecl + let paramName ← mkFreshUserName localDecl.userName.eraseMacroScopes + paramArgs := paramArgs.push paramName + let mut localArgs := #[] + for i in [:xs.size] do + let e := xs[i]! + let ty ← e.fvarId!.getType + let localDecl ← e.fvarId!.getDecl + let paramName ← mkFreshUserName localDecl.userName.eraseMacroScopes + let indFormers ← getIndTypeFormersOf indNames ty xs[:i] + let l' := if let .some x := indFormers then x.toListofTypeFormers else [] + for indFormer in l' do + let new_args := paramArgs ++ localArgs.filter (indFormer.containsFVar ⟨·⟩) + if (← get).results.all (indFormer != ·.indFormer) then + addResult ⟨indFormer, xs[:i], new_args⟩ + let fvars := fvars ++ xs[:i] + let app ← indFormer.mkAppExpr fvars + let hd := app.getAppFn.constName! + let args := app.getAppArgs + addName hd + go hd args fvars + else + addResult ⟨indFormer,xs[:i],new_args⟩ + l := l ++ l' + localArgs := localArgs.push paramName + +def indNameToFunName (indName : Name) : String := + match indName.eraseMacroScopes with + | .str _ t => t + | _ => "instFn" + +partial def mkInstName: IndTypeFormer → String + | .leaf ind _ => indNameToFunName ind.name + | .node ind arr => Id.run do + let mut res ← indNameToFunName ind.name + for indTypeFormer in arr do + if let .inl indFormer := indTypeFormer then + let nestedInstName ← mkInstName indFormer + res := res ++ nestedInstName + return res + /-- Return instance binder syntaxes binding `className α` for every generic parameter `α` of the inductive `indVal` for which such a binding is type-correct. `argNames` is expected to provide names for the parameters (see `mkInductArgNames`). The output matches `instBinder*`. For example, given `inductive Foo {α : Type} (n : Nat) : (β : Type) → Type`, where `β` is an index, invoking ``mkInstImplicitBinders `BarClass foo #[`α, `n, `β]`` gives `` `([BarClass α])``. -/ -def mkInstImplicitBinders (className : Name) (indVal : InductiveVal) (argNames : Array Name) : TermElabM (Array Syntax) := +partial def mkInstImplicitBinders (className : Name) (indTypeFormer : IndTypeFormer) (argNames : Array Name) : TermElabM (Array Syntax) := do + go indTypeFormer argNames +where + go (indTypeFormer : IndTypeFormer) (argNames : Array Name) : TermElabM (Array Syntax) := + let indVal := indTypeFormer.getIndVal + let arr := indTypeFormer.getArr forallBoundedTelescope indVal.type indVal.numParams fun xs _ => do - let mut binders := #[] + let mut binders : Array Syntax := #[] for i in [:xs.size] do - try - let x := xs[i]! - let c ← mkAppM className #[x] - if (← isTypeCorrect c) then - let argName := argNames[i]! - let binder : Syntax ← `(instBinderF| [ $(mkCIdent className):ident $(mkIdent argName):ident ]) - binders := binders.push binder - catch _ => - pure () + if indTypeFormer.isNode && arr[i]? matches some (.inl _) then + let some indFormer := arr[i]!.getLeft? | unreachable! + let nestedBinders ← go indFormer argNames + binders := binders ++ nestedBinders + else try + let x := xs[i]! + let hd ← mkConstWithFreshMVarLevels className + let c := mkAppN hd #[x] + if (← isTypeCorrect c) then + let some argName := argNames[i]? | pure () + let binder ← `(instBinderF| [$(mkCIdent className) $(mkIdent argName)]) + binders := binders.push binder + catch _ => pure () return binders -structure Context where - typeInfos : Array InductiveVal +structure Context : Type where + indNames : List Name + typeArgNames: Array (Array Name) + typeInfos : Array IndTypeFormer auxFunNames : Array Name usePartial : Bool -def mkContext (fnPrefix : String) (typeName : Name) : TermElabM Context := do +def mkContext (fnPrefix : String) (typeName : Name) (withNested : Bool := true): TermElabM Context := do let indVal ← getConstInfoInduct typeName - let mut typeInfos := #[] - for typeName in indVal.all do - typeInfos := typeInfos.push (← getConstInfoInduct typeName) - let mut auxFunNames := #[] - for typeName in indVal.all do - match typeName.eraseMacroScopes with - | .str _ t => auxFunNames := auxFunNames.push (← mkFreshUserName <| Name.mkSimple <| fnPrefix ++ t) - | _ => auxFunNames := auxFunNames.push (← mkFreshUserName `instFn) - trace[Elab.Deriving.beq] "{auxFunNames}" - let usePartial := indVal.isNested || typeInfos.size > 1 - return { - typeInfos := typeInfos - auxFunNames := auxFunNames - usePartial := usePartial - } + let indNames := indVal.all + let mut typeInfos' : List IndTypeFormer.Result := [] + for indName in indNames do + let indVal ← getConstInfoInduct indName + let args ← mkInductArgNames indVal + typeInfos' := ⟨.leaf indVal #[], #[].toSubarray, args⟩::typeInfos' + if withNested then + typeInfos' := (← getIndTypeFormers indVal.all) ++ typeInfos' + let typeArgNames := typeInfos'.map (·.argNames) |>.toArray + let typeInfos := typeInfos'.map (·.indFormer) |>.toArray + let auxFunNames ← typeInfos.mapM fun indFormer => do + return ← mkFreshUserName <| Name.mkSimple <| fnPrefix ++ mkInstName indFormer + let usePartial := !withNested && indVal.isNested + return {indNames, typeArgNames, typeInfos, auxFunNames, usePartial} def mkLocalInstanceLetDecls (ctx : Context) (className : Name) (argNames : Array Name) : TermElabM (Array (TSyntax ``Parser.Term.letDecl)) := do let mut letDecls := #[] for i in [:ctx.typeInfos.size] do - let indVal := ctx.typeInfos[i]! + let indFormer := ctx.typeInfos[i]! let auxFunName := ctx.auxFunNames[i]! + unless indFormer.isLeaf do + continue + let indVal := indFormer.getIndVal let currArgNames ← mkInductArgNames indVal let numParams := indVal.numParams let currIndices := currArgNames[numParams:] @@ -107,22 +393,22 @@ def mkLet (letDecls : Array (TSyntax ``Parser.Term.letDecl)) (body : Term) : Ter `(let $letDecl:letDecl; $body) open TSyntax.Compat in -def mkInstanceCmds (ctx : Context) (className : Name) (typeNames : Array Name) (useAnonCtor := true) : TermElabM (Array Command) := do +def mkInstanceCmds (ctx : Context) (className : Name) (useAnonCtor := true) : TermElabM (Array Command) := do let mut instances := #[] for i in [:ctx.typeInfos.size] do - let indVal := ctx.typeInfos[i]! - if typeNames.contains indVal.name then - let auxFunName := ctx.auxFunNames[i]! - let argNames ← mkInductArgNames indVal - let binders ← mkImplicitBinders argNames - let binders := binders ++ (← mkInstImplicitBinders className indVal argNames) - let indType ← mkInductiveApp indVal argNames - let type ← `($(mkCIdent className) $indType) - let mut val := mkIdent auxFunName - if useAnonCtor then - val ← `(⟨$val⟩) - let instCmd ← `(instance $binders:implicitBinder* : $type := $val) - instances := instances.push instCmd + let indTypeFormer := ctx.typeInfos[i]! + unless indTypeFormer.isLeaf do continue + let auxFunName := ctx.auxFunNames[i]! + let argNames := ctx.typeArgNames[i]! + let binders ← mkImplicitBinders argNames + let binders := binders ++ (← mkInstImplicitBinders className indTypeFormer argNames) + let indType ← indTypeFormer.mkAppTerm argNames + let type ← `($(mkCIdent className) $indType) + let mut val := mkIdent auxFunName + if useAnonCtor then + val ← `(⟨$val⟩) + let instCmd ← `(instance $binders:implicitBinder* : $type := $val) + instances := instances.push instCmd return instances def mkDiscr (varName : Name) : TermElabM (TSyntax ``Parser.Term.matchDiscr) := @@ -135,15 +421,14 @@ structure Header where targetType : Term open TSyntax.Compat in -def mkHeader (className : Name) (arity : Nat) (indVal : InductiveVal) : TermElabM Header := do - let argNames ← mkInductArgNames indVal - let binders ← mkImplicitBinders argNames - let targetType ← mkInductiveApp indVal argNames +def mkHeader (className : Name) (arity : Nat) (argNames : Array Name) (indTypeFormer : IndTypeFormer) : TermElabM Header := do + let mut binders ← mkImplicitBinders argNames + let targetType ← indTypeFormer.mkAppTerm argNames let mut targetNames := #[] for _ in [:arity] do targetNames := targetNames.push (← mkFreshUserName `x) - let binders := binders ++ (← mkInstImplicitBinders className indVal argNames) - let binders := binders ++ (← targetNames.mapM fun targetName => `(explicitBinderF| ($(mkIdent targetName) : $targetType))) + binders := binders ++ (← mkInstImplicitBinders className indTypeFormer argNames) + binders := binders ++ (← targetNames.mapM fun targetName => `(explicitBinderF| ($(mkIdent targetName) : $targetType))) return { binders := binders argNames := argNames @@ -158,4 +443,10 @@ def mkDiscrs (header : Header) (indVal : InductiveVal) : TermElabM (Array (TSynt discrs := discrs.push (← mkDiscr argName) return discrs ++ (← header.targetNames.mapM mkDiscr) +def Context.getFunName? (ctx : Context) (header : Header) (ty : Expr) (xs : Array Expr): TermElabM (Option Name) := do + let indValNum ← ctx.typeInfos.findIdxM? <| + (return .some · == (← getIndTypeFormersOf ctx.indNames ty xs[:header.argNames.size])) + let recField := indValNum.map (ctx.auxFunNames[·]!) + return recField + end Lean.Elab.Deriving diff --git a/tests/lean/3057.lean.expected.out b/tests/lean/3057.lean.expected.out index 34efeb5f83b6..dcb3e5ddfc90 100644 --- a/tests/lean/3057.lean.expected.out +++ b/tests/lean/3057.lean.expected.out @@ -1,10 +1,10 @@ -instReprTree -instReprListTree -instDecidableEqTree -instDecidableEqListTree -instBEqTree -instBEqListTree -instHashableTree -instHashableListTree -instOrdTree -instOrdListTree +instReprTree_1 +instReprListTree_1 +instDecidableEqTree_1 +instDecidableEqListTree_1 +instBEqTree_1 +instBEqListTree_1 +instHashableTree_1 +instHashableListTree_1 +instOrdTree_1 +instOrdListTree_1 diff --git a/tests/lean/decEqMutualInductives.lean b/tests/lean/decEqMutualInductives.lean index fbe0af7e9113..7ee1d50c9829 100644 --- a/tests/lean/decEqMutualInductives.lean +++ b/tests/lean/decEqMutualInductives.lean @@ -4,15 +4,19 @@ set_option trace.Elab.Deriving.decEq true mutual -inductive Tree : Type := +inductive Tree := | node : ListTree → Tree -inductive ListTree : Type := +inductive ListTree := | nil : ListTree | cons : Tree → ListTree → ListTree deriving DecidableEq end +inductive Tree' (α : Type _) : Type _:= + | node : α → Option (List (Tree' α)) → Tree' α +deriving DecidableEq + mutual inductive Foo₁ : Type := | foo₁₁ : Foo₁ @@ -25,3 +29,36 @@ inductive Foo₂ : Type := inductive Foo₃ : Type := | foo₃ : Foo₁ → Foo₃ end + +inductive Min' where + | Base + | Const (a : List Min') +deriving DecidableEq + +inductive ComplexInductive (A B C : Type) (n : Nat) : Type + | constr : A → B → C → ComplexInductive A B C n + +inductive NestedComplex (A C : Type) : Type + | constr : ComplexInductive A (NestedComplex A C) C 1 → NestedComplex A C +deriving DecidableEq + +namespace nested + +inductive Tree (α : Type) where + | node : Array (Tree α) → Tree α +deriving DecidableEq + +end nested + +namespace mess + +mutual + +inductive Mess1 where + | node : Array (Mess2) → Mess1 +deriving DecidableEq +inductive Mess2 where + | node : Array (Mess1) → Mess2 +end + +end mess diff --git a/tests/lean/decEqMutualInductives.lean.expected.out b/tests/lean/decEqMutualInductives.lean.expected.out index bc22d84274e3..34a9234621b7 100644 --- a/tests/lean/decEqMutualInductives.lean.expected.out +++ b/tests/lean/decEqMutualInductives.lean.expected.out @@ -1,55 +1,262 @@ [Elab.Deriving.decEq] [mutual - private def decEqTree✝ (x✝ : @Tree✝) (x✝¹ : @Tree✝) : Decidable✝ (x✝ = x✝¹) := + private def decEqListTree✝ (x✝ : @ListTree✝) (x✝¹ : @ListTree✝) : Decidable✝ (x✝ = x✝¹) := match x✝, x✝¹ with - | @Tree.node a✝, @Tree.node b✝ => - let inst✝ := decEqListTree✝ @a✝ @b✝; - if h✝ : @a✝ = @b✝ then by subst h✝; exact isTrue✝ rfl✝ - else isFalse✝ (by intro n✝; injection n✝; apply h✝ _; assumption) + | @ListTree.nil, @ListTree.nil => isTrue✝ rfl✝ + | ListTree.nil .., ListTree.cons .. => isFalse✝ (by intro h✝; injection h✝) + | ListTree.cons .., ListTree.nil .. => isFalse✝ (by intro h✝; injection h✝) + | @ListTree.cons a✝ a✝¹, @ListTree.cons b✝ b✝¹ => + let inst✝ := decEqTree✝ @a✝ @b✝; + if h✝¹ : @a✝ = @b✝ then by subst h✝¹; + exact + let inst✝¹ := decEqListTree✝ @a✝¹ @b✝¹; + if h✝² : @a✝¹ = @b✝¹ then by subst h✝²; exact isTrue✝¹ rfl✝¹ + else isFalse✝¹ (by intro n✝; injection n✝; contradiction) + else isFalse✝² (by intro n✝¹; injection n✝¹; contradiction) termination_by structural x✝ - private def decEqListTree✝ (x✝² : @ListTree✝) (x✝³ : @ListTree✝) : Decidable✝ (x✝² = x✝³) := + private def decEqTree✝ (x✝² : @Tree✝) (x✝³ : @Tree✝) : Decidable✝ (x✝² = x✝³) := match x✝², x✝³ with - | @ListTree.nil, @ListTree.nil => isTrue✝¹ rfl✝¹ - | ListTree.nil .., ListTree.cons .. => isFalse✝¹ (by intro h✝¹; injection h✝¹) - | ListTree.cons .., ListTree.nil .. => isFalse✝¹ (by intro h✝¹; injection h✝¹) - | @ListTree.cons a✝¹ a✝², @ListTree.cons b✝¹ b✝² => - let inst✝¹ := decEqTree✝ @a✝¹ @b✝¹; - if h✝² : @a✝¹ = @b✝¹ then by subst h✝²; - exact - let inst✝² := decEqListTree✝ @a✝² @b✝²; - if h✝³ : @a✝² = @b✝² then by subst h✝³; exact isTrue✝² rfl✝² - else isFalse✝² (by intro n✝¹; injection n✝¹; apply h✝³ _; assumption) - else isFalse✝³ (by intro n✝²; injection n✝²; apply h✝² _; assumption) + | @Tree.node a✝², @Tree.node b✝² => + let inst✝² := decEqListTree✝ @a✝² @b✝²; + if h✝³ : @a✝² = @b✝² then by subst h✝³; exact isTrue✝² rfl✝² + else isFalse✝³ (by intro n✝²; injection n✝²; contradiction) termination_by structural x✝² end, instance : DecidableEq✝ (@ListTree✝) := - decEqListTree✝] + decEqListTree✝, + instance : DecidableEq✝ (@Tree✝) := + decEqTree✝] +[Elab.Deriving.decEq] + [mutual + private def decEqListTree'✝ {α✝} [DecidableEq✝ α✝] (x✝ : @List✝ (@Tree'✝ α✝)) (x✝¹ : @List✝ (@Tree'✝ α✝)) : + Decidable✝ (x✝ = x✝¹) := + match x✝, x✝¹ with + | @List.nil _, @List.nil _ => isTrue✝ rfl✝ + | List.nil .., List.cons .. => isFalse✝ (by intro h✝; injection h✝) + | List.cons .., List.nil .. => isFalse✝ (by intro h✝; injection h✝) + | @List.cons _ a✝ a✝¹, @List.cons _ b✝ b✝¹ => + let inst✝ := decEqTree'✝ @a✝ @b✝; + if h✝¹ : @a✝ = @b✝ then by subst h✝¹; + exact + let inst✝¹ := decEqListTree'✝ @a✝¹ @b✝¹; + if h✝² : @a✝¹ = @b✝¹ then by subst h✝²; exact isTrue✝¹ rfl✝¹ + else isFalse✝¹ (by intro n✝; injection n✝; contradiction) + else isFalse✝² (by intro n✝¹; injection n✝¹; contradiction) + termination_by structural x✝ + private def decEqOptionListTree'✝ {α✝} [DecidableEq✝ α✝] (x✝² : @Option✝ (@List✝ (@Tree'✝ α✝))) + (x✝³ : @Option✝ (@List✝ (@Tree'✝ α✝))) : Decidable✝ (x✝² = x✝³) := + match x✝², x✝³ with + | @Option.none _, @Option.none _ => isTrue✝ rfl✝ + | Option.none .., Option.some .. => isFalse✝ (by intro h✝; injection h✝) + | Option.some .., Option.none .. => isFalse✝ (by intro h✝; injection h✝) + | @Option.some _ a✝², @Option.some _ b✝² => + let inst✝² := decEqListTree'✝ @a✝² @b✝²; + if h✝³ : @a✝² = @b✝² then by subst h✝³; exact isTrue✝² rfl✝² + else isFalse✝³ (by intro n✝²; injection n✝²; contradiction) + termination_by structural x✝² + private def decEqTree'✝ {α✝¹} [DecidableEq✝ α✝¹] (x✝⁴ : @Tree'✝ α✝¹) (x✝⁵ : @Tree'✝ α✝¹) : + Decidable✝ (x✝⁴ = x✝⁵) := + match x✝⁴, x✝⁵ with + | @Tree'.node _ a✝³ a✝⁴, @Tree'.node _ b✝³ b✝⁴ => + if h✝⁴ : @a✝³ = @b✝³ then by subst h✝⁴; + exact + let inst✝³ := decEqOptionListTree'✝ @a✝⁴ @b✝⁴; + if h✝⁵ : @a✝⁴ = @b✝⁴ then by subst h✝⁵; exact isTrue✝³ rfl✝³ + else isFalse✝⁴ (by intro n✝³; injection n✝³; contradiction) + else isFalse✝⁵ (by intro n✝⁴; injection n✝⁴; contradiction) + termination_by structural x✝⁴ + end, + instance {α✝} [DecidableEq✝ α✝] : DecidableEq✝ (@Tree'✝ α✝) := + decEqTree'✝] [Elab.Deriving.decEq] [mutual - private def decEqFoo₁✝ (x✝ : @Foo₁✝) (x✝¹ : @Foo₁✝) : Decidable✝ (x✝ = x✝¹) := + private def decEqFoo₃✝ (x✝ : @Foo₃✝) (x✝¹ : @Foo₃✝) : Decidable✝ (x✝ = x✝¹) := match x✝, x✝¹ with - | @Foo₁.foo₁₁, @Foo₁.foo₁₁ => isTrue✝ rfl✝ - | Foo₁.foo₁₁ .., Foo₁.foo₁₂ .. => isFalse✝ (by intro h✝; injection h✝) - | Foo₁.foo₁₂ .., Foo₁.foo₁₁ .. => isFalse✝ (by intro h✝; injection h✝) - | @Foo₁.foo₁₂ a✝, @Foo₁.foo₁₂ b✝ => - let inst✝ := decEqFoo₂✝ @a✝ @b✝; - if h✝¹ : @a✝ = @b✝ then by subst h✝¹; exact isTrue✝¹ rfl✝¹ - else isFalse✝¹ (by intro n✝; injection n✝; apply h✝¹ _; assumption) + | @Foo₃.foo₃ a✝, @Foo₃.foo₃ b✝ => + let inst✝ := decEqFoo₁✝ @a✝ @b✝; + if h✝ : @a✝ = @b✝ then by subst h✝; exact isTrue✝ rfl✝ + else isFalse✝ (by intro n✝; injection n✝; contradiction) termination_by structural x✝ private def decEqFoo₂✝ (x✝² : @Foo₂✝) (x✝³ : @Foo₂✝) : Decidable✝ (x✝² = x✝³) := match x✝², x✝³ with | @Foo₂.foo₂ a✝¹, @Foo₂.foo₂ b✝¹ => let inst✝¹ := decEqFoo₃✝ @a✝¹ @b✝¹; - if h✝² : @a✝¹ = @b✝¹ then by subst h✝²; exact isTrue✝² rfl✝² - else isFalse✝² (by intro n✝¹; injection n✝¹; apply h✝² _; assumption) + if h✝¹ : @a✝¹ = @b✝¹ then by subst h✝¹; exact isTrue✝¹ rfl✝¹ + else isFalse✝¹ (by intro n✝¹; injection n✝¹; contradiction) termination_by structural x✝² - private def decEqFoo₃✝ (x✝⁴ : @Foo₃✝) (x✝⁵ : @Foo₃✝) : Decidable✝ (x✝⁴ = x✝⁵) := + private def decEqFoo₁✝ (x✝⁴ : @Foo₁✝) (x✝⁵ : @Foo₁✝) : Decidable✝ (x✝⁴ = x✝⁵) := match x✝⁴, x✝⁵ with - | @Foo₃.foo₃ a✝², @Foo₃.foo₃ b✝² => - let inst✝² := decEqFoo₁✝ @a✝² @b✝²; + | @Foo₁.foo₁₁, @Foo₁.foo₁₁ => isTrue✝² rfl✝² + | Foo₁.foo₁₁ .., Foo₁.foo₁₂ .. => isFalse✝² (by intro h✝²; injection h✝²) + | Foo₁.foo₁₂ .., Foo₁.foo₁₁ .. => isFalse✝² (by intro h✝²; injection h✝²) + | @Foo₁.foo₁₂ a✝², @Foo₁.foo₁₂ b✝² => + let inst✝² := decEqFoo₂✝ @a✝² @b✝²; if h✝³ : @a✝² = @b✝² then by subst h✝³; exact isTrue✝³ rfl✝³ - else isFalse✝³ (by intro n✝²; injection n✝²; apply h✝³ _; assumption) + else isFalse✝³ (by intro n✝²; injection n✝²; contradiction) termination_by structural x✝⁴ end, + instance : DecidableEq✝ (@Foo₃✝) := + decEqFoo₃✝, + instance : DecidableEq✝ (@Foo₂✝) := + decEqFoo₂✝, instance : DecidableEq✝ (@Foo₁✝) := decEqFoo₁✝] +[Elab.Deriving.decEq] + [mutual + private def decEqListMin'✝ (x✝ : @List✝ (@Min'✝)) (x✝¹ : @List✝ (@Min'✝)) : Decidable✝ (x✝ = x✝¹) := + match x✝, x✝¹ with + | @List.nil _, @List.nil _ => isTrue✝ rfl✝ + | List.nil .., List.cons .. => isFalse✝ (by intro h✝; injection h✝) + | List.cons .., List.nil .. => isFalse✝ (by intro h✝; injection h✝) + | @List.cons _ a✝ a✝¹, @List.cons _ b✝ b✝¹ => + let inst✝ := decEqMin'✝ @a✝ @b✝; + if h✝¹ : @a✝ = @b✝ then by subst h✝¹; + exact + let inst✝¹ := decEqListMin'✝ @a✝¹ @b✝¹; + if h✝² : @a✝¹ = @b✝¹ then by subst h✝²; exact isTrue✝¹ rfl✝¹ + else isFalse✝¹ (by intro n✝; injection n✝; contradiction) + else isFalse✝² (by intro n✝¹; injection n✝¹; contradiction) + termination_by structural x✝ + private def decEqMin'✝ (x✝² : @Min'✝) (x✝³ : @Min'✝) : Decidable✝ (x✝² = x✝³) := + match x✝², x✝³ with + | @Min'.Base, @Min'.Base => isTrue✝ rfl✝ + | Min'.Base .., Min'.Const .. => isFalse✝ (by intro h✝; injection h✝) + | Min'.Const .., Min'.Base .. => isFalse✝ (by intro h✝; injection h✝) + | @Min'.Const a✝², @Min'.Const b✝² => + let inst✝² := decEqListMin'✝ @a✝² @b✝²; + if h✝³ : @a✝² = @b✝² then by subst h✝³; exact isTrue✝² rfl✝² + else isFalse✝³ (by intro n✝²; injection n✝²; contradiction) + termination_by structural x✝² + end, + instance : DecidableEq✝ (@Min'✝) := + decEqMin'✝] +[Elab.Deriving.decEq] + [mutual + private def decEqComplexInductiveNestedComplex✝ {A✝} {C✝} [DecidableEq✝ A✝] [DecidableEq✝ A✝] [DecidableEq✝ C✝] + (x✝ : @ComplexInductive✝ A✝ (@NestedComplex✝ A✝ C✝) C✝ 1) + (x✝¹ : @ComplexInductive✝ A✝ (@NestedComplex✝ A✝ C✝) C✝ 1) : Decidable✝ (x✝ = x✝¹) := + match x✝, x✝¹ with + | @ComplexInductive.constr _ _ _ _ a✝ a✝¹ a✝², @ComplexInductive.constr _ _ _ _ b✝ b✝¹ b✝² => + if h✝ : @a✝ = @b✝ then by subst h✝; + exact + let inst✝ := decEqNestedComplex✝ @a✝¹ @b✝¹; + if h✝¹ : @a✝¹ = @b✝¹ then by subst h✝¹; + exact + if h✝² : @a✝² = @b✝² then by subst h✝²; exact isTrue✝ rfl✝ + else isFalse✝ (by intro n✝; injection n✝; contradiction) + else isFalse✝¹ (by intro n✝¹; injection n✝¹; contradiction) + else isFalse✝² (by intro n✝²; injection n✝²; contradiction) + termination_by structural x✝ + private def decEqNestedComplex✝ {A✝¹} {C✝¹} [DecidableEq✝ A✝¹] [DecidableEq✝ C✝¹] (x✝² : @NestedComplex✝ A✝¹ C✝¹) + (x✝³ : @NestedComplex✝ A✝¹ C✝¹) : Decidable✝ (x✝² = x✝³) := + match x✝², x✝³ with + | @NestedComplex.constr _ _ a✝³, @NestedComplex.constr _ _ b✝³ => + let inst✝¹ := decEqComplexInductiveNestedComplex✝ @a✝³ @b✝³; + if h✝³ : @a✝³ = @b✝³ then by subst h✝³; exact isTrue✝¹ rfl✝¹ + else isFalse✝³ (by intro n✝³; injection n✝³; contradiction) + termination_by structural x✝² + end, + instance {A✝} {C✝} [DecidableEq✝ A✝] [DecidableEq✝ C✝] : DecidableEq✝ (@NestedComplex✝ A✝ C✝) := + decEqNestedComplex✝] +[Elab.Deriving.decEq] + [mutual + private def decEqListTree✝ {α✝} [DecidableEq✝ α✝] (x✝ : @List✝ (@nested.Tree✝ α✝)) + (x✝¹ : @List✝ (@nested.Tree✝ α✝)) : Decidable✝ (x✝ = x✝¹) := + match x✝, x✝¹ with + | @List.nil _, @List.nil _ => isTrue✝ rfl✝ + | List.nil .., List.cons .. => isFalse✝ (by intro h✝; injection h✝) + | List.cons .., List.nil .. => isFalse✝ (by intro h✝; injection h✝) + | @List.cons _ a✝ a✝¹, @List.cons _ b✝ b✝¹ => + let inst✝ := decEqTree✝ @a✝ @b✝; + if h✝¹ : @a✝ = @b✝ then by subst h✝¹; + exact + let inst✝¹ := decEqListTree✝ @a✝¹ @b✝¹; + if h✝² : @a✝¹ = @b✝¹ then by subst h✝²; exact isTrue✝¹ rfl✝¹ + else isFalse✝¹ (by intro n✝; injection n✝; contradiction) + else isFalse✝² (by intro n✝¹; injection n✝¹; contradiction) + termination_by structural x✝ + private def decEqArrayTree✝ {α✝¹} [DecidableEq✝ α✝¹] (x✝² : @Array✝ (@nested.Tree✝ α✝¹)) + (x✝³ : @Array✝ (@nested.Tree✝ α✝¹)) : Decidable✝ (x✝² = x✝³) := + match x✝², x✝³ with + | @Array.mk _ a✝², @Array.mk _ b✝² => + let inst✝² := decEqListTree✝ @a✝² @b✝²; + if h✝³ : @a✝² = @b✝² then by subst h✝³; exact isTrue✝² rfl✝² + else isFalse✝³ (by intro n✝²; injection n✝²; contradiction) + termination_by structural x✝² + private def decEqTree✝ {α✝²} [DecidableEq✝ α✝²] (x✝⁴ : @nested.Tree✝ α✝²) (x✝⁵ : @nested.Tree✝ α✝²) : + Decidable✝ (x✝⁴ = x✝⁵) := + match x✝⁴, x✝⁵ with + | @nested.Tree.node _ a✝³, @nested.Tree.node _ b✝³ => + let inst✝³ := decEqArrayTree✝ @a✝³ @b✝³; + if h✝⁴ : @a✝³ = @b✝³ then by subst h✝⁴; exact isTrue✝³ rfl✝³ + else isFalse✝⁴ (by intro n✝³; injection n✝³; contradiction) + termination_by structural x✝⁴ + end, + instance {α✝} [DecidableEq✝ α✝] : DecidableEq✝ (@nested.Tree✝ α✝) := + decEqTree✝] +[Elab.Deriving.decEq] + [mutual + private def decEqListMess1✝ (x✝ : @List✝ (@mess.Mess1✝)) (x✝¹ : @List✝ (@mess.Mess1✝)) : Decidable✝ (x✝ = x✝¹) := + match x✝, x✝¹ with + | @List.nil _, @List.nil _ => isTrue✝ rfl✝ + | List.nil .., List.cons .. => isFalse✝ (by intro h✝; injection h✝) + | List.cons .., List.nil .. => isFalse✝ (by intro h✝; injection h✝) + | @List.cons _ a✝ a✝¹, @List.cons _ b✝ b✝¹ => + let inst✝ := decEqMess1✝ @a✝ @b✝; + if h✝¹ : @a✝ = @b✝ then by subst h✝¹; + exact + let inst✝¹ := decEqListMess1✝ @a✝¹ @b✝¹; + if h✝² : @a✝¹ = @b✝¹ then by subst h✝²; exact isTrue✝¹ rfl✝¹ + else isFalse✝¹ (by intro n✝; injection n✝; contradiction) + else isFalse✝² (by intro n✝¹; injection n✝¹; contradiction) + termination_by structural x✝ + private def decEqArrayMess1✝ (x✝² : @Array✝ (@mess.Mess1✝)) (x✝³ : @Array✝ (@mess.Mess1✝)) : + Decidable✝ (x✝² = x✝³) := + match x✝², x✝³ with + | @Array.mk _ a✝², @Array.mk _ b✝² => + let inst✝² := decEqListMess1✝ @a✝² @b✝²; + if h✝³ : @a✝² = @b✝² then by subst h✝³; exact isTrue✝² rfl✝² + else isFalse✝³ (by intro n✝²; injection n✝²; contradiction) + termination_by structural x✝² + private def decEqListMess2✝ (x✝⁴ : @List✝ (@mess.Mess2✝)) (x✝⁵ : @List✝ (@mess.Mess2✝)) : + Decidable✝ (x✝⁴ = x✝⁵) := + match x✝⁴, x✝⁵ with + | @List.nil _, @List.nil _ => isTrue✝ rfl✝ + | List.nil .., List.cons .. => isFalse✝ (by intro h✝; injection h✝) + | List.cons .., List.nil .. => isFalse✝ (by intro h✝; injection h✝) + | @List.cons _ a✝³ a✝⁴, @List.cons _ b✝³ b✝⁴ => + let inst✝³ := decEqMess2✝ @a✝³ @b✝³; + if h✝⁴ : @a✝³ = @b✝³ then by subst h✝⁴; + exact + let inst✝⁴ := decEqListMess2✝ @a✝⁴ @b✝⁴; + if h✝⁵ : @a✝⁴ = @b✝⁴ then by subst h✝⁵; exact isTrue✝³ rfl✝³ + else isFalse✝⁴ (by intro n✝³; injection n✝³; contradiction) + else isFalse✝⁵ (by intro n✝⁴; injection n✝⁴; contradiction) + termination_by structural x✝⁴ + private def decEqArrayMess2✝ (x✝⁶ : @Array✝ (@mess.Mess2✝)) (x✝⁷ : @Array✝ (@mess.Mess2✝)) : + Decidable✝ (x✝⁶ = x✝⁷) := + match x✝⁶, x✝⁷ with + | @Array.mk _ a✝⁵, @Array.mk _ b✝⁵ => + let inst✝⁵ := decEqListMess2✝ @a✝⁵ @b✝⁵; + if h✝⁶ : @a✝⁵ = @b✝⁵ then by subst h✝⁶; exact isTrue✝⁴ rfl✝⁴ + else isFalse✝⁶ (by intro n✝⁵; injection n✝⁵; contradiction) + termination_by structural x✝⁶ + private def decEqMess2✝ (x✝⁸ : @mess.Mess2✝) (x✝⁹ : @mess.Mess2✝) : Decidable✝ (x✝⁸ = x✝⁹) := + match x✝⁸, x✝⁹ with + | @mess.Mess2.node a✝⁶, @mess.Mess2.node b✝⁶ => + let inst✝⁶ := decEqArrayMess1✝ @a✝⁶ @b✝⁶; + if h✝⁷ : @a✝⁶ = @b✝⁶ then by subst h✝⁷; exact isTrue✝⁵ rfl✝⁵ + else isFalse✝⁷ (by intro n✝⁶; injection n✝⁶; contradiction) + termination_by structural x✝⁸ + private def decEqMess1✝ (x✝¹⁰ : @mess.Mess1✝) (x✝¹¹ : @mess.Mess1✝) : Decidable✝ (x✝¹⁰ = x✝¹¹) := + match x✝¹⁰, x✝¹¹ with + | @mess.Mess1.node a✝⁷, @mess.Mess1.node b✝⁷ => + let inst✝⁷ := decEqArrayMess2✝ @a✝⁷ @b✝⁷; + if h✝⁸ : @a✝⁷ = @b✝⁷ then by subst h✝⁸; exact isTrue✝⁶ rfl✝⁶ + else isFalse✝⁸ (by intro n✝⁷; injection n✝⁷; contradiction) + termination_by structural x✝¹⁰ + end, + instance : DecidableEq✝ (@mess.Mess2✝) := + decEqMess2✝, + instance : DecidableEq✝ (@mess.Mess1✝) := + decEqMess1✝] diff --git a/tests/lean/run/toFromJson.lean b/tests/lean/run/toFromJson.lean index 2c9165b7e70e..75f1609e1dad 100644 --- a/tests/lean/run/toFromJson.lean +++ b/tests/lean/run/toFromJson.lean @@ -42,7 +42,6 @@ def checkRoundTrip [Repr α] [BEq α] [ToJson α] [FromJson α] (obj : α) : Met else throwError "couldn't parse: {repr obj} ≟ {obj |> toJson}" --- set_option trace.Meta.debug true in structure Foo where x : Nat y : String