Skip to content

Commit

Permalink
feat: detect congruences
Browse files Browse the repository at this point in the history
  • Loading branch information
leodemoura committed Dec 23, 2024
1 parent c3fdce3 commit 169d7a7
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 14 deletions.
30 changes: 24 additions & 6 deletions src/Lean/Meta/Tactic/Grind/Core.lean
Original file line number Diff line number Diff line change
Expand Up @@ -142,15 +142,15 @@ partial def internalize (e : Expr) (generation : Nat) : GoalM Unit := do
-- proof irrelevance
let c := args[0]!
internalize c generation
addOccurrence e c
registerParent e c
else
unless f.isConst do
internalize f generation
addOccurrence e f
registerParent e f
for h : i in [: args.size] do
let arg := args[i]
internalize arg generation
addOccurrence e arg
registerParent e arg
mkENode e generation
addCongrTable e

Expand Down Expand Up @@ -180,6 +180,24 @@ private def markAsInconsistent : GoalM Unit :=
def isInconsistent : GoalM Bool :=
return (← get).inconsistent

/--
Remove `root` parents from the congruence table.
This is an auxiliary function performed while merging equivalence classes.
-/
private def removeParents (root : Expr) : GoalM ParentSet := do
let occs ← getParents root
for parent in occs do
modify fun s => { s with congrTable := s.congrTable.erase { e := parent } }
return occs

/--
Reinsert parents into the congruence table and detect new equalities.
This is an auxiliary function performed while merging equivalence classes.
-/
private def reinsertParents (parents : ParentSet) : GoalM Unit := do
for parent in parents do
addCongrTable parent

private partial def addEqStep (lhs rhs proof : Expr) (isHEq : Bool) : GoalM Unit := do
trace[grind.eq] "{lhs} {if isHEq then "" else "="} {rhs}"
let lhsNode ← getENode lhs
Expand Down Expand Up @@ -222,11 +240,11 @@ where
proof? := proof
flipped
}
-- TODO: Remove parents from congruence table
let parents ← removeParents lhsRoot.self
-- TODO: set propagateBool
updateRoots lhs rhsNode.root true -- TODO
trace[grind.debug] "{← ppENodeRef lhs} new root {← ppENodeRef rhsNode.root}, {← ppENodeRef (← getRoot lhs)}"
-- TODO: Reinsert parents into congruence table
reinsertParents parents
setENode lhsNode.root { (← getENode lhsRoot.self) with -- We must retrieve `lhsRoot` since it was updated.
next := rhsRoot.next
}
Expand All @@ -236,7 +254,7 @@ where
hasLambdas := rhsRoot.hasLambdas || lhsRoot.hasLambdas
heqProofs := isHEq || rhsRoot.heqProofs || lhsRoot.heqProofs
}
-- TODO: copy parentst from lhsRoot parents to rhsRoot parents
copyParentsTo parents rhsNode.root

updateRoots (lhs : Expr) (rootNew : Expr) (_propagateBool : Bool) : GoalM Unit := do
let rec loop (e : Expr) : GoalM Unit := do
Expand Down
2 changes: 1 addition & 1 deletion src/Lean/Meta/Tactic/Grind/Inv.lean
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ register_builtin_option grind.debug : Bool := {
descr := "check invariants after updates"
}

def checkEqc (root : ENode) : GoalM Unit := do
private def checkEqc (root : ENode) : GoalM Unit := do
let mut size := 0
let mut curr := root.self
repeat
Expand Down
30 changes: 23 additions & 7 deletions src/Lean/Meta/Tactic/Grind/Types.lean
Original file line number Diff line number Diff line change
Expand Up @@ -220,12 +220,12 @@ abbrev CongrTable (enodes : ENodes) := PHashSet (CongrKey enodes)

-- Remark: we cannot use pointer addresses here because we have to traverse the tree.
abbrev ParentSet := RBTree Expr Expr.quickComp
abbrev Occurrences := PHashMap USize ParentSet
abbrev ParentMap := PHashMap USize ParentSet

structure Goal where
mvarId : MVarId
enodes : ENodes := {}
occs : Occurrences := {}
parents : ParentMap := {}
congrTable : CongrTable enodes := {}
/-- Equations to be processed. -/
newEqs : Array NewEq := #[]
Expand Down Expand Up @@ -260,6 +260,11 @@ def getENode (e : Expr) : GoalM ENode := do
let some n := (← get).enodes.find? (unsafe ptrAddrUnsafe e) | unreachable!
return n

/-- Returns `true` is the root of its equivalence class. -/
def isRoot (e : Expr) : GoalM Bool := do
let some n ← getENode? e | return false -- `e` has not been internalized. Panic instead?
return isSameExpr n.root e

/-- Returns the root element in the equivalence class of `e` IF `e` has been internalized. -/
def getRoot? (e : Expr) : GoalM (Option Expr) := do
let some n ← getENode? e | return none
Expand All @@ -285,23 +290,34 @@ def getTarget? (e : Expr) : GoalM (Option Expr) := do
Records that `parent` is a parent of `child`. This function actually stores the
information in the root (aka canonical representative) of `child`.
-/
def addOccurrence (parent : Expr) (child : Expr) : GoalM Unit := do
def registerParent (parent : Expr) (child : Expr) : GoalM Unit := do
let some childRoot ← getRoot? child | return ()
let key := toENodeKey childRoot
let parents := if let some parents := (← get).occs.find? key then parents else {}
modify fun s => { s with occs := s.occs.insert key (parents.insert parent) }
let parents := if let some parents := (← get).parents.find? key then parents else {}
modify fun s => { s with parents := s.parents.insert key (parents.insert parent) }

/--
Returns the set of expressions `e` is a child of, or an expression in
`e`s equivalence class is a child of.
The information is only up to date if `e` is the root (aka canonical representative) of the equivalence class.
-/
def getOccurrences (e : Expr) : GoalM ParentSet := do
if let some occs := (← get).occs.find? (toENodeKey e) then
def getParents (e : Expr) : GoalM ParentSet := do
if let some occs := (← get).parents.find? (toENodeKey e) then
return occs
else
return {}

/--
Copy `parents` to the parents of `root`.
`root` must be the root of its equivalence class.
-/
def copyParentsTo (parents : ParentSet) (root : Expr) : GoalM Unit := do
let key := toENodeKey root
let mut curr := if let some parents := (← get).parents.find? key then parents else {}
for parent in parents do
curr := curr.insert parent
modify fun s => { s with parents := s.parents.insert key curr }

def setENode (e : Expr) (n : ENode) : GoalM Unit :=
modify fun s => { s with
enodes := s.enodes.insert (unsafe ptrAddrUnsafe e) n
Expand Down
31 changes: 31 additions & 0 deletions tests/lean/run/grind_congr.lean
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import Lean

def f (a : Nat) := a + a + a
def g (a : Nat) := a + a

-- Prints the equivalence class containing a `f` application
open Lean Meta Elab Tactic Grind in
Expand All @@ -22,3 +23,33 @@ warning: declaration uses 'sorry'
example (a b c d : Nat) : a = b → f a = c → f b = d → False := by
grind_test
sorry

/--
info: [d, f b, c, f a]
---
warning: declaration uses 'sorry'
-/
#guard_msgs in
example (a b c d : Nat) : f a = c → f b = d → a = b → False := by
grind_test
sorry

/--
info: [d, f (g b), c, f (g a)]
---
warning: declaration uses 'sorry'
-/
#guard_msgs in
example (a b c d e : Nat) : f (g a) = c → f (g b) = d → a = e → b = e → False := by
grind_test
sorry

/--
info: [d, f (g b), c, f v]
---
warning: declaration uses 'sorry'
-/
#guard_msgs in
example (a b c d e v : Nat) : f v = c → f (g b) = d → a = e → b = e → v = g a → False := by
grind_test
sorry

0 comments on commit 169d7a7

Please sign in to comment.