Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: BitVec normalization rule for udiv by twoPow #6029

Merged
merged 1 commit into from
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/Init/Data/BitVec/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -2809,6 +2809,14 @@ theorem shiftLeft_eq_mul_twoPow (x : BitVec w) (n : Nat) :
ext i
simp [getLsbD_shiftLeft, Fin.is_lt, decide_true, Bool.true_and, mul_twoPow_eq_shiftLeft]

/--
The unsigned division of `x` by `2^k` equals shifting `x` right by `k`,
when `k` is less than the bitwidth `w`.
-/
theorem udiv_twoPow_eq_of_lt {w : Nat} {x : BitVec w} {k : Nat} (hk : k < w) : x / (twoPow w k) = x >>> k := by
have : 2^k < 2^w := Nat.pow_lt_pow_of_lt (by decide) hk
simp [bv_toNat, Nat.shiftRight_eq_div_pow, Nat.mod_eq_of_lt this]

/- ### cons -/

@[simp] theorem true_cons_zero : cons true 0#w = twoPow (w + 1) w := by
Expand Down
30 changes: 30 additions & 0 deletions src/Lean/Elab/Tactic/BVDecide/Frontend/Normalize.lean
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,36 @@ builtin_simproc [bv_normalize] bv_add_const' (((_ : BitVec _) + (_ : BitVec _))

attribute [builtin_bv_normalize_proc↓] reduceIte

/-- Return a number `k` such that `2^k = n`. -/
private def Nat.log2Exact (n : Nat) : Option Nat := do
guard <| n ≠ 0
let k := n.log2
guard <| Nat.pow 2 k == n
return k

-- Build an expression for `x ^ y`.
def mkPow (x y : Expr) : MetaM Expr := mkAppM ``HPow.hPow #[x, y]

builtin_simproc [bv_normalize] bv_udiv_of_two_pow (((_ : BitVec _) / (BitVec.ofNat _ _) : BitVec _)) := fun e => do
let_expr HDiv.hDiv _α _β _γ _self x y := e | return .continue
let some ⟨w, yVal⟩ ← getBitVecValue? y | return .continue
let n := yVal.toNat
-- BitVec.ofNat w n, where n =def= 2^k
let some k := Nat.log2Exact n | return .continue
-- check that k < w.
if k ≥ w then return .continue
let rhs ← mkAppM ``HShiftRight.hShiftRight #[x, mkNatLit k]
-- 2^k = n
let hk ← mkDecideProof (← mkEq (← mkPow (mkNatLit 2) (mkNatLit k)) (mkNatLit n))
-- k < w
let hlt ← mkDecideProof (← mkLt (mkNatLit k) (mkNatLit w))
let proof := mkAppN (mkConst ``Std.Tactic.BVDecide.Normalize.BitVec.udiv_ofNat_eq_of_lt)
#[mkNatLit w, x, mkNatLit n, mkNatLit k, hk, hlt]
return .done {
expr := rhs
proof? := some proof
}

/--
A pass in the normalization pipeline. Takes the current goal and produces a refined one or closes
the goal fully, indicated by returning `none`.
Expand Down
6 changes: 6 additions & 0 deletions src/Std/Tactic/BVDecide/Normalize/BitVec.lean
Original file line number Diff line number Diff line change
Expand Up @@ -304,5 +304,11 @@ attribute [bv_normalize] BitVec.umod_zero
attribute [bv_normalize] BitVec.umod_one
attribute [bv_normalize] BitVec.umod_eq_and

/-- `x / (BitVec.ofNat n)` where `n = 2^k` is the same as shifting `x` right by `k`. -/
theorem BitVec.udiv_ofNat_eq_of_lt (w : Nat) (x : BitVec w) (n : Nat) (k : Nat) (hk : 2 ^ k = n) (hlt : k < w) :
x / (BitVec.ofNat w n) = x >>> k := by
have : BitVec.ofNat w n = BitVec.twoPow w k := by simp [bv_toNat, hk]
rw [this, BitVec.udiv_twoPow_eq_of_lt (hk := by omega)]

end Normalize
end Std.Tactic.BVDecide
2 changes: 2 additions & 0 deletions tests/lean/run/bv_decide_rewriter.lean
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ example {x : BitVec 16} {b : Bool} : (if b then x else x) = x := by bv_normalize
example {b : Bool} {x : Bool} : (bif b then x else x) = x := by bv_normalize
example {x : BitVec 16} : x.abs = if x.msb then -x else x := by bv_normalize
example {x : BitVec 16} : (BitVec.twoPow 16 2) = 4#16 := by bv_normalize
example {x : BitVec 16} : x / (BitVec.twoPow 16 2) = x >>> 2 := by bv_normalize
example {x : BitVec 16} : x / (BitVec.ofNat 16 8) = x >>> 3 := by bv_normalize

section

Expand Down
Loading