From fedaf850bb7c64d9149a3d204b969304e23fbecd Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Wed, 1 Jan 2025 18:56:27 +0100 Subject: [PATCH] fix: theorem instantiation in `grind` (#6492) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR fixes a bug in the theorem instantiation procedure in the (WIP) `grind` tactic. For example, it was missing the following instance in one of the tests: ```lean [grind.ematch.instance] Array.get_set_ne: ∀ (hj : i < bs.size), j ≠ i → (bs.set j w ⋯)[i] = bs[i] ``` This PR also renames the `grind` base monad to `GrindCoreM`. --- src/Lean/Meta/Tactic/Grind/EMatch.lean | 29 +++--------- src/Lean/Meta/Tactic/Grind/Preprocessor.lean | 6 +-- src/Lean/Meta/Tactic/Grind/Run.lean | 8 ++-- src/Lean/Meta/Tactic/Grind/Simp.lean | 4 +- src/Lean/Meta/Tactic/Grind/Types.lean | 50 ++++++++++---------- tests/lean/run/grind_ematch1.lean | 1 + 6 files changed, 43 insertions(+), 55 deletions(-) diff --git a/src/Lean/Meta/Tactic/Grind/EMatch.lean b/src/Lean/Meta/Tactic/Grind/EMatch.lean index e4fc281b08cb..43d2f5299372 100644 --- a/src/Lean/Meta/Tactic/Grind/EMatch.lean +++ b/src/Lean/Meta/Tactic/Grind/EMatch.lean @@ -173,7 +173,7 @@ private def addNewInstance (origin : Origin) (proof : Expr) (generation : Nat) : After processing a (multi-)pattern, use the choice assignment to instantiate the proof. Missing parameters are synthesized using type inference and type class synthesis." -/ -private partial def instantiateTheorem (c : Choice) : M Unit := withDefault do +private partial def instantiateTheorem (c : Choice) : M Unit := withDefault do withNewMCtxDepth do let thm := (← read).thm unless (← markTheorenInstance thm.proof c.assignment) do return () @@ -203,32 +203,17 @@ private partial def instantiateTheorem (c : Choice) : M Unit := withDefault do if (← mvars.allM (·.mvarId!.isAssigned)) then addNewInstance thm.origin (mkAppN proof mvars) c.gen else - -- instance has hypothesis - mkImp mvars 0 proof #[] + let proof := mkAppN proof mvars + let mvars ← mvars.filterM fun mvar => return !(← mvar.mvarId!.isAssigned) + if let some mvarBad ← mvars.findM? fun mvar => return !(← isProof mvar) then + trace[grind.issues] "failed to instantiate {← thm.origin.pp}, failed to instantiate non propositional argument with type{indentExpr (← inferType mvarBad)}" + let proof ← mkLambdaFVars (binderInfoForMVars := .default) mvars (← instantiateMVars proof) + addNewInstance thm.origin proof c.gen where synthesizeInstance (x type : Expr) : MetaM Bool := do let .some val ← trySynthInstance type | return false isDefEq x val - mkImp (mvars : Array Expr) (i : Nat) (proof : Expr) (xs : Array Expr) : M Unit := do - if h : i < mvars.size then - let mvar := mvars[i] - if (← mvar.mvarId!.isAssigned) then - mkImp mvars (i+1) (mkApp proof mvar) xs - else - let mvarType ← instantiateMVars (← inferType mvar) - if mvarType.hasMVar then - let thm := (← read).thm - trace[grind.issues] "failed to create hypothesis for instance of {← thm.origin.pp} hypothesis type has metavars{indentExpr mvarType}" - return () - withLocalDeclD (← mkFreshUserName `h) mvarType fun x => do - mkImp mvars (i+1) (mkApp proof x) (xs.push x) - else - let proof ← instantiateMVars proof - let proof ← mkLambdaFVars xs proof - let thm := (← read).thm - addNewInstance thm.origin proof c.gen - /-- Process choice stack until we don't have more choices to be processed. -/ private partial def processChoices : M Unit := do unless (← get).choiceStack.isEmpty do diff --git a/src/Lean/Meta/Tactic/Grind/Preprocessor.lean b/src/Lean/Meta/Tactic/Grind/Preprocessor.lean index 24eafd0f314d..044beea18b31 100644 --- a/src/Lean/Meta/Tactic/Grind/Preprocessor.lean +++ b/src/Lean/Meta/Tactic/Grind/Preprocessor.lean @@ -27,9 +27,9 @@ structure State where goals : PArray Goal := {} deriving Inhabited -abbrev PreM := StateRefT State GrindM +abbrev PreM := StateRefT State GrindCoreM -def PreM.run (x : PreM α) : GrindM α := do +def PreM.run (x : PreM α) : GrindCoreM α := do x.run' {} inductive IntroResult where @@ -168,7 +168,7 @@ def preprocess (mvarId : MVarId) (mainDeclName : Name) (config : Grind.Config) : Preprocessor.preprocess mvarId |>.run |>.run mainDeclName config def main (mvarId : MVarId) (config : Grind.Config) (mainDeclName : Name) : MetaM (List MVarId) := do - let go : GrindM (List MVarId) := do + let go : GrindCoreM (List MVarId) := do let s ← Preprocessor.preprocess mvarId |>.run let goals := s.goals.toList.filter fun goal => !goal.inconsistent return goals.map (·.mvarId) diff --git a/src/Lean/Meta/Tactic/Grind/Run.lean b/src/Lean/Meta/Tactic/Grind/Run.lean index 7ad71412378d..2f723dd52c17 100644 --- a/src/Lean/Meta/Tactic/Grind/Run.lean +++ b/src/Lean/Meta/Tactic/Grind/Run.lean @@ -25,7 +25,7 @@ def mkMethods : CoreM Methods := do prop e } -def GrindM.run (x : GrindM α) (mainDeclName : Name) (config : Grind.Config) : MetaM α := do +def GrindCoreM.run (x : GrindCoreM α) (mainDeclName : Name) (config : Grind.Config) : MetaM α := do let scState := ShareCommon.State.mk _ let (falseExpr, scState) := ShareCommon.State.shareCommon scState (mkConst ``False) let (trueExpr, scState) := ShareCommon.State.shareCommon scState (mkConst ``True) @@ -37,13 +37,13 @@ def GrindM.run (x : GrindM α) (mainDeclName : Name) (config : Grind.Config) : M (congrTheorems := (← getSimpCongrTheorems)) x (← mkMethods).toMethodsRef { mainDeclName, config, simprocs, simp } |>.run' { scState, trueExpr, falseExpr } -@[inline] def GoalM.run (goal : Goal) (x : GoalM α) : GrindM (α × Goal) := +@[inline] def GoalM.run (goal : Goal) (x : GoalM α) : GrindCoreM (α × Goal) := goal.mvarId.withContext do StateRefT'.run x goal -@[inline] def GoalM.run' (goal : Goal) (x : GoalM Unit) : GrindM Goal := +@[inline] def GoalM.run' (goal : Goal) (x : GoalM Unit) : GrindCoreM Goal := goal.mvarId.withContext do StateRefT'.run' (x *> get) goal -def mkGoal (mvarId : MVarId) : GrindM Goal := do +def mkGoal (mvarId : MVarId) : GrindCoreM Goal := do let trueExpr ← getTrueExpr let falseExpr ← getFalseExpr let thmMap ← getEMatchTheorems diff --git a/src/Lean/Meta/Tactic/Grind/Simp.lean b/src/Lean/Meta/Tactic/Grind/Simp.lean index 50435d7eb788..10f520aae43c 100644 --- a/src/Lean/Meta/Tactic/Grind/Simp.lean +++ b/src/Lean/Meta/Tactic/Grind/Simp.lean @@ -15,7 +15,7 @@ namespace Lean.Meta.Grind -- TODO: implement `simp` discharger using preprocessor state /-- Simplifies the given expression using the `grind` simprocs and normalization theorems. -/ -def simp (e : Expr) : GrindM Simp.Result := do +def simp (e : Expr) : GrindCoreM Simp.Result := do let simpStats := (← get).simpStats let (r, simpStats) ← Meta.simp e (← readThe Context).simp (← readThe Context).simprocs (stats := simpStats) modify fun s => { s with simpStats } @@ -25,7 +25,7 @@ def simp (e : Expr) : GrindM Simp.Result := do Simplifies `e` using `grind` normalization theorems and simprocs, and then applies several other preprocessing steps. -/ -def pre (e : Expr) : GrindM Simp.Result := do +def pre (e : Expr) : GrindCoreM Simp.Result := do let r ← simp e let e' := r.expr let e' ← markNestedProofs e' diff --git a/src/Lean/Meta/Tactic/Grind/Types.lean b/src/Lean/Meta/Tactic/Grind/Types.lean index fb88cb021fd1..e9985f01ada9 100644 --- a/src/Lean/Meta/Tactic/Grind/Types.lean +++ b/src/Lean/Meta/Tactic/Grind/Types.lean @@ -46,7 +46,7 @@ register_builtin_option grind.debug.proofs : Bool := { descr := "check proofs between the elements of all equivalence classes" } -/-- Context for `GrindM` monad. -/ +/-- Context for `GrindCoreM` monad. -/ structure Context where simp : Simp.Context simprocs : Array Simp.Simprocs @@ -66,8 +66,8 @@ instance : BEq CongrTheoremCacheKey where instance : Hashable CongrTheoremCacheKey where hash a := mixHash (unsafe ptrAddrUnsafe a.f).toUInt64 (hash a.numArgs) -/-- State for the `GrindM` monad. -/ -structure State where +/-- State for the `GrindCoreM` monad. -/ +structure CoreState where canon : Canon.State := {} /-- `ShareCommon` (aka `Hashconsing`) state. -/ scState : ShareCommon.State.{0} ShareCommon.objectFactory := ShareCommon.State.mk _ @@ -87,34 +87,34 @@ private opaque MethodsRefPointed : NonemptyType.{0} private def MethodsRef : Type := MethodsRefPointed.type instance : Nonempty MethodsRef := MethodsRefPointed.property -abbrev GrindM := ReaderT MethodsRef $ ReaderT Context $ StateRefT State MetaM +abbrev GrindCoreM := ReaderT MethodsRef $ ReaderT Context $ StateRefT CoreState MetaM /-- Returns the user-defined configuration options -/ -def getConfig : GrindM Grind.Config := +def getConfig : GrindCoreM Grind.Config := return (← readThe Context).config /-- Returns the internalized `True` constant. -/ -def getTrueExpr : GrindM Expr := do +def getTrueExpr : GrindCoreM Expr := do return (← get).trueExpr /-- Returns the internalized `False` constant. -/ -def getFalseExpr : GrindM Expr := do +def getFalseExpr : GrindCoreM Expr := do return (← get).falseExpr -def getMainDeclName : GrindM Name := +def getMainDeclName : GrindCoreM Name := return (← readThe Context).mainDeclName -@[inline] def getMethodsRef : GrindM MethodsRef := +@[inline] def getMethodsRef : GrindCoreM MethodsRef := read /-- Returns maximum term generation that is considered during ematching. -/ -def getMaxGeneration : GrindM Nat := do +def getMaxGeneration : GrindCoreM Nat := do return (← getConfig).gen /-- Abtracts nested proofs in `e`. This is a preprocessing step performed before internalization. -/ -def abstractNestedProofs (e : Expr) : GrindM Expr := do +def abstractNestedProofs (e : Expr) : GrindCoreM Expr := do let nextIdx := (← get).nextThmIdx let (e, s') ← AbstractNestedProofs.visit e |>.run { baseName := (← getMainDeclName) } |>.run |>.run { nextIdx } modify fun s => { s with nextThmIdx := s'.nextIdx } @@ -124,7 +124,7 @@ def abstractNestedProofs (e : Expr) : GrindM Expr := do Applies hash-consing to `e`. Recall that all expressions in a `grind` goal have been hash-consing. We perform this step before we internalize expressions. -/ -def shareCommon (e : Expr) : GrindM Expr := do +def shareCommon (e : Expr) : GrindCoreM Expr := do modifyGet fun { canon, scState, nextThmIdx, congrThms, trueExpr, falseExpr, simpStats } => let (e, scState) := ShareCommon.State.shareCommon scState e (e, { canon, scState, nextThmIdx, congrThms, trueExpr, falseExpr, simpStats }) @@ -132,16 +132,24 @@ def shareCommon (e : Expr) : GrindM Expr := do /-- Canonicalizes nested types, type formers, and instances in `e`. -/ -def canon (e : Expr) : GrindM Expr := do +def canon (e : Expr) : GrindCoreM Expr := do let canonS ← modifyGet fun s => (s.canon, { s with canon := {} }) let (e, canonS) ← Canon.canon e |>.run canonS modify fun s => { s with canon := canonS } return e +/-- Returns `true` if `e` is the internalized `True` expression. -/ +def isTrueExpr (e : Expr) : GrindCoreM Bool := + return isSameExpr e (← getTrueExpr) + +/-- Returns `true` if `e` is the internalized `False` expression. -/ +def isFalseExpr (e : Expr) : GrindCoreM Bool := + return isSameExpr e (← getFalseExpr) + /-- Creates a congruence theorem for a `f`-applications with `numArgs` arguments. -/ -def mkHCongrWithArity (f : Expr) (numArgs : Nat) : GrindM CongrTheorem := do +def mkHCongrWithArity (f : Expr) (numArgs : Nat) : GrindCoreM CongrTheorem := do let key := { f, numArgs } if let some result := (← get).congrThms.find? key then return result @@ -197,6 +205,7 @@ structure ENode where -- TODO: see Lean 3 implementation deriving Inhabited, Repr +/-- New equality to be processed. -/ structure NewEq where lhs : Expr rhs : Expr @@ -252,6 +261,7 @@ where | .app f a => go f (mixHash r (hashRoot enodes a)) | _ => mixHash r (hashRoot enodes e) +/-- Returns `true` if `a` and `b` are congruent modulo the equivalence classes in `enodes`. -/ partial def isCongruent (enodes : ENodeMap) (a b : Expr) : Bool := if a.isAppOfArity ``Lean.Grind.nestedProof 2 && b.isAppOfArity ``Lean.Grind.nestedProof 2 then hasSameRoot enodes (a.getArg! 0) (b.getArg! 0) @@ -343,7 +353,7 @@ structure Goal where def Goal.admit (goal : Goal) : MetaM Unit := goal.mvarId.admit -abbrev GoalM := StateRefT Goal GrindM +abbrev GoalM := StateRefT Goal GrindCoreM abbrev Propagator := Expr → GoalM Unit @@ -362,14 +372,6 @@ def markTheorenInstance (proof : Expr) (assignment : Array Expr) : GoalM Bool := def checkMaxInstancesExceeded : GoalM Bool := do return (← get).numInstances >= (← getConfig).instances -/-- Returns `true` if `e` is the internalized `True` expression. -/ -def isTrueExpr (e : Expr) : GrindM Bool := - return isSameExpr e (← getTrueExpr) - -/-- Returns `true` if `e` is the internalized `False` expression. -/ -def isFalseExpr (e : Expr) : GrindM Bool := - return isSameExpr e (← getFalseExpr) - /-- Returns `some n` if `e` has already been "internalized" into the Otherwise, returns `none`s. @@ -616,7 +618,7 @@ def Methods.toMethodsRef (m : Methods) : MethodsRef := private def MethodsRef.toMethods (m : MethodsRef) : Methods := unsafe unsafeCast m -@[inline] def getMethods : GrindM Methods := +@[inline] def getMethods : GrindCoreM Methods := return (← getMethodsRef).toMethods def propagateUp (e : Expr) : GoalM Unit := do diff --git a/tests/lean/run/grind_ematch1.lean b/tests/lean/run/grind_ematch1.lean index 620e590247c2..09a574756433 100644 --- a/tests/lean/run/grind_ematch1.lean +++ b/tests/lean/run/grind_ematch1.lean @@ -11,6 +11,7 @@ set_option grind.debug.proofs true /-- info: [grind.ematch.instance] Array.get_set_eq: (bs.set j w ⋯)[j] = w [grind.ematch.instance] Array.get_set_eq: (as.set i v ⋯)[i] = v +[grind.ematch.instance] Array.get_set_ne: ∀ (hj : i < bs.size), j ≠ i → (bs.set j w ⋯)[i] = bs[i] -/ #guard_msgs (info) in example (as : Array α)