Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: E-matching module for grind #6488

Merged
merged 4 commits into from
Dec 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 26 additions & 27 deletions src/Lean/Meta/Tactic/Grind/EMatch.lean
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@ import Lean.Meta.Tactic.Grind.Types
import Lean.Meta.Tactic.Grind.Internalize

namespace Lean.Meta.Grind

/-- Returns maximum term generation that is considered during ematching -/
private def getMaxGeneration : GoalM Nat := do
return 10000 -- TODO

/-- Returns `true` if the maximum number of instances has been reached. -/
private def checkMaxInstancesExceeded : GoalM Bool := do
return false -- TODO
/--
Theorem instance found using E-matching.
Recall that we only internalize new instances after we complete a full round of E-matching. -/
structure EMatchTheoremInstance where
-- Proof term of the instance (presumably a proof of `prop` — confirm against the code that builds instances).
proof : Expr
-- Proposition associated with the instance (presumably the statement established by `proof` — TODO confirm).
prop : Expr
-- Generation number recorded for the instance (presumably the term generation used by E-matching — TODO confirm).
generation : Nat
deriving Inhabited

namespace EMatch
/-! This module implements a simple E-matching procedure as a backtracking search. -/
Expand Down Expand Up @@ -51,13 +51,6 @@ structure Choice where
assignment : Array Expr
deriving Inhabited

/-- Theorem instances found so far. We only internalize them after we complete a full round of E-matching. -/
structure TheoremInstance where
proof : Expr
prop : Expr
generation : Nat
deriving Inhabited

/-- Context for the E-matching monad. -/
structure Context where
/-- `useMT` is `true` if we are using the mod-time optimization. It is always set to false for new `EMatchTheorem`s. -/
Expand All @@ -70,7 +63,7 @@ structure Context where
structure State where
/-- Choices that still have to be processed. -/
choiceStack : List Choice := []
newInstances : PArray TheoremInstance := {}
newInstances : Array EMatchTheoremInstance := #[]
deriving Inhabited

abbrev M := ReaderT Context $ StateRefT State GoalM
Expand Down Expand Up @@ -181,6 +174,8 @@ Missing parameters are synthesized using type inference and type class synthesis
-/
private partial def instantiateTheorem (c : Choice) : M Unit := withDefault do
let thm := (← read).thm
unless (← markTheorenInstance thm.proof c.assignment) do
return ()
trace[grind.ematch.instance.assignment] "{← thm.origin.pp}: {assignmentToMessageData c.assignment}"
let proof ← thm.getProofWithFreshMVarLevels
let numParams := thm.numParams
Expand Down Expand Up @@ -285,22 +280,26 @@ where
def ematchTheorems (thms : PArray EMatchTheorem) : M Unit :=
  -- Apply `ematchTheorem` to every theorem stored in `thms`.
  thms.forM fun thm => ematchTheorem thm

def internalizeNewInstances : M Unit := do
-- TODO
return ()

end EMatch

open EMatch

/-- Performs one round of E-matching, and internalizes new instances. -/
def ematch : GoalM Unit := do
let go (thms newThms : PArray EMatchTheorem) : EMatch.M Unit := do
/-- Performs one round of E-matching, and returns new instances. -/
def ematch : GoalM (Array EMatchTheoremInstance) := do
let go (thms newThms : PArray EMatchTheorem) : EMatch.M (Array EMatchTheoremInstance) := do
withReader (fun ctx => { ctx with useMT := true }) <| ematchTheorems thms
withReader (fun ctx => { ctx with useMT := false }) <| ematchTheorems newThms
internalizeNewInstances
unless (← checkMaxInstancesExceeded) do
go (← get).thms (← get).newThms |>.run'
modify fun s => { s with thms := s.thms ++ s.newThms, newThms := {}, gmt := s.gmt + 1 }
return (← get).newInstances
if (← checkMaxInstancesExceeded) then
return #[]
else
let insts ← go (← get).thms (← get).newThms |>.run'
modify fun s => { s with
thms := s.thms ++ s.newThms
newThms := {}
gmt := s.gmt + 1
numInstances := s.numInstances + insts.size
}
return insts

end Lean.Meta.Grind
5 changes: 4 additions & 1 deletion src/Lean/Meta/Tactic/Grind/Internalize.lean
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,10 @@ private partial def activateTheoremPatterns (fName : Name) (generation : Nat) :
let thm := { thm with symbols }
match symbols with
| [] =>
let thm := { thm with patterns := (← thm.patterns.mapM (internalizePattern · generation)) }
-- Recall that we use the proof as part of the key for a set of instances found so far.
-- We don't want to use structural equality when comparing keys.
let proof ← shareCommon thm.proof
let thm := { thm with proof, patterns := (← thm.patterns.mapM (internalizePattern · generation)) }
trace[grind.ematch] "activated `{thm.origin.key}`, {thm.patterns.map ppPattern}"
modify fun s => { s with newThms := s.newThms.push thm }
| _ =>
Expand Down
2 changes: 1 addition & 1 deletion src/Lean/Meta/Tactic/Grind/Preprocessor.lean
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def preprocess (mvarId : MVarId) : PreM State := do
loop (← mkGoal mvarId)
let goals := (← get).goals
-- Testing `ematch` module here. We will rewrite this part later.
let goals ← goals.mapM fun goal => GoalM.run' goal ematch
let goals ← goals.mapM fun goal => GoalM.run' goal (discard <| ematch)
if (← isTracingEnabledFor `grind.pre) then
trace[grind.debug.pre] (← ppGoals goals)
for goal in goals do
Expand Down
125 changes: 93 additions & 32 deletions src/Lean/Meta/Tactic/Grind/Types.lean
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,11 @@ def getMainDeclName : GrindM Name :=
@[inline] def getMethodsRef : GrindM MethodsRef :=
read

/--
Maximum term generation that is taken into account during ematching.
The value is currently a hard-coded placeholder. -/
def getMaxGeneration : GrindM Nat :=
  pure 10000 -- TODO

/--
Abstracts nested proofs in `e`. This is a preprocessing step performed before internalization.
-/
Expand Down Expand Up @@ -193,31 +198,44 @@ structure NewEq where
proof : Expr
isHEq : Bool

abbrev ENodes := PHashMap USize ENode
/--
Key for the `ENodeMap` and `ParentMap` maps.
We use pointer addresses and rely on the fact that all internalized expressions
have been hash-consed, i.e., we have applied `shareCommon` to them.
-/
private structure ENodeKey where
-- Expression wrapped by the key; the `Hashable` instance hashes its pointer address.
expr : Expr

structure CongrKey (enodes : ENodes) where
e : Expr
/--
Hashes an `ENodeKey` using the pointer address of its expression.
The address is shifted right by 3 bits before being converted, exactly as the
`PreInstance` hash does. Heap pointers are presumably 8-byte aligned, so without the
shift the low bits of every hash would be identical and the persistent hash map
would spread keys over only a few buckets.
-/
instance : Hashable ENodeKey where
  hash k := unsafe (ptrAddrUnsafe k.expr >>> 3).toUInt64

private abbrev toENodeKey (e : Expr) : USize :=
unsafe ptrAddrUnsafe e
/-- Two keys are equal when `isSameExpr` holds for the expressions they wrap. -/
instance : BEq ENodeKey :=
  ⟨fun lhs rhs => isSameExpr lhs.expr rhs.expr⟩

-- Persistent hash map from `ENodeKey` to `ENode`.
abbrev ENodeMap := PHashMap ENodeKey ENode

/--
Key for the congruence table.
We need access to the `enodes` to be able to retrieve the equivalence class roots.
-/
structure CongrKey (enodes : ENodeMap) where
-- Expression used as the key; the `BEq` instance compares keys with `isCongruent enodes`.
e : Expr

private def hashRoot (enodes : ENodes) (e : Expr) : UInt64 :=
if let some node := enodes.find? (toENodeKey e) then
toENodeKey node.root |>.toUInt64
private def hashRoot (enodes : ENodeMap) (e : Expr) : UInt64 :=
if let some node := enodes.find? { expr := e } then
unsafe (ptrAddrUnsafe node.root).toUInt64
else
13

private def hasSameRoot (enodes : ENodes) (a b : Expr) : Bool := Id.run do
let ka := toENodeKey a
let kb := toENodeKey b
if ka == kb then
private def hasSameRoot (enodes : ENodeMap) (a b : Expr) : Bool := Id.run do
if isSameExpr a b then
return true
else
let some n1 := enodes.find? ka | return false
let some n2 := enodes.find? kb | return false
toENodeKey n1.root == toENodeKey n2.root
let some n1 := enodes.find? { expr := a } | return false
let some n2 := enodes.find? { expr := b } | return false
isSameExpr n1.root n2.root

def congrHash (enodes : ENodes) (e : Expr) : UInt64 :=
def congrHash (enodes : ENodeMap) (e : Expr) : UInt64 :=
if e.isAppOfArity ``Lean.Grind.nestedProof 2 then
-- We only hash the proposition
hashRoot enodes (e.getArg! 0)
Expand All @@ -229,7 +247,7 @@ where
| .app f a => go f (mixHash r (hashRoot enodes a))
| _ => mixHash r (hashRoot enodes e)

partial def isCongruent (enodes : ENodes) (a b : Expr) : Bool :=
partial def isCongruent (enodes : ENodeMap) (a b : Expr) : Bool :=
if a.isAppOfArity ``Lean.Grind.nestedProof 2 && b.isAppOfArity ``Lean.Grind.nestedProof 2 then
hasSameRoot enodes (a.getArg! 0) (b.getArg! 0)
else
Expand All @@ -249,15 +267,43 @@ instance : Hashable (CongrKey enodes) where
-- Keys are equal exactly when their expressions are congruent with respect to `enodes`.
instance : BEq (CongrKey enodes) :=
  ⟨fun lhs rhs => isCongruent enodes lhs.e rhs.e⟩

abbrev CongrTable (enodes : ENodes) := PHashSet (CongrKey enodes)
abbrev CongrTable (enodes : ENodeMap) := PHashSet (CongrKey enodes)

-- Remark: we cannot use pointer addresses here because we have to traverse the tree.
abbrev ParentSet := RBTree Expr Expr.quickComp
abbrev ParentMap := PHashMap USize ParentSet
abbrev ParentMap := PHashMap ENodeKey ParentSet

/--
The E-matching module instantiates theorems using the `EMatchTheorem proof` and a (partial) assignment.
We want to avoid instantiating the same theorem with the same assignment more than once.
Therefore, we store the (pre-)instance information in a set.
Recall that the proofs of activated theorems have been hash-consed.
The assignment contains internalized expressions, which have also been hash-consed.
-/
structure PreInstance where
-- Proof of the theorem being instantiated (hash-consed when the theorem is activated).
proof : Expr
-- (Partial) assignment found by E-matching; its entries are internalized expressions.
assignment : Array Expr

/--
Hash of a `PreInstance`: starts from the shifted pointer address of the proof and
mixes in the shifted pointer address of every assignment entry, from first to last.
-/
instance : Hashable PreInstance where
  hash inst :=
    inst.assignment.foldl (init := unsafe (ptrAddrUnsafe inst.proof >>> 3).toUInt64) fun acc arg =>
      mixHash acc (unsafe (ptrAddrUnsafe arg >>> 3).toUInt64)

/--
Two pre-instances are equal when `isSameExpr` holds for their proofs, their assignments
have the same size, and `isSameExpr` holds for every pair of corresponding entries.
-/
instance : BEq PreInstance where
  beq lhs rhs :=
    isSameExpr lhs.proof rhs.proof &&
      -- `Array.isEqv` yields `false` on a size mismatch and otherwise compares entries pairwise.
      lhs.assignment.isEqv rhs.assignment isSameExpr

-- Persistent hash set of `PreInstance`s.
abbrev PreInstanceSet := PHashSet PreInstance

structure Goal where
mvarId : MVarId
enodes : ENodes := {}
enodes : ENodeMap := {}
parents : ParentMap := {}
congrTable : CongrTable enodes := {}
/--
Expand Down Expand Up @@ -285,6 +331,8 @@ structure Goal where
thmMap : EMatchTheorems
/-- Number of theorem instances generated so far -/
numInstances : Nat := 0
/-- (pre-)instances found so far -/
instances : PreInstanceSet := {}
deriving Inhabited

def Goal.admit (goal : Goal) : MetaM Unit :=
Expand All @@ -294,6 +342,21 @@ abbrev GoalM := StateRefT Goal GrindM

abbrev Propagator := Expr → GoalM Unit

/--
Records the pair `(proof, assignment)` as a theorem instance found by the E-matching module.
Returns `true` when the pair was not yet in the goal's `instances` set (it is inserted),
and `false` when it was already present (the state is left untouched).
The spelling of the name is kept as is because callers refer to it.
-/
def markTheorenInstance (proof : Expr) (assignment : Array Expr) : GoalM Bool := do
  let key : PreInstance := { proof, assignment }
  if (← get).instances.contains key then
    return false
  else
    modify fun goal => { goal with instances := goal.instances.insert key }
    return true

/--
Reports whether the maximum number of instances has been reached.
The check is not implemented yet, so the answer is always `false`. -/
def checkMaxInstancesExceeded : GoalM Bool :=
  pure false -- TODO

/-- Checks whether `e` is the internalized `True` expression, comparing with `isSameExpr`. -/
def isTrueExpr (e : Expr) : GrindM Bool := do
  let trueExpr ← getTrueExpr
  return isSameExpr e trueExpr
Expand All @@ -307,11 +370,11 @@ Returns `some n` if `e` has already been "internalized" into the
Otherwise, returns `none`.
-/
def getENode? (e : Expr) : GoalM (Option ENode) :=
return (← get).enodes.find? (unsafe ptrAddrUnsafe e)
return (← get).enodes.find? { expr := e }

/-- Returns node associated with `e`. It assumes `e` has already been internalized. -/
def getENode (e : Expr) : GoalM ENode := do
let some n := (← get).enodes.find? (unsafe ptrAddrUnsafe e)
let some n := (← get).enodes.find? { expr := e }
| throwError "internal `grind` error, term has not been internalized{indentExpr e}"
return n

Expand Down Expand Up @@ -362,7 +425,7 @@ def getNext (e : Expr) : GoalM Expr :=

/-- Returns `true` if `e` has already been internalized. -/
def alreadyInternalized (e : Expr) : GoalM Bool :=
return (← get).enodes.contains (unsafe ptrAddrUnsafe e)
return (← get).enodes.contains { expr := e }

def getTarget? (e : Expr) : GoalM (Option Expr) := do
let some n ← getENode? e | return none
Expand Down Expand Up @@ -407,41 +470,39 @@ information in the root (aka canonical representative) of `child`.
-/
def registerParent (parent : Expr) (child : Expr) : GoalM Unit := do
let some childRoot ← getRoot? child | return ()
let key := toENodeKey childRoot
let parents := if let some parents := (← get).parents.find? key then parents else {}
modify fun s => { s with parents := s.parents.insert key (parents.insert parent) }
let parents := if let some parents := (← get).parents.find? { expr := childRoot } then parents else {}
modify fun s => { s with parents := s.parents.insert { expr := childRoot } (parents.insert parent) }

/--
Returns the set of expressions `e` is a child of, or an expression in
`e`s equivalence class is a child of.
The information is only up to date if `e` is the root (aka canonical representative) of the equivalence class.
-/
def getParents (e : Expr) : GoalM ParentSet := do
let some parents := (← get).parents.find? (toENodeKey e) | return {}
let some parents := (← get).parents.find? { expr := e } | return {}
return parents

/--
Similar to `getParents`, but also removes the entry `e ↦ parents` from the parent map.
-/
def getParentsAndReset (e : Expr) : GoalM ParentSet := do
let parents ← getParents e
modify fun s => { s with parents := s.parents.erase (toENodeKey e) }
modify fun s => { s with parents := s.parents.erase { expr := e } }
return parents

/--
Copy `parents` to the parents of `root`.
`root` must be the root of its equivalence class.
-/
def copyParentsTo (parents : ParentSet) (root : Expr) : GoalM Unit := do
let key := toENodeKey root
let mut curr := if let some parents := (← get).parents.find? key then parents else {}
let mut curr := if let some parents := (← get).parents.find? { expr := root } then parents else {}
for parent in parents do
curr := curr.insert parent
modify fun s => { s with parents := s.parents.insert key curr }
modify fun s => { s with parents := s.parents.insert { expr := root } curr }

def setENode (e : Expr) (n : ENode) : GoalM Unit :=
modify fun s => { s with
enodes := s.enodes.insert (unsafe ptrAddrUnsafe e) n
enodes := s.enodes.insert { expr := e } n
congrTable := unsafe unsafeCast s.congrTable
}

Expand Down
Loading