diff --git a/src/Lean/Meta/Tactic/Simp/Main.lean b/src/Lean/Meta/Tactic/Simp/Main.lean index 1bb9a68026c1..d6bab55a6eed 100644 --- a/src/Lean/Meta/Tactic/Simp/Main.lean +++ b/src/Lean/Meta/Tactic/Simp/Main.lean @@ -147,7 +147,7 @@ private def unfold? (e : Expr) : SimpM (Option Expr) := do || (smartUnfolding.get options && (← getEnv).contains (mkSmartUnfoldingNameFor fName)) then withDefault <| unfoldDefinition? e else - -- `We are not unfolding partial applications, and `fName` does not have smart unfolding support. + -- We are not unfolding partial applications, and `fName` does not have smart unfolding support. -- Thus, we must check whether the arity of the function >= number of arguments. let some cinfo := (← getEnv).find? fName | return none let some value := cinfo.value? | return none diff --git a/src/Lean/Meta/WHNF.lean b/src/Lean/Meta/WHNF.lean index e3eb695a7606..850c7a4cf75e 100644 --- a/src/Lean/Meta/WHNF.lean +++ b/src/Lean/Meta/WHNF.lean @@ -181,7 +181,7 @@ private def cleanupNatOffsetMajor (e : Expr) : MetaM Expr := do return mkNatSucc (mkNatAdd e (toExpr (k - 1))) /-- Auxiliary function for reducing recursor applications. -/ -private def reduceRec (recVal : RecursorVal) (recLvls : List Level) (recArgs : Array Expr) (failK : Unit → MetaM α) (successK : Expr → MetaM α) : MetaM α := +private def reduceRec (recVal : RecursorVal) (recLvls : List Level) (recArgs : Array Expr) (failK : Expr → MetaM α) (successK : Expr → MetaM α) : MetaM α := let majorIdx := recVal.getMajorIdx if h : majorIdx < recArgs.size then do let major := recArgs.get ⟨majorIdx, h⟩ @@ -198,11 +198,12 @@ private def reduceRec (recVal : RecursorVal) (recLvls : List Level) (recArgs : A major := major.toCtorIfLit major ← cleanupNatOffsetMajor major major ← toCtorWhenStructure recVal.getMajorInduct major + let failK := failK (mkAppN (.const recVal.name recLvls) (recArgs.set ⟨majorIdx, h⟩ major)) match getRecRuleFor recVal major with | some rule => let majorArgs := major.getAppArgs if recLvls.length != recVal.levelParams.length then - failK () + failK else let rhs := rule.rhs.instantiateLevelParams recVal.levelParams recLvls -- Apply parameters, motives and minor premises from recursor application. @@ -214,9 +215,9 @@ private def reduceRec (recVal : RecursorVal) (recLvls : List Level) (recArgs : A let rhs := mkAppRange rhs nparams majorArgs.size majorArgs let rhs := mkAppRange rhs (majorIdx + 1) recArgs.size recArgs successK rhs - | none => failK () + | none => failK else - failK () + failK (mkAppN (.const recVal.name recLvls) recArgs) -- =========================== /-! # Helper functions for reducing Quot.lift and Quot.ind -/ @@ -642,7 +643,7 @@ where | .notMatcher => matchConstAux f' (fun _ => return e) fun cinfo lvls => match cinfo with - | .recInfo rec => reduceRec rec lvls e.getAppArgs (fun _ => return e) (fun e => do recordUnfold cinfo.name; go e) + | .recInfo rec => reduceRec rec lvls e.getAppArgs pure (fun e => do recordUnfold cinfo.name; go e) | .quotInfo rec => reduceQuotRec rec e.getAppArgs (fun _ => return e) (fun e => do recordUnfold cinfo.name; go e) | c@(.defnInfo _) => do if (← isAuxDef c.name) then @@ -651,11 +652,11 @@ where else return e | _ => return e - | .proj _ i c => + | .proj n i c => let k (c : Expr) := do match (← projectCore? c i) with | some e => go e - | none => return e + | none => return .proj n i c match config.proj with | .no => return e | .yes => k (← go c) diff --git a/src/Std/Tactic.lean b/src/Std/Tactic.lean index b2a31e324973..de1830042ecb 100644 --- a/src/Std/Tactic.lean +++ b/src/Std/Tactic.lean @@ -5,6 +5,7 @@ Authors: Henrik Böving -/ prelude import Std.Tactic.BVDecide +import Std.Tactic.RSimp /-! This directory is mainly used for bootstrapping reasons. Suppose a tactic generates a proof term diff --git a/src/Std/Tactic/BVDecide/Bitblast/BVExpr/Circuit/Impl/Pred.lean b/src/Std/Tactic/BVDecide/Bitblast/BVExpr/Circuit/Impl/Pred.lean index d3e2728f76ca..6c813c8b7dd4 100644 --- a/src/Std/Tactic/BVDecide/Bitblast/BVExpr/Circuit/Impl/Pred.lean +++ b/src/Std/Tactic/BVDecide/Bitblast/BVExpr/Circuit/Impl/Pred.lean @@ -23,7 +23,7 @@ namespace BVPred def bitblast (aig : AIG BVBit) (pred : BVPred) : AIG.Entrypoint BVBit := match pred with | .bin lhs op rhs => - let res := lhs.bitblast aig + have res := lhs.bitblast aig let aig := res.aig let lhsRefs := res.vec let res := rhs.bitblast aig diff --git a/src/Std/Tactic/RSimp.lean b/src/Std/Tactic/RSimp.lean new file mode 100644 index 000000000000..1fc5c7019dd4 --- /dev/null +++ b/src/Std/Tactic/RSimp.lean @@ -0,0 +1,13 @@ +/- +Copyright (c) 2024 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Joachim Breitner +-/ +prelude +import Std.Tactic.RSimp.Setup +import Std.Tactic.RSimp.RSimpDecide +import Std.Tactic.RSimp.Optimize + +/-! +This directory contains the implementation of the `rsimp_decide` tactic and infrastructure for that. +-/ diff --git a/src/Std/Tactic/RSimp/ConvTheorem.lean b/src/Std/Tactic/RSimp/ConvTheorem.lean new file mode 100644 index 000000000000..dfedaf5000a5 --- /dev/null +++ b/src/Std/Tactic/RSimp/ConvTheorem.lean @@ -0,0 +1,167 @@ +/- +Copyright (c) 2024 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Joachim Breitner +-/ + +prelude +import Std.Tactic.RSimp.Setup +import Std.Tactic.RSimp.Fuel +import Lean.Elab.Tactic +import Lean.Elab.DeclUtil +import Lean.Elab.Command +import Lean.Elab.Tactic.Conv +import Init.Tactics +import Init.Conv + +open Lean.Parser.Tactic in +/-- +TODO +-/ +syntax (name := convTheorem) declModifiers "conv_theorem" declId declSig " => " Conv.convSeq : command + +open Lean.Parser.Tactic in +syntax (name := withFuel) "withFuel" " => " Conv.convSeq : conv + +syntax (name := abstractAs) "abstractAs " nestedDeclModifiers ident : conv + +open Lean Meta Elab Command Tactic Conv + +@[tactic withFuel] def evalWithFuel : Tactic := fun stx => do + let lhs ← getLhs + withMainContext do + -- TODO: This needs to generalize some variables first + let (rhs, proof) ← convert lhs (evalTactic stx[2]) + if lhs == rhs then + throwError "Non-productive recursive equation {lhs} = {rhs}." + else if let some (rhs', proof') ← recursionToFuel? lhs rhs proof then + updateLhs rhs' proof' + else + throwError "Did not find {lhs} in {rhs}." + +@[tactic abstractAs] def evalAbstractAs : Tactic := fun stx => do + let lhs ← getLhs + let modifiers ← elabModifiers ⟨stx[1]⟩ + let declId := stx[2] + let ⟨_, declName, _⟩ ← Term.expandDeclId (← getCurrNamespace) (← Term.getLevelNames) declId modifiers + let e ← mkAuxDefinition (compile := false) declName (← inferType lhs) lhs + withSaveInfoContext <| Term.addTermInfo' declId e.getAppFn (isBinder := true) + changeLhs e + +def convert (lhs : Expr) (conv : TacticM Unit) : TermElabM (Expr × Expr) := do + let (rhs, newGoal) ← Conv.mkConvGoalFor lhs + let _ ← Tactic.run newGoal.mvarId! do + conv + for mvarId in (← getGoals) do + liftM <| mvarId.refl <|> mvarId.inferInstance <|> pure () + pruneSolvedGoals + unless (← getGoals).isEmpty do + throwError "convert tactic failed, there are unsolved goals\n{goalsToMessageData (← getGoals)}" + return (← instantiateMVars rhs, ← instantiateMVars newGoal) + +@[command_elab convTheorem] +def elabConvTheorem : CommandElab := fun stx => do + let modifiers ← elabModifiers ⟨stx[0]⟩ + let declId := stx[2] + let (binders, lhsStx) := expandDeclSig stx[3] + runTermElabM fun vars => do + let scopeLevelNames ← Term.getLevelNames + let ⟨shortName, declName, allUserLevelNames⟩ ← Term.expandDeclId (← getCurrNamespace) scopeLevelNames declId modifiers + addDeclarationRangesForBuiltin declName modifiers.stx stx + Term.withAutoBoundImplicitForbiddenPred (fun n => shortName == n) do + Term.withDeclName declName <| Term.withLevelNames allUserLevelNames <| Term.elabBinders binders.getArgs fun xs => do + Term.applyAttributesAt declName modifiers.attrs AttributeApplicationTime.beforeElaboration + let lhs ← Term.elabTermAndSynthesize lhsStx none + Term.synthesizeSyntheticMVarsNoPostponing + let lhs ← instantiateMVars lhs + let (rhs, value) ← convert lhs (Tactic.evalTactic stx[5]) + let eqType ← mkEq lhs rhs + let type ← mkForallFVars xs eqType + let type ← mkForallFVars vars type + let type ← Term.levelMVarToParam type + let value ← mkLambdaFVars xs value + let value ← mkLambdaFVars vars value + let usedParams := collectLevelParams {} type |>.params + match sortDeclLevelParams scopeLevelNames allUserLevelNames usedParams with + | Except.error msg => throwErrorAt stx msg + | Except.ok levelParams => + let type ← instantiateMVars type + let value ← instantiateMVars value + + let decl := Declaration.thmDecl { + name := declName, + levelParams := levelParams, + type := type, + value := value, + } + Term.ensureNoUnassignedMVars decl + addDecl decl + withSaveInfoContext do + Term.addTermInfo' declId (← mkConstWithLevelParams declName) (isBinder := true) + Term.applyAttributesAt declName modifiers.attrs AttributeApplicationTime.afterTypeChecking + Term.applyAttributesAt declName modifiers.attrs AttributeApplicationTime.afterCompilation + + + +conv_theorem test : id 1 + 2 => + skip + unfold id + simp + +/-- info: test : id 1 + 2 = 3 -/ +#guard_msgs in +#check test + +conv_theorem test2 (a b c : Fin 12) : (a + b + c).val => + simp [Fin.ext_iff, Fin.val_add] + abstractAs /-- docstrings possible! -/ test2_rhs + +/-- +info: def test2_rhs : Fin 12 → Fin 12 → Fin 12 → Nat := +fun a b c => (↑a + ↑b + ↑c) % 12 +-/ +#guard_msgs in +#print test2_rhs + +/-- info: test2 (a b c : Fin 12) : ↑(a + b + c) = test2_rhs a b c -/ +#guard_msgs in +#check test2 + +/-- error: unknown identifier 'foo' -/ +#guard_msgs in +conv_theorem bad_elab1 : foo && true => skip + +-- TODO: Why does this not abort nicely + +/-- +error: unknown identifier 'foo' +--- +error: (kernel) declaration has metavariables 'bad_elab2' +-/ +#guard_msgs in +conv_theorem bad_elab2 : foo true => skip + +def fib : Nat → Nat + | 0 => 0 + | 1 => 1 + | n+2 => fib n + fib (n+1) + +conv_theorem fib_optimize : fib => + withFuel => + ext n + unfold fib + simp [← Nat.add_eq] + abstractAs fib.opt + +/-- +info: def fib.opt : Nat → Nat := +rsimp_iterate fib fun ih n => + match n with + | 0 => 0 + | 1 => 1 + | n.succ.succ => (ih n).add (ih (n.add 1)) +-/ +#guard_msgs in +#print fib.opt + +def fib' (n : Nat) := (fib n, fib (n+1)) diff --git a/src/Std/Tactic/RSimp/Fuel.lean b/src/Std/Tactic/RSimp/Fuel.lean new file mode 100644 index 000000000000..91b33907bdb6 --- /dev/null +++ b/src/Std/Tactic/RSimp/Fuel.lean @@ -0,0 +1,32 @@ +import Lean + +open Lean Meta + +def lots_of_fuel : Nat := 9223372036854775807 + +def rsimp_iterate {α : Sort u} (x : α) (f : α → α) : α := + Nat.rec x (fun _ ih => f ih) lots_of_fuel + +theorem reduce_with_fuel {α : Sort u} {x : α} {f : α → α} (h : x = f x) : + x = rsimp_iterate x f := by + unfold rsimp_iterate + exact Nat.rec rfl (fun _ ih => h.trans (congrArg f ih)) lots_of_fuel + +def recursionToFuel? (lhs rhs proof : Expr) : MetaM (Option (Expr × Expr)) := do + let f ← kabstract rhs lhs + if f.hasLooseBVars then + let t ← inferType lhs + let u ← getLevel t + let f := mkLambda `ih .default t f + let rhs' := mkApp3 (.const ``rsimp_iterate [u]) t lhs f + let proof' := mkApp4 (.const ``reduce_with_fuel [u]) t lhs f proof + return some (rhs', proof') + else + return none + +def recursionToFuel (lhs rhs proof : Expr) : MetaM (Expr × Expr) := do + if let some (rhs', proof') ← recursionToFuel? lhs rhs proof then + return (rhs', proof') + else + -- Not (obviously) recursive + return (rhs, proof) diff --git a/src/Std/Tactic/RSimp/Optimize.lean b/src/Std/Tactic/RSimp/Optimize.lean new file mode 100644 index 000000000000..c4aa7528c491 --- /dev/null +++ b/src/Std/Tactic/RSimp/Optimize.lean @@ -0,0 +1,77 @@ +/- +Copyright (c) 2024 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Joachim Breitner +-/ + +prelude +import Std.Tactic.RSimp.Setup +import Std.Tactic.RSimp.Fuel +import Lean.Elab.Tactic +import Init.Tactics + +open Lean Meta + +-- TODO: Namespace + +initialize registerTraceClass `tactic.rsimp_optimize + +def getEqUnfold (declName : Name) : MetaM (Option (Expr × Expr)) := do + -- TODO: Make nicer, and move near the eqUnfold definition + if (← getUnfoldEqnFor? declName (nonRec := false)).isSome then + let unfold := .str declName eqUnfoldThmSuffix + executeReservedNameAction unfold + let unfoldProof ← mkConstWithLevelParams unfold + let some (_, _, rhs) := (← inferType unfoldProof).eq? | throwError "Unexpected type of {unfold}" + return some (rhs, unfoldProof) + else return none + +def optimize (declName : Name) : MetaM Unit := do + let opt_name := .str declName "rsimp" + let proof_name := .str declName "eq_rsimp" + if (← getEnv).contains opt_name then throwError "{opt_name} has already been declared" + if (← getEnv).contains proof_name then throwError "{proof_name} has already been declared" + + let info ← getConstInfoDefn declName + let lhs := mkConst declName (info.levelParams.map mkLevelParam) + let (rhs0, rwProof) ← + if let some (rhs, unfoldProof) ← getEqUnfold declName then + pure (rhs, unfoldProof) + else + let unfoldProof ← mkEqRefl lhs + pure (info.value, unfoldProof) + + -- Do we need to give the user control over the simplifier here? + -- TODO: Unify with rsimp_decide + let .some se ← getSimpExtension? `rsimp | throwError "simp set 'rsimp' not found" + -- TODO: zeta := false seems reasonable, we do not want to duplicate terms + -- but it can produce type-incorrect terms here. + let ctx : Simp.Context := { config := {}, simpTheorems := #[(← se.getTheorems)], congrTheorems := (← Meta.getSimpCongrTheorems) } + let (res, _stats) ← simp rhs0 ctx #[(← Simp.getSimprocs)] none + let rhs := res.expr + let proof ← mkEqTrans rwProof (← res.getProof) + + let (rhs, proof) ← recursionToFuel lhs rhs proof + + trace[tactic.rsimp_optimize] "Optimizing {lhs} to:{indentExpr rhs}" + addDecl <| Declaration.defnDecl { info with + name := opt_name, type := info.type, value := rhs, levelParams := info.levelParams + } + let proof_type ← mkEq lhs (mkConst opt_name (info.levelParams.map mkLevelParam)) + addDecl <| Declaration.thmDecl { + name := proof_name, type := proof_type, value := proof, levelParams := info.levelParams + } + addSimpTheorem se proof_name (post := true) (inv := false) AttributeKind.global (prio := eval_prio default) + +/-- +TODO +-/ +syntax (name := rsimp_optimize) "rsimp_optimize" : attr + +initialize registerBuiltinAttribute { + name := `rsimp_optimize + descr := "optimize for kernel reduction" + add := fun declName _stx attrKind => do + unless attrKind == AttributeKind.global do throwError "invalid attribute 'rsimp_optimize', must be global" + (optimize declName).run' {} {} + } diff --git a/src/Std/Tactic/RSimp/RSimpDecide.lean b/src/Std/Tactic/RSimp/RSimpDecide.lean new file mode 100644 index 000000000000..15c3ab5fb116 --- /dev/null +++ b/src/Std/Tactic/RSimp/RSimpDecide.lean @@ -0,0 +1,81 @@ +/- +Copyright (c) 2024 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Joachim Breitner +-/ + +prelude +import Std.Tactic.RSimp.Setup +import Lean.Elab.Tactic +import Init.Tactics + +-- TODO: Namespace + +open Lean Elab Tactic Meta + +private def preprocessPropToDecide (expectedType : Expr) : TermElabM Expr := do + let mut expectedType ← instantiateMVars expectedType + if expectedType.hasFVar then + expectedType ← zetaReduce expectedType + if expectedType.hasFVar || expectedType.hasMVar then + throwError "expected type must not contain free or meta variables{indentExpr expectedType}" + return expectedType + +theorem of_opt_decide_eq_true {p : Prop} [inst : Decidable p] (c : Bool) (h : decide p = c) + : c = true → p := by subst h; exact of_decide_eq_true + +initialize registerTraceClass `tactic.rsimp_decide +initialize registerTraceClass `tactic.rsimp_decide.debug + +section Syntax +open Lean.Parser.Tactic + +/-- +TODO +-/ +syntax (name := rsimp_decide) "rsimp_decide" (config)? (discharger)? + (&" only")? (" [" (simpErase <|> simpLemma),* "]")? : tactic + +@[tactic rsimp_decide] +def rsimpDecideImpl : Tactic := fun stx => do + -- TODO: Using closeMainGoalUsing did not work as expected + -- closeMainGoalUsing `rsimp_decide fun expectedType _tag => do + withMainContext do + let expectedType ← getMainTarget + let expectedType ← preprocessPropToDecide expectedType + let d ← mkAppOptM ``Decidable.decide #[expectedType, none] + let d ← instantiateMVars d + -- Get instance from `d` + let s := d.appArg! + let decE := mkApp2 (mkConst ``Decidable.decide) expectedType s + let .some se ← getSimpExtension? `rsimp | throwError "simp set 'rsimp' not found" + + -- Setting up the simplifier + -- Passing the stx here is a hairy hack, and only works as long as `rsimp_decide` syntax + -- is compatible with the simp syntax. Maybe mkSimpContext should take the components + -- separately + let scr ← mkSimpContext stx + (simpTheorems := se.getTheorems) (ignoreStarArg := true) (eraseLocal := false) + let (res, _stats) ← scr.dischargeWrapper.with fun discharge? => + simp decE scr.ctx scr.simprocs discharge? + + let optE := res.expr + trace[tactic.rsimp_decide] "Optimized expression:{indentExpr optE}" + let optPrf ← res.getProof + let rflPrf ← mkEqRefl (toExpr true) + let rflType ← mkEq optE (toExpr true) + -- We peform the kernel computation in an auxillary definition, like `decide!` + let levelsInType := (collectLevelParams {} rflType).params + let lemmaLevels := (← Term.getLevelNames).reverse.filter levelsInType.contains + let lemmaName ← + try + mkAuxLemma lemmaLevels rflType rflPrf + catch e => + trace[tactic.rsimp_decide.debug] "mkAuxLemma failed: {e.toMessageData}" + throwTacticEx `rsimp_decide (← getMainGoal) "this may be because the proposition is false, involves non-computable axioms or opaque definitions." + let eqPrf := mkConst lemmaName (lemmaLevels.map .param) + closeMainGoal `rsimp_decide <| + mkApp5 (Lean.mkConst ``of_opt_decide_eq_true) expectedType s optE optPrf eqPrf + + +end Syntax diff --git a/src/Std/Tactic/RSimp/Setup.lean b/src/Std/Tactic/RSimp/Setup.lean new file mode 100644 index 000000000000..8dee6c7d0450 --- /dev/null +++ b/src/Std/Tactic/RSimp/Setup.lean @@ -0,0 +1,16 @@ +/- +Copyright (c) 2024 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Joachim Breitner +-/ +prelude +import Lean.Meta.Tactic.Simp.RegisterCommand + +/-- +The `rsimp` simp set is used by the `rsimp_decide` tactic to optimize terms for kernel reduction. + +It is separate from the default simp set because they have different normal forms. For example +`simp` wants to replae concrete operations like `Nat.add` with the overloaded `+`, but for +efficient reduction, we want to go the other way. +-/ +register_simp_attr rsimp diff --git a/tests/lean/run/rsimp.lean b/tests/lean/run/rsimp.lean new file mode 100644 index 000000000000..a2a3a8cf7396 --- /dev/null +++ b/tests/lean/run/rsimp.lean @@ -0,0 +1,74 @@ +import Std.Tactic.RSimp + +/-! +Basic testing of syntax +-/ + +set_option trace.tactic.rsimp_decide true + +structure MyTrue : Prop + +instance : Decidable MyTrue := .isTrue MyTrue.mk +@[simp] theorem decide_MyTrue : decide MyTrue = true := rfl + +/-- +info: [tactic.rsimp_decide] Optimized expression: + decide MyTrue +-/ +#guard_msgs in +def ex1 : MyTrue := by rsimp_decide + +/-- +info: [tactic.rsimp_decide] Optimized expression: + true +-/ +#guard_msgs in +def ex2 : MyTrue := by rsimp_decide [decide_MyTrue] + + +attribute [rsimp] decide_MyTrue + +/-- +info: [tactic.rsimp_decide] Optimized expression: + true +-/ +#guard_msgs in +def ex3 : MyTrue := by rsimp_decide + +/-- +info: [tactic.rsimp_decide] Optimized expression: + decide MyTrue +-/ +#guard_msgs in +def ex4 : MyTrue := by rsimp_decide only + +/-- +error: tactic 'rsimp_decide' failed, this may be because the proposition is false, involves non-computable axioms or opaque definitions. +⊢ False +--- +info: [tactic.rsimp_decide] Optimized expression: + decide False +-/ +#guard_msgs in +def ex5 : False := by rsimp_decide + +/-- +error: tactic 'rsimp_decide' failed, this may be because the proposition is false, involves non-computable axioms or opaque definitions. +⊢ False +--- +info: [tactic.rsimp_decide] Optimized expression: + decide False +[tactic.rsimp_decide.debug] mkAuxLemma failed: (kernel) declaration type mismatch, 'lean.run.rsimp._auxLemma.3' has type + true = true + but it is expected to have type + decide False = true +-/ +#guard_msgs in +set_option trace.tactic.rsimp_decide.debug true in +def ex6 : False := by rsimp_decide + + +-- Check that level parameters don't trip it up +local instance inst.{u} : Decidable (Nonempty PUnit.{u}) := .isTrue ⟨⟨⟩⟩ +set_option trace.tactic.rsimp_decide false in +def ex7.{v} : Nonempty (PUnit.{v}) := by rsimp_decide diff --git a/tests/lean/run/rsimp_bv_decide.lean b/tests/lean/run/rsimp_bv_decide.lean new file mode 100644 index 000000000000..7668ff35058c --- /dev/null +++ b/tests/lean/run/rsimp_bv_decide.lean @@ -0,0 +1,136 @@ +import Std.Tactic.BVDecide +import Std.Tactic.RSimp +import Std.Tactic.RSimp.ConvTheorem + +theorem ex (i : BitVec 5) : 2 * i &&& 1 = 0 := by bv_decide + +/-- info: true -/ +#guard_msgs in +#eval ex._reflection_def_1 + +-- #print ex._expr_def_1 +-- #print ex._cert_def_1 + +open Lean in +partial def callPaths (source : Name) (target : Name) : CoreM MessageData := do + let rec go (n : Name) : StateT (NameMap Bool) CoreM (Option MessageData) := do + if n = target then + return m!"{.ofConstName n} !" + if let some hit := (← get).find? n then + if hit then + return m!"{.ofConstName n} ↑" + let .defnInfo ci ← getConstInfo n | return none + let ns := Expr.getUsedConstants ci.value + let ms ← ns.filterMapM go + let hit := !ms.isEmpty + modify (·.insert n hit) + unless hit do return none + return some <| .ofConstName n ++ (.nest 1 ("\n" ++ (.joinSep ms.toList "\n"))) + + if let some md ← (go source).run' {} then + pure md + else + pure "No paths from {.ofConstName source} to {.ofConstName target} found" + +open Lean Elab Command in +elab "#call_paths " i1:ident " => " i2:ident : command => liftTermElabM do + let source ← realizeGlobalConstNoOverloadWithInfo i1 + let target ← realizeGlobalConstNoOverloadWithInfo i2 + let m ← callPaths source target + logInfo m + + +def parsedProof : Array Std.Tactic.BVDecide.LRAT.IntAction := + #[Std.Tactic.BVDecide.LRAT.Action.addRup 64 #[-183] #[1, 3], + Std.Tactic.BVDecide.LRAT.Action.addRup 65 #[-8, -182] #[64, 7], + Std.Tactic.BVDecide.LRAT.Action.addRup 66 #[8, 1] #[61], + Std.Tactic.BVDecide.LRAT.Action.addRup 67 #[8] #[62, 66], + Std.Tactic.BVDecide.LRAT.Action.addRup 68 #[-182] #[67, 65], + Std.Tactic.BVDecide.LRAT.Action.addRup 69 #[-9] #[67, 57], + Std.Tactic.BVDecide.LRAT.Action.addRup 70 #[-181] #[68, 67, 10], + Std.Tactic.BVDecide.LRAT.Action.addRup 71 #[10] #[69, 62, 55], + Std.Tactic.BVDecide.LRAT.Action.addRup 72 #[-180] #[70, 67, 13], + Std.Tactic.BVDecide.LRAT.Action.addRup 73 #[-11] #[71, 51], + Std.Tactic.BVDecide.LRAT.Action.addRup 74 #[-179] #[72, 67, 16], + Std.Tactic.BVDecide.LRAT.Action.addRup 75 #[12] #[73, 62, 49], + Std.Tactic.BVDecide.LRAT.Action.addRup 76 #[173] #[74, 62, 19], + Std.Tactic.BVDecide.LRAT.Action.addRup 77 #[-13] #[75, 45], + Std.Tactic.BVDecide.LRAT.Action.addRup 78 #[-140] #[76, 21], + Std.Tactic.BVDecide.LRAT.Action.addRup 79 #[46] #[77, 62, 43], + Std.Tactic.BVDecide.LRAT.Action.addRup 80 #[133] #[78, 62, 25], + Std.Tactic.BVDecide.LRAT.Action.addRup 81 #[-47] #[79, 39], + Std.Tactic.BVDecide.LRAT.Action.addRup 82 #[-99] #[80, 27], + Std.Tactic.BVDecide.LRAT.Action.addRup 83 #[56] #[81, 62, 37], + Std.Tactic.BVDecide.LRAT.Action.addRup 84 #[91] #[82, 62, 31], + Std.Tactic.BVDecide.LRAT.Action.addEmpty 85 #[84, 83, 33]] + +@[rsimp] theorem parse_aux : + Std.Tactic.BVDecide.LRAT.parseLRATProof ex._cert_def_1.toUTF8 + = .ok parsedProof := sorry + +attribute [rsimp_optimize] Std.Tactic.BVDecide.BVExpr.bitblast.blastConst.go +attribute [rsimp_optimize] Std.Tactic.BVDecide.BVExpr.bitblast.blastConst +attribute [rsimp_optimize] Std.Tactic.BVDecide.BVExpr.bitblast.go +attribute [rsimp_optimize] Std.Tactic.BVDecide.BVExpr.bitblast +attribute [rsimp_optimize] Std.Tactic.BVDecide.LRAT.Internal.Formula.performRupAdd +attribute [rsimp_optimize] List.filterMap +attribute [rsimp_optimize] Std.Sat.AIG.Entrypoint.relabelNat +attribute [rsimp_optimize] Std.Sat.AIG.toCNF +attribute [rsimp_optimize] Std.Sat.CNF.numLiterals +attribute [rsimp_optimize] Std.Tactic.BVDecide.LRAT.Internal.intActionToDefaultClauseAction +attribute [rsimp_optimize] Std.Tactic.BVDecide.LRAT.Internal.CNF.convertLRAT +attribute [rsimp_optimize] Std.Tactic.BVDecide.LRAT.Internal.lratChecker +attribute [rsimp_optimize] Std.Tactic.BVDecide.ofBoolExprCached +set_option trace.tactic.rsimp_optimize true in +-- This doesn't rewrite BVExpr.bitblast because it's dependent +-- attribute [rsimp_optimize] Std.Tactic.BVDecide.BVPred.bitblast + +-- Also doesnt work, due to abstracted proofs +set_option pp.proofs true in +conv_theorem bitblast_opt : Std.Tactic.BVDecide.BVPred.bitblast => + unfold Std.Tactic.BVDecide.BVPred.bitblast + unfold Std.Tactic.BVDecide.BVPred.bitblast.proof_1 + simp -zeta [Std.Tactic.BVDecide.BVExpr.bitblast.eq_rsimp] + +attribute [rsimp_optimize] Std.Tactic.BVDecide.BVLogicalExpr.bitblast +attribute [rsimp_optimize] Std.Tactic.BVDecide.LRAT.check +attribute [rsimp_optimize] Std.Tactic.BVDecide.Reflect.verifyCert +attribute [rsimp_optimize] Std.Tactic.BVDecide.Reflect.verifyBVExpr +attribute [rsimp_optimize] ex._reflection_def_1 + + +@[rsimp] conv_theorem unfold_to_parser : + ex._reflection_def_1 => + simp [ex._reflection_def_1, Std.Tactic.BVDecide.Reflect.verifyBVExpr, + Std.Tactic.BVDecide.Reflect.verifyCert, rsimp] + abstractAs optimized_reflection_def + +#call_paths optimized_reflection_def => WellFounded.fix + +-- #print optimized_reflection_def + +open Lean Meta in +elab "#kernel_reduce" t:term : command => Lean.Elab.Command.runTermElabM fun _ => do + let e ← Lean.Elab.Term.elabTermAndSynthesize t none + Lean.Meta.lambdaTelescope e fun _ e => do + -- let e' ← Lean.ofExceptKernelException <| Lean.Kernel.whnf (← Lean.getEnv) (← Lean.getLCtx) e + let e' ← withOptions (smartUnfolding.set · false) <| withTransparency .all <| Lean.Meta.whnf e + Lean.logInfo m!"{e'}" + +open Std.Tactic.BVDecide.LRAT + +-- #kernel_reduce ex._reflection_def_1 +-- /-- info: true -/ #guard_msgs in #eval optimized_reflection_def + +syntax "λ" : term + +open Lean PrettyPrinter Delaborator in +@[delab lam] +def delabLam : Delab := `(λ) + +set_option diagnostics true +set_option diagnostics.threshold 0 +#kernel_reduce optimized_reflection_def + +-- #print ex._reflection_def_1 +-- #eval Std.Sat.AIG.toCNF (Std.Tactic.BVDecide.BVLogicalExpr.bitblast ex._expr_def_1).relabelNat diff --git a/tests/lean/run/rsimp_factorization.lean b/tests/lean/run/rsimp_factorization.lean new file mode 100644 index 000000000000..fcc29b69275b --- /dev/null +++ b/tests/lean/run/rsimp_factorization.lean @@ -0,0 +1,96 @@ +import Std.Tactic.RSimp + +-- An example rsimpset + +-- Unfortunately, `attribute` does not allow to add theorems with symm := true + +@[rsimp] def Nat.beq_eq_symm {x y : Nat} : (x = y) = (x.beq y = true) := (@Nat.beq_eq x y).symm +attribute [rsimp] Nat.dvd_iff_mod_eq_zero +@[rsimp] def Bool.cond_decide_symm := fun α p inst t e => (@Bool.cond_decide α p inst t e).symm +attribute [rsimp] Std.Tactic.BVDecide.Normalize.Bool.decide_eq_true +attribute [rsimp] Bool.decide_and + +@[rsimp] def Nat.ble_eq_symm := fun a b => (@Nat.ble_eq a b).symm +@[rsimp] def Nat.blt_eq_symm := fun a b => (@Nat.blt_eq a b).symm +@[rsimp] def Nat.mod_eq_symm (a b : Nat) : a % b = Nat.mod a b := rfl +@[rsimp] def Nat.add_eq_symm := fun a b => (@Nat.add_eq a b).symm +@[rsimp] def Nat.mul_eq_symm := fun a b => (@Nat.mul_eq a b).symm + +-- Somehow, simp does not like unfolding cond.match_1 +-- attribute [rsimp] cond cond.match_1 +@[rsimp] theorem cond_eq_rec (b : Bool) (t e : α) : cond b t e = b.rec e t := by + cases b <;> rfl + +@[semireducible] def minFacAux (n : Nat) : Nat → Nat + | k => + if n < k * k then n + else + if k ∣ n then k + else + minFacAux n (k + 2) +termination_by k => n + 2 - k +decreasing_by sorry + +def minFac (n : Nat) : Nat := if 2 ∣ n then 2 else minFacAux n 3 +def isPrime (p : Nat) : Bool := 2 ≤ p ∧ minFac p = p + +/-- +error: maximum recursion depth has been reached +use `set_option maxRecDepth ` to increase limit +use `set_option diagnostics true` to get diagnostic information +-/ +#guard_msgs in +example : isPrime 524287 := by decide + +attribute [rsimp_optimize] minFacAux +attribute [rsimp_optimize] minFac +attribute [rsimp_optimize] isPrime + +/-- +info: def minFacAux.rsimp : Nat → Nat → Nat := +rsimp_iterate minFacAux fun ih n x => Bool.rec (Bool.rec (ih n (x.add 2)) x ((n.mod x).beq 0)) n (n.blt (x.mul x)) +-/ +#guard_msgs in +#print minFacAux.rsimp + +/-- +info: def minFac.rsimp : Nat → Nat := +fun n => Bool.rec (minFacAux.rsimp n 3) 2 ((n.mod 2).beq 0) +-/ +#guard_msgs in +#print minFac.rsimp + +/-- +info: def isPrime.rsimp : Nat → Bool := +fun p => Nat.ble 2 p && (minFac.rsimp p).beq p +-/ +#guard_msgs in +#print isPrime.rsimp + + +set_option trace.tactic.rsimp_decide true in + +/-- +info: [tactic.rsimp_decide] Optimized expression: + isPrime.rsimp 524287 +-/ +#guard_msgs in +example : isPrime 524287 := by rsimp_decide + +-- 6ms: +-- #time example : isPrime 524287 := by rsimp_decide +-- 25ms: +-- #time example : isPrime 10000019 := by rsimp_decide + + +-- For larger ones, we get deep kernel recursion. I wonder what's recursive here? + +set_option trace.tactic.rsimp_decide.debug true in +/-- +error: tactic 'rsimp_decide' failed, this may be because the proposition is false, involves non-computable axioms or opaque definitions. +⊢ isPrime 100000007 = true +--- +info: [tactic.rsimp_decide.debug] mkAuxLemma failed: (kernel) deep recursion detected +-/ +#guard_msgs in +example : isPrime 100000007 := by rsimp_decide diff --git a/tests/lean/run/rsimp_linear.lean b/tests/lean/run/rsimp_linear.lean new file mode 100644 index 000000000000..0999ab64b561 --- /dev/null +++ b/tests/lean/run/rsimp_linear.lean @@ -0,0 +1,411 @@ +import Std.Tactic.RSimp +import Std.Tactic.RSimp.ConvTheorem + +import Lean + +namespace Data + +/-! +A data structure for modelling `Fin n → α` (or `Array α`) optimized for a fast kernel-reduction get +operation. + +For now not universe-polymorphic; smaller proof objects and no complication with the `ToExpr` type +class. +-/ + +inductive RArray (α : Type) : Type where + | leaf : α → RArray α + | branch : Nat → RArray α → RArray α → RArray α + +variable {α : Type} + +/-- The crucial operation, written with very little abstractional overhead -/ +noncomputable def RArray.get (a : RArray α) (n : Nat) : α := + RArray.rec (fun x => x) (fun p _ _ l r => (Nat.ble p n).rec l r) a + +theorem RArray.get_eq_def (a : RArray α) (n : Nat) : + a.get n = match a with + | .leaf x => x + | .branch p l r => (Nat.ble p n).rec (l.get n) (r.get n) := by + conv => lhs; unfold RArray.get + split <;> rfl + +def RArray.getImpl (a : RArray α) (n : Nat) : α := + match a with + | .leaf x => x + | .branch p l r => if n < p then l.getImpl n else r.getImpl n + +@[csimp] +theorem RArray.get_eq_getImpl : @RArray.get = @RArray.getImpl := by + ext α a n + induction a with + | leaf _ => rfl + | branch p l r ihl ihr => + rw [RArray.getImpl, RArray.get_eq_def] + simp only [ihl, ihr] + cases hnp : Nat.ble p n + · replace hnp := ne_true_of_eq_false hnp + simp at hnp + rw [if_pos] + omega + · simp at hnp + rw [if_neg] + omega + +instance : GetElem (RArray α) Nat α (fun _ _ => True) where + getElem a n _ := a.get n + +def RArray.size : RArray α → Nat + | leaf _ => 1 + | branch _ l r => l.size + r.size + +def RArray.ofFn {n : Nat} (f : Fin n → α) (h : 0 < n) : RArray α := + go 0 n h (Nat.le_refl _) +where + go (lb ub : Nat) (h1 : lb < ub) (h2 : ub ≤ n) : RArray α := + if h : lb + 1 = ub then + .leaf (f ⟨lb, Nat.lt_of_lt_of_le h1 h2⟩) + else + let mid := (lb + ub)/2 + .branch mid (go lb mid (by omega) (by omega)) (go mid ub (by omega) h2) + +def RArray.ofArray (xs : Array α) (h : 0 < xs.size) : RArray α := + .ofFn (fun i => xs.get i) h + +theorem RArray.ofFn_correct {n : Nat} (f : Fin n → α) (h : 0 < n) (i : Fin n): + RArray.get (.ofFn f h) i = f i := + go 0 n h (Nat.le_refl _) (Nat.zero_le _) i.2 +where + go lb ub h1 h2 (h3 : lb ≤ i.val) (h3 : i.val < ub) : RArray.get (.ofFn.go f lb ub h1 h2) i = f i := by + induction lb, ub, h1, h2 using RArray.ofFn.go.induct (f := f) (n := n) + case case1 => + simp [ofFn.go, RArray.get_eq_getImpl, RArray.getImpl] + congr + omega + case case2 ih1 ih2 hiu => + rw [ofFn.go]; simp only [↓reduceDIte, *] + simp [RArray.get_eq_getImpl, RArray.getImpl] at * + split + · rw [ih1] <;> omega + · rw [ih2] <;> omega + + +section Meta +open Lean + +def RArray.toExpr (ty : Expr) (f : α → Expr) : RArray α → Expr + | .leaf x => + mkApp2 (mkConst ``RArray.leaf) ty (f x) + | .branch p l r => + mkApp4 (mkConst ``RArray.branch) ty (.lit (.natVal p)) (l.toExpr ty f) (r.toExpr ty f) + +instance [ToExpr α] : ToExpr (RArray α) where + toExpr a := a.toExpr (toTypeExpr α) (toExpr ·) + toTypeExpr := mkApp (mkConst ``RArray) (toTypeExpr α) + +end Meta +end Data + +open Nat.Linear + +section sortFuse + +-- Oddly, fusing norm and fuse does not yield a speed-up + +def Nat.Linear.Poly.insertSortedFused (k : Nat) (v : Var) (p : Poly) : Poly := + match p with + | [] => [(k, v)] + | (k', v') :: p => + bif Nat.blt v v' then + (k, v) :: (k', v') :: p + else + bif Nat.beq v v' then + (k + k', v) :: p + else + (k', v') :: insertSortedFused k v p + +def Nat.Linear.Poly.sortFuse (p : Poly) : Poly := + let rec go (p : Poly) (r : Poly) : Poly := + match p with + | [] => r + | (k, v) :: p => go p (r.insertSortedFused k v) + go p [] + +/-- warning: declaration uses 'sorry' -/ +#guard_msgs in +theorem Nat.Linear.Poly.norm_eq_sortFuse (p : Poly) : p.norm = p.sortFuse := by + sorry + +end sortFuse + +section ToPoly + +def Nat.Linear.Expr.toPolyAux (coeff : Nat) : Expr → (Poly → Poly) + | Expr.num k => bif k == 0 then id else ((coeff * k, fixedVar) :: ·) + | Expr.var i => ((coeff, i) :: ·) + | Expr.add a b => a.toPolyAux coeff ∘ b.toPolyAux coeff + | Expr.mulL k a => if k == 0 then id else a.toPolyAux (k * coeff) + | Expr.mulR a k => if k == 0 then id else a.toPolyAux (k * coeff) + +-- attribute [rsimp_optimize] Nat.Linear.Expr.toPoly + +noncomputable +def Nat.Linear.Expr.toPolyFast (e : Expr) : Poly := + Nat.Linear.Expr.toPolyAux 1 e [] + +@[simp] +theorem Nat.Linear.Poly.mul.go_append : + Nat.Linear.Poly.mul.go k (p₁ ++ p₂) = + Nat.Linear.Poly.mul.go k p₁ ++ Nat.Linear.Poly.mul.go k p₂ := by + induction p₁ <;> simp [mul.go, *] + +@[simp] +theorem Nat.Linear.Poly.mul_nil : + Nat.Linear.Poly.mul k [] = [] := by simp [mul, mul.go] + +@[simp] +theorem Nat.Linear.Poly.mul_0 : + Nat.Linear.Poly.mul 0 p = [] := by simp [mul] + +@[simp] +theorem Nat.Linear.Poly.mul_append : + Nat.Linear.Poly.mul k (p₁ ++ p₂) = + Nat.Linear.Poly.mul k p₁ ++ Nat.Linear.Poly.mul k p₂ := by + unfold Poly.mul + simp only [cond_eq_if] + split <;> (try split) <;> simp + +@[simp] +theorem Nat.Linear.Poly.mul_go_mul_go : + Poly.mul.go k (Poly.mul.go k' p) = Poly.mul.go (k * k') p := by + induction p <;> simp_all [mul.go, Nat.mul_assoc] + +theorem Nat.mul_eq_one (n m : Nat) : n * m = 1 ↔ (n = 1 ∧ m = 1) := by + cases n <;> simp + rename_i n + cases m <;> simp + rename_i m + cases n <;> simp + rename_i n + simp [Nat.mul_add, ← Nat.add_assoc] + + +@[simp] +theorem Nat.Linear.Poly.mul_mul : + Poly.mul k (Poly.mul k' p) = Poly.mul (k' * k) p := by + unfold Poly.mul + simp only [cond_eq_if, beq_iff_eq, Nat.mul_eq_one, Nat.mul_eq_zero] + repeat' split <;> try (simp_all [mul.go, Nat.mul_comm]) + +theorem Nat.Linear.Expr.toPoly_eq_toPolyFast : + Nat.Linear.Expr.toPoly = Nat.Linear.Expr.toPolyFast := by + funext p + unfold toPolyFast + suffices ∀ k r, k ≠ 0 → (p.toPoly.mul k ++ r = toPolyAux k p r) by simpa using this 1 [] + intro k r hk + induction p generalizing k r + · simp [toPoly, toPolyAux, cond_eq_if, hk] + split + . simp [Poly.mul, Poly.mul.go] + . simp [Poly.mul, Poly.mul.go, cond_eq_if, *] + split <;> simp [*] + · simp [toPoly, toPolyAux, cond_eq_if, Poly.mul, Poly.mul.go, hk] + split <;> simp [*] + next iha ihb => + simp [toPoly, toPolyAux, cond_eq_if, hk] + rw [← ihb k _ hk, ← iha k _ hk] + next k' _ ih => + simp [toPoly, toPolyAux, cond_eq_if, hk] + split + · simp [*] + · rw [← ih (k' * k) _ _] + simp [*, Nat.mul_ne_zero] + next _ k' ih => + simp [toPoly, toPolyAux, cond_eq_if, hk] + split + · simp [*] + · rw [← ih (k' * k) _ _] + simp [*, Nat.mul_ne_zero] + +end ToPoly + +section AltDenote + +def Nat.Linear.Expr.denote' (ctx : Data.RArray Nat) : Expr → Nat + | Expr.add a b => Nat.add (denote' ctx a) (denote' ctx b) + | Expr.num k => k + | Expr.var v => ctx.get v + | Expr.mulL k e => Nat.mul k (denote' ctx e) + | Expr.mulR e k => Nat.mul (denote' ctx e) k + +end AltDenote + +-- theorem Nat.beq_eq' (a b : Nat) : (a == b) = Nat.beq a b := sorry + +attribute [rsimp ←] Nat.beq_eq Nat.mul_eq Nat.add_eq Bool.cond_decide +attribute [rsimp] Std.Tactic.BVDecide.Normalize.Bool.decide_eq_true +attribute [rsimp] BEq.beq +attribute [rsimp_optimize] Nat.Linear.Poly.mul.go +attribute [rsimp_optimize] Nat.Linear.Poly.mul + +attribute [rsimp_optimize] Nat.Linear.Poly.insertSortedFused +attribute [rsimp_optimize] Nat.Linear.Poly.sortFuse.go +attribute [rsimp_optimize] Nat.Linear.Poly.sortFuse + +attribute [rsimp_optimize] Nat.Linear.Poly.insertSorted +attribute [rsimp_optimize] Nat.Linear.Poly.sort.go +attribute [rsimp_optimize] Nat.Linear.Poly.sort +attribute [rsimp_optimize] Nat.Linear.Poly.fuse +attribute [rsimp_optimize] Nat.Linear.Expr.toPolyAux +attribute [rsimp_optimize] Nat.Linear.Expr.toPolyFast + +-- This is actually a bit slower, it seems +-- But faster if there are repeated variables +attribute [rsimp] Nat.Linear.Poly.norm_eq_sortFuse +-- attribute [rsimp_optimize] Nat.Linear.Poly.norm + + +-- attribute [rsimp_optimize] Nat.Linear.Expr.toPoly +attribute [rsimp] Nat.Linear.Expr.toPoly_eq_toPolyFast + +attribute [rsimp_optimize] Nat.Linear.Expr.toNormPoly +attribute [rsimp ←] List.reverseAux_eq +attribute [rsimp_optimize] Nat.Linear.Poly.cancelAux +attribute [rsimp_optimize] Nat.Linear.Poly.cancel + +attribute [rsimp_optimize] Nat.Linear.Var.denote.go +attribute [rsimp_optimize] Nat.Linear.Var.denote +-- set_option trace.tactic.rsimp_optimize true in +attribute [rsimp_optimize] Nat.Linear.Expr.denote + +/-- A hook to use below, and to easily swap out the definition -/ +def Nat.Linear.Expr.toPoly' := @Nat.Linear.Expr.toPoly + +theorem Nat.Linear.Expr.toPoly'_eq_to_Poly : + Nat.Linear.Expr.toPoly = Nat.Linear.Expr.toPoly' := rfl +-- set_option trace.tactic.rsimp_optimize true in +attribute [rsimp_optimize] Nat.Linear.Expr.toPoly' + +/-- warning: declaration uses 'sorry' -/ +#guard_msgs in +theorem Nat.Linear.Expr.of_cancel_eq_no_rfl (ctx : Context) (a b c d : Expr) : + (a.denote ctx = b.denote ctx) = (c.denote ctx = d.denote ctx) := sorry + +theorem Nat.Linear.Expr.of_cancel_eq_opt (ctx : Context) (a b c d : Expr) + (h : Poly.cancel.rsimp (Expr.toNormPoly.rsimp a) (Expr.toNormPoly.rsimp b) = + (Expr.toPoly'.rsimp c, Expr.toPoly'.rsimp d)) : + (a.denote ctx = b.denote ctx) = (c.denote ctx = d.denote ctx) := by + revert h + simp only [← Expr.toNormPoly.eq_rsimp, ← Expr.toPolyFast.eq_rsimp, + ← Poly.cancel.eq_rsimp, Nat.Linear.Expr.toPoly'_eq_to_Poly, ← toPoly'.eq_rsimp] + exact Expr.of_cancel_eq ctx a b c d + +theorem Nat.Linear.Expr.of_cancel_eq_opt_denote (ctx : Context) (a b c d : Expr) + (h : Poly.cancel.rsimp (Expr.toNormPoly.rsimp a) (Expr.toNormPoly.rsimp b) = + (Expr.toPolyFast.rsimp c, Expr.toPolyFast.rsimp d)) : + (Nat.Linear.Expr.denote.rsimp ctx a = Nat.Linear.Expr.denote.rsimp ctx b) = + (Nat.Linear.Expr.denote.rsimp ctx c = Nat.Linear.Expr.denote.rsimp ctx d) := by + revert h + simp only [← Expr.toNormPoly.eq_rsimp, ← Expr.toPolyFast.eq_rsimp, + ← Poly.cancel.eq_rsimp, ← Nat.Linear.Expr.toPoly_eq_toPolyFast, + ← Nat.Linear.Expr.denote.eq_rsimp + ] + exact Expr.of_cancel_eq ctx a b c d + +/-- warning: declaration uses 'sorry' -/ +#guard_msgs in +theorem Nat.Linear.Expr.of_cancel_eq_opt_denote' (ctx : Data.RArray Nat) (a b c d : Expr) + (h : Poly.cancel.rsimp (Expr.toNormPoly.rsimp a) (Expr.toNormPoly.rsimp b) = (Expr.toPolyFast.rsimp c, Expr.toPolyFast.rsimp d)) : + (Nat.Linear.Expr.denote' ctx a = Nat.Linear.Expr.denote' ctx b) = + (Nat.Linear.Expr.denote' ctx c = Nat.Linear.Expr.denote' ctx d) := by + sorry + all_goals + revert h + simp only [← Expr.toNormPoly.eq_rsimp, ← Expr.toPolyFast.eq_rsimp, + ← Poly.cancel.eq_rsimp, ← Nat.Linear.Expr.toPoly_eq_toPolyFast, + ← Nat.Linear.Expr.denote.eq_rsimp + ] + exact Expr.of_cancel_eq ctx a b c d + +/-- warning: declaration uses 'sorry' -/ +#guard_msgs in +theorem Nat.Linear.Expr.of_cancel_eq_opt_denote'_no_rfl (ctx : Data.RArray Nat) (a b c d : Expr) + (h : True) : + (Nat.Linear.Expr.denote' ctx a = Nat.Linear.Expr.denote' ctx b) = + (Nat.Linear.Expr.denote' ctx c = Nat.Linear.Expr.denote' ctx d) := by + sorry + +open Lean Meta + +def bench (variant : Nat) : MetaM Unit := + let n := 40 + let decls := Array.ofFn fun (i : Fin n) => ((`x).appendIndexAfter i, (fun _ => pure (mkConst ``Nat))) + withLocalDeclsD decls fun xs => do + if h : 0 < xs.size then + let mut e₁ := Expr.num 42 + let mut e₂ := Expr.num 23 + for _ in [:4] do + for i in [:xs.size] do + e₁ := .add (.mulL i e₁) (.var i) + e₂ := .add (.var i) (Expr.mulR e₂ (xs.size - i)) + + let (p₁', p₂') := Poly.cancel e₁.toNormPoly e₂.toNormPoly + let e₁' := p₁'.toExpr + let e₂' := p₂'.toExpr + have _value_orig := mkApp6 (.const ``Expr.of_cancel_eq []) + (← mkListLit (mkConst ``Nat) xs.toList) + (toExpr e₁) (toExpr e₂) (toExpr e₁') (toExpr e₂') + (← mkEqRefl (toExpr (p₁', p₂'))) + have _value_no_rfl := mkApp5 (.const ``Expr.of_cancel_eq_no_rfl []) + (← mkListLit (mkConst ``Nat) xs.toList) + (toExpr e₁) (toExpr e₂) (toExpr e₁') (toExpr e₂') + have _value_opt := mkApp6 (.const ``Expr.of_cancel_eq_opt []) + (← mkListLit (mkConst ``Nat) xs.toList) + (toExpr e₁) (toExpr e₂) (toExpr e₁') (toExpr e₂') + (← mkEqRefl (toExpr (p₁', p₂'))) + have _value_opt_denote := mkApp6 (.const ``Expr.of_cancel_eq_opt_denote []) + (← mkListLit (mkConst ``Nat) xs.toList) + (toExpr e₁) (toExpr e₂) (toExpr e₁') (toExpr e₂') + (← mkEqRefl (toExpr (p₁', p₂'))) + have _value_opt_denote' := mkApp6 (.const ``Expr.of_cancel_eq_opt_denote' []) + (Data.RArray.toExpr (mkConst ``Nat) id (.ofArray xs h)) + (toExpr e₁) (toExpr e₂) (toExpr e₁') (toExpr e₂') + (← mkEqRefl (toExpr (p₁', p₂'))) + have _value_opt_denote'_no_rfl := mkApp6 (.const ``Expr.of_cancel_eq_opt_denote'_no_rfl []) + (Data.RArray.toExpr (mkConst ``Nat) id (.ofArray xs h)) + (toExpr e₁) (toExpr e₂) (toExpr e₁') (toExpr e₂') + (mkConst ``True.intro) + let value := match variant with + | 0 => _value_orig + | 1 => _value_no_rfl + | 2 => _value_opt + | 3 => _value_opt_denote + | 4 => _value_opt_denote' + | _ => _value_opt_denote'_no_rfl + let value ← mkLambdaFVars xs value + let exp₁ ← Linear.Nat.LinearExpr.toArith xs e₁ + let exp₂ ← Linear.Nat.LinearExpr.toArith xs e₂ + let exp₁' ← Linear.Nat.LinearExpr.toArith xs e₁' + let exp₂' ← Linear.Nat.LinearExpr.toArith xs e₂' + let type ← mkEq (← mkEq exp₁ exp₂) (← mkEq exp₁' exp₂') + let type ← mkForallFVars xs type + + let name := `linear_test + let decl := .thmDecl { name, value, type, levelParams := [] } + let timings : List Nat ← (List.range 5).mapM fun _ => + withoutModifyingEnv do + let start ← IO.monoMsNow + addDecl decl + return (← IO.monoMsNow) - start + let some best := timings.min? | unreachable! + logInfo m!"time {variant}: {best}ms" + -- logInfo m!"{type}" + else + unreachable! + +-- run_meta bench 0 +-- run_meta bench 1 +run_meta bench 2 +run_meta bench 3 +run_meta bench 4 +run_meta bench 5 diff --git a/tests/lean/run/rsimp_magma.lean b/tests/lean/run/rsimp_magma.lean new file mode 100644 index 000000000000..61928c5b3139 --- /dev/null +++ b/tests/lean/run/rsimp_magma.lean @@ -0,0 +1,145 @@ +import Std.Tactic.RSimp + +/-! +This test applies the rsimp_decide tactic to a calculation from the equational_theories +project. +-/ + +/- +First a bit of of setup. Maybe some of this can eventually become part of the default +rsimp-set. The proofs here are not great, sorry for that, I hope they do not break too often. +-/ + +def Fin.all {n : Nat} (P : ∀ i < n, Bool) : Bool := go n (Nat.le_refl n) + where + go := Nat.rec + (motive := fun i => i ≤ n → Bool) + (fun _ => true) + (fun i ih p => P i (by omega) && ih (by omega)) + +theorem Fin.all_eq_true_iff {n : Nat} (P : ∀ i < n, Bool) : + Fin.all P ↔ (∀ (i : Nat) (hj : i < n), P i (by omega) = true) := + go (Nat.le_refl n) +where + go {i : Nat} (h : i ≤ n) : + Fin.all.go P i h ↔ (∀ (j : Nat) (hj : j < i), P j (by omega) = true) := by + induction i + case zero => simp [Fin.all.go]; rfl + case succ i ih => + symm + calc (∀ (j : Nat) (hj : j < i + 1), P j (by omega)) + _ ↔ P i h ∧ (∀ (j : Nat) (hj : j < i), P j (by omega)) := by + constructor + · exact fun h' => ⟨h' i (by omega), fun j hj => h' j (by omega)⟩ + · intro h' j hj + by_cases j = i + · subst j; apply h'.1 + · apply h'.2 j (by omega) + _ ↔ P i h ∧ (∀ (j : Nat) (hj : j < i), P j (by omega)) := by simp + _ ↔ P i h = true ∧ all.go P i (by omega) = true := by rw [ih] + _ ↔ all.go P (i+1) h = true := by simp [all.go] + + +def Nat.all_below (n : Nat) (P : Nat → Bool) : Bool := + Nat.rec true (fun i ih => P i && ih) n + +@[rsimp] +def Fin.all_eq_all_below {n : Nat} (P : ∀ i < n, Bool) (P' : Nat → Bool) + (hP : ∀ i h, P i h = P' i) : Fin.all P = Nat.all_below n P' := by + suffices ∀ i (h : i ≤ n), Fin.all.go P i h = Nat.all_below i P' + by apply this + intros i h + induction i + case zero => rfl + case succ i ih => + simp [all.go, Nat.all_below] + congr + · apply hP + · apply ih + + +theorem Bool.eq_of_eq_true_iff_eq_true {a b : Bool} : (a = true ↔ b = true) → a = b := by + cases a; cases b + all_goals simp + +@[rsimp] +theorem Fin.decideAll_to_Fin.all {n : Nat} {P : Fin n → Prop} [DecidablePred P] : + decide (∀ x, P x) = Fin.all (fun i h => decide (P ⟨i, h⟩)) := by + apply Bool.eq_of_eq_true_iff_eq_true + simp [Fin.all_eq_true_iff, Fin.forall_iff] + +@[rsimp] +theorem Nat.decideEq_to_beq {x y : Nat} : + decide (x = y) = Nat.beq x y := by + simp [decide, instDecidableEqNat, Nat.decEq] + split + · simp [*] + · simp [*] + + +attribute [rsimp] Nat.decideEq_to_beq +attribute [rsimp] Fin.ext_iff +attribute [rsimp] Fin.val_mul Fin.val_add Mul.mul Fin.mul +attribute [rsimp] instHAdd instHMul instAddNat instMulNat instHPow instPowNat instNatPowNat +attribute [rsimp] instHMod Nat.instMod instHDiv Nat.instDiv + +-- Now the example calculation + +namespace Example +class Magma (α : Type u) where /-- op -/ op : α → α → α +@[inherit_doc] infixl:65 " ◇ " => Magma.op + +def opOfTable {n : Nat} (t : Nat) (a : Fin n) (b : Fin n) : Fin n := + let i := a.val * n + b.val + let r := (t / n^i) % n + ⟨r, Nat.mod_lt _ (Fin.pos a)⟩ + +attribute [rsimp] Magma.op opOfTable + +def table : Nat := 176572862725894008122698639442158340463570358062018791456284713065412594783123644086682432661794684073102303331486778326370940525772356431236683795848309863276639424307474540043134479302998 + +abbrev Equation2531 (G: Type _) [Magma G] := ∀ x y : G, x = (y ◇ ((y ◇ x) ◇ x)) ◇ y + +@[rsimp] +def M2 : Magma (Fin 13) where + op := opOfTable table + +-- #time approx 130ms +theorem Equation2531_M2_unopt : @Equation2531 (Fin 13) M2 := by decide + +set_option trace.tactic.rsimp_decide true +set_option pp.fieldNotation.generalized false + +/-- +info: [tactic.rsimp_decide] Optimized expression: + Nat.all_below 13 fun i => + Nat.all_below 13 fun i_1 => + Nat.beq i + (Nat.mod + (Nat.div Example.table + (Nat.pow 13 + (Nat.add + (Nat.mul + (Nat.mod + (Nat.div Example.table + (Nat.pow 13 + (Nat.add (Nat.mul i_1 13) + (Nat.mod + (Nat.div Example.table + (Nat.pow 13 + (Nat.add + (Nat.mul + (Nat.mod (Nat.div Example.table (Nat.pow 13 (Nat.add (Nat.mul i_1 13) i))) 13) + 13) + i))) + 13)))) + 13) + 13) + i_1))) + 13) +-/ +#guard_msgs in +-- #time -- approx 33ms +theorem Equation2531_M2_opt : @Equation2531 (Fin 13) M2 := by rsimp_decide + +end Example diff --git a/tests/lean/run/rsimp_optimize.lean b/tests/lean/run/rsimp_optimize.lean new file mode 100644 index 000000000000..1436544c0f38 --- /dev/null +++ b/tests/lean/run/rsimp_optimize.lean @@ -0,0 +1,76 @@ +import Std.Tactic.RSimp + + +-- A little experimental rsimp set + +attribute [rsimp] Fin.ext_iff +attribute [rsimp] Fin.val_mul Fin.val_add +-- Unfortunately, `attribute` does not allow to add theorems with symm? +def Bool.cond_decide_symm := fun α p inst t e => (@Bool.cond_decide α p inst t e).symm +attribute [rsimp] Bool.cond_decide_symm +def Nat.beq_eq_symm {x y : Nat} : (x = y) = (x.beq y = true) := (@Nat.beq_eq x y).symm +attribute [rsimp] Nat.beq_eq_symm +attribute [rsimp] Std.Tactic.BVDecide.Normalize.Bool.decide_eq_true +@[rsimp] theorem Bool.cond_true_false (b : Bool) : cond b true false = b := by simp + + +-- A function we may want to optimize + +def foo (a b : Fin 12) : Bool := if a * b = a + b then true else false + +attribute [rsimp_optimize] foo + +/-- +info: def foo.rsimp : Fin 12 → Fin 12 → Bool := +fun a b => (↑a * ↑b % 12).beq ((↑a + ↑b) % 12) +-/ +#guard_msgs in +#print foo.rsimp + +/-- info: foo.eq_rsimp : foo = foo.rsimp -/ +#guard_msgs in +#check foo.eq_rsimp + +/-- error: foo.rsimp has already been declared -/ +#guard_msgs in +attribute [rsimp_optimize] foo + +-- Now a recursive function +def bar (a b : Fin 12) : Bool := + if a = 0 then + false + else + if a * b = a + b then true else bar (a - 1) (b + 1) + +attribute [rsimp_optimize] bar + +/-- +info: def bar.rsimp : Fin 12 → Fin 12 → Bool := +rsimp_iterate bar fun ih a b => + bif (↑a).beq ↑0 then false else bif (↑a * ↑b % 12).beq ((↑a + ↑b) % 12) then true else ih (a - 1) (b + 1) +-/ +#guard_msgs in +#print bar.rsimp + + +namespace NotReallyRecursive + +-- Mostly a curious corner case: A recursive function with recursion that will be optimized away + +def bar (a b : Fin 12) : Bool := + if a = 0 then + false + else + if true then true else bar (a - 1) (b + 1) +termination_by a + +attribute [rsimp] Bool.cond_true +attribute [rsimp_optimize] bar +/-- +info: def NotReallyRecursive.bar.rsimp : Fin 12 → Fin 12 → Bool := +fun a b => bif (↑a).beq ↑0 then false else true +-/ +#guard_msgs in +#print bar.rsimp + +end NotReallyRecursive