From c84ad0ef8d66d673dcf217d8c4dedf6274048e7c Mon Sep 17 00:00:00 2001 From: Tobias Grosser Date: Wed, 1 Jan 2025 21:05:53 +0000 Subject: [PATCH] feat: bv_decide short-circuit mul_eq_mul with shared left|right This PR adds short-circuit support to bv_decide to accelerate certain multiplications. In particular, `a * x = b * x` can be extended to `a = b v (a * x = b * x)`. The latter is faster if `a = b` is indeed true. --- .../Tactic/BVDecide/Frontend/Normalize.lean | 36 ++++++++++++++++++- src/Std/Tactic/BVDecide/Normalize/BitVec.lean | 12 +++++++ tests/lean/run/bv_arith.lean | 15 ++++++++ tests/lean/run/bv_llvm.lean | 8 +++++ 4 files changed, 70 insertions(+), 1 deletion(-) diff --git a/src/Lean/Elab/Tactic/BVDecide/Frontend/Normalize.lean b/src/Lean/Elab/Tactic/BVDecide/Frontend/Normalize.lean index 8f5e651a13b7..82159cd3f9f5 100644 --- a/src/Lean/Elab/Tactic/BVDecide/Frontend/Normalize.lean +++ b/src/Lean/Elab/Tactic/BVDecide/Frontend/Normalize.lean @@ -221,6 +221,39 @@ def rewriteRulesPass (maxSteps : Nat) : Pass where let some (_, newGoal) := result? | return none return newGoal +/-- +Responsible for applying short-circuit optimizations for `*`. +-/ +def shortCircuitPass (maxSteps : Nat) : Pass where + name := `shortCircuitPass + run goal := do + let mut theorems : SimpTheoremsArray := #[] + + let cl : Expr := mkConst ``mul_beq_mul_short_circuit_left + let ol : Lean.Meta.Origin := Lean.Meta.Origin.decl `mul_beq_mul_short_circuit_left + theorems ← theorems.addTheorem ol cl + + let cr : Expr := mkConst ``mul_beq_mul_short_circuit_right + let or : Lean.Meta.Origin := Lean.Meta.Origin.decl `mul_beq_mul_short_circuit_right + theorems ← theorems.addTheorem or cr + + let bn : Expr := mkConst ``Bool.not_not + let obn: Lean.Meta.Origin := Lean.Meta.Origin.decl `not_not + theorems ← theorems.addTheorem obn bn + + let simpCtx ← Simp.mkContext + (config := { failIfUnchanged := false, zetaDelta := true, singlePass := true, maxSteps }) + (simpTheorems := theorems) + (congrTheorems := (← getSimpCongrTheorems)) + + let hyps ← goal.getNondepPropHyps + let ⟨result?, _⟩ ← simpGoal goal + (ctx := simpCtx) + (simprocs := #[]) + (fvarIdsToSimp := hyps) + let some (_, newGoal) := result? | return none + return newGoal + /-- Flatten out ands. That is look for hypotheses of the form `h : (x && y) = true` and replace them with `h.left : x = true` and `h.right : y = true`. This can enable more fine grained substitutions @@ -359,7 +392,8 @@ def bvNormalize (g : MVarId) (cfg : BVDecideConfig) : MetaM (Option MVarId) := d -- Contradiction proof let some g ← g.falseOrByContra | return none trace[Meta.Tactic.bv] m!"Running preprocessing pipeline on:\n{g}" - Pass.fixpointPipeline (Pass.passPipeline cfg) g + let some g ← Pass.fixpointPipeline (Pass.passPipeline cfg) g | return none + (Pass.shortCircuitPass cfg.maxSteps).run g @[builtin_tactic Lean.Parser.Tactic.bvNormalize] def evalBVNormalize : Tactic := fun diff --git a/src/Std/Tactic/BVDecide/Normalize/BitVec.lean b/src/Std/Tactic/BVDecide/Normalize/BitVec.lean index 09088f9eb595..dd466a0084a1 100644 --- a/src/Std/Tactic/BVDecide/Normalize/BitVec.lean +++ b/src/Std/Tactic/BVDecide/Normalize/BitVec.lean @@ -309,5 +309,17 @@ theorem BitVec.udiv_ofNat_eq_of_lt (w : Nat) (x : BitVec w) (n : Nat) (k : Nat) have : BitVec.ofNat w n = BitVec.twoPow w k := by simp [bv_toNat, hk] rw [this, BitVec.udiv_twoPow_eq_of_lt (hk := by omega)] +theorem mul_beq_mul_short_circuit_left {x z y : BitVec w} : + ((x * z == y * z)) = !(!(x == y) && !(x * z == y * z)) := by + simp + intros + congr + +theorem mul_beq_mul_short_circuit_right {x z y : BitVec w} : + ((z * x == z * y)) = !(!(x == y) && !(z * x == z * y)) := by + simp + intros + congr + end Normalize end Std.Tactic.BVDecide diff --git a/tests/lean/run/bv_arith.lean b/tests/lean/run/bv_arith.lean index 32805b47fa55..6320a92e494e 100644 --- a/tests/lean/run/bv_arith.lean +++ b/tests/lean/run/bv_arith.lean @@ -70,3 +70,18 @@ theorem arith_unit_18 (x y : BitVec 8) (hx : x.msb = true) (h : y.msb = true) : theorem arith_unit_19 (x y : BitVec 8) (hx : x.msb = true) (h : y.msb = true) : x.srem y = -((-x) % (-y)) := by bv_decide + +-- This theorem is not short-circuited, so it slow for large bitwidths. +theorem mul_mul_eq_mul_mul (x₁ x₂ y₁ y₂ z : BitVec 4) (h₁ : x₁ = y₁) (h₂ : x₂ = y₂) : + x₁ * (x₂ * z) = y₁ * (y₂ * z) := by + bv_decide + +-- This theorem is short-circuited and scales to standard bitwidths. +theorem mul_eq_mul_eq_right (x y z : BitVec 64) (h : x = y) : + x * z = y * z := by + bv_decide + +-- This theorem is short-circuited and scales to standard bitwidths. +theorem mul_eq_mul_eq_left (x y z : BitVec 64) (h : x = y) : + z * x = z * y := by + bv_decide diff --git a/tests/lean/run/bv_llvm.lean b/tests/lean/run/bv_llvm.lean index c5e47432106f..4b42e0fc86d4 100644 --- a/tests/lean/run/bv_llvm.lean +++ b/tests/lean/run/bv_llvm.lean @@ -10,3 +10,11 @@ theorem test21_thm (x : _root_.BitVec 8) : theorem bitvec_AndOrXor_1683_2 : ∀ (a b : BitVec 64), (b ≤ a) || (a != b) = true := by intros; bv_decide + +theorem short_circuit_mul_right (x x_1 : BitVec 32) (h : ¬BitVec.ofBool (x_1 &&& 4096#32 == 0#32) = 1#1) : + (x ||| 4096#32) * (x ||| 4096#32) = (x ||| x_1 &&& 4096#32) * (x ||| 4096#32) := by + bv_decide + +theorem short_circuit_mul_left (x x_1 : BitVec 32) (h : ¬BitVec.ofBool (x_1 &&& 4096#32 == 0#32) = 1#1) : + (x ||| 4096#32) * (x ||| 4096#32) = (x ||| 4096#32) * (x ||| x_1 &&& 4096#32) := by + bv_decide