Skip to content

Commit

Permalink
feat: detect congruent terms in grind (#6437)
Browse files Browse the repository at this point in the history
This PR adds support for detecting congruent terms in the (WIP) `grind`
tactic. It also introduces the `grind.debug` option, which, when set to
`true`, checks many invariants after each equivalence class is merged.
This option is intended solely for debugging purposes.
  • Loading branch information
leodemoura authored Dec 24, 2024
1 parent 5240405 commit b18f3a3
Show file tree
Hide file tree
Showing 8 changed files with 206 additions and 13 deletions.
5 changes: 5 additions & 0 deletions src/Lean/Expr.lean
Original file line number Diff line number Diff line change
Expand Up @@ -769,6 +769,11 @@ opaque quickLt (a : @& Expr) (b : @& Expr) : Bool
@[extern "lean_expr_lt"]
opaque lt (a : @& Expr) (b : @& Expr) : Bool

def quickComp (a b : Expr) : Ordering :=
if quickLt a b then .lt
else if quickLt b a then .gt
else .eq

/--
Return true iff `a` and `b` are alpha equivalent.
Binder annotations are ignored.
Expand Down
1 change: 1 addition & 0 deletions src/Lean/Meta/Tactic/Grind.lean
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import Lean.Meta.Tactic.Grind.Injection
import Lean.Meta.Tactic.Grind.Core
import Lean.Meta.Tactic.Grind.Canon
import Lean.Meta.Tactic.Grind.MarkNestedProofs
import Lean.Meta.Tactic.Grind.Inv

namespace Lean

Expand Down
52 changes: 39 additions & 13 deletions src/Lean/Meta/Tactic/Grind/Core.lean
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Authors: Leonardo de Moura
prelude
import Init.Grind.Util
import Lean.Meta.Tactic.Grind.Types
import Lean.Meta.Tactic.Grind.Inv
import Lean.Meta.LitValues

namespace Lean.Meta.Grind
Expand Down Expand Up @@ -60,6 +61,9 @@ def ppENodeDecl (e : Expr) : GoalM Format := do
r := r ++ ", [val]"
if n.ctor then
r := r ++ ", [ctor]"
if grind.debug.get (← getOptions) then
if let some target ← getTarget? e then
r := r ++ f!" ↝ {← ppENodeRef target}"
return r

/-- Pretty print goal state for debugging purposes. -/
Expand Down Expand Up @@ -136,13 +140,17 @@ partial def internalize (e : Expr) (generation : Nat) : GoalM Unit := do
if f.isConstOf ``Lean.Grind.nestedProof && args.size == 2 then
-- We only internalize the proposition. We can skip the proof because of
-- proof irrelevance
internalize args[0]! generation
let c := args[0]!
internalize c generation
registerParent e c
else
unless f.isConst do
internalize f generation
registerParent e f
for h : i in [: args.size] do
let arg := args[i]
internalize arg generation
registerParent e arg
mkENode e generation
addCongrTable e

Expand Down Expand Up @@ -172,6 +180,24 @@ private def markAsInconsistent : GoalM Unit :=
def isInconsistent : GoalM Bool :=
return (← get).inconsistent

/--
Remove `root` parents from the congruence table.
This is an auxiliary function performed while merging equivalence classes.
-/
private def removeParents (root : Expr) : GoalM ParentSet := do
let parents ← getParentsAndReset root
for parent in parents do
modify fun s => { s with congrTable := s.congrTable.erase { e := parent } }
return parents

/--
Reinsert parents into the congruence table and detect new equalities.
This is an auxiliary function performed while merging equivalence classes.
-/
private def reinsertParents (parents : ParentSet) : GoalM Unit := do
for parent in parents do
addCongrTable parent

private partial def addEqStep (lhs rhs proof : Expr) (isHEq : Bool) : GoalM Unit := do
trace[grind.eq] "{lhs} {if isHEq then "" else "="} {rhs}"
let lhsNode ← getENode lhs
Expand All @@ -182,23 +208,24 @@ private partial def addEqStep (lhs rhs proof : Expr) (isHEq : Bool) : GoalM Unit
return ()
let lhsRoot ← getENode lhsNode.root
let rhsRoot ← getENode rhsNode.root
let mut valueInconsistency := false
if lhsRoot.interpreted && rhsRoot.interpreted then
if lhsNode.root.isTrue || rhsNode.root.isTrue then
markAsInconsistent
else
valueInconsistency := true
if (lhsRoot.interpreted && !rhsRoot.interpreted)
|| (lhsRoot.ctor && !rhsRoot.ctor)
|| (lhsRoot.size > rhsRoot.size && !rhsRoot.interpreted && !rhsRoot.ctor) then
go rhs lhs rhsNode lhsNode rhsRoot lhsRoot true
else
go lhs rhs lhsNode rhsNode lhsRoot rhsRoot false
-- TODO: propagate value inconsistency
trace[grind.debug] "after addEqStep, {← ppState}"
checkInvariants
where
go (lhs rhs : Expr) (lhsNode rhsNode lhsRoot rhsRoot : ENode) (flipped : Bool) : GoalM Unit := do
trace[grind.debug] "adding {← ppENodeRef lhs} ↦ {← ppENodeRef rhs}"
let mut valueInconsistency := false
if lhsRoot.interpreted && rhsRoot.interpreted then
if lhsNode.root.isTrue || rhsNode.root.isTrue then
markAsInconsistent
else
valueInconsistency := true
-- TODO: process valueInconsistency := true
/-
We have the following `target?/proof?`
`lhs -> ... -> lhsNode.root`
Expand All @@ -213,22 +240,21 @@ where
proof? := proof
flipped
}
-- TODO: Remove parents from congruence table
let parents ← removeParents lhsRoot.self
-- TODO: set propagateBool
updateRoots lhs rhsNode.root true -- TODO
trace[grind.debug] "{← ppENodeRef lhs} new root {← ppENodeRef rhsNode.root}, {← ppENodeRef (← getRoot lhs)}"
-- TODO: Reinsert parents into congruence table
setENode lhsNode.root { lhsRoot with
reinsertParents parents
setENode lhsNode.root { (← getENode lhsRoot.self) with -- We must retrieve `lhsRoot` since it was updated.
next := rhsRoot.next
root := rhsNode.root
}
setENode rhsNode.root { rhsRoot with
next := lhsRoot.next
size := rhsRoot.size + lhsRoot.size
hasLambdas := rhsRoot.hasLambdas || lhsRoot.hasLambdas
heqProofs := isHEq || rhsRoot.heqProofs || lhsRoot.heqProofs
}
-- TODO: copy parentst from lhsRoot parents to rhsRoot parents
copyParentsTo parents rhsNode.root

updateRoots (lhs : Expr) (rootNew : Expr) (_propagateBool : Bool) : GoalM Unit := do
let rec loop (e : Expr) : GoalM Unit := do
Expand Down
67 changes: 67 additions & 0 deletions src/Lean/Meta/Tactic/Grind/Inv.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/-
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.Meta.Tactic.Grind.Types

namespace Lean.Meta.Grind

/-!
Debugging support code for checking basic invariants.
-/

register_builtin_option grind.debug : Bool := {
defValue := false
group := "debug"
descr := "check invariants after updates"
}

private def checkEqc (root : ENode) : GoalM Unit := do
let mut size := 0
let mut curr := root.self
repeat
size := size + 1
-- The root of `curr` must be `root`
assert! isSameExpr (← getRoot curr) root.self
-- Starting at `curr`, following the `target?` field leads to `root`.
let mut n := curr
repeat
if let some target ← getTarget? n then
n := target
else
break
assert! isSameExpr n root.self
-- Go to next element
curr ← getNext curr
if isSameExpr root.self curr then
break
-- The size of the equivalence class is correct.
assert! root.size == size

private def checkParents (e : Expr) : GoalM Unit := do
if (← isRoot e) then
for parent in (← getParents e) do
let mut found := false
-- There is an argument `arg` s.t. root of `arg` is `e`.
for arg in parent.getAppArgs do
if isSameExpr (← getRoot arg) e then
found := true
break
assert! found
else
-- All the parents are stored in the root of the equivalence class.
assert! (← getParents e).isEmpty

/--
Check basic invariants if `grind.debug` is enabled.
-/
def checkInvariants : GoalM Unit := do
if grind.debug.get (← getOptions) then
for (_, node) in (← get).enodes do
checkParents node.self
if isSameExpr node.self node.root then
checkEqc node

end Lean.Meta.Grind
57 changes: 57 additions & 0 deletions src/Lean/Meta/Tactic/Grind/Types.lean
Original file line number Diff line number Diff line change
Expand Up @@ -218,9 +218,14 @@ instance : BEq (CongrKey enodes) where

abbrev CongrTable (enodes : ENodes) := PHashSet (CongrKey enodes)

-- Remark: we cannot use pointer addresses here because we have to traverse the tree.
abbrev ParentSet := RBTree Expr Expr.quickComp
abbrev ParentMap := PHashMap USize ParentSet

structure Goal where
mvarId : MVarId
enodes : ENodes := {}
parents : ParentMap := {}
congrTable : CongrTable enodes := {}
/-- Equations to be processed. -/
newEqs : Array NewEq := #[]
Expand Down Expand Up @@ -255,6 +260,16 @@ def getENode (e : Expr) : GoalM ENode := do
let some n := (← get).enodes.find? (unsafe ptrAddrUnsafe e) | unreachable!
return n

/-- Returns `true` is the root of its equivalence class. -/
def isRoot (e : Expr) : GoalM Bool := do
let some n ← getENode? e | return false -- `e` has not been internalized. Panic instead?
return isSameExpr n.root e

/-- Returns the root element in the equivalence class of `e` IF `e` has been internalized. -/
def getRoot? (e : Expr) : GoalM (Option Expr) := do
let some n ← getENode? e | return none
return some n.root

/-- Returns the root element in the equivalence class of `e`. -/
def getRoot (e : Expr) : GoalM Expr :=
return (← getENode e).root
Expand All @@ -267,6 +282,48 @@ def getNext (e : Expr) : GoalM Expr :=
def alreadyInternalized (e : Expr) : GoalM Bool :=
return (← get).enodes.contains (unsafe ptrAddrUnsafe e)

def getTarget? (e : Expr) : GoalM (Option Expr) := do
let some n ← getENode? e | return none
return n.target?

/--
Records that `parent` is a parent of `child`. This function actually stores the
information in the root (aka canonical representative) of `child`.
-/
def registerParent (parent : Expr) (child : Expr) : GoalM Unit := do
let some childRoot ← getRoot? child | return ()
let key := toENodeKey childRoot
let parents := if let some parents := (← get).parents.find? key then parents else {}
modify fun s => { s with parents := s.parents.insert key (parents.insert parent) }

/--
Returns the set of expressions `e` is a child of, or an expression in
`e`s equivalence class is a child of.
The information is only up to date if `e` is the root (aka canonical representative) of the equivalence class.
-/
def getParents (e : Expr) : GoalM ParentSet := do
let some parents := (← get).parents.find? (toENodeKey e) | return {}
return parents

/--
Similar to `getParents`, but also removes the entry `e ↦ parents` from the parent map.
-/
def getParentsAndReset (e : Expr) : GoalM ParentSet := do
let parents ← getParents e
modify fun s => { s with parents := s.parents.erase (toENodeKey e) }
return parents

/--
Copy `parents` to the parents of `root`.
`root` must be the root of its equivalence class.
-/
def copyParentsTo (parents : ParentSet) (root : Expr) : GoalM Unit := do
let key := toENodeKey root
let mut curr := if let some parents := (← get).parents.find? key then parents else {}
for parent in parents do
curr := curr.insert parent
modify fun s => { s with parents := s.parents.insert key curr }

def setENode (e : Expr) (n : ENode) : GoalM Unit :=
modify fun s => { s with
enodes := s.enodes.insert (unsafe ptrAddrUnsafe e) n
Expand Down
33 changes: 33 additions & 0 deletions tests/lean/run/grind_congr.lean
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import Lean

def f (a : Nat) := a + a + a
def g (a : Nat) := a + a

-- Prints the equivalence class containing a `f` application
open Lean Meta Elab Tactic Grind in
Expand All @@ -11,6 +12,8 @@ elab "grind_test" : tactic => withMainContext do
let eqc ← getEqc n.self
logInfo eqc

set_option grind.debug true

/--
info: [d, f b, c, f a]
---
Expand All @@ -20,3 +23,33 @@ warning: declaration uses 'sorry'
example (a b c d : Nat) : a = b → f a = c → f b = d → False := by
grind_test
sorry

/--
info: [d, f b, c, f a]
---
warning: declaration uses 'sorry'
-/
#guard_msgs in
example (a b c d : Nat) : f a = c → f b = d → a = b → False := by
grind_test
sorry

/--
info: [d, f (g b), c, f (g a)]
---
warning: declaration uses 'sorry'
-/
#guard_msgs in
example (a b c d e : Nat) : f (g a) = c → f (g b) = d → a = e → b = e → False := by
grind_test
sorry

/--
info: [d, f (g b), c, f v]
---
warning: declaration uses 'sorry'
-/
#guard_msgs in
example (a b c d e v : Nat) : f v = c → f (g b) = d → a = e → b = e → v = g a → False := by
grind_test
sorry
2 changes: 2 additions & 0 deletions tests/lean/run/grind_nested_proofs.lean
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ elab "grind_test" : tactic => withMainContext do
let nodes ← filterENodes fun e => return e.self.isAppOf ``Lean.Grind.nestedProof
logInfo (nodes.toList.map (·.self))

set_option grind.debug true

/--
info: [Lean.Grind.nestedProof (i < a.toList.length) (_example.proof_1 i j a b h1 h2),
Lean.Grind.nestedProof (j < a.toList.length) h1,
Expand Down
2 changes: 2 additions & 0 deletions tests/lean/run/grind_pre.lean
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ elab "grind_pre" : tactic => do

abbrev f (a : α) := a

set_option grind.debug true

/--
warning: declaration uses 'sorry'
---
Expand Down

0 comments on commit b18f3a3

Please sign in to comment.