diff --git a/src/Lean/Meta/Tactic/Grind/Core.lean b/src/Lean/Meta/Tactic/Grind/Core.lean index f97f27482295..f655ca85c447 100644 --- a/src/Lean/Meta/Tactic/Grind/Core.lean +++ b/src/Lean/Meta/Tactic/Grind/Core.lean @@ -8,7 +8,6 @@ import Init.Grind.Util import Lean.Meta.LitValues import Lean.Meta.Tactic.Grind.Types import Lean.Meta.Tactic.Grind.Inv -import Lean.Meta.Tactic.Grind.Propagate import Lean.Meta.Tactic.Grind.PP namespace Lean.Meta.Grind @@ -147,8 +146,7 @@ where } let parents ← removeParents lhsRoot.self -- TODO: set propagateBool - let isTrueOrFalse ← isTrueExpr rhsNode.root <||> isFalseExpr rhsNode.root - updateRoots lhs rhsNode.root (isTrueOrFalse && !(← isInconsistent)) + updateRoots lhs rhsNode.root trace[grind.debug] "{← ppENodeRef lhs} new root {← ppENodeRef rhsNode.root}, {← ppENodeRef (← getRoot lhs)}" reinsertParents parents setENode lhsNode.root { (← getENode lhsRoot.self) with -- We must retrieve `lhsRoot` since it was updated. @@ -162,17 +160,16 @@ where } copyParentsTo parents rhsNode.root unless (← isInconsistent) do - if isTrueOrFalse then - for parent in parents do - propagateConectivesUp parent + for parent in parents do + propagateUp parent - updateRoots (lhs : Expr) (rootNew : Expr) (propagateTruth : Bool) : GoalM Unit := do + updateRoots (lhs : Expr) (rootNew : Expr) : GoalM Unit := do let rec loop (e : Expr) : GoalM Unit := do -- TODO: propagateBool let n ← getENode e setENode e { n with root := rootNew } - if propagateTruth then - propagateConnectivesDown e + unless (← isInconsistent) do + propagateDown e if isSameExpr lhs n.next then return () loop n.next loop lhs diff --git a/src/Lean/Meta/Tactic/Grind/Propagate.lean b/src/Lean/Meta/Tactic/Grind/Propagate.lean index 897d4f851e56..c7d7812122f0 100644 --- a/src/Lean/Meta/Tactic/Grind/Propagate.lean +++ b/src/Lean/Meta/Tactic/Grind/Propagate.lean @@ -4,7 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE. Authors: Leonardo de Moura -/ prelude -import Init.Grind.Lemmas +import Init.Grind import Lean.Meta.Tactic.Grind.Proof namespace Lean.Meta.Grind @@ -19,9 +19,8 @@ and propagates the following equalities: - If `a = False`, propagates `(a ∧ b) = False`. - If `b = False`, propagates `(a ∧ b) = False`. -/ -private def propagateAndUp (e : Expr) : GoalM Unit := do - let a := e.appFn!.appArg! - let b := e.appArg! +builtin_grind_propagator propagateAndUp ↑And := fun e => do + let_expr And a b := e | return () if (← isEqTrue a) then -- a = True → (a ∧ b) = b pushEq e b <| mkApp3 (mkConst ``Lean.Grind.and_eq_of_eq_true_left) a b (← mkEqTrueProof a) @@ -39,10 +38,9 @@ private def propagateAndUp (e : Expr) : GoalM Unit := do Propagates truth values downwards for a conjunction `a ∧ b` when the expression itself is known to be `True`. -/ -private def propagateAndDown (e : Expr) : GoalM Unit := do +builtin_grind_propagator propagateAndDown ↓And := fun e => do if (← isEqTrue e) then - let a := e.appFn!.appArg! - let b := e.appArg! + let_expr And a b := e | return () let h ← mkEqTrueProof e pushEqTrue a <| mkApp3 (mkConst ``Lean.Grind.eq_true_of_and_eq_true_left) a b h pushEqTrue b <| mkApp3 (mkConst ``Lean.Grind.eq_true_of_and_eq_true_right) a b h @@ -57,9 +55,8 @@ and propagates the following equalities: - If `a = True`, propagates `(a ∨ b) = True`. - If `b = True`, propagates `(a ∨ b) = True`. -/ -private def propagateOrUp (e : Expr) : GoalM Unit := do - let a := e.appFn!.appArg! - let b := e.appArg! +builtin_grind_propagator propagateOrUp ↑Or := fun e => do + let_expr Or a b := e | return () if (← isEqFalse a) then -- a = False → (a ∨ b) = b pushEq e b <| mkApp3 (mkConst ``Lean.Grind.or_eq_of_eq_false_left) a b (← mkEqFalseProof a) @@ -77,10 +74,9 @@ private def propagateOrUp (e : Expr) : GoalM Unit := do Propagates truth values downwards for a disjuction `a ∨ b` when the expression itself is known to be `False`. -/ -private def propagateOrDown (e : Expr) : GoalM Unit := do +builtin_grind_propagator propagateOrDown ↓Or := fun e => do if (← isEqFalse e) then - let a := e.appFn!.appArg! - let b := e.appArg! + let_expr Or a b := e | return () let h ← mkEqFalseProof e pushEqFalse a <| mkApp3 (mkConst ``Lean.Grind.eq_false_of_or_eq_false_left) a b h pushEqFalse b <| mkApp3 (mkConst ``Lean.Grind.eq_false_of_or_eq_false_right) a b h @@ -92,8 +88,8 @@ This function checks the truth value of `a` and propagates the following equalit - If `a = False`, propagates `(Not a) = True`. - If `a = True`, propagates `(Not a) = False`. -/ -private def propagateNotUp (e : Expr) : GoalM Unit := do - let a := e.appArg! +builtin_grind_propagator propagateNotUp ↑Not := fun e => do + let_expr Not a := e | return () if (← isEqFalse a) then -- a = False → (Not a) = True pushEqTrue e <| mkApp2 (mkConst ``Lean.Grind.not_eq_of_eq_false) a (← mkEqFalseProof a) @@ -108,18 +104,17 @@ This function performs the following: - If `(Not a) = False`, propagates `a = True`. - If `(Not a) = True`, propagates `a = False`. -/ -private def propagateNotDown (e : Expr) : GoalM Unit := do +builtin_grind_propagator propagateNotDown ↓Not := fun e => do if (← isEqFalse e) then - let a := e.appArg! + let_expr Not a := e | return () pushEqTrue a <| mkApp2 (mkConst ``Lean.Grind.eq_true_of_not_eq_false) a (← mkEqFalseProof e) else if (← isEqTrue e) then - let a := e.appArg! + let_expr Not a := e | return () pushEqFalse a <| mkApp2 (mkConst ``Lean.Grind.eq_false_of_not_eq_true) a (← mkEqTrueProof e) /-- Propagates `Eq` upwards -/ -def propagateEqUp (e : Expr) : GoalM Unit := do - let a := e.appFn!.appArg! - let b := e.appArg! +builtin_grind_propagator propagateEqUp ↑Eq := fun e => do + let_expr Eq _ a b := e | return () if (← isEqTrue a) then pushEq e b <| mkApp3 (mkConst ``Lean.Grind.eq_eq_of_eq_true_left) a b (← mkEqTrueProof a) else if (← isEqTrue b) then @@ -128,45 +123,15 @@ def propagateEqUp (e : Expr) : GoalM Unit := do pushEqTrue e <| mkApp2 (mkConst ``of_eq_true) e (← mkEqProof a b) /-- Propagates `Eq` downwards -/ -def propagateEqDown (e : Expr) : GoalM Unit := do +builtin_grind_propagator propagateEqDown ↓Eq := fun e => do if (← isEqTrue e) then - let a := e.appFn!.appArg! - let b := e.appArg! + let_expr Eq _ a b := e | return () pushEq a b <| mkApp2 (mkConst ``of_eq_true) e (← mkEqTrueProof e) /-- Propagates `HEq` downwards -/ -def propagateHEqDown (e : Expr) : GoalM Unit := do +builtin_grind_propagator propagateHEqDown ↓HEq := fun e => do if (← isEqTrue e) then - let a := e.appFn!.appFn!.appArg! - let b := e.appArg! + let_expr HEq _ a _ b := e | return () pushHEq a b <| mkApp2 (mkConst ``of_eq_true) e (← mkEqTrueProof e) -/-- Propagates equalities upwards for logical connectives. -/ -def propagateConectivesUp (e : Expr) : GoalM Unit := do - let .const declName _ := e.getAppFn | return () - if declName == ``Eq && e.getAppNumArgs == 3 then - propagateEqUp e - else if declName == ``And && e.getAppNumArgs == 2 then - propagateAndUp e - else if declName == ``Or && e.getAppNumArgs == 2 then - propagateOrUp e - else if declName == ``Not && e.getAppNumArgs == 1 then - propagateNotUp e - -- TODO support for equality between Props - -/-- Propagates equalities downwards for logical connectives. -/ -def propagateConnectivesDown (e : Expr) : GoalM Unit := do - let .const declName _ := e.getAppFn | return () - if declName == ``Eq && e.getAppNumArgs == 3 then - propagateEqDown e - else if declName == ``HEq && e.getAppNumArgs == 4 then - propagateHEqDown e - else if declName == ``And && e.getAppNumArgs == 2 then - propagateAndDown e - else if declName == ``Or && e.getAppNumArgs == 2 then - propagateOrDown e - else if declName == ``Not && e.getAppNumArgs == 1 then - propagateNotDown e - -- TODO support for `if-then-else`, equality between Props - end Lean.Meta.Grind diff --git a/src/Lean/Meta/Tactic/Grind/Types.lean b/src/Lean/Meta/Tactic/Grind/Types.lean index de9328939acf..23c3c19d5083 100644 --- a/src/Lean/Meta/Tactic/Grind/Types.lean +++ b/src/Lean/Meta/Tactic/Grind/Types.lean @@ -87,6 +87,9 @@ def getFalseExpr : GrindM Expr := do def getMainDeclName : GrindM Name := return (← readThe Context).mainDeclName +@[inline] def getMethodsRef : GrindM MethodsRef := + read + /-- Abtracts nested proofs in `e`. This is a preprocessing step performed before internalization. -/ @@ -449,17 +452,26 @@ def forEachEqc (f : ENode → GoalM Unit) : GoalM Unit := do if isSameExpr n.self n.root then f n -structure Methods where +private structure Methods where propagateUp : Propagator := fun _ => return () propagateDown : Propagator := fun _ => return () deriving Inhabited -def Methods.toMethodsRef (m : Methods) : MethodsRef := +private def Methods.toMethodsRef (m : Methods) : MethodsRef := unsafe unsafeCast m -def MethodsRef.toMethods (m : MethodsRef) : Methods := +private def MethodsRef.toMethods (m : MethodsRef) : Methods := unsafe unsafeCast m +@[inline] def getMethods : GrindM Methods := + return (← getMethodsRef).toMethods + +def propagateUp (e : Expr) : GoalM Unit := do + (← getMethods).propagateUp e + +def propagateDown (e : Expr) : GoalM Unit := do + (← getMethods).propagateDown e + /-- Builtin propagators. -/ structure BuiltinPropagators where up : Std.HashMap Name Propagator := {} @@ -509,7 +521,7 @@ builtin_initialize add := fun declName stx _ => addBuiltin declName stx } -def getMethods : CoreM Methods := do +def mkMethods : CoreM Methods := do let builtinPropagators ← builtinPropagatorsRef.get return { propagateUp := fun e => do @@ -532,7 +544,7 @@ def GrindM.run (x : GrindM α) (mainDeclName : Name) : MetaM α := do (config := { arith := true }) (simpTheorems := #[thms]) (congrTheorems := (← getSimpCongrTheorems)) - x (← getMethods).toMethodsRef { mainDeclName, simprocs, simp } |>.run' { scState, trueExpr, falseExpr } + x (← mkMethods).toMethodsRef { mainDeclName, simprocs, simp } |>.run' { scState, trueExpr, falseExpr } @[inline] def GoalM.run (goal : Goal) (x : GoalM α) : GrindM (α × Goal) := goal.mvarId.withContext do StateRefT'.run x goal