diff --git a/src/Init/Grind.lean b/src/Init/Grind.lean index 12db182084cf..58dfe1290ddf 100644 --- a/src/Init/Grind.lean +++ b/src/Init/Grind.lean @@ -8,3 +8,4 @@ import Init.Grind.Norm import Init.Grind.Tactics import Init.Grind.Lemmas import Init.Grind.Cases +import Init.Grind.Propagator diff --git a/src/Init/Grind/Propagator.lean b/src/Init/Grind/Propagator.lean new file mode 100644 index 000000000000..0edb9d8fb465 --- /dev/null +++ b/src/Init/Grind/Propagator.lean @@ -0,0 +1,27 @@ +/- +Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +prelude +import Init.NotationExtra + +namespace Lean.Parser + +/-- A user-defined propagator for the `grind` tactic. -/ +-- TODO: not implemented yet +syntax (docComment)? "grind_propagator " (Tactic.simpPre <|> Tactic.simpPost) ident " (" ident ")" " := " term : command + +/-- A builtin propagator for the `grind` tactic. -/ +syntax (docComment)? "builtin_grind_propagator " ident (Tactic.simpPre <|> Tactic.simpPost) ident " := " term : command + +/-- Auxiliary attribute for builtin `grind` propagators. -/ +syntax (name := grindPropagatorBuiltinAttr) "builtin_grind_propagator" (Tactic.simpPre <|> Tactic.simpPost) ident : attr + +macro_rules + | `($[$doc?:docComment]? builtin_grind_propagator $propagatorName:ident $direction $op:ident := $body) => do + let propagatorType := `Lean.Meta.Grind.Propagator + `($[$doc?:docComment]? def $propagatorName:ident : $(mkIdent propagatorType) := $body + attribute [builtin_grind_propagator $direction $op] $propagatorName) + +end Lean.Parser diff --git a/src/Lean/Meta/Tactic/Grind.lean b/src/Lean/Meta/Tactic/Grind.lean index 0716c656250f..9ad5cc03cc3a 100644 --- a/src/Lean/Meta/Tactic/Grind.lean +++ b/src/Lean/Meta/Tactic/Grind.lean @@ -28,6 +28,7 @@ builtin_initialize registerTraceClass `grind.issues builtin_initialize registerTraceClass `grind.add builtin_initialize registerTraceClass `grind.pre builtin_initialize registerTraceClass `grind.debug +builtin_initialize registerTraceClass `grind.simp builtin_initialize registerTraceClass `grind.congr end Lean diff --git a/src/Lean/Meta/Tactic/Grind/Simp.lean b/src/Lean/Meta/Tactic/Grind/Simp.lean index 601f5472b1a5..50435d7eb788 100644 --- a/src/Lean/Meta/Tactic/Grind/Simp.lean +++ b/src/Lean/Meta/Tactic/Grind/Simp.lean @@ -17,7 +17,7 @@ namespace Lean.Meta.Grind /-- Simplifies the given expression using the `grind` simprocs and normalization theorems. -/ def simp (e : Expr) : GrindM Simp.Result := do let simpStats := (← get).simpStats - let (r, simpStats) ← Meta.simp e (← read).simp (← read).simprocs (stats := simpStats) + let (r, simpStats) ← Meta.simp e (← readThe Context).simp (← readThe Context).simprocs (stats := simpStats) modify fun s => { s with simpStats } return r @@ -35,6 +35,7 @@ def pre (e : Expr) : GrindM Simp.Result := do let e' ← normalizeLevels e' let e' ← canon e' let e' ← shareCommon e' + trace[grind.simp] "{e}\n===>\n{e'}" return { r with expr := e' } end Lean.Meta.Grind diff --git a/src/Lean/Meta/Tactic/Grind/Types.lean b/src/Lean/Meta/Tactic/Grind/Types.lean index 98cb48180891..de9328939acf 100644 --- a/src/Lean/Meta/Tactic/Grind/Types.lean +++ b/src/Lean/Meta/Tactic/Grind/Types.lean @@ -70,19 +70,11 @@ structure State where trueExpr : Expr falseExpr : Expr -abbrev GrindM := ReaderT Context $ StateRefT State MetaM +private opaque MethodsRefPointed : NonemptyType.{0} +private def MethodsRef : Type := MethodsRefPointed.type +instance : Nonempty MethodsRef := MethodsRefPointed.property -def GrindM.run (x : GrindM α) (mainDeclName : Name) : MetaM α := do - let scState := ShareCommon.State.mk _ - let (falseExpr, scState) := ShareCommon.State.shareCommon scState (mkConst ``False) - let (trueExpr, scState) := ShareCommon.State.shareCommon scState (mkConst ``True) - let thms ← grindNormExt.getTheorems - let simprocs := #[(← grindNormSimprocExt.getSimprocs)] - let simp ← Simp.mkContext - (config := { arith := true }) - (simpTheorems := #[thms]) - (congrTheorems := (← getSimpCongrTheorems)) - x { mainDeclName, simprocs, simp } |>.run' { scState, trueExpr, falseExpr } +abbrev GrindM := ReaderT MethodsRef $ ReaderT Context $ StateRefT State MetaM /-- Returns the internalized `True` constant. -/ def getTrueExpr : GrindM Expr := do @@ -93,7 +85,7 @@ def getFalseExpr : GrindM Expr := do return (← get).falseExpr def getMainDeclName : GrindM Name := - return (← read).mainDeclName + return (← readThe Context).mainDeclName /-- Abtracts nested proofs in `e`. This is a preprocessing step performed before internalization. @@ -269,11 +261,7 @@ def Goal.admit (goal : Goal) : MetaM Unit := abbrev GoalM := StateRefT Goal GrindM -@[inline] def GoalM.run (goal : Goal) (x : GoalM α) : GrindM (α × Goal) := - goal.mvarId.withContext do StateRefT'.run x goal - -@[inline] def GoalM.run' (goal : Goal) (x : GoalM Unit) : GrindM Goal := - goal.mvarId.withContext do StateRefT'.run' (x *> get) goal +abbrev Propagator := Expr → GoalM Unit /-- Return `true` if the goal is inconsistent. -/ def isInconsistent : GoalM Bool := @@ -437,13 +425,6 @@ def mkENode (e : Expr) (generation : Nat) : GoalM Unit := do let interpreted ← isInterpreted e mkENodeCore e interpreted ctor generation -def mkGoal (mvarId : MVarId) : GrindM Goal := do - let trueExpr ← getTrueExpr - let falseExpr ← getFalseExpr - GoalM.run' { mvarId } do - mkENodeCore falseExpr (interpreted := true) (ctor := false) (generation := 0) - mkENodeCore trueExpr (interpreted := true) (ctor := false) (generation := 0) - /-- Returns all enodes in the goal -/ def getENodes : GoalM (Array ENode) := do -- We must sort because we are using pointer addresses as keys in `enodes` @@ -468,4 +449,102 @@ def forEachEqc (f : ENode → GoalM Unit) : GoalM Unit := do if isSameExpr n.self n.root then f n +structure Methods where + propagateUp : Propagator := fun _ => return () + propagateDown : Propagator := fun _ => return () + deriving Inhabited + +def Methods.toMethodsRef (m : Methods) : MethodsRef := + unsafe unsafeCast m + +def MethodsRef.toMethods (m : MethodsRef) : Methods := + unsafe unsafeCast m + +/-- Builtin propagators. -/ +structure BuiltinPropagators where + up : Std.HashMap Name Propagator := {} + down : Std.HashMap Name Propagator := {} + deriving Inhabited + +builtin_initialize builtinPropagatorsRef : IO.Ref BuiltinPropagators ← IO.mkRef {} + +private def registerBuiltinPropagatorCore (declName : Name) (up : Bool) (proc : Propagator) : IO Unit := do + unless (← initializing) do + throw (IO.userError s!"invalid builtin `grind` propagator declaration, it can only be registered during initialization") + if up then + if (← builtinPropagatorsRef.get).up.contains declName then + throw (IO.userError s!"invalid builtin `grind` upward propagator `{declName}`, it has already been declared") + builtinPropagatorsRef.modify fun { up, down } => { up := up.insert declName proc, down } + else + if (← builtinPropagatorsRef.get).down.contains declName then + throw (IO.userError s!"invalid builtin `grind` downward propagator `{declName}`, it has already been declared") + builtinPropagatorsRef.modify fun { up, down } => { up, down := down.insert declName proc } + +def registerBuiltinUpwardPropagator (declName : Name) (proc : Propagator) : IO Unit := + registerBuiltinPropagatorCore declName true proc + +def registerBuiltinDownwardPropagator (declName : Name) (proc : Propagator) : IO Unit := + registerBuiltinPropagatorCore declName false proc + +private def addBuiltin (propagatorName : Name) (stx : Syntax) : AttrM Unit := do + let go : MetaM Unit := do + let up := stx[1].getKind == ``Lean.Parser.Tactic.simpPost + let addDeclName := if up then + ``registerBuiltinUpwardPropagator + else + ``registerBuiltinDownwardPropagator + let declName ← resolveGlobalConstNoOverload stx[2] + let val := mkAppN (mkConst addDeclName) #[toExpr declName, mkConst propagatorName] + let initDeclName ← mkFreshUserName (propagatorName ++ `declare) + declareBuiltin initDeclName val + go.run' {} + +builtin_initialize + registerBuiltinAttribute { + ref := by exact decl_name% + name := `grindPropagatorBuiltinAttr + descr := "Builtin `grind` propagator procedure" + applicationTime := AttributeApplicationTime.afterCompilation + erase := fun _ => throwError "Not implemented yet, [-builtin_simproc]" + add := fun declName stx _ => addBuiltin declName stx + } + +def getMethods : CoreM Methods := do + let builtinPropagators ← builtinPropagatorsRef.get + return { + propagateUp := fun e => do + let .const declName _ := e.getAppFn | return () + if let some prop := builtinPropagators.up[declName]? then + prop e + propagateDown := fun e => do + let .const declName _ := e.getAppFn | return () + if let some prop := builtinPropagators.down[declName]? then + prop e + } + +def GrindM.run (x : GrindM α) (mainDeclName : Name) : MetaM α := do + let scState := ShareCommon.State.mk _ + let (falseExpr, scState) := ShareCommon.State.shareCommon scState (mkConst ``False) + let (trueExpr, scState) := ShareCommon.State.shareCommon scState (mkConst ``True) + let thms ← grindNormExt.getTheorems + let simprocs := #[(← grindNormSimprocExt.getSimprocs)] + let simp ← Simp.mkContext + (config := { arith := true }) + (simpTheorems := #[thms]) + (congrTheorems := (← getSimpCongrTheorems)) + x (← getMethods).toMethodsRef { mainDeclName, simprocs, simp } |>.run' { scState, trueExpr, falseExpr } + +@[inline] def GoalM.run (goal : Goal) (x : GoalM α) : GrindM (α × Goal) := + goal.mvarId.withContext do StateRefT'.run x goal + +@[inline] def GoalM.run' (goal : Goal) (x : GoalM Unit) : GrindM Goal := + goal.mvarId.withContext do StateRefT'.run' (x *> get) goal + +def mkGoal (mvarId : MVarId) : GrindM Goal := do + let trueExpr ← getTrueExpr + let falseExpr ← getFalseExpr + GoalM.run' { mvarId } do + mkENodeCore falseExpr (interpreted := true) (ctor := false) (generation := 0) + mkENodeCore trueExpr (interpreted := true) (ctor := false) (generation := 0) + end Lean.Meta.Grind