Skip to content

Commit

Permalink
refactor: move simplifier support to GrindM
Browse files Browse the repository at this point in the history
This PR refactors `grind` and adds support for invokind the simplifier
using the `GrindM` monad.
  • Loading branch information
leodemoura committed Dec 25, 2024
1 parent f9f8abe commit 7de9466
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 58 deletions.
1 change: 1 addition & 0 deletions src/Lean/Meta/Tactic/Grind.lean
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import Lean.Meta.Tactic.Grind.Inv
import Lean.Meta.Tactic.Grind.Proof
import Lean.Meta.Tactic.Grind.Propagate
import Lean.Meta.Tactic.Grind.PP
import Lean.Meta.Tactic.Grind.Simp

namespace Lean

Expand Down
15 changes: 1 addition & 14 deletions src/Lean/Meta/Tactic/Grind/Core.lean
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,11 @@ import Lean.Meta.Tactic.Grind.PP

namespace Lean.Meta.Grind

/--
Creates an `ENode` for `e` if one does not already exist.
This method assumes `e` has been hashconsed.
-/
def mkENode (e : Expr) (generation : Nat) : GoalM Unit := do
if (← alreadyInternalized e) then return ()
let ctor := (← isConstructorAppCore? e).isSome
let interpreted ← isInterpreted e
mkENodeCore e interpreted ctor generation

/-- We use this auxiliary constant to mark delayed congruence proofs. -/
private def congrPlaceholderProof := mkConst (Name.mkSimple "[congruence]")

/-- Adds `e` to congruence table. -/
def addCongrTable (e : Expr) : GoalM Unit := do
private def addCongrTable (e : Expr) : GoalM Unit := do
if let some { e := e' } := (← get).congrTable.find? { e } then
trace[grind.congr] "{e} = {e'}"
pushEqHEq e e' congrPlaceholderProof
Expand Down Expand Up @@ -95,9 +85,6 @@ where
private def markAsInconsistent : GoalM Unit :=
modify fun s => { s with inconsistent := true }

def isInconsistent : GoalM Bool :=
return (← get).inconsistent

/--
Remove `root` parents from the congruence table.
This is an auxiliary function performed while merging equivalence classes.
Expand Down
42 changes: 6 additions & 36 deletions src/Lean/Meta/Tactic/Grind/Preprocessor.lean
Original file line number Diff line number Diff line change
Expand Up @@ -16,41 +16,19 @@ import Lean.Meta.Tactic.Grind.Util
import Lean.Meta.Tactic.Grind.Cases
import Lean.Meta.Tactic.Grind.Injection
import Lean.Meta.Tactic.Grind.Core
import Lean.Meta.Tactic.Grind.MarkNestedProofs
import Lean.Meta.Tactic.Grind.Simp

namespace Lean.Meta.Grind
namespace Preprocessor

-- TODO: use congruence closure and decision procedures during pre-processing
-- TODO: implement `simp` discharger using preprocessor state

structure Context where
simp : Simp.Context
simprocs : Array Simp.Simprocs
deriving Inhabited

structure State where
simpStats : Simp.Stats := {}
goals : PArray Goal := {}
deriving Inhabited

abbrev PreM := ReaderT Context $ StateRefT State GrindM
abbrev PreM := StateRefT State GrindM

def PreM.run (x : PreM α) : GrindM α := do
let thms ← grindNormExt.getTheorems
let simprocs := #[(← grindNormSimprocExt.getSimprocs)]
let simp ← Simp.mkContext
(config := { arith := true })
(simpTheorems := #[thms])
(congrTheorems := (← getSimpCongrTheorems))
x { simp, simprocs } |>.run' {}

def simp (_goal : Goal) (e : Expr) : PreM Simp.Result := do
-- TODO: use `goal` state in the simplifier
let simpStats := (← get).simpStats
let (r, simpStats) ← Meta.simp e (← read).simp (← read).simprocs (stats := simpStats)
modify fun s => { s with simpStats }
return r
x.run' {}

inductive IntroResult where
| done
Expand All @@ -70,24 +48,16 @@ def introNext (goal : Goal) : PreM IntroResult := do
let tag ← goal.mvarId.getTag
let q := target.bindingBody!
-- TODO: keep applying simp/eraseIrrelevantMData/canon/shareCommon until no progress
let r ← simp goal p
let p' := r.expr
let p' ← markNestedProofs p'
let p' ← unfoldReducible p'
let p' ← eraseIrrelevantMData p'
let p' ← foldProjs p'
let p' ← normalizeLevels p'
let p' ← canon p'
let p' ← shareCommon p'
let r ← pre p
let fvarId ← mkFreshFVarId
let lctx := (← getLCtx).mkLocalDecl fvarId target.bindingName! p' target.bindingInfo!
let lctx := (← getLCtx).mkLocalDecl fvarId target.bindingName! r.expr target.bindingInfo!
let mvarNew ← mkFreshExprMVarAt lctx (← getLocalInstances) q .syntheticOpaque tag
let mvarIdNew := mvarNew.mvarId!
mvarIdNew.withContext do
let h ← mkLambdaFVars #[mkFVar fvarId] mvarNew
match r.proof? with
| some he =>
let hNew := mkAppN (mkConst ``Lean.Grind.intro_with_eq) #[p, p', q, he, h]
let hNew := mkAppN (mkConst ``Lean.Grind.intro_with_eq) #[p, r.expr, q, he, h]
goal.mvarId.assign hNew
return .newHyp fvarId { goal with mvarId := mvarIdNew }
| none =>
Expand Down
40 changes: 40 additions & 0 deletions src/Lean/Meta/Tactic/Grind/Simp.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/-
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.Meta.Tactic.Simp.Main
import Lean.Meta.Tactic.Grind.Util
import Lean.Meta.Tactic.Grind.Types
import Lean.Meta.Tactic.Grind.MarkNestedProofs

namespace Lean.Meta.Grind

-- TODO: use congruence closure and decision procedures during pre-processing
-- TODO: implement `simp` discharger using preprocessor state

/-- Simplifies the given expression using the `grind` simprocs and normalization theorems. -/
def simp (e : Expr) : GrindM Simp.Result := do
let simpStats := (← get).simpStats
let (r, simpStats) ← Meta.simp e (← read).simp (← read).simprocs (stats := simpStats)
modify fun s => { s with simpStats }
return r

/--
Simplifies `e` using `grind` normalization theorems and simprocs,
and then applies several other preprocessing steps.
-/
def pre (e : Expr) : GrindM Simp.Result := do
let r ← simp e
let e' := r.expr
let e' ← markNestedProofs e'
let e' ← unfoldReducible e'
let e' ← eraseIrrelevantMData e'
let e' ← foldProjs e'
let e' ← normalizeLevels e'
let e' ← canon e'
let e' ← shareCommon e'
return { r with expr := e' }

end Lean.Meta.Grind
43 changes: 35 additions & 8 deletions src/Lean/Meta/Tactic/Grind/Types.lean
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@ import Lean.Util.ShareCommon
import Lean.Meta.Basic
import Lean.Meta.CongrTheorems
import Lean.Meta.AbstractNestedProofs
import Lean.Meta.Tactic.Grind.Canon
import Lean.Meta.Tactic.Simp.Types
import Lean.Meta.Tactic.Util
import Lean.Meta.Tactic.Grind.Canon
import Lean.Meta.Tactic.Grind.Attr

namespace Lean.Meta.Grind

Expand All @@ -32,12 +34,13 @@ register_builtin_option grind.debug : Bool := {
descr := "check invariants after updates"
}

/-- Context for `GrindM` monad. -/
structure Context where
simp : Simp.Context
simprocs : Array Simp.Simprocs
mainDeclName : Name

/--
Key for the congruence theorem cache.
-/
/-- Key for the congruence theorem cache. -/
structure CongrTheoremCacheKey where
f : Expr
numArgs : Nat
Expand All @@ -50,6 +53,7 @@ instance : BEq CongrTheoremCacheKey where
instance : Hashable CongrTheoremCacheKey where
hash a := mixHash (unsafe ptrAddrUnsafe a.f).toUInt64 (hash a.numArgs)

/-- State for the `GrindM` monad. -/
structure State where
canon : Canon.State := {}
/-- `ShareCommon` (aka `Hashconsing`) state. -/
Expand All @@ -62,6 +66,7 @@ structure State where
Remark: we currently do not reuse congruence theorems
-/
congrThms : PHashMap CongrTheoremCacheKey CongrTheorem := {}
simpStats : Simp.Stats := {}
trueExpr : Expr
falseExpr : Expr

Expand All @@ -71,11 +76,19 @@ def GrindM.run (x : GrindM α) (mainDeclName : Name) : MetaM α := do
let scState := ShareCommon.State.mk _
let (falseExpr, scState) := ShareCommon.State.shareCommon scState (mkConst ``False)
let (trueExpr, scState) := ShareCommon.State.shareCommon scState (mkConst ``True)
x { mainDeclName } |>.run' { scState, trueExpr, falseExpr }

let thms ← grindNormExt.getTheorems
let simprocs := #[(← grindNormSimprocExt.getSimprocs)]
let simp ← Simp.mkContext
(config := { arith := true })
(simpTheorems := #[thms])
(congrTheorems := (← getSimpCongrTheorems))
x { mainDeclName, simprocs, simp } |>.run' { scState, trueExpr, falseExpr }

/-- Returns the internalized `True` constant. -/
def getTrueExpr : GrindM Expr := do
return (← get).trueExpr

/-- Returns the internalized `False` constant. -/
def getFalseExpr : GrindM Expr := do
return (← get).falseExpr

Expand All @@ -96,9 +109,9 @@ Applies hash-consing to `e`. Recall that all expressions in a `grind` goal have
been hash-consing. We perform this step before we internalize expressions.
-/
def shareCommon (e : Expr) : GrindM Expr := do
modifyGet fun { canon, scState, nextThmIdx, congrThms, trueExpr, falseExpr } =>
modifyGet fun { canon, scState, nextThmIdx, congrThms, trueExpr, falseExpr, simpStats } =>
let (e, scState) := ShareCommon.State.shareCommon scState e
(e, { canon, scState, nextThmIdx, congrThms, trueExpr, falseExpr })
(e, { canon, scState, nextThmIdx, congrThms, trueExpr, falseExpr, simpStats })

/--
Canonicalizes nested types, type formers, and instances in `e`.
Expand Down Expand Up @@ -262,6 +275,10 @@ abbrev GoalM := StateRefT Goal GrindM
@[inline] def GoalM.run' (goal : Goal) (x : GoalM Unit) : GrindM Goal :=
goal.mvarId.withContext do StateRefT'.run' (x *> get) goal

/-- Return `true` if the goal is inconsistent. -/
def isInconsistent : GoalM Bool :=
return (← get).inconsistent

/-- Returns `true` if `e` is the internalized `True` expression. -/
def isTrueExpr (e : Expr) : GrindM Bool :=
return isSameExpr e (← getTrueExpr)
Expand Down Expand Up @@ -410,6 +427,16 @@ def mkENodeCore (e : Expr) (interpreted ctor : Bool) (generation : Nat) : GoalM
}
modify fun s => { s with nextIdx := s.nextIdx + 1 }

/--
Creates an `ENode` for `e` if one does not already exist.
This method assumes `e` has been hashconsed.
-/
def mkENode (e : Expr) (generation : Nat) : GoalM Unit := do
if (← alreadyInternalized e) then return ()
let ctor := (← isConstructorAppCore? e).isSome
let interpreted ← isInterpreted e
mkENodeCore e interpreted ctor generation

def mkGoal (mvarId : MVarId) : GrindM Goal := do
let trueExpr ← getTrueExpr
let falseExpr ← getFalseExpr
Expand Down

0 comments on commit 7de9466

Please sign in to comment.