Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: ematch theorem activation for grind #6475

Merged
merged 3 commits into from
Dec 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Lean/Elab/Tactic/Grind.lean
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def elabGrindPattern : CommandElab := fun stx => do
let pattern ← instantiateMVars (← elabTerm term none)
let pattern ← Grind.unfoldReducible pattern
return pattern.abstract xs
Grind.addTheoremPattern declName xs.size patterns.toList
Grind.addEMatchTheorem declName xs.size patterns.toList
| _ => throwUnsupportedSyntax

def grind (mvarId : MVarId) (mainDeclName : Name) : MetaM Unit := do
Expand Down
2 changes: 1 addition & 1 deletion src/Lean/Meta/Tactic/Grind.lean
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import Lean.Meta.Tactic.Grind.PP
import Lean.Meta.Tactic.Grind.Simp
import Lean.Meta.Tactic.Grind.Ctor
import Lean.Meta.Tactic.Grind.Parser
import Lean.Meta.Tactic.Grind.TheoremPatterns
import Lean.Meta.Tactic.Grind.EMatchTheorem

namespace Lean

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,15 @@ inductive Origin where
| other
deriving Inhabited, Repr

structure TheoremPattern where
/-- A unique identifier corresponding to the origin. -/
def Origin.key : Origin → Name
| .decl declName => declName
| .fvar fvarId => fvarId.name
| .stx id _ => id
| .other => `other

/-- A theorem for heuristic instantiation based on E-matching. -/
structure EMatchTheorem where
proof : Expr
numParams : Nat
patterns : List Expr
Expand All @@ -34,16 +42,21 @@ structure TheoremPattern where
origin : Origin
deriving Inhabited

abbrev TheoremPatterns := SMap Name (List TheoremPattern)
/-- The key is a symbol from `EMatchTheorem.symbols`. -/
abbrev EMatchTheorems := PHashMap Name (List EMatchTheorem)

builtin_initialize theoremPatternsExt : SimpleScopedEnvExtension TheoremPattern TheoremPatterns ←
def EMatchTheorems.insert (s : EMatchTheorems) (thm : EMatchTheorem) : EMatchTheorems := Id.run do
let .const declName :: syms := thm.symbols
| unreachable!
let thm := { thm with symbols := syms }
if let some thms := s.find? declName then
return PersistentHashMap.insert s declName (thm::thms)
else
return PersistentHashMap.insert s declName [thm]

private builtin_initialize ematchTheoremsExt : SimpleScopedEnvExtension EMatchTheorem EMatchTheorems ←
registerSimpleScopedEnvExtension {
addEntry := fun s t => Id.run do
let .const declName :: _ := t.symbols | unreachable!
if let some ts := s.find? declName then
s.insert declName (t::ts)
else
s.insert declName [t]
addEntry := EMatchTheorems.insert
initial := .empty
}

Expand Down Expand Up @@ -282,19 +295,23 @@ private def ppParamsAt (proof : Expr) (numParms : Nat) (paramPos : List Nat) : M
msg := msg ++ m!"{x} : {← inferType x}"
addMessageContextFull msg

def addTheoremPattern (declName : Name) (numParams : Nat) (patterns : List Expr) : MetaM Unit := do
def addEMatchTheorem (declName : Name) (numParams : Nat) (patterns : List Expr) : MetaM Unit := do
let .thmInfo info ← getConstInfo declName
| throwError "`{declName}` is not a theorem, you cannot assign patterns to non-theorems for the `grind` tactic"
let us := info.levelParams.map mkLevelParam
let proof := mkConst declName us
let (patterns, symbols, bvarFound) ← NormalizePattern.main patterns
assert! symbols.all fun s => s matches .const _
trace[grind.pattern] "{declName}: {patterns.map ppPattern}"
if let .missing pos ← checkCoverage proof numParams bvarFound then
let pats : MessageData := m!"{patterns.map ppPattern}"
throwError "invalid pattern(s) for `{declName}`{indentD pats}\nthe following theorem parameters cannot be instantiated:{indentD (← ppParamsAt proof numParams pos)}"
theoremPatternsExt.add {
ematchTheoremsExt.add {
proof, patterns, numParams, symbols
origin := .decl declName
}

def getEMatchTheorems : CoreM EMatchTheorems :=
return ematchTheoremsExt.getState (← getEnv)

end Lean.Meta.Grind
19 changes: 18 additions & 1 deletion src/Lean/Meta/Tactic/Grind/Internalize.lean
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,21 @@ private def updateAppMap (e : Expr) : GoalM Unit := do
s.appMap.insert key [e]
}

private def activateTheoremPatterns (fName : Name) : GoalM Unit := do
if let some thms := (← get).thmMap.find? fName then
modify fun s => { s with thmMap := s.thmMap.erase fName }
let appMap := (← get).appMap
for thm in thms do
let symbols := thm.symbols.filter fun sym => !appMap.contains sym
let thm := { thm with symbols }
match symbols with
| [] =>
trace[grind.pattern] "activated `{thm.origin.key}`"
modify fun s => { s with newThms := s.newThms.push thm }
| _ =>
trace[grind.pattern] "reinsert `{thm.origin.key}`"
modify fun s => { s with thmMap := s.thmMap.insert thm }

partial def internalize (e : Expr) (generation : Nat) : GoalM Unit := do
if (← alreadyInternalized e) then return ()
match e with
Expand All @@ -63,7 +78,9 @@ partial def internalize (e : Expr) (generation : Nat) : GoalM Unit := do
internalize c generation
registerParent e c
else
unless f.isConst do
if let .const fName _ := f then
activateTheoremPatterns fName
else
internalize f generation
registerParent e f
for h : i in [: args.size] do
Expand Down
3 changes: 2 additions & 1 deletion src/Lean/Meta/Tactic/Grind/Run.lean
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ def GrindM.run (x : GrindM α) (mainDeclName : Name) : MetaM α := do
def mkGoal (mvarId : MVarId) : GrindM Goal := do
let trueExpr ← getTrueExpr
let falseExpr ← getFalseExpr
GoalM.run' { mvarId } do
let thmMap ← getEMatchTheorems
GoalM.run' { mvarId, thmMap } do
mkENodeCore falseExpr (interpreted := true) (ctor := false) (generation := 0)
mkENodeCore trueExpr (interpreted := true) (ctor := false) (generation := 0)

Expand Down
10 changes: 10 additions & 0 deletions src/Lean/Meta/Tactic/Grind/Types.lean
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import Lean.Meta.Tactic.Simp.Types
import Lean.Meta.Tactic.Util
import Lean.Meta.Tactic.Grind.Canon
import Lean.Meta.Tactic.Grind.Attr
import Lean.Meta.Tactic.Grind.EMatchTheorem

namespace Lean.Meta.Grind

Expand Down Expand Up @@ -273,6 +274,15 @@ structure Goal where
gmt : Nat := 0
/-- Next unique index for creating ENodes -/
nextIdx : Nat := 0
/-- Active theorems that we have performed ematching at least once. -/
thms : PArray EMatchTheorem := {}
/-- Active theorems that we have not performed any round of ematching yet. -/
newThms : PArray EMatchTheorem := {}
/--
Inactive global theorems. As we internalize terms, we activate theorems as we find their symbols.
Local theorem provided by users are added directly into `newThms`.
-/
thmMap : EMatchTheorems
deriving Inhabited

def Goal.admit (goal : Goal) : MetaM Unit :=
Expand Down
39 changes: 39 additions & 0 deletions tests/lean/run/grind_pattern2.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
def Set (α : Type) := α → Bool

def insertElem [DecidableEq α] (s : Set α) (a : α) : Set α :=
fun x => a = x || s x

def contains (s : Set α) (a : α) : Bool :=
s a

theorem contains_insert [DecidableEq α] (s : Set α) (a : α) : contains (insertElem s a) a := by
simp [contains, insertElem]

grind_pattern contains_insert => contains (insertElem s a) a

-- TheoremPattern activation test

set_option trace.grind.pattern true

/--
warning: declaration uses 'sorry'
---
info: [grind.pattern] activated `contains_insert`
-/
#guard_msgs in
example [DecidableEq α] (s₁ s₂ : Set α) (a₁ a₂ : α) :
s₂ = insertElem s₁ a₁ → a₁ = a₂ → contains s₂ a₂ := by
fail_if_success grind
sorry

/--
warning: declaration uses 'sorry'
---
info: [grind.pattern] reinsert `contains_insert`
[grind.pattern] activated `contains_insert`
-/
#guard_msgs in
example [DecidableEq α] (s₁ s₂ : Set α) (a₁ a₂ : α) :
¬ contains s₂ a₂ → s₂ = insertElem s₁ a₁ → a₁ = a₂ → False := by
fail_if_success grind
sorry
Loading