Skip to content

Commit

Permalink
chore: cleaup proofs, move them to the right locations
Browse files Browse the repository at this point in the history
  • Loading branch information
bollu committed Jun 8, 2024
1 parent e6cba22 commit 2f0cb91
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 92 deletions.
7 changes: 7 additions & 0 deletions src/Init/Data/BitVec/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,13 @@ theorem ofBool_append (msb : Bool) (lsbs : BitVec w) :

end bitwise

section twoPow

/-- `twoPow i` is the bitvector `2^i` if `i < w`, and `0` otherwise. That is, 2 to the power `i`. -/
def twoPow {w : Nat} (i : Nat) : BitVec w := (1#w) <<< i

end twoPow

section normalization_eqs
/-! We add simp-lemmas that rewrite bitvector operations into the equivalent notation -/
@[simp] theorem append_eq (x : BitVec w) (y : BitVec v) : BitVec.append x y = x ++ y := rfl
Expand Down
114 changes: 22 additions & 92 deletions src/Init/Data/BitVec/Bitblast.lean
Original file line number Diff line number Diff line change
Expand Up @@ -252,75 +252,11 @@ theorem sle_eq_carry (x y : BitVec w) :

/-! ### mul recurrence for bitblasting -/

open BitVec in
/-- The Bitvector that is equal to `2^i % 2^w`, the power of 2 (`pot`). -/
def pot {w : Nat} (i : Nat) : BitVec w := (1#w) <<< i

@[simp]
theorem toNat_pot (w : Nat) (i : Nat) : (pot i : BitVec w).toNat = 2^i % 2^w := by
rcases w with rfl | w
· simp [Nat.mod_one]
· simp [pot, toNat_shiftLeft]
have h1 : 1 < 2 ^ (w + 1) := Nat.one_lt_two_pow (by omega)
rw [Nat.mod_eq_of_lt h1]
rw [Nat.shiftLeft_eq, Nat.one_mul]

/-- `testBit 1 i` is true iff the index `i` equals 0. -/
private theorem Nat.testBit_one_eq_true_iff_self_eq_zero {i : Nat} :
Nat.testBit 1 i = true ↔ i = 0 := by
cases i <;> simp

@[simp]
theorem getLsb_pot (i j : Nat) : (pot i : BitVec w).getLsb j = ((i < w) && (i = j)) := by
rcases w with rfl | w
· simp only [pot, BitVec.reduceOfNat, Nat.zero_le, getLsb_ge, Bool.false_eq,
Bool.and_eq_false_imp, decide_eq_true_eq, decide_eq_false_iff_not]
omega
· simp only [pot, getLsb_shiftLeft, getLsb_ofNat]
by_cases hj : j < i
· simp only [hj, decide_True, Bool.not_true, Bool.and_false, Bool.false_and, Bool.false_eq,
Bool.and_eq_false_imp, decide_eq_true_eq, decide_eq_false_iff_not]
omega
· by_cases hi : Nat.testBit 1 (j - i)
· obtain hi' := Nat.testBit_one_eq_true_iff_self_eq_zero.mp hi
have hij : j = i := by omega
simp_all
· have hij : i ≠ j := by
intro h; subst h
simp at hi
simp_all

theorem and_pot_eq_getLsb (x : BitVec w) (i : Nat) :
x &&& (pot i : BitVec w) = if x.getLsb i then pot i else 0#w := by
ext j
simp only [getLsb_and, getLsb_pot]
by_cases hj : i = j <;> by_cases hx : x.getLsb i <;> simp_all

@[simp]
theorem mul_pot_eq_shiftLeft (x : BitVec w) (i : Nat) :
x * (pot i : BitVec w) = x <<< i := by
apply eq_of_toNat_eq
simp only [toNat_mul, toNat_pot, toNat_shiftLeft, Nat.shiftLeft_eq]
by_cases hi : i < w
· have hpow : 2^i < 2^w := Nat.pow_lt_pow_of_lt (by omega) (by omega)
rw [Nat.mod_eq_of_lt hpow]
· have hpow : 2 ^ i % 2 ^ w = 0 := by
rw [Nat.mod_eq_zero_of_dvd]
apply Nat.pow_dvd_pow 2 (by omega)
simp [Nat.mul_mod, hpow]

theorem BitVec.toNat_pot (w : Nat) (i : Nat) : (pot i : BitVec w).toNat = 2^i % 2^w := by
rcases w with rfl | w
· simp [Nat.mod_one]
· simp [pot, toNat_shiftLeft]
have hone : 1 < 2 ^ (w + 1) := by
rw [show 1 = 2^0 by simp[Nat.pow_zero]]
exact Nat.pow_lt_pow_of_lt (by omega) (by omega)
simp [Nat.mod_eq_of_lt hone, Nat.shiftLeft_eq]

theorem zeroExtend_truncate_succ_eq_zeroExtend_truncate_add_pot (x : BitVec w) (i : Nat) :
/-- Recurrence lemma that saus that truncating to `i+1` bits and then zero extending to `w`
equals truncating upto `i` bits `[0..i-1]`, and then adding the `i`th bit of `x`. -/
theorem zeroExtend_truncate_succ_eq_zeroExtend_truncate_add_twoPow (x : BitVec w) (i : Nat) :
zeroExtend w (x.truncate (i + 1)) =
zeroExtend w (x.truncate i) + (x &&& (BitVec.pot i)) := by
zeroExtend w (x.truncate i) + (x &&& twoPow i) := by
rw [add_eq_or_of_and_eq_zero]
· ext k
simp only [getLsb_zeroExtend, Fin.is_lt, decide_True, Bool.true_and, getLsb_or, getLsb_and]
Expand All @@ -341,30 +277,24 @@ theorem zeroExtend_truncate_succ_eq_zeroExtend_truncate_add_pot (x : BitVec w) (

theorem mulRec_eq_mul_signExtend_truncate (l r : BitVec w) (s : Nat) :
(mulRec l r s) = l * ((r.truncate (s + 1)).zeroExtend w) := by
induction w generalizing s
case zero => apply Subsingleton.elim
case succ w' hw =>
induction s
case zero =>
simp [mulRec_zero_eq]
by_cases r.getLsb 0
case pos hr =>
simp only [hr, ↓reduceIte, truncate, zeroExtend_one_eq_ofBool_getLsb_zero,
hr, ofBool_true, ofNat_eq_ofNat]
rw [zeroExtend_ofNat_one_eq_ofNat_one_of_lt (by omega)]; simp
case neg hr =>
simp [hr, zeroExtend_one_eq_ofBool_getLsb_zero]
case succ s' hs =>
rw [mulRec_succ_eq]
rw [hs];
have heq :
(if r.getLsb (s' + 1) = true then l <<< (s' + 1) else 0) =
(l * (r &&& (BitVec.pot (s' + 1)))) := by
simp only [ofNat_eq_ofNat, and_pot_eq_getLsb]
by_cases hr : r.getLsb (s' + 1) <;> simp [hr]
rw [heq, ← BitVec.mul_add]
rw [← zeroExtend_truncate_succ_eq_zeroExtend_truncate_add_pot]

induction s
case zero =>
simp [mulRec_zero_eq]
by_cases r.getLsb 0
case pos hr =>
simp only [hr, ↓reduceIte, truncate, zeroExtend_one_eq_ofBool_getLsb_zero,
hr, ofBool_true, ofNat_eq_ofNat]
rw [zeroExtend_ofNat_one_eq_ofNat_one_of_lt (by omega)]; simp
case neg hr =>
simp [hr, zeroExtend_one_eq_ofBool_getLsb_zero]
case succ s' hs =>
rw [mulRec_succ_eq, hs]
have heq :
(if r.getLsb (s' + 1) = true then l <<< (s' + 1) else 0) =
(l * (r &&& (BitVec.twoPow (s' + 1)))) := by
simp only [ofNat_eq_ofNat, and_twoPow_eq_getLsb]
by_cases hr : r.getLsb (s' + 1) <;> simp [hr]
rw [heq, ← BitVec.mul_add, ← zeroExtend_truncate_succ_eq_zeroExtend_truncate_add_twoPow]

/-- Zero extending by number of bits larger than the bitwidth has no effect. -/
theorem zeroExtend_of_ge {x : BitVec w} {i j : Nat} (hi : i ≥ w) :
Expand Down
59 changes: 59 additions & 0 deletions src/Init/Data/BitVec/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -1428,4 +1428,63 @@ theorem getLsb_rotateRight {x : BitVec w} {r i : Nat} :
· simp
· rw [← rotateRight_mod_eq_rotateRight, getLsb_rotateRight_of_le (Nat.mod_lt _ (by omega))]

/- ## twoPow -/

@[simp]
theorem toNat_twoPow (w : Nat) (i : Nat) : (twoPow i : BitVec w).toNat = 2^i % 2^w := by
rcases w with rfl | w
· simp [Nat.mod_one]
· simp [twoPow, toNat_shiftLeft]
have h1 : 1 < 2 ^ (w + 1) := Nat.one_lt_two_pow (by omega)
rw [Nat.mod_eq_of_lt h1]
rw [Nat.shiftLeft_eq, Nat.one_mul]

@[simp]
theorem getLsb_twoPow (i j : Nat) : (twoPow i : BitVec w).getLsb j = ((i < w) && (i = j)) := by
rcases w with rfl | w
· simp only [twoPow, BitVec.reduceOfNat, Nat.zero_le, getLsb_ge, Bool.false_eq,
Bool.and_eq_false_imp, decide_eq_true_eq, decide_eq_false_iff_not]
omega
· simp only [twoPow, getLsb_shiftLeft, getLsb_ofNat]
by_cases hj : j < i
· simp only [hj, decide_True, Bool.not_true, Bool.and_false, Bool.false_and, Bool.false_eq,
Bool.and_eq_false_imp, decide_eq_true_eq, decide_eq_false_iff_not]
omega
· by_cases hi : Nat.testBit 1 (j - i)
· obtain hi' := Nat.testBit_one_eq_true_iff_self_eq_zero.mp hi
have hij : j = i := by omega
simp_all
· have hij : i ≠ j := by
intro h; subst h
simp at hi
simp_all

theorem and_twoPow_eq_getLsb (x : BitVec w) (i : Nat) :
x &&& (twoPow i : BitVec w) = if x.getLsb i then twoPow i else 0#w := by
ext j
simp only [getLsb_and, getLsb_twoPow]
by_cases hj : i = j <;> by_cases hx : x.getLsb i <;> simp_all

@[simp]
theorem mul_twoPow_eq_shiftLeft (x : BitVec w) (i : Nat) :
x * (twoPow i : BitVec w) = x <<< i := by
apply eq_of_toNat_eq
simp only [toNat_mul, toNat_twoPow, toNat_shiftLeft, Nat.shiftLeft_eq]
by_cases hi : i < w
· have hpow : 2^i < 2^w := Nat.pow_lt_pow_of_lt (by omega) (by omega)
rw [Nat.mod_eq_of_lt hpow]
· have hpow : 2 ^ i % 2 ^ w = 0 := by
rw [Nat.mod_eq_zero_of_dvd]
apply Nat.pow_dvd_pow 2 (by omega)
simp [Nat.mul_mod, hpow]

theorem BitVec.toNat_twoPow (w : Nat) (i : Nat) : (twoPow i : BitVec w).toNat = 2^i % 2^w := by
rcases w with rfl | w
· simp [Nat.mod_one]
· simp [twoPow, toNat_shiftLeft]
have hone : 1 < 2 ^ (w + 1) := by
rw [show 1 = 2^0 by simp[Nat.pow_zero]]
exact Nat.pow_lt_pow_of_lt (by omega) (by omega)
simp [Nat.mod_eq_of_lt hone, Nat.shiftLeft_eq]

end BitVec

0 comments on commit 2f0cb91

Please sign in to comment.