Skip to content

Commit

Permalink
fix: theorem instantiation in grind (#6492)
Browse files Browse the repository at this point in the history
This PR fixes a bug in the theorem instantiation procedure in the (WIP)
`grind` tactic. For example, it was missing the following instance in
one of the tests:

```lean
[grind.ematch.instance] Array.get_set_ne: ∀ (hj : i < bs.size), j ≠ i → (bs.set j w ⋯)[i] = bs[i]
```

This PR also renames the `grind` base monad to `GrindCoreM`.
  • Loading branch information
leodemoura authored Jan 1, 2025
1 parent 6d44715 commit fedaf85
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 55 deletions.
29 changes: 7 additions & 22 deletions src/Lean/Meta/Tactic/Grind/EMatch.lean
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ private def addNewInstance (origin : Origin) (proof : Expr) (generation : Nat) :
After processing a (multi-)pattern, use the choice assignment to instantiate the proof.
Missing parameters are synthesized using type inference and type class synthesis."
-/
private partial def instantiateTheorem (c : Choice) : M Unit := withDefault do
private partial def instantiateTheorem (c : Choice) : M Unit := withDefault do withNewMCtxDepth do
let thm := (← read).thm
unless (← markTheorenInstance thm.proof c.assignment) do
return ()
Expand Down Expand Up @@ -203,32 +203,17 @@ private partial def instantiateTheorem (c : Choice) : M Unit := withDefault do
if (← mvars.allM (·.mvarId!.isAssigned)) then
addNewInstance thm.origin (mkAppN proof mvars) c.gen
else
-- instance has hypothesis
mkImp mvars 0 proof #[]
let proof := mkAppN proof mvars
let mvars ← mvars.filterM fun mvar => return !(← mvar.mvarId!.isAssigned)
if let some mvarBad ← mvars.findM? fun mvar => return !(← isProof mvar) then
trace[grind.issues] "failed to instantiate {← thm.origin.pp}, failed to instantiate non propositional argument with type{indentExpr (← inferType mvarBad)}"
let proof ← mkLambdaFVars (binderInfoForMVars := .default) mvars (← instantiateMVars proof)
addNewInstance thm.origin proof c.gen
where
synthesizeInstance (x type : Expr) : MetaM Bool := do
let .some val ← trySynthInstance type | return false
isDefEq x val

mkImp (mvars : Array Expr) (i : Nat) (proof : Expr) (xs : Array Expr) : M Unit := do
if h : i < mvars.size then
let mvar := mvars[i]
if (← mvar.mvarId!.isAssigned) then
mkImp mvars (i+1) (mkApp proof mvar) xs
else
let mvarType ← instantiateMVars (← inferType mvar)
if mvarType.hasMVar then
let thm := (← read).thm
trace[grind.issues] "failed to create hypothesis for instance of {← thm.origin.pp} hypothesis type has metavars{indentExpr mvarType}"
return ()
withLocalDeclD (← mkFreshUserName `h) mvarType fun x => do
mkImp mvars (i+1) (mkApp proof x) (xs.push x)
else
let proof ← instantiateMVars proof
let proof ← mkLambdaFVars xs proof
let thm := (← read).thm
addNewInstance thm.origin proof c.gen

/-- Process choice stack until we don't have more choices to be processed. -/
private partial def processChoices : M Unit := do
unless (← get).choiceStack.isEmpty do
Expand Down
6 changes: 3 additions & 3 deletions src/Lean/Meta/Tactic/Grind/Preprocessor.lean
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ structure State where
goals : PArray Goal := {}
deriving Inhabited

abbrev PreM := StateRefT State GrindM
abbrev PreM := StateRefT State GrindCoreM

def PreM.run (x : PreM α) : GrindM α := do
def PreM.run (x : PreM α) : GrindCoreM α := do
x.run' {}

inductive IntroResult where
Expand Down Expand Up @@ -168,7 +168,7 @@ def preprocess (mvarId : MVarId) (mainDeclName : Name) (config : Grind.Config) :
Preprocessor.preprocess mvarId |>.run |>.run mainDeclName config

def main (mvarId : MVarId) (config : Grind.Config) (mainDeclName : Name) : MetaM (List MVarId) := do
let go : GrindM (List MVarId) := do
let go : GrindCoreM (List MVarId) := do
let s ← Preprocessor.preprocess mvarId |>.run
let goals := s.goals.toList.filter fun goal => !goal.inconsistent
return goals.map (·.mvarId)
Expand Down
8 changes: 4 additions & 4 deletions src/Lean/Meta/Tactic/Grind/Run.lean
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def mkMethods : CoreM Methods := do
prop e
}

def GrindM.run (x : GrindM α) (mainDeclName : Name) (config : Grind.Config) : MetaM α := do
def GrindCoreM.run (x : GrindCoreM α) (mainDeclName : Name) (config : Grind.Config) : MetaM α := do
let scState := ShareCommon.State.mk _
let (falseExpr, scState) := ShareCommon.State.shareCommon scState (mkConst ``False)
let (trueExpr, scState) := ShareCommon.State.shareCommon scState (mkConst ``True)
Expand All @@ -37,13 +37,13 @@ def GrindM.run (x : GrindM α) (mainDeclName : Name) (config : Grind.Config) : M
(congrTheorems := (← getSimpCongrTheorems))
x (← mkMethods).toMethodsRef { mainDeclName, config, simprocs, simp } |>.run' { scState, trueExpr, falseExpr }

@[inline] def GoalM.run (goal : Goal) (x : GoalM α) : GrindM (α × Goal) :=
@[inline] def GoalM.run (goal : Goal) (x : GoalM α) : GrindCoreM (α × Goal) :=
goal.mvarId.withContext do StateRefT'.run x goal

@[inline] def GoalM.run' (goal : Goal) (x : GoalM Unit) : GrindM Goal :=
@[inline] def GoalM.run' (goal : Goal) (x : GoalM Unit) : GrindCoreM Goal :=
goal.mvarId.withContext do StateRefT'.run' (x *> get) goal

def mkGoal (mvarId : MVarId) : GrindM Goal := do
def mkGoal (mvarId : MVarId) : GrindCoreM Goal := do
let trueExpr ← getTrueExpr
let falseExpr ← getFalseExpr
let thmMap ← getEMatchTheorems
Expand Down
4 changes: 2 additions & 2 deletions src/Lean/Meta/Tactic/Grind/Simp.lean
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ namespace Lean.Meta.Grind
-- TODO: implement `simp` discharger using preprocessor state

/-- Simplifies the given expression using the `grind` simprocs and normalization theorems. -/
def simp (e : Expr) : GrindM Simp.Result := do
def simp (e : Expr) : GrindCoreM Simp.Result := do
let simpStats := (← get).simpStats
let (r, simpStats) ← Meta.simp e (← readThe Context).simp (← readThe Context).simprocs (stats := simpStats)
modify fun s => { s with simpStats }
Expand All @@ -25,7 +25,7 @@ def simp (e : Expr) : GrindM Simp.Result := do
Simplifies `e` using `grind` normalization theorems and simprocs,
and then applies several other preprocessing steps.
-/
def pre (e : Expr) : GrindM Simp.Result := do
def pre (e : Expr) : GrindCoreM Simp.Result := do
let r ← simp e
let e' := r.expr
let e' ← markNestedProofs e'
Expand Down
50 changes: 26 additions & 24 deletions src/Lean/Meta/Tactic/Grind/Types.lean
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ register_builtin_option grind.debug.proofs : Bool := {
descr := "check proofs between the elements of all equivalence classes"
}

/-- Context for `GrindM` monad. -/
/-- Context for `GrindCoreM` monad. -/
structure Context where
simp : Simp.Context
simprocs : Array Simp.Simprocs
Expand All @@ -66,8 +66,8 @@ instance : BEq CongrTheoremCacheKey where
instance : Hashable CongrTheoremCacheKey where
hash a := mixHash (unsafe ptrAddrUnsafe a.f).toUInt64 (hash a.numArgs)

/-- State for the `GrindM` monad. -/
structure State where
/-- State for the `GrindCoreM` monad. -/
structure CoreState where
canon : Canon.State := {}
/-- `ShareCommon` (aka `Hashconsing`) state. -/
scState : ShareCommon.State.{0} ShareCommon.objectFactory := ShareCommon.State.mk _
Expand All @@ -87,34 +87,34 @@ private opaque MethodsRefPointed : NonemptyType.{0}
private def MethodsRef : Type := MethodsRefPointed.type
instance : Nonempty MethodsRef := MethodsRefPointed.property

abbrev GrindM := ReaderT MethodsRef $ ReaderT Context $ StateRefT State MetaM
abbrev GrindCoreM := ReaderT MethodsRef $ ReaderT Context $ StateRefT CoreState MetaM

/-- Returns the user-defined configuration options -/
def getConfig : GrindM Grind.Config :=
def getConfig : GrindCoreM Grind.Config :=
return (← readThe Context).config

/-- Returns the internalized `True` constant. -/
def getTrueExpr : GrindM Expr := do
def getTrueExpr : GrindCoreM Expr := do
return (← get).trueExpr

/-- Returns the internalized `False` constant. -/
def getFalseExpr : GrindM Expr := do
def getFalseExpr : GrindCoreM Expr := do
return (← get).falseExpr

def getMainDeclName : GrindM Name :=
def getMainDeclName : GrindCoreM Name :=
return (← readThe Context).mainDeclName

@[inline] def getMethodsRef : GrindM MethodsRef :=
@[inline] def getMethodsRef : GrindCoreM MethodsRef :=
read

/-- Returns maximum term generation that is considered during ematching. -/
def getMaxGeneration : GrindM Nat := do
def getMaxGeneration : GrindCoreM Nat := do
return (← getConfig).gen

/--
Abtracts nested proofs in `e`. This is a preprocessing step performed before internalization.
-/
def abstractNestedProofs (e : Expr) : GrindM Expr := do
def abstractNestedProofs (e : Expr) : GrindCoreM Expr := do
let nextIdx := (← get).nextThmIdx
let (e, s') ← AbstractNestedProofs.visit e |>.run { baseName := (← getMainDeclName) } |>.run |>.run { nextIdx }
modify fun s => { s with nextThmIdx := s'.nextIdx }
Expand All @@ -124,24 +124,32 @@ def abstractNestedProofs (e : Expr) : GrindM Expr := do
Applies hash-consing to `e`. Recall that all expressions in a `grind` goal have
been hash-consing. We perform this step before we internalize expressions.
-/
def shareCommon (e : Expr) : GrindM Expr := do
def shareCommon (e : Expr) : GrindCoreM Expr := do
modifyGet fun { canon, scState, nextThmIdx, congrThms, trueExpr, falseExpr, simpStats } =>
let (e, scState) := ShareCommon.State.shareCommon scState e
(e, { canon, scState, nextThmIdx, congrThms, trueExpr, falseExpr, simpStats })

/--
Canonicalizes nested types, type formers, and instances in `e`.
-/
def canon (e : Expr) : GrindM Expr := do
def canon (e : Expr) : GrindCoreM Expr := do
let canonS ← modifyGet fun s => (s.canon, { s with canon := {} })
let (e, canonS) ← Canon.canon e |>.run canonS
modify fun s => { s with canon := canonS }
return e

/-- Returns `true` if `e` is the internalized `True` expression. -/
def isTrueExpr (e : Expr) : GrindCoreM Bool :=
return isSameExpr e (← getTrueExpr)

/-- Returns `true` if `e` is the internalized `False` expression. -/
def isFalseExpr (e : Expr) : GrindCoreM Bool :=
return isSameExpr e (← getFalseExpr)

/--
Creates a congruence theorem for a `f`-applications with `numArgs` arguments.
-/
def mkHCongrWithArity (f : Expr) (numArgs : Nat) : GrindM CongrTheorem := do
def mkHCongrWithArity (f : Expr) (numArgs : Nat) : GrindCoreM CongrTheorem := do
let key := { f, numArgs }
if let some result := (← get).congrThms.find? key then
return result
Expand Down Expand Up @@ -197,6 +205,7 @@ structure ENode where
-- TODO: see Lean 3 implementation
deriving Inhabited, Repr

/-- New equality to be processed. -/
structure NewEq where
lhs : Expr
rhs : Expr
Expand Down Expand Up @@ -252,6 +261,7 @@ where
| .app f a => go f (mixHash r (hashRoot enodes a))
| _ => mixHash r (hashRoot enodes e)

/-- Returns `true` if `a` and `b` are congruent modulo the equivalence classes in `enodes`. -/
partial def isCongruent (enodes : ENodeMap) (a b : Expr) : Bool :=
if a.isAppOfArity ``Lean.Grind.nestedProof 2 && b.isAppOfArity ``Lean.Grind.nestedProof 2 then
hasSameRoot enodes (a.getArg! 0) (b.getArg! 0)
Expand Down Expand Up @@ -343,7 +353,7 @@ structure Goal where
def Goal.admit (goal : Goal) : MetaM Unit :=
goal.mvarId.admit

abbrev GoalM := StateRefT Goal GrindM
abbrev GoalM := StateRefT Goal GrindCoreM

abbrev Propagator := Expr → GoalM Unit

Expand All @@ -362,14 +372,6 @@ def markTheorenInstance (proof : Expr) (assignment : Array Expr) : GoalM Bool :=
def checkMaxInstancesExceeded : GoalM Bool := do
return (← get).numInstances >= (← getConfig).instances

/-- Returns `true` if `e` is the internalized `True` expression. -/
def isTrueExpr (e : Expr) : GrindM Bool :=
return isSameExpr e (← getTrueExpr)

/-- Returns `true` if `e` is the internalized `False` expression. -/
def isFalseExpr (e : Expr) : GrindM Bool :=
return isSameExpr e (← getFalseExpr)

/--
Returns `some n` if `e` has already been "internalized" into the
Otherwise, returns `none`s.
Expand Down Expand Up @@ -616,7 +618,7 @@ def Methods.toMethodsRef (m : Methods) : MethodsRef :=
private def MethodsRef.toMethods (m : MethodsRef) : Methods :=
unsafe unsafeCast m

@[inline] def getMethods : GrindM Methods :=
@[inline] def getMethods : GrindCoreM Methods :=
return (← getMethodsRef).toMethods

def propagateUp (e : Expr) : GoalM Unit := do
Expand Down
1 change: 1 addition & 0 deletions tests/lean/run/grind_ematch1.lean
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ set_option grind.debug.proofs true
/--
info: [grind.ematch.instance] Array.get_set_eq: (bs.set j w ⋯)[j] = w
[grind.ematch.instance] Array.get_set_eq: (as.set i v ⋯)[i] = v
[grind.ematch.instance] Array.get_set_ne: ∀ (hj : i < bs.size), j ≠ i → (bs.set j w ⋯)[i] = bs[i]
-/
#guard_msgs (info) in
example (as : Array α)
Expand Down

0 comments on commit fedaf85

Please sign in to comment.