Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: detect congruent terms in grind #6437

Merged
merged 5 commits into from
Dec 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/Lean/Expr.lean
Original file line number Diff line number Diff line change
Expand Up @@ -769,6 +769,11 @@ opaque quickLt (a : @& Expr) (b : @& Expr) : Bool
@[extern "lean_expr_lt"]
opaque lt (a : @& Expr) (b : @& Expr) : Bool

def quickComp (a b : Expr) : Ordering :=
if quickLt a b then .lt
else if quickLt b a then .gt
else .eq

/--
Return true iff `a` and `b` are alpha equivalent.
Binder annotations are ignored.
Expand Down
1 change: 1 addition & 0 deletions src/Lean/Meta/Tactic/Grind.lean
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import Lean.Meta.Tactic.Grind.Injection
import Lean.Meta.Tactic.Grind.Core
import Lean.Meta.Tactic.Grind.Canon
import Lean.Meta.Tactic.Grind.MarkNestedProofs
import Lean.Meta.Tactic.Grind.Inv

namespace Lean

Expand Down
52 changes: 39 additions & 13 deletions src/Lean/Meta/Tactic/Grind/Core.lean
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Authors: Leonardo de Moura
prelude
import Init.Grind.Util
import Lean.Meta.Tactic.Grind.Types
import Lean.Meta.Tactic.Grind.Inv
import Lean.Meta.LitValues

namespace Lean.Meta.Grind
Expand Down Expand Up @@ -60,6 +61,9 @@ def ppENodeDecl (e : Expr) : GoalM Format := do
r := r ++ ", [val]"
if n.ctor then
r := r ++ ", [ctor]"
if grind.debug.get (← getOptions) then
if let some target ← getTarget? e then
r := r ++ f!" ↝ {← ppENodeRef target}"
return r

/-- Pretty print goal state for debugging purposes. -/
Expand Down Expand Up @@ -136,13 +140,17 @@ partial def internalize (e : Expr) (generation : Nat) : GoalM Unit := do
if f.isConstOf ``Lean.Grind.nestedProof && args.size == 2 then
-- We only internalize the proposition. We can skip the proof because of
-- proof irrelevance
internalize args[0]! generation
let c := args[0]!
internalize c generation
registerParent e c
else
unless f.isConst do
internalize f generation
registerParent e f
for h : i in [: args.size] do
let arg := args[i]
internalize arg generation
registerParent e arg
mkENode e generation
addCongrTable e

Expand Down Expand Up @@ -172,6 +180,24 @@ private def markAsInconsistent : GoalM Unit :=
def isInconsistent : GoalM Bool :=
return (← get).inconsistent

/--
Remove `root` parents from the congruence table.
This is an auxiliary function performed while merging equivalence classes.
-/
private def removeParents (root : Expr) : GoalM ParentSet := do
let parents ← getParentsAndReset root
for parent in parents do
modify fun s => { s with congrTable := s.congrTable.erase { e := parent } }
return parents

/--
Reinsert parents into the congruence table and detect new equalities.
This is an auxiliary function performed while merging equivalence classes.
-/
private def reinsertParents (parents : ParentSet) : GoalM Unit := do
for parent in parents do
addCongrTable parent

private partial def addEqStep (lhs rhs proof : Expr) (isHEq : Bool) : GoalM Unit := do
trace[grind.eq] "{lhs} {if isHEq then "≡" else "="} {rhs}"
let lhsNode ← getENode lhs
Expand All @@ -182,23 +208,24 @@ private partial def addEqStep (lhs rhs proof : Expr) (isHEq : Bool) : GoalM Unit
return ()
let lhsRoot ← getENode lhsNode.root
let rhsRoot ← getENode rhsNode.root
let mut valueInconsistency := false
if lhsRoot.interpreted && rhsRoot.interpreted then
if lhsNode.root.isTrue || rhsNode.root.isTrue then
markAsInconsistent
else
valueInconsistency := true
if (lhsRoot.interpreted && !rhsRoot.interpreted)
|| (lhsRoot.ctor && !rhsRoot.ctor)
|| (lhsRoot.size > rhsRoot.size && !rhsRoot.interpreted && !rhsRoot.ctor) then
go rhs lhs rhsNode lhsNode rhsRoot lhsRoot true
else
go lhs rhs lhsNode rhsNode lhsRoot rhsRoot false
-- TODO: propagate value inconsistency
trace[grind.debug] "after addEqStep, {← ppState}"
checkInvariants
where
go (lhs rhs : Expr) (lhsNode rhsNode lhsRoot rhsRoot : ENode) (flipped : Bool) : GoalM Unit := do
trace[grind.debug] "adding {← ppENodeRef lhs} ↦ {← ppENodeRef rhs}"
let mut valueInconsistency := false
if lhsRoot.interpreted && rhsRoot.interpreted then
if lhsNode.root.isTrue || rhsNode.root.isTrue then
markAsInconsistent
else
valueInconsistency := true
-- TODO: process valueInconsistency := true
/-
We have the following `target?/proof?`
`lhs -> ... -> lhsNode.root`
Expand All @@ -213,22 +240,21 @@ where
proof? := proof
flipped
}
-- TODO: Remove parents from congruence table
let parents ← removeParents lhsRoot.self
-- TODO: set propagateBool
updateRoots lhs rhsNode.root true -- TODO
trace[grind.debug] "{← ppENodeRef lhs} new root {← ppENodeRef rhsNode.root}, {← ppENodeRef (← getRoot lhs)}"
-- TODO: Reinsert parents into congruence table
setENode lhsNode.root { lhsRoot with
reinsertParents parents
setENode lhsNode.root { (← getENode lhsRoot.self) with -- We must retrieve `lhsRoot` since it was updated.
next := rhsRoot.next
root := rhsNode.root
}
setENode rhsNode.root { rhsRoot with
next := lhsRoot.next
size := rhsRoot.size + lhsRoot.size
hasLambdas := rhsRoot.hasLambdas || lhsRoot.hasLambdas
heqProofs := isHEq || rhsRoot.heqProofs || lhsRoot.heqProofs
}
-- TODO: copy parentst from lhsRoot parents to rhsRoot parents
copyParentsTo parents rhsNode.root

updateRoots (lhs : Expr) (rootNew : Expr) (_propagateBool : Bool) : GoalM Unit := do
let rec loop (e : Expr) : GoalM Unit := do
Expand Down
67 changes: 67 additions & 0 deletions src/Lean/Meta/Tactic/Grind/Inv.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/-
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.Meta.Tactic.Grind.Types

namespace Lean.Meta.Grind

/-!
Debugging support code for checking basic invariants.
-/

register_builtin_option grind.debug : Bool := {
defValue := false
group := "debug"
descr := "check invariants after updates"
}

private def checkEqc (root : ENode) : GoalM Unit := do
let mut size := 0
let mut curr := root.self
repeat
size := size + 1
-- The root of `curr` must be `root`
assert! isSameExpr (← getRoot curr) root.self
-- Starting at `curr`, following the `target?` field leads to `root`.
let mut n := curr
repeat
if let some target ← getTarget? n then
n := target
else
break
assert! isSameExpr n root.self
-- Go to next element
curr ← getNext curr
if isSameExpr root.self curr then
break
-- The size of the equivalence class is correct.
assert! root.size == size

private def checkParents (e : Expr) : GoalM Unit := do
if (← isRoot e) then
for parent in (← getParents e) do
let mut found := false
-- There is an argument `arg` s.t. root of `arg` is `e`.
for arg in parent.getAppArgs do
if isSameExpr (← getRoot arg) e then
found := true
break
assert! found
else
-- All the parents are stored in the root of the equivalence class.
assert! (← getParents e).isEmpty

/--
Check basic invariants if `grind.debug` is enabled.
-/
def checkInvariants : GoalM Unit := do
if grind.debug.get (← getOptions) then
for (_, node) in (← get).enodes do
checkParents node.self
if isSameExpr node.self node.root then
checkEqc node

end Lean.Meta.Grind
57 changes: 57 additions & 0 deletions src/Lean/Meta/Tactic/Grind/Types.lean
Original file line number Diff line number Diff line change
Expand Up @@ -218,9 +218,14 @@ instance : BEq (CongrKey enodes) where

abbrev CongrTable (enodes : ENodes) := PHashSet (CongrKey enodes)

-- Remark: we cannot use pointer addresses here because we have to traverse the tree.
abbrev ParentSet := RBTree Expr Expr.quickComp
abbrev ParentMap := PHashMap USize ParentSet

structure Goal where
mvarId : MVarId
enodes : ENodes := {}
parents : ParentMap := {}
congrTable : CongrTable enodes := {}
/-- Equations to be processed. -/
newEqs : Array NewEq := #[]
Expand Down Expand Up @@ -255,6 +260,16 @@ def getENode (e : Expr) : GoalM ENode := do
let some n := (← get).enodes.find? (unsafe ptrAddrUnsafe e) | unreachable!
return n

/-- Returns `true` is the root of its equivalence class. -/
def isRoot (e : Expr) : GoalM Bool := do
let some n ← getENode? e | return false -- `e` has not been internalized. Panic instead?
return isSameExpr n.root e

/-- Returns the root element in the equivalence class of `e` IF `e` has been internalized. -/
def getRoot? (e : Expr) : GoalM (Option Expr) := do
let some n ← getENode? e | return none
return some n.root

/-- Returns the root element in the equivalence class of `e`. -/
def getRoot (e : Expr) : GoalM Expr :=
return (← getENode e).root
Expand All @@ -267,6 +282,48 @@ def getNext (e : Expr) : GoalM Expr :=
def alreadyInternalized (e : Expr) : GoalM Bool :=
return (← get).enodes.contains (unsafe ptrAddrUnsafe e)

def getTarget? (e : Expr) : GoalM (Option Expr) := do
let some n ← getENode? e | return none
return n.target?

/--
Records that `parent` is a parent of `child`. This function actually stores the
information in the root (aka canonical representative) of `child`.
-/
def registerParent (parent : Expr) (child : Expr) : GoalM Unit := do
let some childRoot ← getRoot? child | return ()
let key := toENodeKey childRoot
let parents := if let some parents := (← get).parents.find? key then parents else {}
modify fun s => { s with parents := s.parents.insert key (parents.insert parent) }

/--
Returns the set of expressions `e` is a child of, or an expression in
`e`s equivalence class is a child of.
The information is only up to date if `e` is the root (aka canonical representative) of the equivalence class.
-/
def getParents (e : Expr) : GoalM ParentSet := do
let some parents := (← get).parents.find? (toENodeKey e) | return {}
return parents

/--
Similar to `getParents`, but also removes the entry `e ↦ parents` from the parent map.
-/
def getParentsAndReset (e : Expr) : GoalM ParentSet := do
let parents ← getParents e
modify fun s => { s with parents := s.parents.erase (toENodeKey e) }
return parents

/--
Copy `parents` to the parents of `root`.
`root` must be the root of its equivalence class.
-/
def copyParentsTo (parents : ParentSet) (root : Expr) : GoalM Unit := do
let key := toENodeKey root
let mut curr := if let some parents := (← get).parents.find? key then parents else {}
for parent in parents do
curr := curr.insert parent
modify fun s => { s with parents := s.parents.insert key curr }

def setENode (e : Expr) (n : ENode) : GoalM Unit :=
modify fun s => { s with
enodes := s.enodes.insert (unsafe ptrAddrUnsafe e) n
Expand Down
33 changes: 33 additions & 0 deletions tests/lean/run/grind_congr.lean
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import Lean

def f (a : Nat) := a + a + a
def g (a : Nat) := a + a

-- Prints the equivalence class containing a `f` application
open Lean Meta Elab Tactic Grind in
Expand All @@ -11,6 +12,8 @@ elab "grind_test" : tactic => withMainContext do
let eqc ← getEqc n.self
logInfo eqc

set_option grind.debug true

/--
info: [d, f b, c, f a]
---
Expand All @@ -20,3 +23,33 @@ warning: declaration uses 'sorry'
example (a b c d : Nat) : a = b → f a = c → f b = d → False := by
grind_test
sorry

/--
info: [d, f b, c, f a]
---
warning: declaration uses 'sorry'
-/
#guard_msgs in
example (a b c d : Nat) : f a = c → f b = d → a = b → False := by
grind_test
sorry

/--
info: [d, f (g b), c, f (g a)]
---
warning: declaration uses 'sorry'
-/
#guard_msgs in
example (a b c d e : Nat) : f (g a) = c → f (g b) = d → a = e → b = e → False := by
grind_test
sorry

/--
info: [d, f (g b), c, f v]
---
warning: declaration uses 'sorry'
-/
#guard_msgs in
example (a b c d e v : Nat) : f v = c → f (g b) = d → a = e → b = e → v = g a → False := by
grind_test
sorry
2 changes: 2 additions & 0 deletions tests/lean/run/grind_nested_proofs.lean
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ elab "grind_test" : tactic => withMainContext do
let nodes ← filterENodes fun e => return e.self.isAppOf ``Lean.Grind.nestedProof
logInfo (nodes.toList.map (·.self))

set_option grind.debug true

/--
info: [Lean.Grind.nestedProof (i < a.toList.length) (_example.proof_1 i j a b h1 h2),
Lean.Grind.nestedProof (j < a.toList.length) h1,
Expand Down
2 changes: 2 additions & 0 deletions tests/lean/run/grind_pre.lean
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ elab "grind_pre" : tactic => do

abbrev f (a : α) := a

set_option grind.debug true

/--
warning: declaration uses 'sorry'
---
Expand Down
Loading