Skip to content

Commit

Permalink
feat: support for builtin grind propagators (#6448)
Browse files Browse the repository at this point in the history
This PR declares the command `builtin_grind_propagator` for registering
equation propagator for `grind`. It also declares the auxiliary the
attribute.
  • Loading branch information
leodemoura authored Dec 25, 2024
1 parent 977b8e0 commit 3cddae6
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 26 deletions.
1 change: 1 addition & 0 deletions src/Init/Grind.lean
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ import Init.Grind.Norm
import Init.Grind.Tactics
import Init.Grind.Lemmas
import Init.Grind.Cases
import Init.Grind.Propagator
27 changes: 27 additions & 0 deletions src/Init/Grind/Propagator.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/-
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Init.NotationExtra

namespace Lean.Parser

/-- A user-defined propagator for the `grind` tactic. -/
-- TODO: not implemented yet
syntax (docComment)? "grind_propagator " (Tactic.simpPre <|> Tactic.simpPost) ident " (" ident ")" " := " term : command

/-- A builtin propagator for the `grind` tactic. -/
syntax (docComment)? "builtin_grind_propagator " ident (Tactic.simpPre <|> Tactic.simpPost) ident " := " term : command

/-- Auxiliary attribute for builtin `grind` propagators. -/
syntax (name := grindPropagatorBuiltinAttr) "builtin_grind_propagator" (Tactic.simpPre <|> Tactic.simpPost) ident : attr

macro_rules
| `($[$doc?:docComment]? builtin_grind_propagator $propagatorName:ident $direction $op:ident := $body) => do
let propagatorType := `Lean.Meta.Grind.Propagator
`($[$doc?:docComment]? def $propagatorName:ident : $(mkIdent propagatorType) := $body
attribute [builtin_grind_propagator $direction $op] $propagatorName)

end Lean.Parser
1 change: 1 addition & 0 deletions src/Lean/Meta/Tactic/Grind.lean
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ builtin_initialize registerTraceClass `grind.issues
builtin_initialize registerTraceClass `grind.add
builtin_initialize registerTraceClass `grind.pre
builtin_initialize registerTraceClass `grind.debug
builtin_initialize registerTraceClass `grind.simp
builtin_initialize registerTraceClass `grind.congr

end Lean
3 changes: 2 additions & 1 deletion src/Lean/Meta/Tactic/Grind/Simp.lean
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ namespace Lean.Meta.Grind
/-- Simplifies the given expression using the `grind` simprocs and normalization theorems. -/
def simp (e : Expr) : GrindM Simp.Result := do
let simpStats := (← get).simpStats
let (r, simpStats) ← Meta.simp e (← read).simp (← read).simprocs (stats := simpStats)
let (r, simpStats) ← Meta.simp e (← readThe Context).simp (← readThe Context).simprocs (stats := simpStats)
modify fun s => { s with simpStats }
return r

Expand All @@ -35,6 +35,7 @@ def pre (e : Expr) : GrindM Simp.Result := do
let e' ← normalizeLevels e'
let e' ← canon e'
let e' ← shareCommon e'
trace[grind.simp] "{e}\n===>\n{e'}"
return { r with expr := e' }

end Lean.Meta.Grind
129 changes: 104 additions & 25 deletions src/Lean/Meta/Tactic/Grind/Types.lean
Original file line number Diff line number Diff line change
Expand Up @@ -70,19 +70,11 @@ structure State where
trueExpr : Expr
falseExpr : Expr

abbrev GrindM := ReaderT Context $ StateRefT State MetaM
private opaque MethodsRefPointed : NonemptyType.{0}
private def MethodsRef : Type := MethodsRefPointed.type
instance : Nonempty MethodsRef := MethodsRefPointed.property

def GrindM.run (x : GrindM α) (mainDeclName : Name) : MetaM α := do
let scState := ShareCommon.State.mk _
let (falseExpr, scState) := ShareCommon.State.shareCommon scState (mkConst ``False)
let (trueExpr, scState) := ShareCommon.State.shareCommon scState (mkConst ``True)
let thms ← grindNormExt.getTheorems
let simprocs := #[(← grindNormSimprocExt.getSimprocs)]
let simp ← Simp.mkContext
(config := { arith := true })
(simpTheorems := #[thms])
(congrTheorems := (← getSimpCongrTheorems))
x { mainDeclName, simprocs, simp } |>.run' { scState, trueExpr, falseExpr }
abbrev GrindM := ReaderT MethodsRef $ ReaderT Context $ StateRefT State MetaM

/-- Returns the internalized `True` constant. -/
def getTrueExpr : GrindM Expr := do
Expand All @@ -93,7 +85,7 @@ def getFalseExpr : GrindM Expr := do
return (← get).falseExpr

def getMainDeclName : GrindM Name :=
return (← read).mainDeclName
return (← readThe Context).mainDeclName

/--
Abtracts nested proofs in `e`. This is a preprocessing step performed before internalization.
Expand Down Expand Up @@ -269,11 +261,7 @@ def Goal.admit (goal : Goal) : MetaM Unit :=

abbrev GoalM := StateRefT Goal GrindM

@[inline] def GoalM.run (goal : Goal) (x : GoalM α) : GrindM (α × Goal) :=
goal.mvarId.withContext do StateRefT'.run x goal

@[inline] def GoalM.run' (goal : Goal) (x : GoalM Unit) : GrindM Goal :=
goal.mvarId.withContext do StateRefT'.run' (x *> get) goal
abbrev Propagator := Expr → GoalM Unit

/-- Return `true` if the goal is inconsistent. -/
def isInconsistent : GoalM Bool :=
Expand Down Expand Up @@ -437,13 +425,6 @@ def mkENode (e : Expr) (generation : Nat) : GoalM Unit := do
let interpreted ← isInterpreted e
mkENodeCore e interpreted ctor generation

def mkGoal (mvarId : MVarId) : GrindM Goal := do
let trueExpr ← getTrueExpr
let falseExpr ← getFalseExpr
GoalM.run' { mvarId } do
mkENodeCore falseExpr (interpreted := true) (ctor := false) (generation := 0)
mkENodeCore trueExpr (interpreted := true) (ctor := false) (generation := 0)

/-- Returns all enodes in the goal -/
def getENodes : GoalM (Array ENode) := do
-- We must sort because we are using pointer addresses as keys in `enodes`
Expand All @@ -468,4 +449,102 @@ def forEachEqc (f : ENode → GoalM Unit) : GoalM Unit := do
if isSameExpr n.self n.root then
f n

structure Methods where
propagateUp : Propagator := fun _ => return ()
propagateDown : Propagator := fun _ => return ()
deriving Inhabited

def Methods.toMethodsRef (m : Methods) : MethodsRef :=
unsafe unsafeCast m

def MethodsRef.toMethods (m : MethodsRef) : Methods :=
unsafe unsafeCast m

/-- Builtin propagators. -/
structure BuiltinPropagators where
up : Std.HashMap Name Propagator := {}
down : Std.HashMap Name Propagator := {}
deriving Inhabited

builtin_initialize builtinPropagatorsRef : IO.Ref BuiltinPropagators ← IO.mkRef {}

private def registerBuiltinPropagatorCore (declName : Name) (up : Bool) (proc : Propagator) : IO Unit := do
unless (← initializing) do
throw (IO.userError s!"invalid builtin `grind` propagator declaration, it can only be registered during initialization")
if up then
if (← builtinPropagatorsRef.get).up.contains declName then
throw (IO.userError s!"invalid builtin `grind` upward propagator `{declName}`, it has already been declared")
builtinPropagatorsRef.modify fun { up, down } => { up := up.insert declName proc, down }
else
if (← builtinPropagatorsRef.get).down.contains declName then
throw (IO.userError s!"invalid builtin `grind` downward propagator `{declName}`, it has already been declared")
builtinPropagatorsRef.modify fun { up, down } => { up, down := down.insert declName proc }

def registerBuiltinUpwardPropagator (declName : Name) (proc : Propagator) : IO Unit :=
registerBuiltinPropagatorCore declName true proc

def registerBuiltinDownwardPropagator (declName : Name) (proc : Propagator) : IO Unit :=
registerBuiltinPropagatorCore declName false proc

private def addBuiltin (propagatorName : Name) (stx : Syntax) : AttrM Unit := do
let go : MetaM Unit := do
let up := stx[1].getKind == ``Lean.Parser.Tactic.simpPost
let addDeclName := if up then
``registerBuiltinUpwardPropagator
else
``registerBuiltinDownwardPropagator
let declName ← resolveGlobalConstNoOverload stx[2]
let val := mkAppN (mkConst addDeclName) #[toExpr declName, mkConst propagatorName]
let initDeclName ← mkFreshUserName (propagatorName ++ `declare)
declareBuiltin initDeclName val
go.run' {}

builtin_initialize
registerBuiltinAttribute {
ref := by exact decl_name%
name := `grindPropagatorBuiltinAttr
descr := "Builtin `grind` propagator procedure"
applicationTime := AttributeApplicationTime.afterCompilation
erase := fun _ => throwError "Not implemented yet, [-builtin_simproc]"
add := fun declName stx _ => addBuiltin declName stx
}

def getMethods : CoreM Methods := do
let builtinPropagators ← builtinPropagatorsRef.get
return {
propagateUp := fun e => do
let .const declName _ := e.getAppFn | return ()
if let some prop := builtinPropagators.up[declName]? then
prop e
propagateDown := fun e => do
let .const declName _ := e.getAppFn | return ()
if let some prop := builtinPropagators.down[declName]? then
prop e
}

def GrindM.run (x : GrindM α) (mainDeclName : Name) : MetaM α := do
let scState := ShareCommon.State.mk _
let (falseExpr, scState) := ShareCommon.State.shareCommon scState (mkConst ``False)
let (trueExpr, scState) := ShareCommon.State.shareCommon scState (mkConst ``True)
let thms ← grindNormExt.getTheorems
let simprocs := #[(← grindNormSimprocExt.getSimprocs)]
let simp ← Simp.mkContext
(config := { arith := true })
(simpTheorems := #[thms])
(congrTheorems := (← getSimpCongrTheorems))
x (← getMethods).toMethodsRef { mainDeclName, simprocs, simp } |>.run' { scState, trueExpr, falseExpr }

@[inline] def GoalM.run (goal : Goal) (x : GoalM α) : GrindM (α × Goal) :=
goal.mvarId.withContext do StateRefT'.run x goal

@[inline] def GoalM.run' (goal : Goal) (x : GoalM Unit) : GrindM Goal :=
goal.mvarId.withContext do StateRefT'.run' (x *> get) goal

def mkGoal (mvarId : MVarId) : GrindM Goal := do
let trueExpr ← getTrueExpr
let falseExpr ← getFalseExpr
GoalM.run' { mvarId } do
mkENodeCore falseExpr (interpreted := true) (ctor := false) (generation := 0)
mkENodeCore trueExpr (interpreted := true) (ctor := false) (generation := 0)

end Lean.Meta.Grind

0 comments on commit 3cddae6

Please sign in to comment.