Skip to content

Commit 132b86d

Browse files
committed
feat: BitVector normalization for udiv by twoPow
This PR adds a normalization rule to `bv_normalize` (which is used by `bv_decide`) that converts `x / 2^k` into `x >>> k` under suitable conditions. This allows us to simplify the expensive division circuits that are used for bitblasting into much cheaper shifting circuits. Concretely, it allows for the following canonicalization: ```lean example {x : BitVec 16} : x / (BitVec.twoPow 16 2) = x >>> 2 := by bv_normalize example {x : BitVec 16} : x / (BitVec.ofNat 16 8) = x >>> 3 := by bv_normalize ```
1 parent 456e6d2 commit 132b86d

File tree

4 files changed

+46
-0
lines changed

4 files changed

+46
-0
lines changed

src/Init/Data/BitVec/Lemmas.lean

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2809,6 +2809,14 @@ theorem shiftLeft_eq_mul_twoPow (x : BitVec w) (n : Nat) :
28092809
ext i
28102810
simp [getLsbD_shiftLeft, Fin.is_lt, decide_true, Bool.true_and, mul_twoPow_eq_shiftLeft]
28112811

2812+
/--
2813+
The unsigned division of `x` by `2^k` equals shifting `x` right by `k`,
2814+
when `k` is less than the bitwidth `w`.
2815+
-/
2816+
theorem udiv_twoPow_eq_of_lt {w : Nat} {x : BitVec w} {k : Nat} (hk : k < w) : x / (twoPow w k) = x >>> k := by
2817+
have : 2^k < 2^w := Nat.pow_lt_pow_of_lt (by decide) hk
2818+
simp [bv_toNat, Nat.shiftRight_eq_div_pow, Nat.mod_eq_of_lt this]
2819+
28122820
/- ### cons -/
28132821

28142822
@[simp] theorem true_cons_zero : cons true 0#w = twoPow (w + 1) w := by

src/Lean/Elab/Tactic/BVDecide/Frontend/Normalize.lean

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,36 @@ builtin_simproc [bv_normalize] bv_add_const' (((_ : BitVec _) + (_ : BitVec _))
129129

130130
attribute [builtin_bv_normalize_proc↓] reduceIte
131131

132+
/-- Return a number `k` such that `2^k = n`. -/
133+
private def Nat.log2Exact (n : Nat) : Option Nat := do
134+
guard <| n ≠ 0
135+
let k := n.log2
136+
guard <| Nat.pow 2 k == n
137+
return k
138+
139+
-- Build an expression for `x ^ y`
140+
def mkPow (x y : Expr) : MetaM Expr := mkAppM ``HPow.hPow #[x, y]
141+
142+
builtin_simproc [bv_normalize] bv_udiv_of_two_pow (((_ : BitVec _) / (BitVec.ofNat _ _) : BitVec _)) := fun e => do
143+
let_expr HDiv.hDiv _α _β _γ _self x y := e | return .continue
144+
let some ⟨w, yVal⟩ ← getBitVecValue? y | return .continue
145+
let n := yVal.toNat
146+
-- BitVec.ofNat w n, where n =def= 2^k
147+
let some k := Nat.log2Exact n | return .continue
148+
-- check that k < w.
149+
if k ≥ w then return .continue
150+
let rhs ← mkAppM ``HShiftRight.hShiftRight #[x, mkNatLit k]
151+
-- 2^k = n
152+
let hk ← mkDecideProof (← mkEq (← mkPow (mkNatLit 2) (mkNatLit k)) (mkNatLit n))
153+
-- k < w
154+
let hlt ← mkDecideProof (← mkLt (mkNatLit k) (mkNatLit w))
155+
let proof := mkAppN (mkConst ``Std.Tactic.BVDecide.Normalize.BitVec.udiv_ofNat_eq_of_lt)
156+
#[mkNatLit w, x, mkNatLit n, mkNatLit k, hk, hlt]
157+
return .done {
158+
expr := rhs
159+
proof? := some proof
160+
}
161+
132162
/--
133163
A pass in the normalization pipeline. Takes the current goal and produces a refined one or closes
134164
the goal fully, indicated by returning `none`.

src/Std/Tactic/BVDecide/Normalize/BitVec.lean

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,5 +304,11 @@ attribute [bv_normalize] BitVec.umod_zero
304304
attribute [bv_normalize] BitVec.umod_one
305305
attribute [bv_normalize] BitVec.umod_eq_and
306306

307+
/-- `x / (BitVec.ofNat n)` where `n = 2^k` is the same as shifting `x` right by `k`. -/
308+
theorem BitVec.udiv_ofNat_eq_of_lt (w : Nat) (x : BitVec w) (n : Nat) (k : Nat) (hk : 2 ^ k = n) (hlt : k < w) :
309+
x / (BitVec.ofNat w n) = x >>> k := by
310+
have : BitVec.ofNat w n = BitVec.twoPow w k := by simp [bv_toNat, hk]
311+
rw [this, BitVec.udiv_twoPow_eq_of_lt (hk := by omega)]
312+
307313
end Normalize
308314
end Std.Tactic.BVDecide

tests/lean/run/bv_decide_rewriter.lean

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ example {x : BitVec 16} {b : Bool} : (if b then x else x) = x := by bv_normalize
8282
example {b : Bool} {x : Bool} : (bif b then x else x) = x := by bv_normalize
8383
example {x : BitVec 16} : x.abs = if x.msb then -x else x := by bv_normalize
8484
example {x : BitVec 16} : (BitVec.twoPow 16 2) = 4#16 := by bv_normalize
85+
example {x : BitVec 16} : x / (BitVec.twoPow 16 2) = x >>> 2 := by bv_normalize
86+
example {x : BitVec 16} : x / (BitVec.ofNat 16 8) = x >>> 3 := by bv_normalize
8587

8688
section
8789

0 commit comments

Comments
 (0)