From d93545b317a94bf32fd97e09dfa0cd8cdb2ed405 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Mon, 11 Nov 2024 07:54:25 +0000 Subject: [PATCH] feat: BitVector normalization for udiv by twoPow This PR adds a normalization rule to `bv_normalize` (which is used by `bv_decide`) that converts `x / 2^k` into `x >>> k` under suitable conditions. This allows us to simplify the expensive division circuits that are used for bitblasting into much cheaper shifting circuits. Concretely, it allows for the following canonicalization: ```lean example {x : BitVec 16} : x / (BitVec.twoPow 16 2) = x >>> 2 := by bv_normalize example {x : BitVec 16} : x / (BitVec.ofNat 16 8) = x >>> 3 := by bv_normalize ``` --- src/Init/Data/BitVec/Lemmas.lean | 8 +++++ .../Tactic/BVDecide/Frontend/Normalize.lean | 30 +++++++++++++++++++ src/Std/Tactic/BVDecide/Normalize/BitVec.lean | 6 ++++ tests/lean/run/bv_decide_rewriter.lean | 2 ++ 4 files changed, 46 insertions(+) diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index 139d07705068..954c13a7e942 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -2809,6 +2809,14 @@ theorem shiftLeft_eq_mul_twoPow (x : BitVec w) (n : Nat) : ext i simp [getLsbD_shiftLeft, Fin.is_lt, decide_true, Bool.true_and, mul_twoPow_eq_shiftLeft] +/-- +The unsigned division of `x` by `2^k` equals shifting `x` right by `k`, +when `k` is less than the bitwidth `w`. +-/ +theorem udiv_twoPow_eq_of_lt {w : Nat} {x : BitVec w} {k : Nat} (hk : k < w) : x / (twoPow w k) = x >>> k := by + have : 2^k < 2^w := Nat.pow_lt_pow_of_lt (by decide) hk + simp [bv_toNat, Nat.shiftRight_eq_div_pow, Nat.mod_eq_of_lt this] + /- ### cons -/ @[simp] theorem true_cons_zero : cons true 0#w = twoPow (w + 1) w := by diff --git a/src/Lean/Elab/Tactic/BVDecide/Frontend/Normalize.lean b/src/Lean/Elab/Tactic/BVDecide/Frontend/Normalize.lean index 8bad5ffb2659..fbdf59c83fe0 100644 --- a/src/Lean/Elab/Tactic/BVDecide/Frontend/Normalize.lean +++ b/src/Lean/Elab/Tactic/BVDecide/Frontend/Normalize.lean @@ -129,6 +129,36 @@ builtin_simproc [bv_normalize] bv_add_const' (((_ : BitVec _) + (_ : BitVec _)) attribute [builtin_bv_normalize_proc↓] reduceIte +/-- Return a number `k` such that `2^k = n`. -/ +private def Nat.log2Exact (n : Nat) : Option Nat := do + guard <| n ≠ 0 + let k := n.log2 + guard <| Nat.pow 2 k == n + return k + +-- Build an expression for `x ^ y`. +def mkPow (x y : Expr) : MetaM Expr := mkAppM ``HPow.hPow #[x, y] + +builtin_simproc [bv_normalize] bv_udiv_of_two_pow (((_ : BitVec _) / (BitVec.ofNat _ _) : BitVec _)) := fun e => do + let_expr HDiv.hDiv _α _β _γ _self x y := e | return .continue + let some ⟨w, yVal⟩ ← getBitVecValue? y | return .continue + let n := yVal.toNat + -- BitVec.ofNat w n, where n =def= 2^k + let some k := Nat.log2Exact n | return .continue + -- check that k < w. + if k ≥ w then return .continue + let rhs ← mkAppM ``HShiftRight.hShiftRight #[x, mkNatLit k] + -- 2^k = n + let hk ← mkDecideProof (← mkEq (← mkPow (mkNatLit 2) (mkNatLit k)) (mkNatLit n)) + -- k < w + let hlt ← mkDecideProof (← mkLt (mkNatLit k) (mkNatLit w)) + let proof := mkAppN (mkConst ``Std.Tactic.BVDecide.Normalize.BitVec.udiv_ofNat_eq_of_lt) + #[mkNatLit w, x, mkNatLit n, mkNatLit k, hk, hlt] + return .done { + expr := rhs + proof? := some proof + } + /-- A pass in the normalization pipeline. Takes the current goal and produces a refined one or closes the goal fully, indicated by returning `none`. diff --git a/src/Std/Tactic/BVDecide/Normalize/BitVec.lean b/src/Std/Tactic/BVDecide/Normalize/BitVec.lean index a7763e01f8d6..b7187f59dee3 100644 --- a/src/Std/Tactic/BVDecide/Normalize/BitVec.lean +++ b/src/Std/Tactic/BVDecide/Normalize/BitVec.lean @@ -304,5 +304,11 @@ attribute [bv_normalize] BitVec.umod_zero attribute [bv_normalize] BitVec.umod_one attribute [bv_normalize] BitVec.umod_eq_and +/-- `x / (BitVec.ofNat n)` where `n = 2^k` is the same as shifting `x` right by `k`. -/ +theorem BitVec.udiv_ofNat_eq_of_lt (w : Nat) (x : BitVec w) (n : Nat) (k : Nat) (hk : 2 ^ k = n) (hlt : k < w) : + x / (BitVec.ofNat w n) = x >>> k := by + have : BitVec.ofNat w n = BitVec.twoPow w k := by simp [bv_toNat, hk] + rw [this, BitVec.udiv_twoPow_eq_of_lt (hk := by omega)] + end Normalize end Std.Tactic.BVDecide diff --git a/tests/lean/run/bv_decide_rewriter.lean b/tests/lean/run/bv_decide_rewriter.lean index 6d4e3aa804d6..eb5b9fddbbe6 100644 --- a/tests/lean/run/bv_decide_rewriter.lean +++ b/tests/lean/run/bv_decide_rewriter.lean @@ -82,6 +82,8 @@ example {x : BitVec 16} {b : Bool} : (if b then x else x) = x := by bv_normalize example {b : Bool} {x : Bool} : (bif b then x else x) = x := by bv_normalize example {x : BitVec 16} : x.abs = if x.msb then -x else x := by bv_normalize example {x : BitVec 16} : (BitVec.twoPow 16 2) = 4#16 := by bv_normalize +example {x : BitVec 16} : x / (BitVec.twoPow 16 2) = x >>> 2 := by bv_normalize +example {x : BitVec 16} : x / (BitVec.ofNat 16 8) = x >>> 3 := by bv_normalize section