diff --git a/src/Lean/Elab/Deriving/BEq.lean b/src/Lean/Elab/Deriving/BEq.lean index 5e9cb3cc5bba..f064ab8f2713 100644 --- a/src/Lean/Elab/Deriving/BEq.lean +++ b/src/Lean/Elab/Deriving/BEq.lean @@ -12,15 +12,19 @@ namespace Lean.Elab.Deriving.BEq open Lean.Parser.Term open Meta -def mkBEqHeader (indVal : InductiveVal) : TermElabM Header := do - mkHeader `BEq 2 indVal +def mkBEqHeader (argNames : Array Name) (nestedOcc : NestedOccurence) : TermElabM Header := do + mkHeader `BEq 2 argNames nestedOcc -def mkMatch (header : Header) (indVal : InductiveVal) (auxFunName : Name) : TermElabM Term := do +def mkMatch (ctx : Context) (header : Header) (e : Expr) (fvars : Array Expr) : TermElabM Term := do + let f := e.getAppFn + let ind := f.constName! + let lvls := f.constLevels! + let indVal ← getConstInfoInduct ind let discrs ← mkDiscrs header indVal - let alts ← mkAlts + let alts ← mkAlts indVal lvls `(match $[$discrs],* with $alts:matchAlt*) where - mkElseAlt : TermElabM (TSyntax ``matchAltExpr) := do + mkElseAlt (indVal : InductiveVal): TermElabM (TSyntax ``matchAltExpr) := do let mut patterns := #[] -- add `_` pattern for indices for _ in [:indVal.numIndices] do @@ -30,11 +34,15 @@ where let altRhs ← `(false) `(matchAltExpr| | $[$patterns:term],* => $altRhs:term) - mkAlts : TermElabM (Array (TSyntax ``matchAlt)) := do + mkAlts (indVal : InductiveVal) (lvl : List Level): TermElabM (Array (TSyntax ``matchAlt)) := do let mut alts := #[] for ctorName in indVal.ctors do + let args := e.getAppArgs let ctorInfo ← getConstInfoCtor ctorName - let alt ← forallTelescopeReducing ctorInfo.type fun xs type => do + let subargs := args[:ctorInfo.numParams] + let ctorApp := mkAppN (mkConst ctorInfo.name lvl) subargs + let ctorType ← inferType ctorApp + let alt ← forallTelescopeReducing ctorType fun xs type => do let type ← Core.betaReduce type -- we 'beta-reduce' to eliminate "artificial" dependencies let mut patterns := #[] -- add `_` pattern for indices @@ -45,7 +53,7 @@ where let mut rhs ← `(true) let mut rhs_empty := true for i in [:ctorInfo.numFields] do - let pos := indVal.numParams + ctorInfo.numFields - i - 1 + let pos := indVal.numParams + ctorInfo.numFields - subargs.size - i - 1 let x := xs[pos]! if type.containsFVar x.fvarId! then -- If resulting type depends on this field, we don't need to compare @@ -59,7 +67,7 @@ where let xType ← inferType x if (← isProp xType) then continue - if xType.isAppOf indVal.name then + if let some auxFunName ← ctx.getFunName? header xType fvars then if rhs_empty then rhs ← `($(mkIdent auxFunName):ident $a:ident $b:ident) rhs_empty := false @@ -88,19 +96,21 @@ where patterns := patterns.push (← `(@$(mkIdent ctorName):ident $ctorArgs2.reverse:term*)) `(matchAltExpr| | $[$patterns:term],* => $rhs:term) alts := alts.push alt - alts := alts.push (← mkElseAlt) + alts := alts.push (← mkElseAlt indVal) return alts def mkAuxFunction (ctx : Context) (i : Nat) : TermElabM Command := do let auxFunName := ctx.auxFunNames[i]! - let indVal := ctx.typeInfos[i]! - let header ← mkBEqHeader indVal - let mut body ← mkMatch header indVal auxFunName + let nestedOcc := ctx.typeInfos[i]! + let argNames := ctx.typeArgNames[i]! + let header ← mkBEqHeader argNames nestedOcc + let binders := header.binders + Term.elabBinders binders fun xs => do + let type ← Term.elabTerm header.targetType none + let mut body ← mkMatch ctx header type xs if ctx.usePartial then let letDecls ← mkLocalInstanceLetDecls ctx `BEq header.argNames body ← mkLet letDecls body - let binders := header.binders - if ctx.usePartial then `(private partial def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Bool := $body:term) else `(private def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Bool := $body:term) @@ -115,8 +125,8 @@ def mkMutualBlock (ctx : Context) : TermElabM Syntax := do end) private def mkBEqInstanceCmds (declName : Name) : TermElabM (Array Syntax) := do - let ctx ← mkContext "beq" declName - let cmds := #[← mkMutualBlock ctx] ++ (← mkInstanceCmds ctx `BEq #[declName]) + let ctx ← mkContext "beq" declName false + let cmds := #[← mkMutualBlock ctx] ++ (← mkInstanceCmds ctx `BEq) trace[Elab.Deriving.beq] "\n{cmds}" return cmds @@ -125,8 +135,8 @@ private def mkBEqEnumFun (ctx : Context) (name : Name) : TermElabM Syntax := do `(private def $(mkIdent auxFunName):ident (x y : $(mkIdent name)) : Bool := x.toCtorIdx == y.toCtorIdx) private def mkBEqEnumCmd (name : Name): TermElabM (Array Syntax) := do - let ctx ← mkContext "beq" name - let cmds := #[← mkBEqEnumFun ctx name] ++ (← mkInstanceCmds ctx `BEq #[name]) + let ctx ← mkContext "beq" name false + let cmds := #[← mkBEqEnumFun ctx name] ++ (← mkInstanceCmds ctx `BEq) trace[Elab.Deriving.beq] "\n{cmds}" return cmds diff --git a/src/Lean/Elab/Deriving/DecEq.lean b/src/Lean/Elab/Deriving/DecEq.lean index a3b6fbf5e0b5..6cfd25d3b389 100644 --- a/src/Lean/Elab/Deriving/DecEq.lean +++ b/src/Lean/Elab/Deriving/DecEq.lean @@ -6,19 +6,25 @@ Authors: Leonardo de Moura prelude import Lean.Meta.Transform import Lean.Meta.Inductive +import Lean.Meta.InferType import Lean.Elab.Deriving.Basic import Lean.Elab.Deriving.Util +import Lean.Elab.Binders namespace Lean.Elab.Deriving.DecEq open Lean.Parser.Term open Meta -def mkDecEqHeader (indVal : InductiveVal) : TermElabM Header := do - mkHeader `DecidableEq 2 indVal +def mkDecEqHeader (argNames : Array Name) (nestedOcc : NestedOccurence) : TermElabM Header := do + mkHeader `DecidableEq 2 argNames nestedOcc -def mkMatch (ctx : Context) (header : Header) (indVal : InductiveVal) : TermElabM Term := do +def mkMatch (ctx : Context) (header : Header) (e : Expr) (fvars : Array Expr) : TermElabM Term := do + let f := e.getAppFn + let ind := f.constName! + let lvls := f.constLevels! + let indVal ← getConstInfoInduct ind let discrs ← mkDiscrs header indVal - let alts ← mkAlts + let alts ← mkAlts indVal lvls `(match $[$discrs],* with $alts:matchAlt*) where mkSameCtorRhs : List (Ident × Ident × Option Name × Bool) → TermElabM Term @@ -31,14 +37,14 @@ where `(if h : @$a = @$b then by subst h; exact $sameCtor:term else - isFalse (by intro n; injection n; apply h _; assumption)) + isFalse (by intro n; injection n; contradiction)) if let some auxFunName := recField then -- add local instance for `a = b` using the function being defined `auxFunName` `(let inst := $(mkIdent auxFunName) @$a @$b; $rhs) else return rhs - mkAlts : TermElabM (Array (TSyntax ``matchAlt)) := do + mkAlts (indVal : InductiveVal) (lvl : List Level): TermElabM (Array (TSyntax ``matchAlt)) := do let mut alts := #[] for ctorName₁ in indVal.ctors do let ctorInfo ← getConstInfoCtor ctorName₁ @@ -48,7 +54,10 @@ where for _ in [:indVal.numIndices] do patterns := patterns.push (← `(_)) if ctorName₁ == ctorName₂ then - let alt ← forallTelescopeReducing ctorInfo.type fun xs type => do + let args := e.getAppArgs + let ctorApp := mkAppN (mkConst ctorInfo.name lvl) args[:ctorInfo.numParams] + let ctorType ← inferType ctorApp + let alt ← forallTelescopeReducing ctorType fun xs type => do let type ← Core.betaReduce type -- we 'beta-reduce' to eliminate "artificial" dependencies let mut patterns := patterns let mut ctorArgs1 := #[] @@ -59,7 +68,7 @@ where ctorArgs2 := ctorArgs2.push (← `(_)) let mut todo := #[] for i in [:ctorInfo.numFields] do - let x := xs[indVal.numParams + i]! + let x := xs[i]! if type.containsFVar x.fvarId! then -- If resulting type depends on this field, we don't need to compare ctorArgs1 := ctorArgs1.push (← `(_)) @@ -70,11 +79,8 @@ where ctorArgs1 := ctorArgs1.push a ctorArgs2 := ctorArgs2.push b let xType ← inferType x - let indValNum := - ctx.typeInfos.findIdx? - (xType.isAppOf ∘ ConstantVal.name ∘ InductiveVal.toConstantVal) - let recField := indValNum.map (ctx.auxFunNames[·]!) - let isProof ← isProp xType + let recField ← ctx.getFunName? header xType fvars + let isProof ← isProp xType todo := todo.push (a, b, recField, isProof) patterns := patterns.push (← `(@$(mkIdent ctorName₁):ident $ctorArgs1:term*)) patterns := patterns.push (← `(@$(mkIdent ctorName₁):ident $ctorArgs2:term*)) @@ -87,47 +93,43 @@ where alts := alts.push (← `(matchAltExpr| | $[$patterns:term],* => $rhs:term)) return alts -def mkAuxFunction (ctx : Context) (auxFunName : Name) (indVal : InductiveVal): TermElabM (TSyntax `command) := do - let header ← mkDecEqHeader indVal - let body ← mkMatch ctx header indVal +def mkAuxFunction (ctx : Context) (auxFunName : Name) (argNames : Array Name) (nestedOcc : NestedOccurence): TermElabM (TSyntax `command) := do + let header ← mkDecEqHeader argNames nestedOcc let binders := header.binders let target₁ := mkIdent header.targetNames[0]! let target₂ := mkIdent header.targetNames[1]! - let termSuffix ← if indVal.isRec + let termSuffix ← if ctx.auxFunNames.size > 1 || nestedOcc.getIndVal.isRec then `(Parser.Termination.suffix|termination_by structural $target₁) else `(Parser.Termination.suffix|) - let type ← `(Decidable ($target₁ = $target₂)) + Term.elabBinders binders fun xs => do + let type ← Term.elabTerm header.targetType none + let body ← mkMatch ctx header type xs + let type ← `(Decidable ($target₁ = $target₂)) `(private def $(mkIdent auxFunName):ident $binders:bracketedBinder* : $type:term := $body:term $termSuffix:suffix) def mkAuxFunctions (ctx : Context) : TermElabM (TSyntax `command) := do let mut res : Array (TSyntax `command) := #[] for i in [:ctx.auxFunNames.size] do - let auxFunName := ctx.auxFunNames[i]! - let indVal := ctx.typeInfos[i]! - res := res.push (← mkAuxFunction ctx auxFunName indVal) + let auxFunName := ctx.auxFunNames[i]! + let nestedOcc := ctx.typeInfos[i]! + let argNames := ctx.typeArgNames[i]! + res := res.push (← mkAuxFunction ctx auxFunName argNames nestedOcc) `(command| mutual $[$res:command]* end) def mkDecEqCmds (indVal : InductiveVal) : TermElabM (Array Syntax) := do let ctx ← mkContext "decEq" indVal.name - let cmds := #[← mkAuxFunctions ctx] ++ (← mkInstanceCmds ctx `DecidableEq #[indVal.name] (useAnonCtor := false)) + let cmds := #[← mkAuxFunctions ctx] ++ (← mkInstanceCmds ctx `DecidableEq (useAnonCtor := false)) trace[Elab.Deriving.decEq] "\n{cmds}" return cmds open Command -def mkDecEq (declName : Name) : CommandElabM Bool := do +def mkDecEq (declName : Name) : CommandElabM Unit := do let indVal ← getConstInfoInduct declName - if indVal.isNested then - return false -- nested inductive types are not supported yet - else - let cmds ← liftTermElabM <| mkDecEqCmds indVal - -- `cmds` can have a number of syntax nodes quadratic in the number of constructors - -- and thus create as many info tree nodes, which we never make use of but which can - -- significantly slow down e.g. the unused variables linter; avoid creating them - withEnableInfoTree false do - cmds.forM elabCommand - return true + let cmds ← liftTermElabM <| mkDecEqCmds indVal + withEnableInfoTree false do + cmds.forM elabCommand partial def mkEnumOfNat (declName : Name) : MetaM Unit := do let indVal ← getConstInfoInduct declName @@ -195,15 +197,15 @@ def mkDecEqEnum (declName : Name) : CommandElabM Unit := do trace[Elab.Deriving.decEq] "\n{cmd}" elabCommand cmd -def mkDecEqInstance (declName : Name) : CommandElabM Bool := do +def mkDecEqInstance (declName : Name) : CommandElabM Unit := do if (← isEnumType declName) then mkDecEqEnum declName - return true else mkDecEq declName def mkDecEqInstanceHandler (declNames : Array Name) : CommandElabM Bool := do - declNames.foldlM (fun b n => andM (pure b) (mkDecEqInstance n)) true + declNames.forM mkDecEqInstance + return true builtin_initialize registerDerivingHandler `DecidableEq mkDecEqInstanceHandler diff --git a/src/Lean/Elab/Deriving/FromToJson.lean b/src/Lean/Elab/Deriving/FromToJson.lean index 2f8f2ce5f84c..6e517a4ce47c 100644 --- a/src/Lean/Elab/Deriving/FromToJson.lean +++ b/src/Lean/Elab/Deriving/FromToJson.lean @@ -15,148 +15,130 @@ open Lean.Json open Lean.Parser.Term open Lean.Meta +def mkToJsonHeader (argNames : Array Name) (nestedOcc : NestedOccurence) : TermElabM Header := do + mkHeader ``ToJson 1 argNames nestedOcc + +def mkFromJsonHeader (argNames : Array Name) (nestedOcc : NestedOccurence) : TermElabM Header := do + let header ← mkHeader ``FromJson 0 argNames nestedOcc + let jsonArg ← `(bracketedBinderF|(json : Json)) + return {header with + binders := header.binders.push jsonArg} + def mkJsonField (n : Name) : CoreM (Bool × Term) := do let .str .anonymous s := n | throwError "invalid json field name {n}" let s₁ := s.dropRightWhile (· == '?') return (s != s₁, Syntax.mkStrLit s₁) -def mkToJsonInstance (declName : Name) : CommandElabM Bool := do - if isStructure (← getEnv) declName then - let cmds ← liftTermElabM do - let ctx ← mkContext "toJson" declName - let header ← mkHeader ``ToJson 1 ctx.typeInfos[0]! - let fields := getStructureFieldsFlattened (← getEnv) declName (includeSubobjectFields := false) - let fields ← fields.mapM fun field => do - let (isOptField, nm) ← mkJsonField field - let target := mkIdent header.targetNames[0]! - if isOptField then ``(opt $nm ($target).$(mkIdent field)) - else ``([($nm, toJson ($target).$(mkIdent field))]) - let cmd ← `(private def $(mkIdent ctx.auxFunNames[0]!):ident $header.binders:bracketedBinder* : Json := - mkObj <| List.join [$fields,*]) - return #[cmd] ++ (← mkInstanceCmds ctx ``ToJson #[declName]) - cmds.forM elabCommand - return true - else - let indVal ← getConstInfoInduct declName - let cmds ← liftTermElabM do - let ctx ← mkContext "toJson" declName - let toJsonFuncId := mkIdent ctx.auxFunNames[0]! - -- Return syntax to JSONify `id`, either via `ToJson` or recursively - -- if `id`'s type is the type we're deriving for. - let mkToJson (id : Ident) (type : Expr) : TermElabM Term := do - if type.isAppOf indVal.name then `($toJsonFuncId:ident $id:ident) - else ``(toJson $id:ident) - let header ← mkHeader ``ToJson 1 ctx.typeInfos[0]! - let discrs ← mkDiscrs header indVal - let alts ← mkAlts indVal fun ctor args userNames => do - let ctorStr := ctor.name.eraseMacroScopes.getString! - match args, userNames with - | #[], _ => ``(toJson $(quote ctorStr)) - | #[(x, t)], none => ``(mkObj [($(quote ctorStr), $(← mkToJson x t))]) - | xs, none => - let xs ← xs.mapM fun (x, t) => mkToJson x t - ``(mkObj [($(quote ctorStr), Json.arr #[$[$xs:term],*])]) - | xs, some userNames => - let xs ← xs.mapIdxM fun idx (x, t) => do - `(($(quote userNames[idx]!.eraseMacroScopes.getString!), $(← mkToJson x t))) - ``(mkObj [($(quote ctorStr), mkObj [$[$xs:term],*])]) - let auxTerm ← `(match $[$discrs],* with $alts:matchAlt*) - let auxCmd ← - if ctx.usePartial then - let letDecls ← mkLocalInstanceLetDecls ctx ``ToJson header.argNames - let auxTerm ← mkLet letDecls auxTerm - `(private partial def $toJsonFuncId:ident $header.binders:bracketedBinder* : Json := $auxTerm) - else - `(private def $toJsonFuncId:ident $header.binders:bracketedBinder* : Json := $auxTerm) - return #[auxCmd] ++ (← mkInstanceCmds ctx ``ToJson #[declName]) - cmds.forM elabCommand - return true +def mkToJson (ctx : Context) (header : Header) (id : TSyntax `term) (type : Expr) (fvars : Array Expr) : TermElabM Term := do + if let some auxFunName ← ctx.getFunName? header type fvars then + `($(mkIdent auxFunName):ident $id) + else ``(toJson $id) + +def mkToJsonBodyForStruct (header : Header) (e : Expr) : TermElabM Term := do + let f := e.getAppFn + let indName := f.constName! + let fields := getStructureFieldsFlattened (← getEnv) indName (includeSubobjectFields := false) + let fields ← fields.mapM fun field => do + let (isOptField, nm) ← mkJsonField field + let target := mkIdent header.targetNames[0]! + if isOptField then ``(opt $nm $target.$(mkIdent field)) + else ``([($nm, toJson ($target).$(mkIdent field))]) + `(mkObj <| List.join [$fields,*]) + +def mkToJsonBodyForInduct (ctx : Context) (header : Header) (e : Expr) (fvars : Array Expr): TermElabM Term := do + let f := e.getAppFn + let indName := f.constName! + let lvls := f.constLevels! + let indVal ← getConstInfoInduct indName + -- Return syntax to JSONify `id`, either via `ToJson` or recursively + -- if `id`'s type is the type we're deriving for. + + let discrs ← mkDiscrs header indVal + let alts ← mkAlts indVal lvls fun ctor args userNames => do + let ctorStr := ctor.name.eraseMacroScopes.getString! + match args, userNames with + | #[], _ => ``(toJson $(quote ctorStr)) + | #[(x, t)], none => ``(mkObj [($(quote ctorStr), $(← mkToJson ctx header x t fvars))]) + | xs, none => + let xs ← xs.mapM fun (x, t) => mkToJson ctx header x t fvars + ``(mkObj [($(quote ctorStr), Json.arr #[$[$xs:term],*])]) + | xs, some userNames => + let xs ← xs.mapIdxM fun idx (x, t) => do + `(($(quote userNames[idx]!.eraseMacroScopes.getString!), $(← mkToJson ctx header x t fvars))) + ``(mkObj [($(quote ctorStr), mkObj [$[$xs:term],*])]) + `(match $[$discrs],* with $alts:matchAlt*) where mkAlts - (indVal : InductiveVal) + (indVal : InductiveVal) (lvl : List Level) (rhs : ConstructorVal → Array (Ident × Expr) → Option (Array Name) → TermElabM Term) : TermElabM (Array (TSyntax ``matchAlt)) := do - indVal.ctors.toArray.mapM fun ctor => do - let ctorInfo ← getConstInfoCtor ctor - forallTelescopeReducing ctorInfo.type fun xs _ => do - let mut patterns := #[] - -- add `_` pattern for indices - for _ in [:indVal.numIndices] do - patterns := patterns.push (← `(_)) - let mut ctorArgs := #[] - -- add `_` for inductive parameters, they are inaccessible - for _ in [:indVal.numParams] do - ctorArgs := ctorArgs.push (← `(_)) - -- bound constructor arguments and their types - let mut binders := #[] - let mut userNames := #[] - for i in [:ctorInfo.numFields] do - let x := xs[indVal.numParams + i]! - let localDecl ← x.fvarId!.getDecl - if !localDecl.userName.hasMacroScopes then - userNames := userNames.push localDecl.userName - let a := mkIdent (← mkFreshUserName `a) - binders := binders.push (a, localDecl.type) - ctorArgs := ctorArgs.push a - patterns := patterns.push (← `(@$(mkIdent ctorInfo.name):ident $ctorArgs:term*)) - let rhs ← rhs ctorInfo binders (if userNames.size == binders.size then some userNames else none) - `(matchAltExpr| | $[$patterns:term],* => $rhs:term) - -def mkFromJsonInstance (declName : Name) : CommandElabM Bool := do - if isStructure (← getEnv) declName then - let cmds ← liftTermElabM do - let ctx ← mkContext "fromJson" declName - let header ← mkHeader ``FromJson 0 ctx.typeInfos[0]! - let fields := getStructureFieldsFlattened (← getEnv) declName (includeSubobjectFields := false) - let getters ← fields.mapM (fun field => do - let getter ← `(getObjValAs? j _ $(Prod.snd <| ← mkJsonField field)) - let getter ← `(doElem| Except.mapError (fun s => (toString $(quote declName)) ++ "." ++ (toString $(quote field)) ++ ": " ++ s) <| $getter) - return getter - ) - let fields := fields.map mkIdent - let cmd ← `(private def $(mkIdent ctx.auxFunNames[0]!):ident $header.binders:bracketedBinder* (j : Json) - : Except String $(← mkInductiveApp ctx.typeInfos[0]! header.argNames) := do - $[let $fields:ident ← $getters]* - return { $[$fields:ident := $(id fields)],* }) - return #[cmd] ++ (← mkInstanceCmds ctx ``FromJson #[declName]) - cmds.forM elabCommand - return true - else - let indVal ← getConstInfoInduct declName - let cmds ← liftTermElabM do - let ctx ← mkContext "fromJson" declName - let header ← mkHeader ``FromJson 0 ctx.typeInfos[0]! - let fromJsonFuncId := mkIdent ctx.auxFunNames[0]! - let alts ← mkAlts indVal fromJsonFuncId - let mut auxTerm ← alts.foldrM (fun xs x => `(Except.orElseLazy $xs (fun _ => $x))) (← `(Except.error "no inductive constructor matched")) - if ctx.usePartial then - let letDecls ← mkLocalInstanceLetDecls ctx ``FromJson header.argNames - auxTerm ← mkLet letDecls auxTerm - -- FromJson is not structurally recursive even non-nested recursive inductives, - -- so we also use `partial` then. - let auxCmd ← - if ctx.usePartial || indVal.isRec then - `(private partial def $fromJsonFuncId:ident $header.binders:bracketedBinder* (json : Json) - : Except String $(← mkInductiveApp ctx.typeInfos[0]! header.argNames) := - $auxTerm) - else - `(private def $fromJsonFuncId:ident $header.binders:bracketedBinder* (json : Json) - : Except String $(← mkInductiveApp ctx.typeInfos[0]! header.argNames) := - $auxTerm) - return #[auxCmd] ++ (← mkInstanceCmds ctx ``FromJson #[declName]) - cmds.forM elabCommand - return true + let mut alts := #[] + for ctorName in indVal.ctors do + let args := e.getAppArgs + let ctorInfo ← getConstInfoCtor ctorName + let ctorApp := mkAppN (mkConst ctorInfo.name lvl) args[:ctorInfo.numParams] + let ctorType ← inferType ctorApp + let alt ← forallTelescopeReducing ctorType fun xs _ => do + let mut patterns := #[] + -- add `_` pattern for indices + for _ in [:indVal.numIndices] do + patterns := patterns.push (← `(_)) + let mut ctorArgs := #[] + -- add `_` for inductive parameters, they are inaccessible + for _ in [:indVal.numParams] do + ctorArgs := ctorArgs.push (← `(_)) + -- bound constructor arguments and their types + let mut binders := #[] + let mut userNames := #[] + for i in [:ctorInfo.numFields] do + let x := xs[i]! + let localDecl ← x.fvarId!.getDecl + if !localDecl.userName.hasMacroScopes then + userNames := userNames.push localDecl.userName + let a := mkIdent (← mkFreshUserName `a) + binders := binders.push (a, localDecl.type) + ctorArgs := ctorArgs.push a + patterns := patterns.push (← `(@$(mkIdent ctorInfo.name):ident $ctorArgs:term*)) + let rhs ← rhs ctorInfo binders (if userNames.size == binders.size then some userNames else none) + `(matchAltExpr| | $[$patterns:term],* => $rhs:term) + alts := alts.push alt + return alts + +def mkFromJsonBodyForStruct (e : Expr) : TermElabM Term := do + let f := e.getAppFn + let indName := f.constName! + let fields := getStructureFieldsFlattened (← getEnv) indName (includeSubobjectFields := false) + let getters ← fields.mapM (fun field => do + let getter ← `(getObjValAs? json _ $(Prod.snd <| ← mkJsonField field)) + let getter ← `(doElem| Except.mapError (fun s => (toString $(quote indName)) ++ "." ++ (toString $(quote field)) ++ ": " ++ s) <| $getter) + return getter + ) + let fields := fields.map mkIdent + `(do + $[let $fields:ident ← $getters]* + return { $[$fields:ident := $(id fields)],* }) +def mkFromJsonBodyForInduct (ctx : Context) (header : Header) (e : Expr) (fvars : Array Expr): TermElabM Term := do + let f := e.getAppFn + let indName := f.constName! + let lvls := f.constLevels! + let indVal ← getConstInfoInduct indName + let alts ← mkAlts indVal lvls + let auxTerm ← alts.foldrM (fun xs x => `(Except.orElseLazy $xs (fun _ => $x))) (← `(Except.error "no inductive constructor matched")) + `($auxTerm) where - mkAlts (indVal : InductiveVal) (fromJsonFuncId : Ident) : TermElabM (Array Term) := do - let alts ← - indVal.ctors.toArray.mapM fun ctor => do - let ctorInfo ← getConstInfoCtor ctor - forallTelescopeReducing ctorInfo.type fun xs _ => do + mkAlts (indVal : InductiveVal) (lvl : List Level): TermElabM (Array Term) := do + let mut alts := #[] + for ctorName in indVal.ctors do + let args := e.getAppArgs + let ctorInfo ← getConstInfoCtor ctorName + let ctorApp := mkAppN (mkConst ctorInfo.name lvl) args[:ctorInfo.numParams] + let ctorType ← inferType ctorApp + let alt ← do forallTelescopeReducing ctorType fun xs _ => do let mut binders := #[] let mut userNames := #[] for i in [:ctorInfo.numFields] do - let x := xs[indVal.numParams + i]! + let x := xs[i]! let localDecl ← x.fvarId!.getDecl if !localDecl.userName.hasMacroScopes then userNames := userNames.push localDecl.userName @@ -165,9 +147,11 @@ where -- Return syntax to parse `id`, either via `FromJson` or recursively -- if `id`'s type is the type we're deriving for. - let mkFromJson (idx : Nat) (type : Expr) : TermElabM (TSyntax ``doExpr) := - if type.isAppOf indVal.name then `(Lean.Parser.Term.doExpr| $fromJsonFuncId:ident jsons[$(quote idx)]!) - else `(Lean.Parser.Term.doExpr| fromJson? jsons[$(quote idx)]!) + let mkFromJson (idx : Nat) (type : Expr) : TermElabM (TSyntax ``doExpr) := do + if let some auxFunName ← ctx.getFunName? header type fvars then + `(Lean.Parser.Term.doExpr| $(mkIdent auxFunName) jsons[$(quote idx)]!) + else + `(Lean.Parser.Term.doExpr| fromJson? jsons[$(quote idx)]!) let identNames := binders.map Prod.fst let fromJsons ← binders.mapIdxM fun idx (_, type) => mkFromJson idx type let userNamesOpt ← if binders.size == userNames.size then @@ -175,23 +159,112 @@ where else ``(none) let stx ← - `((Json.parseTagged json $(quote ctor.eraseMacroScopes.getString!) $(quote ctorInfo.numFields) $(quote userNamesOpt)).bind + `((Json.parseTagged json $(quote ctorName.eraseMacroScopes.getString!) $(quote ctorInfo.numFields) $(quote userNamesOpt)).bind (fun jsons => do $[let $identNames:ident ← $fromJsons:doExpr]* - return $(mkIdent ctor):ident $identNames*)) + return $(mkIdent ctorName):ident $identNames*)) pure (stx, ctorInfo.numFields) + alts := alts.push alt -- the smaller cases, especially the ones without fields are likely faster - let alts := alts.qsort (fun (_, x) (_, y) => x < y) - return alts.map Prod.fst + let alts' := alts.qsort (fun (_, x) (_, y) => x < y) + return alts'.map Prod.fst + +def mkToJsonBody (ctx : Context) (header : Header) (e : Expr) (fvars : Array Expr): TermElabM Term := do + if isStructure (← getEnv) e.getAppFn.constName! then + mkToJsonBodyForStruct header e + else + mkToJsonBodyForInduct ctx header e fvars + +def mkToJsonAuxFunction (ctx : Context) (i : Nat) : TermElabM Command := do + let auxFunName := ctx.auxFunNames[i]! + let nestedOcc := ctx.typeInfos[i]! + let argNames := ctx.typeArgNames[i]! + let header ← mkToJsonHeader argNames nestedOcc + let binders := header.binders + Term.elabBinders binders fun xs => do + let type ← Term.elabTerm header.targetType none + let mut body ← mkToJsonBody ctx header type xs + if ctx.usePartial then + let letDecls ← mkLocalInstanceLetDecls ctx ``ToJson header.argNames + body ← mkLet letDecls body + `(private partial def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Json := $body:term) + else + `(private def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Json := $body:term) + +def mkFromJsonBody (ctx : Context) (header : Header) (e : Expr) (fvars : Array Expr): TermElabM Term := do + if isStructure (← getEnv) e.getAppFn.constName! then + mkFromJsonBodyForStruct e + else + mkFromJsonBodyForInduct ctx header e fvars + +def mkFromJsonAuxFunction (ctx : Context) (i : Nat) : TermElabM Command := do + let auxFunName := ctx.auxFunNames[i]! + let nestedOcc := ctx.typeInfos[i]! + let argNames := ctx.typeArgNames[i]! + let header ← mkFromJsonHeader argNames nestedOcc --TODO fix header info + let binders := header.binders + Term.elabBinders binders fun xs => do + let type ← Term.elabTerm header.targetType none + let mut body ← mkFromJsonBody ctx header type xs + if ctx.usePartial || nestedOcc.getIndVal.isRec then + let letDecls ← mkLocalInstanceLetDecls ctx ``FromJson header.argNames + body ← mkLet letDecls body + `(private partial def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Except String $(← ctx.typeInfos[i]!.mkAppTerm header.argNames) := $body:term) + else + `(private def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Except String $(← ctx.typeInfos[i]!.mkAppTerm header.argNames) := $body:term) + + +def mkToJsonMutualBlock (ctx : Context) : TermElabM Command := do + let mut auxDefs := #[] + for i in [:ctx.typeInfos.size] do + auxDefs := auxDefs.push (← mkToJsonAuxFunction ctx i) + `(mutual + $auxDefs:command* + end) + +def mkFromJsonMutualBlock (ctx : Context) : TermElabM Command := do + let mut auxDefs := #[] + for i in [:ctx.typeInfos.size] do + auxDefs := auxDefs.push (← mkFromJsonAuxFunction ctx i) + `(mutual + $auxDefs:command* + end) + +private def mkToJsonInstance (declName : Name) : TermElabM (Array Command) := do + let ctx ← mkContext "toJson" declName false + let cmds := #[← mkToJsonMutualBlock ctx] ++ (← mkInstanceCmds ctx ``ToJson) + trace[Elab.Deriving.toJson] "\n{cmds}" + return cmds + +private def mkFromJsonInstance (declName : Name) : TermElabM (Array Command) := do + let ctx ← mkContext "fromJson" declName false + let cmds := #[← mkFromJsonMutualBlock ctx] ++ (← mkInstanceCmds ctx ``FromJson) + trace[Elab.Deriving.fromJson] "\n{cmds}" + return cmds def mkToJsonInstanceHandler (declNames : Array Name) : CommandElabM Bool := do - declNames.foldlM (fun b n => andM (pure b) (mkToJsonInstance n)) true + if (← declNames.allM isInductive) && declNames.size > 0 then + for declName in declNames do + let cmds ← liftTermElabM <| mkToJsonInstance declName + cmds.forM elabCommand + return true + else + return false def mkFromJsonInstanceHandler (declNames : Array Name) : CommandElabM Bool := do - declNames.foldlM (fun b n => andM (pure b) (mkFromJsonInstance n)) true + if (← declNames.allM isInductive) && declNames.size > 0 then + for declName in declNames do + let cmds ← liftTermElabM <| mkFromJsonInstance declName + cmds.forM elabCommand + return true + else + return false builtin_initialize registerDerivingHandler ``ToJson mkToJsonInstanceHandler registerDerivingHandler ``FromJson mkFromJsonInstanceHandler + registerTraceClass `Elab.Deriving.toJson + registerTraceClass `Elab.Deriving.fromJson + end Lean.Elab.Deriving.FromToJson diff --git a/src/Lean/Elab/Deriving/Hashable.lean b/src/Lean/Elab/Deriving/Hashable.lean index be7c5ee1d85a..f2b4f25f1841 100644 --- a/src/Lean/Elab/Deriving/Hashable.lean +++ b/src/Lean/Elab/Deriving/Hashable.lean @@ -13,22 +13,28 @@ open Command open Lean.Parser.Term open Meta -def mkHashableHeader (indVal : InductiveVal) : TermElabM Header := do - mkHeader `Hashable 1 indVal +def mkHashableHeader (argNames : Array Name) (nestedOcc : NestedOccurence) : TermElabM Header := do + mkHeader `Hashable 1 argNames nestedOcc -def mkMatch (ctx : Context) (header : Header) (indVal : InductiveVal) : TermElabM Term := do +def mkMatch (ctx : Context) (header : Header) (e : Expr) (fvars : Array Expr) : TermElabM Term := do + let f := e.getAppFn + let ind := f.constName! + let lvls := f.constLevels! + let indVal ← getConstInfoInduct ind let discrs ← mkDiscrs header indVal - let alts ← mkAlts + let alts ← mkAlts indVal lvls `(match $[$discrs],* with $alts:matchAlt*) where - mkAlts : TermElabM (Array (TSyntax ``matchAlt)) := do + mkAlts (indVal : InductiveVal) (lvl : List Level) : TermElabM (Array (TSyntax ``matchAlt)) := do let mut alts := #[] let mut ctorIdx := 0 - let allIndVals := indVal.all.toArray for ctorName in indVal.ctors do + let args := e.getAppArgs let ctorInfo ← getConstInfoCtor ctorName - let alt ← forallTelescopeReducing ctorInfo.type fun xs _ => do + let ctorApp := mkAppN (mkConst ctorInfo.name lvl) args[:ctorInfo.numParams] + let ctorType ← inferType ctorApp + let alt ← forallTelescopeReducing ctorType fun xs _ => do let mut patterns := #[] -- add `_` pattern for indices for _ in [:indVal.numIndices] do @@ -39,16 +45,14 @@ where for _ in [:indVal.numParams] do ctorArgs := ctorArgs.push (← `(_)) for i in [:ctorInfo.numFields] do - let x := xs[indVal.numParams + i]! + let x := xs[i]! let a := mkIdent (← mkFreshUserName `a) ctorArgs := ctorArgs.push a - let xTy ← whnf (← inferType x) - match xTy.getAppFn with - | .const declName .. => - match allIndVals.findIdx? (· == declName) with - | some x => rhs ← `(mixHash $rhs ($(mkIdent ctx.auxFunNames[x]!) $a:ident)) - | none => rhs ← `(mixHash $rhs (hash $a:ident)) - | _ => rhs ← `(mixHash $rhs (hash $a:ident)) + let xType ← inferType x + if let some auxFunName ← ctx.getFunName? header xType fvars then + rhs ← `(mixHash $rhs ($(mkIdent auxFunName) $a)) + else + rhs ← `(mixHash $rhs (hash $a:ident)) patterns := patterns.push (← `(@$(mkIdent ctorName):ident $ctorArgs:term*)) `(matchAltExpr| | $[$patterns:term],* => $rhs:term) alts := alts.push alt @@ -57,15 +61,17 @@ where def mkAuxFunction (ctx : Context) (i : Nat) : TermElabM Command := do let auxFunName := ctx.auxFunNames[i]! - let indVal := ctx.typeInfos[i]! - let header ← mkHashableHeader indVal - let mut body ← mkMatch ctx header indVal - if ctx.usePartial then - let letDecls ← mkLocalInstanceLetDecls ctx `Hashable header.argNames - body ← mkLet letDecls body + let nestedOcc := ctx.typeInfos[i]! + let argNames := ctx.typeArgNames[i]! + let header ← mkHashableHeader argNames nestedOcc let binders := header.binders + Term.elabBinders binders fun xs => do + let type ← Term.elabTerm header.targetType none + let mut body ← mkMatch ctx header type xs if ctx.usePartial then -- TODO(Dany): Get rid of this code branch altogether once we have well-founded recursion + let letDecls ← mkLocalInstanceLetDecls ctx `Hashable header.argNames + body ← mkLet letDecls body `(private partial def $(mkIdent auxFunName):ident $binders:bracketedBinder* : UInt64 := $body:term) else `(private def $(mkIdent auxFunName):ident $binders:bracketedBinder* : UInt64 := $body:term) @@ -77,8 +83,8 @@ def mkHashFuncs (ctx : Context) : TermElabM Syntax := do `(mutual $auxDefs:command* end) private def mkHashableInstanceCmds (declName : Name) : TermElabM (Array Syntax) := do - let ctx ← mkContext "hash" declName - let cmds := #[← mkHashFuncs ctx] ++ (← mkInstanceCmds ctx `Hashable #[declName]) + let ctx ← mkContext "hash" declName false + let cmds := #[← mkHashFuncs ctx] ++ (← mkInstanceCmds ctx `Hashable) trace[Elab.Deriving.hashable] "\n{cmds}" return cmds diff --git a/src/Lean/Elab/Deriving/Ord.lean b/src/Lean/Elab/Deriving/Ord.lean index aed902aa4167..223ec4f90b88 100644 --- a/src/Lean/Elab/Deriving/Ord.lean +++ b/src/Lean/Elab/Deriving/Ord.lean @@ -12,19 +12,26 @@ namespace Lean.Elab.Deriving.Ord open Lean.Parser.Term open Meta -def mkOrdHeader (indVal : InductiveVal) : TermElabM Header := do - mkHeader `Ord 2 indVal +def mkOrdHeader (argNames : Array Name) (nestedOcc : NestedOccurence) : TermElabM Header := do + mkHeader `Ord 2 argNames nestedOcc -def mkMatch (header : Header) (indVal : InductiveVal) : TermElabM Term := do +def mkMatch (ctx : Context) (header : Header) (e : Expr) (fvars : Array Expr) : TermElabM Term := do + let f := e.getAppFn + let ind := f.constName! + let lvls := f.constLevels! + let indVal ← getConstInfoInduct ind let discrs ← mkDiscrs header indVal - let alts ← mkAlts + let alts ← mkAlts indVal lvls `(match $[$discrs],* with $alts:matchAlt*) where - mkAlts : TermElabM (Array (TSyntax ``matchAlt)) := do + mkAlts (indVal : InductiveVal) (lvl : List Level) : TermElabM (Array (TSyntax ``matchAlt)) := do let mut alts := #[] for ctorName in indVal.ctors do + let args := e.getAppArgs let ctorInfo ← getConstInfoCtor ctorName - let alt ← forallTelescopeReducing ctorInfo.type fun xs type => do + let ctorApp := mkAppN (mkConst ctorInfo.name lvl) args[:ctorInfo.numParams] + let ctorType ← inferType ctorApp + let alt ← forallTelescopeReducing ctorType fun xs type => do let type ← Core.betaReduce type -- we 'beta-reduce' to eliminate "artificial" dependencies let mut indPatterns := #[] -- add `_` pattern for indices @@ -39,8 +46,9 @@ where ctorArgs1 := ctorArgs1.push (← `(_)) ctorArgs2 := ctorArgs2.push (← `(_)) for i in [:ctorInfo.numFields] do - let x := xs[indVal.numParams + i]! - if type.containsFVar x.fvarId! || (←isProp (←inferType x)) then + let x := xs[i]! + let xType ← inferType x + if type.containsFVar x.fvarId! || (←isProp xType) then -- If resulting type depends on this field or is a proof, we don't need to compare ctorArgs1 := ctorArgs1.push (← `(_)) ctorArgs2 := ctorArgs2.push (← `(_)) @@ -49,7 +57,12 @@ where let b := mkIdent (← mkFreshUserName `b) ctorArgs1 := ctorArgs1.push a ctorArgs2 := ctorArgs2.push b - rhsCont := fun rhs => `(Ordering.then (compare $a $b) $rhs) >>= rhsCont + let compare ← + if let some auxFunName ← ctx.getFunName? header xType fvars then + `($(mkIdent auxFunName) $a $b) + else + `(compare $a $b) + rhsCont := fun rhs => `(Ordering.then $compare $rhs) >>= rhsCont let lPat ← `(@$(mkIdent ctorName):ident $ctorArgs1:term*) let rPat ← `(@$(mkIdent ctorName):ident $ctorArgs2:term*) let patterns := indPatterns ++ #[lPat, rPat] @@ -64,17 +77,14 @@ where def mkAuxFunction (ctx : Context) (i : Nat) : TermElabM Command := do let auxFunName := ctx.auxFunNames[i]! - let indVal := ctx.typeInfos[i]! - let header ← mkOrdHeader indVal - let mut body ← mkMatch header indVal - if ctx.usePartial || indVal.isRec then - let letDecls ← mkLocalInstanceLetDecls ctx `Ord header.argNames - body ← mkLet letDecls body + let nestedOcc := ctx.typeInfos[i]! + let argNames := ctx.typeArgNames[i]! + let header ← mkOrdHeader argNames nestedOcc let binders := header.binders - if ctx.usePartial || indVal.isRec then - `(private partial def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Ordering := $body:term) - else - `(private def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Ordering := $body:term) + Term.elabBinders binders fun xs => do + let type ← Term.elabTerm header.targetType none + let body ← mkMatch ctx header type xs + `(private def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Ordering := $body:term) def mkMutualBlock (ctx : Context) : TermElabM Syntax := do let mut auxDefs := #[] @@ -87,7 +97,7 @@ def mkMutualBlock (ctx : Context) : TermElabM Syntax := do private def mkOrdInstanceCmds (declName : Name) : TermElabM (Array Syntax) := do let ctx ← mkContext "ord" declName - let cmds := #[← mkMutualBlock ctx] ++ (← mkInstanceCmds ctx `Ord #[declName]) + let cmds := #[← mkMutualBlock ctx] ++ (← mkInstanceCmds ctx `Ord) trace[Elab.Deriving.ord] "\n{cmds}" return cmds diff --git a/src/Lean/Elab/Deriving/Repr.lean b/src/Lean/Elab/Deriving/Repr.lean index 0a03ed1cc0e3..269603efa572 100644 --- a/src/Lean/Elab/Deriving/Repr.lean +++ b/src/Lean/Elab/Deriving/Repr.lean @@ -14,13 +14,21 @@ open Lean.Parser.Term open Meta open Std -def mkReprHeader (indVal : InductiveVal) : TermElabM Header := do - let header ← mkHeader `Repr 1 indVal +def mkReprHeader (argNames : Array Name) (nestedOcc : NestedOccurence) : TermElabM Header := do + let header ← mkHeader `Repr 1 argNames nestedOcc return { header with binders := header.binders.push (← `(bracketedBinderF| (prec : Nat))) } -def mkBodyForStruct (header : Header) (indVal : InductiveVal) : TermElabM Term := do +def mkRepr (ctx : Context) (header : Header) (id : TSyntax `term) (type : Expr) (fvars : Array Expr) : TermElabM Term := do + if let some auxFunName ← ctx.getFunName? header type fvars then + ``($(mkIdent auxFunName) $id) + else ``(reprPrec $id) + +def mkBodyForStruct(header : Header) (e : Expr) : TermElabM Term := do + let f := e.getAppFn + let ind := f.constName! + let indVal ← getConstInfoInduct ind let ctorVal ← getConstInfoCtor indVal.ctors.head! let fieldNames := getStructureFields (← getEnv) indVal.name let numParams := indVal.numParams @@ -42,16 +50,23 @@ def mkBodyForStruct (header : Header) (indVal : InductiveVal) : TermElabM Term : fields ← `($fields ++ $fieldNameLit ++ " := " ++ (Format.group (Format.nest $indent (repr ($target.$(mkIdent fieldName):ident))))) `(Format.bracket "{ " $fields:term " }") -def mkBodyForInduct (header : Header) (indVal : InductiveVal) (auxFunName : Name) : TermElabM Term := do +def mkBodyForInduct (ctx : Context) (header : Header) (e : Expr) (fvars : Array Expr): TermElabM Term := do + let f := e.getAppFn + let ind := f.constName! + let lvls := f.constLevels! + let indVal ← getConstInfoInduct ind let discrs ← mkDiscrs header indVal - let alts ← mkAlts + let alts ← mkAlts indVal lvls `(match $[$discrs],* with $alts:matchAlt*) where - mkAlts : TermElabM (Array (TSyntax ``matchAlt)) := do + mkAlts (indVal : InductiveVal) (lvl : List Level) : TermElabM (Array (TSyntax ``matchAlt)) := do let mut alts := #[] for ctorName in indVal.ctors do + let args := e.getAppArgs let ctorInfo ← getConstInfoCtor ctorName - let alt ← forallTelescopeReducing ctorInfo.type fun xs _ => do + let ctorApp := mkAppN (mkConst ctorInfo.name lvl) args[:ctorInfo.numParams] + let ctorType ← inferType ctorApp + let alt ← forallTelescopeReducing ctorType fun xs _ => do let mut patterns := #[] -- add `_` pattern for indices for _ in [:indVal.numIndices] do @@ -63,38 +78,39 @@ where for _ in [:indVal.numParams] do ctorArgs := ctorArgs.push (← `(_)) for i in [:ctorInfo.numFields] do - let x := xs[indVal.numParams + i]! + let x := xs[i]! let a := mkIdent (← mkFreshUserName `a) ctorArgs := ctorArgs.push a let localDecl ← x.fvarId!.getDecl if localDecl.binderInfo.isExplicit then - if (← inferType x).isAppOf indVal.name then - rhs ← `($rhs ++ Format.line ++ $(mkIdent auxFunName):ident $a:ident max_prec) - else if (← isType x <||> isProof x) then + if (← isType x <||> isProof x) then rhs ← `($rhs ++ Format.line ++ "_") else - rhs ← `($rhs ++ Format.line ++ reprArg $a) + let repr ← mkRepr ctx header a (← inferType x) fvars + rhs ← `($rhs ++ Format.line ++ $repr max_prec) patterns := patterns.push (← `(@$(mkIdent ctorName):ident $ctorArgs:term*)) `(matchAltExpr| | $[$patterns:term],* => Repr.addAppParen (Format.group (Format.nest (if prec >= max_prec then 1 else 2) ($rhs:term))) prec) alts := alts.push alt return alts -def mkBody (header : Header) (indVal : InductiveVal) (auxFunName : Name) : TermElabM Term := do - if isStructure (← getEnv) indVal.name then - mkBodyForStruct header indVal +def mkBody (ctx : Context) (header : Header) (e : Expr) (fvars : Array Expr): TermElabM Term := do + if isStructure (← getEnv) e.getAppFn.constName! then + mkBodyForStruct header e else - mkBodyForInduct header indVal auxFunName + mkBodyForInduct ctx header e fvars def mkAuxFunction (ctx : Context) (i : Nat) : TermElabM Command := do let auxFunName := ctx.auxFunNames[i]! - let indVal := ctx.typeInfos[i]! - let header ← mkReprHeader indVal - let mut body ← mkBody header indVal auxFunName + let nestedOcc := ctx.typeInfos[i]! + let argNames := ctx.typeArgNames[i]! + let header ← mkReprHeader argNames nestedOcc + let binders := header.binders + Term.elabBinders binders fun xs => do + let type ← Term.elabTerm header.targetType none + let mut body ← mkBody ctx header type xs if ctx.usePartial then let letDecls ← mkLocalInstanceLetDecls ctx `Repr header.argNames body ← mkLet letDecls body - let binders := header.binders - if ctx.usePartial then `(private partial def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Format := $body:term) else `(private def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Format := $body:term) @@ -108,8 +124,8 @@ def mkMutualBlock (ctx : Context) : TermElabM Syntax := do end) private def mkReprInstanceCmd (declName : Name) : TermElabM (Array Syntax) := do - let ctx ← mkContext "repr" declName - let cmds := #[← mkMutualBlock ctx] ++ (← mkInstanceCmds ctx `Repr #[declName]) + let ctx ← mkContext "repr" declName false + let cmds := #[← mkMutualBlock ctx] ++ (← mkInstanceCmds ctx `Repr) trace[Elab.Deriving.repr] "\n{cmds}" return cmds diff --git a/src/Lean/Elab/Deriving/Util.lean b/src/Lean/Elab/Deriving/Util.lean index 03f7eecae200..4fad61507dac 100644 --- a/src/Lean/Elab/Deriving/Util.lean +++ b/src/Lean/Elab/Deriving/Util.lean @@ -6,6 +6,11 @@ Authors: Leonardo de Moura prelude import Lean.Parser.Term import Lean.Elab.Term +import Lean.Elab.Binders +import Lean.PrettyPrinter +import Lean.Data.Options +import Init.Data.Sum +import Lean.Meta.CollectFVars namespace Lean.Elab.Deriving open Meta @@ -36,59 +41,349 @@ open TSyntax.Compat in /-- Return implicit binder syntaxes for the given `argNames`. The output matches `implicitBinder*`. For example, ``#[`foo,`bar]`` gives `` `({foo} {bar})``. -/ -def mkImplicitBinders (argNames : Array Name) : TermElabM (Array (TSyntax ``Parser.Term.implicitBinder)) := +def mkImplicitBinders (argNames : Array Name) : TermElabM (Array (TSyntax ``Parser.Term.implicitBinder)) := do argNames.mapM fun argName => `(implicitBinderF| { $(mkIdent argName) }) +/-- + Represents `ind params1 ... paramsn`, where `paramsi` is either a nested occurence, a constant name, or a free variable + + Example : + ``` + inductive Foo (α β γ: Type) : Type → Type + + inductive Bar + | foo : A -> B -> Foo A (Option Bar) B Nat → Bar + + ``` + this nested occurence `Foo A (Option Bar) B Nat` is encoded as + ```lean + node Foo #[ + .inr (.bvar 2), + .inl (node Option #[.inr (.leaf Bar)]), + .inr (.bvar 1), + ] + ``` + + Remark 1 : Since an inductive can only be nested as a parameter, not an index, of another inductive type, + we assume `params.size == ind.numParams` + + Remark 2 : Free variables are abstracted away in nested occurences. + This is useful when trying to delete duplicate occurences, since the check is now purely syntactical. + -/ +inductive NestedOccurence : Type := + | node (ind : InductiveVal) (params : Array (NestedOccurence ⊕ Expr)) + | leaf (ind : InductiveVal) (fvars : Array Expr) + +namespace NestedOccurence + +instance : Inhabited NestedOccurence := ⟨leaf default #[]⟩ +partial instance : BEq NestedOccurence := ⟨go⟩ +where go + | leaf ind₁ _,.leaf ind₂ _ => ind₁.name == ind₂.name + | node ind₁ arr₁,.node ind₂ arr₂ => Id.run do + unless ind₁.name == ind₂.name && arr₁.size == arr₂.size do + return false + for i in [:arr₁.size] do + unless @Sum.instBEq _ _ ⟨go⟩ inferInstance |>.beq arr₁[i]! arr₂[i]! do + return false + return true + | _,_ => false + +partial instance : ToString NestedOccurence := ⟨go⟩ +where go + | leaf ind e=> s!"leaf {ind.name} {e}" + | node ind arr => + let s := arr.map (@instToStringSum _ _ ⟨go⟩ inferInstance |>.toString) + s!"node {ind.name} {s}" + +@[inline] +def getIndVal : NestedOccurence → InductiveVal + | leaf indVal _| node indVal _ => indVal + +@[inline] +def getArr : NestedOccurence → Array (NestedOccurence ⊕ Expr) + | leaf .. => #[] + | node _ arr => arr + +@[inline] +def isLeaf : NestedOccurence → Bool + | leaf ..=> true + | node .. => false + +@[inline] +def isNode : NestedOccurence → Bool := not ∘ isLeaf + +partial def containsFVar (fvarId : FVarId) : NestedOccurence → Bool + | leaf _ e => e.any (Expr.containsFVar · fvarId) + | node _ arr => arr.any (Sum.lift (containsFVar fvarId) (Expr.containsFVar · fvarId)) + +partial def toListofNests (e : NestedOccurence) : List NestedOccurence := + match e with + | .leaf _ _ => [] + | .node _ arr => + let l := flip arr.foldr [] fun occ l => + if let .inl occ := occ then + occ.toListofNests ++ l + else l + e::l + +/-- Return the inductive declaration's type applied to the arguments in `argNames`. -/ +partial def mkAppTerm (nestedOcc : NestedOccurence) (argNames : Array Name) : TermElabM Term := do + go nestedOcc argNames +where + go (nestedOcc : NestedOccurence) (argNames : Array Name) : TermElabM Term := do + match nestedOcc with + | leaf indVal _ => do + let f := mkCIdent indVal.name + let numArgs := indVal.numParams + indVal.numIndices + unless argNames.size >= numArgs do + throwError s!"Expected {numArgs} arguments for {indVal.name}, got {argNames}" + let mut args := Array.mkArray numArgs default + for i in [:numArgs] do + let arg := mkIdent argNames[i]! + args := args.set! i arg + `(@$f $args*) + | node indVal arr => + let f := mkCIdent indVal.name + let mut args := #[] + for nestedOcc? in arr do + match nestedOcc? with + | .inl occ => + let arg ← go occ argNames + args := args.push arg + | .inr (.bvar i) => + let some argName := argNames[argNames.size-i-1]? + | throwError s!"Cannot instantiate {nestedOcc} : not enough arguments given" + let id := mkIdent argName + args := args.push <| ←`($id) + | .inr e => + let tm ← PrettyPrinter.delab e + args := args.push <| ←`($tm) + `(@$f $args*) + +/-- Return the inductive declaration's type applied to the arguments in `argNames`. -/ +partial def mkAppExpr (nestedOcc : NestedOccurence) (argNames : Array Expr) : TermElabM Expr := do + let res ← go nestedOcc argNames + return res +where + go (nestedOcc : NestedOccurence) (argNames : Array Expr): TermElabM Expr := do + match nestedOcc with + | leaf indVal _ => do + let numArgs := indVal.numParams + indVal.numIndices + unless argNames.size >= numArgs do + throwError s!"Expected {numArgs} arguments for {indVal.name}, got {argNames}" + let mut args := Array.mkArray numArgs default + for i in [:numArgs] do + let arg := argNames[i]! + args := args.modify i (fun _ => arg) + let name ← Meta.mkConstWithFreshMVarLevels indVal.name + let res := args.foldl mkApp name + return res + | node indVal arr => + let mut args := #[] + for nestedOcc? in arr do + match nestedOcc? with + | .inl occ => + let arg ← go occ argNames + args := args.push arg + | .inr (.bvar i) => + let some argName := argNames[argNames.size-i-1]? + | throwError s!"Cannot instantiate {nestedOcc} : not enough arguments given" + args := args.push argName + | .inr e => + args := args.push e + let res ← Meta.mkAppOptM indVal.name (args.map some) + return res + +structure Result where + occ : NestedOccurence + args : Subarray Expr + argNames : Array Name + +instance : ToString Result where + toString res := s!"⟨{res.occ},{res.args},{res.argNames}⟩" + +instance : BEq Result := ⟨(·.occ == ·.occ)⟩ + +structure Context where + indNames : List Name + res : List Result + +end NestedOccurence + +abbrev NestedOccM := StateT NestedOccurence.Context TermElabM + +def withIndNames (indNames : List Name) (f : NestedOccM Unit) : TermElabM NestedOccurence.Context := do + let ⟨(),ctx⟩ ← StateT.run f ⟨indNames,[]⟩ + return ctx + +def add_res (x : NestedOccurence.Result) : NestedOccM Unit := do + let ⟨names,res⟩ ← get + set (⟨names,x::res⟩ : NestedOccurence.Context) + +def add_name (n : Name) : NestedOccM Unit := do + let ⟨names,res⟩ ← get + set (⟨n::names,res⟩ : NestedOccurence.Context) + +partial def getNestedOccurencesOf (inds : List Name) (e: Expr) (fvars : Array Expr): MetaM (Option NestedOccurence) := do + let .inl occs ← go e | return none + trace[Elab.Deriving] s!"getNestedOccurencesOf {inds} {e} {fvars} =\n{occs}" + return occs +where + go (e : Expr) : MetaM (NestedOccurence ⊕ Expr) := do + trace[Elab.Deriving] s!"go {inds} {e} {fvars}" + let hd := e.getAppFn + let args := e.getAppArgs + trace[Elab.Deriving] s!"args : {args}" + let fallback _ := return .inr <| e.abstract fvars + let .const name _ := hd | fallback () + if let some indName := inds.find? (· = name) then + let indVal ← getConstInfoInduct indName + let args := args.map (Expr.instantiateRev · fvars) + return .inl <| .leaf indVal args + else + try + let indVal ← getConstInfoInduct name + let args := args.map (Expr.abstract · fvars) + trace[Elab.Deriving] s!"abstracted args : {args}" + let nestedOccsArgs ← args.mapM go + if nestedOccsArgs.any Sum.isLeft then + return .inl <| .node indVal nestedOccsArgs + else fallback () + catch _ => fallback () + +partial def getNestedOccurences (indNames : List Name) : TermElabM (List NestedOccurence.Result) := do + let ⟨_,l⟩ ← withIndNames indNames do + for name in indNames do + go name #[] #[] + let l := l.eraseDups + trace[Elab.Deriving] s!"getNestedOccurences {indNames} =\n{l}" + return l.eraseDups +where + go (indName : Name) (args : Array Expr) (fvars : Array Expr): NestedOccM Unit := do + trace[Elab.Deriving] s!"go2 {indNames} {indName} {args} {fvars}" + let indVal ← getConstInfoInduct indName + if !indVal.isNested && args.size == 0 then + return + let constrs ← indVal.ctors.mapM getConstInfoCtor + for constInfo in constrs do + let instConstInfo ← forallBoundedTelescope constInfo.type args.size fun xs e => + return e.abstract xs |>.instantiateRev args + forallTelescope instConstInfo fun xs _ => do + let mut paramArgs ← fvars.mapM fun e => do + let localDecl ← e.fvarId!.getDecl + mkFreshUserName localDecl.userName.eraseMacroScopes + let mut l := [] + for i in [:constInfo.numParams-args.size] do + let some e := xs[i]? | break + let localDecl ← e.fvarId!.getDecl + let paramName ← mkFreshUserName localDecl.userName.eraseMacroScopes + paramArgs := paramArgs.push paramName + let mut localArgs := #[] + for i in [:xs.size] do + let e := xs[i]! + let ty ← e.fvarId!.getType + let localDecl ← e.fvarId!.getDecl + let paramName ← mkFreshUserName localDecl.userName.eraseMacroScopes + let occs ← getNestedOccurencesOf indNames ty xs[:i] + let l' := if let .some x := occs then x.toListofNests else [] + for occ in l' do + trace[Elab.Deriving] s!"paramArgs : {paramArgs}" + let new_args := paramArgs ++ localArgs.filter (occ.containsFVar ⟨·⟩) + trace[Elab.Deriving] s!"localArgs : {localArgs}" + trace[Elab.Deriving] s!"occ : {occ}" + -- let new_args := new_args.filter (occ.containsFVar ⟨·⟩) + trace[Elab.Deriving] s!"filtered vars : {new_args}" + if (← get).res.all (occ != ·.occ) then + add_res ⟨occ,xs[:i],new_args⟩ + let fvars := fvars ++ xs[:i] + let app ← occ.mkAppExpr fvars + let hd := app.getAppFn.constName! + let args := app.getAppArgs + add_name hd + go hd args fvars + else + add_res ⟨occ,xs[:i],new_args⟩ + l := l ++ l' + localArgs := localArgs.push paramName + +def indNameToFunName (indName : Name) : String := + match indName.eraseMacroScopes with + | .str _ t => t + | _ => "instFn" + +partial def mkInstName: NestedOccurence → String + | .leaf ind _ => indNameToFunName ind.name + | .node ind arr => Id.run do + let mut res ← indNameToFunName ind.name + for nestedOcc in arr do + if let .inl occ := nestedOcc then + let nestedInstName ← mkInstName occ + res := res ++ nestedInstName + return res + /-- Return instance binder syntaxes binding `className α` for every generic parameter `α` of the inductive `indVal` for which such a binding is type-correct. `argNames` is expected to provide names for the parameters (see `mkInductArgNames`). The output matches `instBinder*`. For example, given `inductive Foo {α : Type} (n : Nat) : (β : Type) → Type`, where `β` is an index, invoking ``mkInstImplicitBinders `BarClass foo #[`α, `n, `β]`` gives `` `([BarClass α])``. -/ -def mkInstImplicitBinders (className : Name) (indVal : InductiveVal) (argNames : Array Name) : TermElabM (Array Syntax) := +partial def mkInstImplicitBinders (className : Name) (nestedOcc : NestedOccurence) (argNames : Array Name) : TermElabM (Array Syntax) := do + go nestedOcc argNames +where + go (nestedOcc : NestedOccurence) (argNames : Array Name) : TermElabM (Array Syntax) := + let indVal := nestedOcc.getIndVal + let arr := nestedOcc.getArr forallBoundedTelescope indVal.type indVal.numParams fun xs _ => do - let mut binders := #[] + let mut binders : Array Syntax := #[] for i in [:xs.size] do - try - let x := xs[i]! - let c ← mkAppM className #[x] - if (← isTypeCorrect c) then - let argName := argNames[i]! - let binder : Syntax ← `(instBinderF| [ $(mkCIdent className):ident $(mkIdent argName):ident ]) - binders := binders.push binder - catch _ => - pure () + if nestedOcc.isNode && arr[i]? matches some (.inl _) then + let occ := arr[i]!.getLeft! + let nestedBinders ← go occ argNames + binders := binders ++ nestedBinders + else try + let x := xs[i]! + let hd ← mkConstWithFreshMVarLevels className + let c := mkAppN hd #[x] + if (← isTypeCorrect c) then + let some argName := argNames[i]? | pure () + let binder ← `(instBinderF| [ $(mkCIdent className) $(mkIdent argName)]) + binders := binders.push binder + catch _ => pure () return binders -structure Context where - typeInfos : Array InductiveVal +structure Context : Type where + indNames : List Name + typeArgNames: Array (Array Name) + typeInfos : Array NestedOccurence auxFunNames : Array Name usePartial : Bool -def mkContext (fnPrefix : String) (typeName : Name) : TermElabM Context := do +def mkContext (fnPrefix : String) (typeName : Name) (withNested : Bool := true): TermElabM Context := do let indVal ← getConstInfoInduct typeName - let mut typeInfos := #[] - for typeName in indVal.all do - typeInfos := typeInfos.push (← getConstInfoInduct typeName) - let mut auxFunNames := #[] - for typeName in indVal.all do - match typeName.eraseMacroScopes with - | .str _ t => auxFunNames := auxFunNames.push (← mkFreshUserName <| Name.mkSimple <| fnPrefix ++ t) - | _ => auxFunNames := auxFunNames.push (← mkFreshUserName `instFn) - trace[Elab.Deriving.beq] "{auxFunNames}" - let usePartial := indVal.isNested || typeInfos.size > 1 - return { - typeInfos := typeInfos - auxFunNames := auxFunNames - usePartial := usePartial - } + let indNames := indVal.all + let mut typeInfos' : List NestedOccurence.Result := [] + for indName in indNames do + let indVal ← getConstInfoInduct indName + let args ← mkInductArgNames indVal + typeInfos' := ⟨.leaf indVal #[],#[].toSubarray,args⟩::typeInfos' + if withNested then + typeInfos' := (← getNestedOccurences indVal.all) ++ typeInfos' + let typeArgNames := typeInfos'.map (·.argNames) |>.toArray + let typeInfos := typeInfos'.map (·.occ) |>.toArray + let auxFunNames ← typeInfos.mapM fun occ => do + return ← mkFreshUserName <| Name.mkSimple <| fnPrefix ++ mkInstName occ + let usePartial := !withNested && indVal.isNested + return {indNames, typeArgNames, typeInfos, auxFunNames, usePartial} def mkLocalInstanceLetDecls (ctx : Context) (className : Name) (argNames : Array Name) : TermElabM (Array (TSyntax ``Parser.Term.letDecl)) := do let mut letDecls := #[] for i in [:ctx.typeInfos.size] do - let indVal := ctx.typeInfos[i]! + let occ := ctx.typeInfos[i]! let auxFunName := ctx.auxFunNames[i]! + unless occ.isLeaf do continue + let indVal := occ.getIndVal let currArgNames ← mkInductArgNames indVal let numParams := indVal.numParams let currIndices := currArgNames[numParams:] @@ -107,22 +402,22 @@ def mkLet (letDecls : Array (TSyntax ``Parser.Term.letDecl)) (body : Term) : Ter `(let $letDecl:letDecl; $body) open TSyntax.Compat in -def mkInstanceCmds (ctx : Context) (className : Name) (typeNames : Array Name) (useAnonCtor := true) : TermElabM (Array Command) := do +def mkInstanceCmds (ctx : Context) (className : Name) (useAnonCtor := true) : TermElabM (Array Command) := do let mut instances := #[] for i in [:ctx.typeInfos.size] do - let indVal := ctx.typeInfos[i]! - if typeNames.contains indVal.name then - let auxFunName := ctx.auxFunNames[i]! - let argNames ← mkInductArgNames indVal - let binders ← mkImplicitBinders argNames - let binders := binders ++ (← mkInstImplicitBinders className indVal argNames) - let indType ← mkInductiveApp indVal argNames - let type ← `($(mkCIdent className) $indType) - let mut val := mkIdent auxFunName - if useAnonCtor then - val ← `(⟨$val⟩) - let instCmd ← `(instance $binders:implicitBinder* : $type := $val) - instances := instances.push instCmd + let nestedOcc := ctx.typeInfos[i]! + unless nestedOcc.isLeaf do continue + let auxFunName := ctx.auxFunNames[i]! + let argNames := ctx.typeArgNames[i]! + let binders ← mkImplicitBinders argNames + let binders := binders ++ (← mkInstImplicitBinders className nestedOcc argNames) + let indType ← nestedOcc.mkAppTerm argNames + let type ← `($(mkCIdent className) $indType) + let mut val := mkIdent auxFunName + if useAnonCtor then + val ← `(⟨$val⟩) + let instCmd ← `(instance $binders:implicitBinder* : $type := $val) + instances := instances.push instCmd return instances def mkDiscr (varName : Name) : TermElabM (TSyntax ``Parser.Term.matchDiscr) := @@ -135,15 +430,14 @@ structure Header where targetType : Term open TSyntax.Compat in -def mkHeader (className : Name) (arity : Nat) (indVal : InductiveVal) : TermElabM Header := do - let argNames ← mkInductArgNames indVal - let binders ← mkImplicitBinders argNames - let targetType ← mkInductiveApp indVal argNames +def mkHeader (className : Name) (arity : Nat) (argNames : Array Name) (nestedOcc : NestedOccurence) : TermElabM Header := do + let mut binders ← mkImplicitBinders argNames + let targetType ← nestedOcc.mkAppTerm argNames let mut targetNames := #[] for _ in [:arity] do targetNames := targetNames.push (← mkFreshUserName `x) - let binders := binders ++ (← mkInstImplicitBinders className indVal argNames) - let binders := binders ++ (← targetNames.mapM fun targetName => `(explicitBinderF| ($(mkIdent targetName) : $targetType))) + binders := binders ++ (← mkInstImplicitBinders className nestedOcc argNames) + binders := binders ++ (← targetNames.mapM fun targetName => `(explicitBinderF| ($(mkIdent targetName) : $targetType))) return { binders := binders argNames := argNames @@ -158,4 +452,10 @@ def mkDiscrs (header : Header) (indVal : InductiveVal) : TermElabM (Array (TSynt discrs := discrs.push (← mkDiscr argName) return discrs ++ (← header.targetNames.mapM mkDiscr) +def Context.getFunName? (ctx : Context) (header : Header) (ty : Expr) (xs : Array Expr): TermElabM (Option Name) := do + let indValNum ← ctx.typeInfos.findIdxM? <| + (return .some · == (← getNestedOccurencesOf ctx.indNames ty xs[:header.argNames.size])) + let recField := indValNum.map (ctx.auxFunNames[·]!) + return recField + end Lean.Elab.Deriving diff --git a/tests/lean/3057.lean.expected.out b/tests/lean/3057.lean.expected.out index 34efeb5f83b6..dcb3e5ddfc90 100644 --- a/tests/lean/3057.lean.expected.out +++ b/tests/lean/3057.lean.expected.out @@ -1,10 +1,10 @@ -instReprTree -instReprListTree -instDecidableEqTree -instDecidableEqListTree -instBEqTree -instBEqListTree -instHashableTree -instHashableListTree -instOrdTree -instOrdListTree +instReprTree_1 +instReprListTree_1 +instDecidableEqTree_1 +instDecidableEqListTree_1 +instBEqTree_1 +instBEqListTree_1 +instHashableTree_1 +instHashableListTree_1 +instOrdTree_1 +instOrdListTree_1 diff --git a/tests/lean/decEqMutualInductives.lean b/tests/lean/decEqMutualInductives.lean index fbe0af7e9113..f18d6fcf6138 100644 --- a/tests/lean/decEqMutualInductives.lean +++ b/tests/lean/decEqMutualInductives.lean @@ -1,18 +1,19 @@ /-! Verify that the derive handler for `DecidableEq` handles mutual inductive types-/ --- Print the generated derivations -set_option trace.Elab.Deriving.decEq true - mutual -inductive Tree : Type := +inductive Tree := | node : ListTree → Tree -inductive ListTree : Type := +inductive ListTree := | nil : ListTree | cons : Tree → ListTree → ListTree deriving DecidableEq end +inductive Tree' (α : Type _) : Type _:= + | node : α → Option (List (Tree' α)) → Tree' α +deriving DecidableEq + mutual inductive Foo₁ : Type := | foo₁₁ : Foo₁ @@ -25,3 +26,36 @@ inductive Foo₂ : Type := inductive Foo₃ : Type := | foo₃ : Foo₁ → Foo₃ end + +inductive Min' where + | Base + | Const (a : List Min') +deriving DecidableEq + +inductive ComplexInductive (A B C : Type) (n : Nat) : Type + | constr : A → B → C → ComplexInductive A B C n + +inductive NestedComplex (A C : Type) : Type + | constr : ComplexInductive A (NestedComplex A C) C 1 → NestedComplex A C +deriving DecidableEq + +namespace nested + +inductive Tree (α : Type) where + | node : Array (Tree α) → Tree α +deriving DecidableEq + +end nested + +namespace mess + +mutual + +inductive Mess1 where + | node : Array (Mess2) → Mess1 +deriving DecidableEq +inductive Mess2 where + | node : Array (Mess1) → Mess2 +end + +end mess diff --git a/tests/lean/decEqMutualInductives.lean.expected.out b/tests/lean/decEqMutualInductives.lean.expected.out index bc22d84274e3..e69de29bb2d1 100644 --- a/tests/lean/decEqMutualInductives.lean.expected.out +++ b/tests/lean/decEqMutualInductives.lean.expected.out @@ -1,55 +0,0 @@ -[Elab.Deriving.decEq] - [mutual - private def decEqTree✝ (x✝ : @Tree✝) (x✝¹ : @Tree✝) : Decidable✝ (x✝ = x✝¹) := - match x✝, x✝¹ with - | @Tree.node a✝, @Tree.node b✝ => - let inst✝ := decEqListTree✝ @a✝ @b✝; - if h✝ : @a✝ = @b✝ then by subst h✝; exact isTrue✝ rfl✝ - else isFalse✝ (by intro n✝; injection n✝; apply h✝ _; assumption) - termination_by structural x✝ - private def decEqListTree✝ (x✝² : @ListTree✝) (x✝³ : @ListTree✝) : Decidable✝ (x✝² = x✝³) := - match x✝², x✝³ with - | @ListTree.nil, @ListTree.nil => isTrue✝¹ rfl✝¹ - | ListTree.nil .., ListTree.cons .. => isFalse✝¹ (by intro h✝¹; injection h✝¹) - | ListTree.cons .., ListTree.nil .. => isFalse✝¹ (by intro h✝¹; injection h✝¹) - | @ListTree.cons a✝¹ a✝², @ListTree.cons b✝¹ b✝² => - let inst✝¹ := decEqTree✝ @a✝¹ @b✝¹; - if h✝² : @a✝¹ = @b✝¹ then by subst h✝²; - exact - let inst✝² := decEqListTree✝ @a✝² @b✝²; - if h✝³ : @a✝² = @b✝² then by subst h✝³; exact isTrue✝² rfl✝² - else isFalse✝² (by intro n✝¹; injection n✝¹; apply h✝³ _; assumption) - else isFalse✝³ (by intro n✝²; injection n✝²; apply h✝² _; assumption) - termination_by structural x✝² - end, - instance : DecidableEq✝ (@ListTree✝) := - decEqListTree✝] -[Elab.Deriving.decEq] - [mutual - private def decEqFoo₁✝ (x✝ : @Foo₁✝) (x✝¹ : @Foo₁✝) : Decidable✝ (x✝ = x✝¹) := - match x✝, x✝¹ with - | @Foo₁.foo₁₁, @Foo₁.foo₁₁ => isTrue✝ rfl✝ - | Foo₁.foo₁₁ .., Foo₁.foo₁₂ .. => isFalse✝ (by intro h✝; injection h✝) - | Foo₁.foo₁₂ .., Foo₁.foo₁₁ .. => isFalse✝ (by intro h✝; injection h✝) - | @Foo₁.foo₁₂ a✝, @Foo₁.foo₁₂ b✝ => - let inst✝ := decEqFoo₂✝ @a✝ @b✝; - if h✝¹ : @a✝ = @b✝ then by subst h✝¹; exact isTrue✝¹ rfl✝¹ - else isFalse✝¹ (by intro n✝; injection n✝; apply h✝¹ _; assumption) - termination_by structural x✝ - private def decEqFoo₂✝ (x✝² : @Foo₂✝) (x✝³ : @Foo₂✝) : Decidable✝ (x✝² = x✝³) := - match x✝², x✝³ with - | @Foo₂.foo₂ a✝¹, @Foo₂.foo₂ b✝¹ => - let inst✝¹ := decEqFoo₃✝ @a✝¹ @b✝¹; - if h✝² : @a✝¹ = @b✝¹ then by subst h✝²; exact isTrue✝² rfl✝² - else isFalse✝² (by intro n✝¹; injection n✝¹; apply h✝² _; assumption) - termination_by structural x✝² - private def decEqFoo₃✝ (x✝⁴ : @Foo₃✝) (x✝⁵ : @Foo₃✝) : Decidable✝ (x✝⁴ = x✝⁵) := - match x✝⁴, x✝⁵ with - | @Foo₃.foo₃ a✝², @Foo₃.foo₃ b✝² => - let inst✝² := decEqFoo₁✝ @a✝² @b✝²; - if h✝³ : @a✝² = @b✝² then by subst h✝³; exact isTrue✝³ rfl✝³ - else isFalse✝³ (by intro n✝²; injection n✝²; apply h✝³ _; assumption) - termination_by structural x✝⁴ - end, - instance : DecidableEq✝ (@Foo₁✝) := - decEqFoo₁✝] diff --git a/tests/lean/run/toFromJson.lean b/tests/lean/run/toFromJson.lean index 2c9165b7e70e..75f1609e1dad 100644 --- a/tests/lean/run/toFromJson.lean +++ b/tests/lean/run/toFromJson.lean @@ -42,7 +42,6 @@ def checkRoundTrip [Repr α] [BEq α] [ToJson α] [FromJson α] (obj : α) : Met else throwError "couldn't parse: {repr obj} ≟ {obj |> toJson}" --- set_option trace.Meta.debug true in structure Foo where x : Nat y : String