Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: rsimp_decide etc #5839

Draft
wants to merge 19 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Lean/Meta/Tactic/Simp/Main.lean
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ private def unfold? (e : Expr) : SimpM (Option Expr) := do
|| (smartUnfolding.get options && (← getEnv).contains (mkSmartUnfoldingNameFor fName)) then
withDefault <| unfoldDefinition? e
else
-- `We are not unfolding partial applications, and `fName` does not have smart unfolding support.
-- We are not unfolding partial applications, and `fName` does not have smart unfolding support.
-- Thus, we must check whether the arity of the function >= number of arguments.
let some cinfo := (← getEnv).find? fName | return none
let some value := cinfo.value? | return none
Expand Down
15 changes: 8 additions & 7 deletions src/Lean/Meta/WHNF.lean
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ private def cleanupNatOffsetMajor (e : Expr) : MetaM Expr := do
return mkNatSucc (mkNatAdd e (toExpr (k - 1)))

/-- Auxiliary function for reducing recursor applications. -/
private def reduceRec (recVal : RecursorVal) (recLvls : List Level) (recArgs : Array Expr) (failK : Unit → MetaM α) (successK : Expr → MetaM α) : MetaM α :=
private def reduceRec (recVal : RecursorVal) (recLvls : List Level) (recArgs : Array Expr) (failK : Expr → MetaM α) (successK : Expr → MetaM α) : MetaM α :=
let majorIdx := recVal.getMajorIdx
if h : majorIdx < recArgs.size then do
let major := recArgs.get ⟨majorIdx, h⟩
Expand All @@ -198,11 +198,12 @@ private def reduceRec (recVal : RecursorVal) (recLvls : List Level) (recArgs : A
major := major.toCtorIfLit
major ← cleanupNatOffsetMajor major
major ← toCtorWhenStructure recVal.getMajorInduct major
let failK := failK (mkAppN (.const recVal.name recLvls) (recArgs.set ⟨majorIdx, h⟩ major))
match getRecRuleFor recVal major with
| some rule =>
let majorArgs := major.getAppArgs
if recLvls.length != recVal.levelParams.length then
failK ()
failK
else
let rhs := rule.rhs.instantiateLevelParams recVal.levelParams recLvls
-- Apply parameters, motives and minor premises from recursor application.
Expand All @@ -214,9 +215,9 @@ private def reduceRec (recVal : RecursorVal) (recLvls : List Level) (recArgs : A
let rhs := mkAppRange rhs nparams majorArgs.size majorArgs
let rhs := mkAppRange rhs (majorIdx + 1) recArgs.size recArgs
successK rhs
| none => failK ()
| none => failK
else
failK ()
failK (mkAppN (.const recVal.name recLvls) recArgs)

-- ===========================
/-! # Helper functions for reducing Quot.lift and Quot.ind -/
Expand Down Expand Up @@ -642,7 +643,7 @@ where
| .notMatcher =>
matchConstAux f' (fun _ => return e) fun cinfo lvls =>
match cinfo with
| .recInfo rec => reduceRec rec lvls e.getAppArgs (fun _ => return e) (fun e => do recordUnfold cinfo.name; go e)
| .recInfo rec => reduceRec rec lvls e.getAppArgs pure (fun e => do recordUnfold cinfo.name; go e)
| .quotInfo rec => reduceQuotRec rec e.getAppArgs (fun _ => return e) (fun e => do recordUnfold cinfo.name; go e)
| c@(.defnInfo _) => do
if (← isAuxDef c.name) then
Expand All @@ -651,11 +652,11 @@ where
else
return e
| _ => return e
| .proj _ i c =>
| .proj n i c =>
let k (c : Expr) := do
match (← projectCore? c i) with
| some e => go e
| none => return e
| none => return .proj n i c
match config.proj with
| .no => return e
| .yes => k (← go c)
Expand Down
1 change: 1 addition & 0 deletions src/Std/Tactic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Authors: Henrik Böving
-/
prelude
import Std.Tactic.BVDecide
import Std.Tactic.RSimp

/-!
This directory is mainly used for bootstrapping reasons. Suppose a tactic generates a proof term
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ namespace BVPred
def bitblast (aig : AIG BVBit) (pred : BVPred) : AIG.Entrypoint BVBit :=
match pred with
| .bin lhs op rhs =>
let res := lhs.bitblast aig
have res := lhs.bitblast aig
let aig := res.aig
let lhsRefs := res.vec
let res := rhs.bitblast aig
Expand Down
13 changes: 13 additions & 0 deletions src/Std/Tactic/RSimp.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
/-
Copyright (c) 2024 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Joachim Breitner
-/
prelude
import Std.Tactic.RSimp.Setup
import Std.Tactic.RSimp.RSimpDecide
import Std.Tactic.RSimp.Optimize

/-!
This directory contains the implementation of the `rsimp_decide` tactic and infrastructure for that.
-/
167 changes: 167 additions & 0 deletions src/Std/Tactic/RSimp/ConvTheorem.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
/-
Copyright (c) 2024 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Joachim Breitner
-/

prelude
import Std.Tactic.RSimp.Setup
import Std.Tactic.RSimp.Fuel
import Lean.Elab.Tactic
import Lean.Elab.DeclUtil
import Lean.Elab.Command
import Lean.Elab.Tactic.Conv
import Init.Tactics
import Init.Conv

open Lean.Parser.Tactic in
/--
TODO
-/
syntax (name := convTheorem) declModifiers "conv_theorem" declId declSig " => " Conv.convSeq : command

open Lean.Parser.Tactic in
syntax (name := withFuel) "withFuel" " => " Conv.convSeq : conv

syntax (name := abstractAs) "abstractAs " nestedDeclModifiers ident : conv

open Lean Meta Elab Command Tactic Conv

@[tactic withFuel] def evalWithFuel : Tactic := fun stx => do
let lhs ← getLhs
withMainContext do
-- TODO: This needs to generalize some variables first
let (rhs, proof) ← convert lhs (evalTactic stx[2])
if lhs == rhs then
throwError "Non-productive recursive equation {lhs} = {rhs}."
else if let some (rhs', proof') ← recursionToFuel? lhs rhs proof then
updateLhs rhs' proof'
else
throwError "Did not find {lhs} in {rhs}."

@[tactic abstractAs] def evalAbstractAs : Tactic := fun stx => do
let lhs ← getLhs
let modifiers ← elabModifiers ⟨stx[1]⟩
let declId := stx[2]
let ⟨_, declName, _⟩ ← Term.expandDeclId (← getCurrNamespace) (← Term.getLevelNames) declId modifiers
let e ← mkAuxDefinition (compile := false) declName (← inferType lhs) lhs
withSaveInfoContext <| Term.addTermInfo' declId e.getAppFn (isBinder := true)
changeLhs e

def convert (lhs : Expr) (conv : TacticM Unit) : TermElabM (Expr × Expr) := do
let (rhs, newGoal) ← Conv.mkConvGoalFor lhs
let _ ← Tactic.run newGoal.mvarId! do
conv
for mvarId in (← getGoals) do
liftM <| mvarId.refl <|> mvarId.inferInstance <|> pure ()
pruneSolvedGoals
unless (← getGoals).isEmpty do
throwError "convert tactic failed, there are unsolved goals\n{goalsToMessageData (← getGoals)}"
return (← instantiateMVars rhs, ← instantiateMVars newGoal)

@[command_elab convTheorem]
def elabConvTheorem : CommandElab := fun stx => do
let modifiers ← elabModifiers ⟨stx[0]⟩
let declId := stx[2]
let (binders, lhsStx) := expandDeclSig stx[3]
runTermElabM fun vars => do
let scopeLevelNames ← Term.getLevelNames
let ⟨shortName, declName, allUserLevelNames⟩ ← Term.expandDeclId (← getCurrNamespace) scopeLevelNames declId modifiers
addDeclarationRangesForBuiltin declName modifiers.stx stx
Term.withAutoBoundImplicitForbiddenPred (fun n => shortName == n) do
Term.withDeclName declName <| Term.withLevelNames allUserLevelNames <| Term.elabBinders binders.getArgs fun xs => do
Term.applyAttributesAt declName modifiers.attrs AttributeApplicationTime.beforeElaboration
let lhs ← Term.elabTermAndSynthesize lhsStx none
Term.synthesizeSyntheticMVarsNoPostponing
let lhs ← instantiateMVars lhs
let (rhs, value) ← convert lhs (Tactic.evalTactic stx[5])
let eqType ← mkEq lhs rhs
let type ← mkForallFVars xs eqType
let type ← mkForallFVars vars type
let type ← Term.levelMVarToParam type
let value ← mkLambdaFVars xs value
let value ← mkLambdaFVars vars value
let usedParams := collectLevelParams {} type |>.params
match sortDeclLevelParams scopeLevelNames allUserLevelNames usedParams with
| Except.error msg => throwErrorAt stx msg
| Except.ok levelParams =>
let type ← instantiateMVars type
let value ← instantiateMVars value

let decl := Declaration.thmDecl {
name := declName,
levelParams := levelParams,
type := type,
value := value,
}
Term.ensureNoUnassignedMVars decl
addDecl decl
withSaveInfoContext do
Term.addTermInfo' declId (← mkConstWithLevelParams declName) (isBinder := true)
Term.applyAttributesAt declName modifiers.attrs AttributeApplicationTime.afterTypeChecking
Term.applyAttributesAt declName modifiers.attrs AttributeApplicationTime.afterCompilation



conv_theorem test : id 1 + 2 =>
skip
unfold id
simp

/-- info: test : id 1 + 2 = 3 -/
#guard_msgs in
#check test

conv_theorem test2 (a b c : Fin 12) : (a + b + c).val =>
simp [Fin.ext_iff, Fin.val_add]
abstractAs /-- docstrings possible! -/ test2_rhs

/--
info: def test2_rhs : Fin 12 → Fin 12 → Fin 12 → Nat :=
fun a b c => (↑a + ↑b + ↑c) % 12
-/
#guard_msgs in
#print test2_rhs

/-- info: test2 (a b c : Fin 12) : ↑(a + b + c) = test2_rhs a b c -/
#guard_msgs in
#check test2

/-- error: unknown identifier 'foo' -/
#guard_msgs in
conv_theorem bad_elab1 : foo && true => skip

-- TODO: Why does this not abort nicely

/--
error: unknown identifier 'foo'
---
error: (kernel) declaration has metavariables 'bad_elab2'
-/
#guard_msgs in
conv_theorem bad_elab2 : foo true => skip

def fib : Nat → Nat
| 0 => 0
| 1 => 1
| n+2 => fib n + fib (n+1)

conv_theorem fib_optimize : fib =>
withFuel =>
ext n
unfold fib
simp [← Nat.add_eq]
abstractAs fib.opt

/--
info: def fib.opt : Nat → Nat :=
rsimp_iterate fib fun ih n =>
match n with
| 0 => 0
| 1 => 1
| n.succ.succ => (ih n).add (ih (n.add 1))
-/
#guard_msgs in
#print fib.opt

def fib' (n : Nat) := (fib n, fib (n+1))
32 changes: 32 additions & 0 deletions src/Std/Tactic/RSimp/Fuel.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import Lean

open Lean Meta

def lots_of_fuel : Nat := 9223372036854775807

def rsimp_iterate {α : Sort u} (x : α) (f : α → α) : α :=
Nat.rec x (fun _ ih => f ih) lots_of_fuel

theorem reduce_with_fuel {α : Sort u} {x : α} {f : α → α} (h : x = f x) :
x = rsimp_iterate x f := by
unfold rsimp_iterate
exact Nat.rec rfl (fun _ ih => h.trans (congrArg f ih)) lots_of_fuel

def recursionToFuel? (lhs rhs proof : Expr) : MetaM (Option (Expr × Expr)) := do
let f ← kabstract rhs lhs
if f.hasLooseBVars then
let t ← inferType lhs
let u ← getLevel t
let f := mkLambda `ih .default t f
let rhs' := mkApp3 (.const ``rsimp_iterate [u]) t lhs f
let proof' := mkApp4 (.const ``reduce_with_fuel [u]) t lhs f proof
return some (rhs', proof')
else
return none

def recursionToFuel (lhs rhs proof : Expr) : MetaM (Expr × Expr) := do
if let some (rhs', proof') ← recursionToFuel? lhs rhs proof then
return (rhs', proof')
else
-- Not (obviously) recursive
return (rhs, proof)
77 changes: 77 additions & 0 deletions src/Std/Tactic/RSimp/Optimize.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/-
Copyright (c) 2024 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Joachim Breitner
-/

prelude
import Std.Tactic.RSimp.Setup
import Std.Tactic.RSimp.Fuel
import Lean.Elab.Tactic
import Init.Tactics

open Lean Meta

-- TODO: Namespace

initialize registerTraceClass `tactic.rsimp_optimize

def getEqUnfold (declName : Name) : MetaM (Option (Expr × Expr)) := do
-- TODO: Make nicer, and move near the eqUnfold definition
if (← getUnfoldEqnFor? declName (nonRec := false)).isSome then
let unfold := .str declName eqUnfoldThmSuffix
executeReservedNameAction unfold
let unfoldProof ← mkConstWithLevelParams unfold
let some (_, _, rhs) := (← inferType unfoldProof).eq? | throwError "Unexpected type of {unfold}"
return some (rhs, unfoldProof)
else return none

def optimize (declName : Name) : MetaM Unit := do
let opt_name := .str declName "rsimp"
let proof_name := .str declName "eq_rsimp"
if (← getEnv).contains opt_name then throwError "{opt_name} has already been declared"
if (← getEnv).contains proof_name then throwError "{proof_name} has already been declared"

let info ← getConstInfoDefn declName
let lhs := mkConst declName (info.levelParams.map mkLevelParam)
let (rhs0, rwProof) ←
if let some (rhs, unfoldProof) ← getEqUnfold declName then
pure (rhs, unfoldProof)
else
let unfoldProof ← mkEqRefl lhs
pure (info.value, unfoldProof)

-- Do we need to give the user control over the simplifier here?
-- TODO: Unify with rsimp_decide
let .some se ← getSimpExtension? `rsimp | throwError "simp set 'rsimp' not found"
-- TODO: zeta := false seems reasonable, we do not want to duplicate terms
-- but it can produce type-incorrect terms here.
let ctx : Simp.Context := { config := {}, simpTheorems := #[(← se.getTheorems)], congrTheorems := (← Meta.getSimpCongrTheorems) }
let (res, _stats) ← simp rhs0 ctx #[(← Simp.getSimprocs)] none
let rhs := res.expr
let proof ← mkEqTrans rwProof (← res.getProof)

let (rhs, proof) ← recursionToFuel lhs rhs proof

trace[tactic.rsimp_optimize] "Optimizing {lhs} to:{indentExpr rhs}"
addDecl <| Declaration.defnDecl { info with
name := opt_name, type := info.type, value := rhs, levelParams := info.levelParams
}
let proof_type ← mkEq lhs (mkConst opt_name (info.levelParams.map mkLevelParam))
addDecl <| Declaration.thmDecl {
name := proof_name, type := proof_type, value := proof, levelParams := info.levelParams
}
addSimpTheorem se proof_name (post := true) (inv := false) AttributeKind.global (prio := eval_prio default)

/--
TODO
-/
syntax (name := rsimp_optimize) "rsimp_optimize" : attr

initialize registerBuiltinAttribute {
name := `rsimp_optimize
descr := "optimize for kernel reduction"
add := fun declName _stx attrKind => do
unless attrKind == AttributeKind.global do throwError "invalid attribute 'rsimp_optimize', must be global"
(optimize declName).run' {} {}
}
Loading
Loading