Skip to content

Commit

Permalink
feat: support for builtin grind propagators (part 2)
Browse files Browse the repository at this point in the history
This PR completes the implementation of the command `builtin_grind_propagator`.
  • Loading branch information
leodemoura committed Dec 25, 2024
1 parent 65e8ba0 commit b7ba82d
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 69 deletions.
15 changes: 6 additions & 9 deletions src/Lean/Meta/Tactic/Grind/Core.lean
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import Init.Grind.Util
import Lean.Meta.LitValues
import Lean.Meta.Tactic.Grind.Types
import Lean.Meta.Tactic.Grind.Inv
import Lean.Meta.Tactic.Grind.Propagate
import Lean.Meta.Tactic.Grind.PP

namespace Lean.Meta.Grind
Expand Down Expand Up @@ -147,8 +146,7 @@ where
}
let parents ← removeParents lhsRoot.self
-- TODO: set propagateBool
let isTrueOrFalse ← isTrueExpr rhsNode.root <||> isFalseExpr rhsNode.root
updateRoots lhs rhsNode.root (isTrueOrFalse && !(← isInconsistent))
updateRoots lhs rhsNode.root
trace[grind.debug] "{← ppENodeRef lhs} new root {← ppENodeRef rhsNode.root}, {← ppENodeRef (← getRoot lhs)}"
reinsertParents parents
setENode lhsNode.root { (← getENode lhsRoot.self) with -- We must retrieve `lhsRoot` since it was updated.
Expand All @@ -162,17 +160,16 @@ where
}
copyParentsTo parents rhsNode.root
unless (← isInconsistent) do
if isTrueOrFalse then
for parent in parents do
propagateConectivesUp parent
for parent in parents do
propagateUp parent

updateRoots (lhs : Expr) (rootNew : Expr) (propagateTruth : Bool) : GoalM Unit := do
updateRoots (lhs : Expr) (rootNew : Expr) : GoalM Unit := do
let rec loop (e : Expr) : GoalM Unit := do
-- TODO: propagateBool
let n ← getENode e
setENode e { n with root := rootNew }
if propagateTruth then
propagateConnectivesDown e
unless (← isInconsistent) do
propagateDown e
if isSameExpr lhs n.next then return ()
loop n.next
loop lhs
Expand Down
75 changes: 20 additions & 55 deletions src/Lean/Meta/Tactic/Grind/Propagate.lean
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Init.Grind.Lemmas
import Init.Grind
import Lean.Meta.Tactic.Grind.Proof

namespace Lean.Meta.Grind
Expand All @@ -19,9 +19,8 @@ and propagates the following equalities:
- If `a = False`, propagates `(a ∧ b) = False`.
- If `b = False`, propagates `(a ∧ b) = False`.
-/
private def propagateAndUp (e : Expr) : GoalM Unit := do
let a := e.appFn!.appArg!
let b := e.appArg!
builtin_grind_propagator propagateAndUp ↑And := fun e => do
let_expr And a b := e | return ()
if (← isEqTrue a) then
-- a = True → (a ∧ b) = b
pushEq e b <| mkApp3 (mkConst ``Lean.Grind.and_eq_of_eq_true_left) a b (← mkEqTrueProof a)
Expand All @@ -39,10 +38,9 @@ private def propagateAndUp (e : Expr) : GoalM Unit := do
Propagates truth values downwards for a conjunction `a ∧ b` when the
expression itself is known to be `True`.
-/
private def propagateAndDown (e : Expr) : GoalM Unit := do
builtin_grind_propagator propagateAndDown ↓And := fun e => do
if (← isEqTrue e) then
let a := e.appFn!.appArg!
let b := e.appArg!
let_expr And a b := e | return ()
let h ← mkEqTrueProof e
pushEqTrue a <| mkApp3 (mkConst ``Lean.Grind.eq_true_of_and_eq_true_left) a b h
pushEqTrue b <| mkApp3 (mkConst ``Lean.Grind.eq_true_of_and_eq_true_right) a b h
Expand All @@ -57,9 +55,8 @@ and propagates the following equalities:
- If `a = True`, propagates `(a ∨ b) = True`.
- If `b = True`, propagates `(a ∨ b) = True`.
-/
private def propagateOrUp (e : Expr) : GoalM Unit := do
let a := e.appFn!.appArg!
let b := e.appArg!
builtin_grind_propagator propagateOrUp ↑Or := fun e => do
let_expr Or a b := e | return ()
if (← isEqFalse a) then
-- a = False → (a ∨ b) = b
pushEq e b <| mkApp3 (mkConst ``Lean.Grind.or_eq_of_eq_false_left) a b (← mkEqFalseProof a)
Expand All @@ -77,10 +74,9 @@ private def propagateOrUp (e : Expr) : GoalM Unit := do
Propagates truth values downwards for a disjuction `a ∨ b` when the
expression itself is known to be `False`.
-/
private def propagateOrDown (e : Expr) : GoalM Unit := do
builtin_grind_propagator propagateOrDown ↓Or := fun e => do
if (← isEqFalse e) then
let a := e.appFn!.appArg!
let b := e.appArg!
let_expr Or a b := e | return ()
let h ← mkEqFalseProof e
pushEqFalse a <| mkApp3 (mkConst ``Lean.Grind.eq_false_of_or_eq_false_left) a b h
pushEqFalse b <| mkApp3 (mkConst ``Lean.Grind.eq_false_of_or_eq_false_right) a b h
Expand All @@ -92,8 +88,8 @@ This function checks the truth value of `a` and propagates the following equalit
- If `a = False`, propagates `(Not a) = True`.
- If `a = True`, propagates `(Not a) = False`.
-/
private def propagateNotUp (e : Expr) : GoalM Unit := do
let a := e.appArg!
builtin_grind_propagator propagateNotUp ↑Not := fun e => do
let_expr Not a := e | return ()
if (← isEqFalse a) then
-- a = False → (Not a) = True
pushEqTrue e <| mkApp2 (mkConst ``Lean.Grind.not_eq_of_eq_false) a (← mkEqFalseProof a)
Expand All @@ -108,18 +104,17 @@ This function performs the following:
- If `(Not a) = False`, propagates `a = True`.
- If `(Not a) = True`, propagates `a = False`.
-/
private def propagateNotDown (e : Expr) : GoalM Unit := do
builtin_grind_propagator propagateNotDown ↓Not := fun e => do
if (← isEqFalse e) then
let a := e.appArg!
let_expr Not a := e | return ()
pushEqTrue a <| mkApp2 (mkConst ``Lean.Grind.eq_true_of_not_eq_false) a (← mkEqFalseProof e)
else if (← isEqTrue e) then
let a := e.appArg!
let_expr Not a := e | return ()
pushEqFalse a <| mkApp2 (mkConst ``Lean.Grind.eq_false_of_not_eq_true) a (← mkEqTrueProof e)

/-- Propagates `Eq` upwards -/
def propagateEqUp (e : Expr) : GoalM Unit := do
let a := e.appFn!.appArg!
let b := e.appArg!
builtin_grind_propagator propagateEqUp ↑Eq := fun e => do
let_expr Eq _ a b := e | return ()
if (← isEqTrue a) then
pushEq e b <| mkApp3 (mkConst ``Lean.Grind.eq_eq_of_eq_true_left) a b (← mkEqTrueProof a)
else if (← isEqTrue b) then
Expand All @@ -128,45 +123,15 @@ def propagateEqUp (e : Expr) : GoalM Unit := do
pushEqTrue e <| mkApp2 (mkConst ``of_eq_true) e (← mkEqProof a b)

/-- Propagates `Eq` downwards -/
def propagateEqDown (e : Expr) : GoalM Unit := do
builtin_grind_propagator propagateEqDown ↓Eq := fun e => do
if (← isEqTrue e) then
let a := e.appFn!.appArg!
let b := e.appArg!
let_expr Eq _ a b := e | return ()
pushEq a b <| mkApp2 (mkConst ``of_eq_true) e (← mkEqTrueProof e)

/-- Propagates `HEq` downwards -/
def propagateHEqDown (e : Expr) : GoalM Unit := do
builtin_grind_propagator propagateHEqDown ↓HEq := fun e => do
if (← isEqTrue e) then
let a := e.appFn!.appFn!.appArg!
let b := e.appArg!
let_expr HEq _ a _ b := e | return ()
pushHEq a b <| mkApp2 (mkConst ``of_eq_true) e (← mkEqTrueProof e)

/-- Propagates equalities upwards for logical connectives. -/
def propagateConectivesUp (e : Expr) : GoalM Unit := do
let .const declName _ := e.getAppFn | return ()
if declName == ``Eq && e.getAppNumArgs == 3 then
propagateEqUp e
else if declName == ``And && e.getAppNumArgs == 2 then
propagateAndUp e
else if declName == ``Or && e.getAppNumArgs == 2 then
propagateOrUp e
else if declName == ``Not && e.getAppNumArgs == 1 then
propagateNotUp e
-- TODO support for equality between Props

/-- Propagates equalities downwards for logical connectives. -/
def propagateConnectivesDown (e : Expr) : GoalM Unit := do
let .const declName _ := e.getAppFn | return ()
if declName == ``Eq && e.getAppNumArgs == 3 then
propagateEqDown e
else if declName == ``HEq && e.getAppNumArgs == 4 then
propagateHEqDown e
else if declName == ``And && e.getAppNumArgs == 2 then
propagateAndDown e
else if declName == ``Or && e.getAppNumArgs == 2 then
propagateOrDown e
else if declName == ``Not && e.getAppNumArgs == 1 then
propagateNotDown e
-- TODO support for `if-then-else`, equality between Props

end Lean.Meta.Grind
22 changes: 17 additions & 5 deletions src/Lean/Meta/Tactic/Grind/Types.lean
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ def getFalseExpr : GrindM Expr := do
def getMainDeclName : GrindM Name :=
return (← readThe Context).mainDeclName

@[inline] def getMethodsRef : GrindM MethodsRef :=
read

/--
Abtracts nested proofs in `e`. This is a preprocessing step performed before internalization.
-/
Expand Down Expand Up @@ -449,17 +452,26 @@ def forEachEqc (f : ENode → GoalM Unit) : GoalM Unit := do
if isSameExpr n.self n.root then
f n

structure Methods where
private structure Methods where
propagateUp : Propagator := fun _ => return ()
propagateDown : Propagator := fun _ => return ()
deriving Inhabited

def Methods.toMethodsRef (m : Methods) : MethodsRef :=
private def Methods.toMethodsRef (m : Methods) : MethodsRef :=
unsafe unsafeCast m

def MethodsRef.toMethods (m : MethodsRef) : Methods :=
private def MethodsRef.toMethods (m : MethodsRef) : Methods :=
unsafe unsafeCast m

@[inline] def getMethods : GrindM Methods :=
return (← getMethodsRef).toMethods

def propagateUp (e : Expr) : GoalM Unit := do
(← getMethods).propagateUp e

def propagateDown (e : Expr) : GoalM Unit := do
(← getMethods).propagateDown e

/-- Builtin propagators. -/
structure BuiltinPropagators where
up : Std.HashMap Name Propagator := {}
Expand Down Expand Up @@ -509,7 +521,7 @@ builtin_initialize
add := fun declName stx _ => addBuiltin declName stx
}

def getMethods : CoreM Methods := do
def mkMethods : CoreM Methods := do
let builtinPropagators ← builtinPropagatorsRef.get
return {
propagateUp := fun e => do
Expand All @@ -532,7 +544,7 @@ def GrindM.run (x : GrindM α) (mainDeclName : Name) : MetaM α := do
(config := { arith := true })
(simpTheorems := #[thms])
(congrTheorems := (← getSimpCongrTheorems))
x (← getMethods).toMethodsRef { mainDeclName, simprocs, simp } |>.run' { scState, trueExpr, falseExpr }
x (← mkMethods).toMethodsRef { mainDeclName, simprocs, simp } |>.run' { scState, trueExpr, falseExpr }

@[inline] def GoalM.run (goal : Goal) (x : GoalM α) : GrindM (α × Goal) :=
goal.mvarId.withContext do StateRefT'.run x goal
Expand Down

0 comments on commit b7ba82d

Please sign in to comment.