Skip to content

Commit

Permalink
feat: BitVector normalization for udiv by twoPow
Browse files Browse the repository at this point in the history
This PR adds a normalization rule to `bv_normalize` (which is used by `bv_decide`) that converts `x / 2^k` into `x >>> k` under suitable conditions. This allows us to simplify the expensive division circuits that are used for bitblasting into much cheaper shifting circuits. Concretely, it allows for the following canonicalization:

```lean
example {x : BitVec 16} : x / (BitVec.twoPow 16 2) = x >>> 2 := by bv_normalize
example {x : BitVec 16} : x / (BitVec.ofNat 16 8) = x >>> 3 := by bv_normalize
```
  • Loading branch information
bollu committed Nov 11, 2024
1 parent 456e6d2 commit d93545b
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 0 deletions.
8 changes: 8 additions & 0 deletions src/Init/Data/BitVec/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -2809,6 +2809,14 @@ theorem shiftLeft_eq_mul_twoPow (x : BitVec w) (n : Nat) :
ext i
simp [getLsbD_shiftLeft, Fin.is_lt, decide_true, Bool.true_and, mul_twoPow_eq_shiftLeft]

/--
The unsigned division of `x` by `2^k` equals shifting `x` right by `k`,
when `k` is less than the bitwidth `w`.
-/
theorem udiv_twoPow_eq_of_lt {w : Nat} {x : BitVec w} {k : Nat} (hk : k < w) : x / (twoPow w k) = x >>> k := by
have : 2^k < 2^w := Nat.pow_lt_pow_of_lt (by decide) hk
simp [bv_toNat, Nat.shiftRight_eq_div_pow, Nat.mod_eq_of_lt this]

/- ### cons -/

@[simp] theorem true_cons_zero : cons true 0#w = twoPow (w + 1) w := by
Expand Down
30 changes: 30 additions & 0 deletions src/Lean/Elab/Tactic/BVDecide/Frontend/Normalize.lean
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,36 @@ builtin_simproc [bv_normalize] bv_add_const' (((_ : BitVec _) + (_ : BitVec _))

attribute [builtin_bv_normalize_proc↓] reduceIte

/-- Return a number `k` such that `2^k = n`. -/
private def Nat.log2Exact (n : Nat) : Option Nat := do
guard <| n ≠ 0
let k := n.log2
guard <| Nat.pow 2 k == n
return k

-- Build an expression for `x ^ y`.
def mkPow (x y : Expr) : MetaM Expr := mkAppM ``HPow.hPow #[x, y]

builtin_simproc [bv_normalize] bv_udiv_of_two_pow (((_ : BitVec _) / (BitVec.ofNat _ _) : BitVec _)) := fun e => do
let_expr HDiv.hDiv _α _β _γ _self x y := e | return .continue
let some ⟨w, yVal⟩ ← getBitVecValue? y | return .continue
let n := yVal.toNat
-- BitVec.ofNat w n, where n =def= 2^k
let some k := Nat.log2Exact n | return .continue
-- check that k < w.
if k ≥ w then return .continue
let rhs ← mkAppM ``HShiftRight.hShiftRight #[x, mkNatLit k]
-- 2^k = n
let hk ← mkDecideProof (← mkEq (← mkPow (mkNatLit 2) (mkNatLit k)) (mkNatLit n))
-- k < w
let hlt ← mkDecideProof (← mkLt (mkNatLit k) (mkNatLit w))
let proof := mkAppN (mkConst ``Std.Tactic.BVDecide.Normalize.BitVec.udiv_ofNat_eq_of_lt)
#[mkNatLit w, x, mkNatLit n, mkNatLit k, hk, hlt]
return .done {
expr := rhs
proof? := some proof
}

/--
A pass in the normalization pipeline. Takes the current goal and produces a refined one or closes
the goal fully, indicated by returning `none`.
Expand Down
6 changes: 6 additions & 0 deletions src/Std/Tactic/BVDecide/Normalize/BitVec.lean
Original file line number Diff line number Diff line change
Expand Up @@ -304,5 +304,11 @@ attribute [bv_normalize] BitVec.umod_zero
attribute [bv_normalize] BitVec.umod_one
attribute [bv_normalize] BitVec.umod_eq_and

/-- `x / (BitVec.ofNat n)` where `n = 2^k` is the same as shifting `x` right by `k`. -/
theorem BitVec.udiv_ofNat_eq_of_lt (w : Nat) (x : BitVec w) (n : Nat) (k : Nat) (hk : 2 ^ k = n) (hlt : k < w) :
x / (BitVec.ofNat w n) = x >>> k := by
have : BitVec.ofNat w n = BitVec.twoPow w k := by simp [bv_toNat, hk]
rw [this, BitVec.udiv_twoPow_eq_of_lt (hk := by omega)]

end Normalize
end Std.Tactic.BVDecide
2 changes: 2 additions & 0 deletions tests/lean/run/bv_decide_rewriter.lean
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ example {x : BitVec 16} {b : Bool} : (if b then x else x) = x := by bv_normalize
example {b : Bool} {x : Bool} : (bif b then x else x) = x := by bv_normalize
example {x : BitVec 16} : x.abs = if x.msb then -x else x := by bv_normalize
example {x : BitVec 16} : (BitVec.twoPow 16 2) = 4#16 := by bv_normalize
example {x : BitVec 16} : x / (BitVec.twoPow 16 2) = x >>> 2 := by bv_normalize
example {x : BitVec 16} : x / (BitVec.ofNat 16 8) = x >>> 3 := by bv_normalize

section

Expand Down

0 comments on commit d93545b

Please sign in to comment.