Skip to content

Commit 738db83

Browse files
committed
WIP
1 parent e39d27c commit 738db83

File tree

2 files changed

+82
-59
lines changed

2 files changed

+82
-59
lines changed

src/Init/Data/BitVec/Bitblast.lean

Lines changed: 81 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -190,47 +190,112 @@ We implement [Booth's multiplication circuit](https://en.wikipedia.org/wiki/Boot
190190
on bitvectors, and show that this circuit is equal to our straightforward `BitVec.mul` implementation.
191191
-/
192192

193-
def mulAdd (a : BitVec (w+v)) (x : BitVec w) (y : BitVec v) : BitVec (w+v) :=
194-
let x : BitVec (w+v) := x.zeroExtend' (le_add_right w v)
195-
Prod.snd <| iunfoldr (s:=a) fun (i : Fin (w+v)) a =>
193+
def mulAdd (a x y : BitVec w) : BitVec w :=
194+
Prod.snd <| iunfoldr (s:=a) fun (i : Fin w) a =>
196195
let a := if y.getLsb i = true then a + x else a
197196
(a >>> 1, a.getLsb 0)
198197

199-
def mulAddAccumulator (a : BitVec (w+v)) (x : BitVec w) (y : BitVec v) (i : Nat) : BitVec (w+v) :=
200-
(a + (x.zeroExtend' <| le_add_right w v) * ((y.extractLsb' 0 i).zeroExtend _) ) >>> i
198+
def mulAddAccumulator (a x y : BitVec w) (i : Nat) : BitVec w :=
199+
(a + x * (y.truncate i |>.zeroExtend _)) >>> i
200+
201+
@[simp] theorem truncate_zero : truncate 0 x = 0#0 := of_length_zero
202+
203+
@[simp] theorem mul_zero (x : BitVec w) : x * 0#w = 0#w := rfl
204+
@[simp] theorem shiftRight_zero (x : BitVec w) : x >>> 0 = x := rfl
205+
@[simp] theorem shiftLeft_zero (x : BitVec w) : x <<< 0 = x := by apply eq_of_toNat_eq; simp
206+
207+
theorem mulAddAccumulator_zero (a x y : BitVec w) : mulAddAccumulator a x y 0 = a := by
208+
simp [mulAddAccumulator]
209+
210+
theorem Nat.shiftRight_add' (m n k : Nat) :
211+
m >>> n + k = (m + (k <<< n)) >>> n := by
212+
sorry
213+
214+
theorem shiftRight_add' (x y : BitVec w) (n : Nat) :
215+
x >>> n + y = (x + (y <<< n)) >>> n := by
216+
sorry
217+
218+
#check BitVec.shiftRight_shiftRight
219+
220+
theorem zeroExtend_truncate_eq_and (x : BitVec w) (i : Nat) :
221+
zeroExtend w (x.truncate i) = x &&& ((-1 : BitVec _) >>> (w-i)) := by
222+
sorry
223+
224+
theorem add_shiftRight (x y : BitVec w) (n : Nat) : (x + y) >>> n = (x >>> n) + (y >>> n) := by
225+
sorry
226+
227+
@[simp] theorem zero_shiftRight (w n : Nat) : 0#w >>> n = 0#w := by
228+
sorry
229+
230+
theorem mod_two_pow_shiftRight (x m n : Nat) : (x % 2^m) >>> n = (x >>> n) % (2^(m+n)) := by
231+
induction n
232+
case zero => rfl
233+
case succ n ih =>
234+
simp [shiftRight_succ]
235+
sorry
236+
237+
theorem shiftLeft_shiftRight_eq_zeroExtend_truncate (x : BitVec w) (i : Nat) :
238+
x <<< i >>> i = zeroExtend w (truncate (w-i) x) := by
239+
apply eq_of_toNat_eq
240+
simp only [toNat_ushiftRight, toNat_shiftLeft, toNat_truncate]
241+
induction i
242+
case a.zero => simp
243+
case a.succ i ih =>
244+
rw [mod_two_pow_shiftRight]
245+
sorry
246+
247+
theorem mulAddAccumulator_succ (a x y : BitVec w) :
248+
mulAddAccumulator a x y (i+1)
249+
= (mulAddAccumulator a x y i >>> 1)
250+
+ bif y.getLsb (i+1) then (x.truncate (i+1) |>.zeroExtend _) else 0#w := by
251+
-- ext j
252+
simp only [mulAddAccumulator, natCast_eq_ofNat, BitVec.shiftRight_shiftRight]
253+
have :
254+
x * zeroExtend w (truncate (i + 1) y)
255+
= x * zeroExtend w (truncate i y) + (bif y.getLsb (i+1) then x <<< (i+1) else 0) := by
256+
simp [← shiftLeft_shiftRight_eq_zeroExtend_truncate]
257+
rw [this, ← BitVec.add_assoc, add_shiftRight]
258+
congr
259+
cases y.getLsb (i+1)
260+
· simp
261+
· simp; sorry
262+
263+
264+
201265

202266
@[simp]
203267
theorem zeroExtend_zero_width (x : BitVec 0) : zeroExtend w x = 0#w := by
204268
sorry
205269

206-
@[simp] theorem shiftRight_zero (x : BitVec w) : x >>> 0 = x := rfl
207-
@[simp] theorem mul_zero (x : BitVec w) : x * 0#w = 0#w := rfl
270+
-- @[simp] theorem shiftRight_zero (x : BitVec w) : x >>> 0 = x := rfl
271+
-- @[simp] theorem mul_zero (x : BitVec w) : x * 0#w = 0#w := rfl
208272

209273
theorem extractLsb'_succ_eq_concat (x : BitVec w) (s n : Nat) :
210274
x.extractLsb' s (n+1) = cons (x.getLsb (s+n)) (x.extractLsb' s n) := by
211275
sorry
212276

213-
theorem mulAdd_spec (a : BitVec (w+v)) (x : BitVec w) (y : BitVec v) :
214-
mulAdd a x y = a + (x.zeroExtend' <| le_add_right w v) * (y.zeroExtend' <| le_add_left v w) := by
277+
theorem mulAdd_spec (a x y : BitVec w):
278+
mulAdd a x y = a + x * y := by
215279
simp only [mulAdd]
216280
rw [iunfoldr_replace (state := mulAddAccumulator a x y)]
217-
· simp [mulAddAccumulator]
281+
· simp [mulAddAccumulator, Nat.mod_one]
218282
· intro i
219-
simp only [mulAddAccumulator, Prod.mk.injEq]
220-
simp only [extractLsb'_succ_eq_concat y 0 i, Nat.zero_add]
283+
simp only [mulAddAccumulator, Prod.mk.injEq, natCast_eq_ofNat]
221284
cases y.getLsb i <;> simp
222285
· sorry
223286
· sorry
224287

225-
@[simp] theorem zeroExtend'_mul_zeroExtend' (x y : BitVec w) (h : w ≤ v) :
226-
x.zeroExtend' h * y.zeroExtend' h = (x * y).zeroExtend' h := by
288+
theorem getLsb_mul (x y : BitVec w) (i : Fin w) :
289+
(x * y).getLsb i = Bool.xor (x.getLsb i && y.getLsb i) ((mulAddAccumulator 0 x y i).getLsb 0) := by
227290
sorry
228291

292+
theorem zeroExtend'_mul_zeroExtend' (x y : BitVec w) (h : w ≤ v) :
293+
x.zeroExtend' h * y.zeroExtend' h = (x * y).zeroExtend' h := by
294+
sorry
229295

230296
@[simp] theorem zeroExtend'_rfl (x : BitVec w) (h : w ≤ w := by rfl) : x.zeroExtend' h = x := rfl
231297

232-
@[simp]
233-
theorem truncate_zeroExtend' (x : BitVec w) (h : w ≤ v) : truncate w (x.zeroExtend' h) = x := by
298+
@[simp] theorem truncate_zeroExtend' (x : BitVec w) (h : w ≤ v) : truncate w (x.zeroExtend' h) = x := by
234299
simp [truncate, zeroExtend]
235300
intro h'
236301
have h_eq : w = v := Nat.le_antisymm h h'
@@ -241,46 +306,4 @@ theorem mul_eq_mulAdd (x y : BitVec w) :
241306
x * y = (mulAdd 0 x y).truncate w := by
242307
simp [mulAdd_spec]
243308

244-
@[simp]
245-
theorem extractLsb'_zero (x : BitVec w) : extractLsb' 0 n x = truncate n x := by
246-
simp [extractLsb']
247-
248-
@[simp]
249-
theorem extractLsb'_succ_concat : extractLsb' (start+1) n (concat x a) = extractLsb' start n x := by
250-
simp [extractLsb']
251-
sorry
252-
253-
-- theorem mulAdd_eq
254-
255-
theorem mul_eq_mulAdd (x y : BitVec w) :
256-
x * y = (mulAdd 0 x y).truncate _ := by
257-
suffices ∀ {v w} (x : BitVec (w+v)) (y : BitVec w) (z : BitVec v),
258-
x * (y ++ z) = (mulAdd (x*z) x y).truncate _
259-
by simpa using @this 0 w x y 0
260-
induction w
261-
case zero =>
262-
sorry
263-
case succ w ih =>
264-
have ⟨x, x₀, hx⟩ : ∃ (x' : BitVec w) (x₀ : Bool), x = BitVec.concat x' x₀ := sorry
265-
have ⟨y, y₀, hy⟩ : ∃ (y' : BitVec w) (y₀ : Bool), y = BitVec.concat y' y₀ := sorry
266-
subst hx hy
267-
cases y₀ <;> simp [mulAdd]
268-
· simp [extractLsb']
269-
rw [show 0#w = 0 from rfl, ← ih]
270-
· sorry
271-
272-
def mulC (x y : BitVec w) : BitVec w :=
273-
go _
274-
where go (acc : BitVec w) (x y : BitVec w) : BitVec w
275-
276-
-- def boothMul (x y : BitVec w) : BitVec w :=
277-
-- let a : BitVec (w+w+1) := x ++ (0 : BitVec (w+1))
278-
-- let s : BitVec (w+w+1) := -x ++ (0 : BitVec (w+1))
279-
-- let p : BitVec (w+w+1) := (0 : BitVec w) ++ (y : BitVec w) ++ (0 : BitVec 1)
280-
-- go a s p w
281-
-- where
282-
-- go (a s p : BitVec (w+w+1)) : Nat → BitVec w
283-
-- | 0 => p
284-
-- | n+1 =>
285-
286309
end BitVec

src/Init/Data/BitVec/Lemmas.lean

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -704,7 +704,7 @@ theorem msb_append {x : BitVec w} {y : BitVec v} :
704704
simp only [getLsb_append, cond_eq_if]
705705
split <;> simp [*]
706706

707-
theorem BitVec.shiftRight_shiftRight (w : Nat) (x : BitVec w) (n m : Nat) :
707+
theorem shiftRight_shiftRight (w : Nat) (x : BitVec w) (n m : Nat) :
708708
(x >>> n) >>> m = x >>> (n + m) := by
709709
ext i
710710
simp [Nat.add_assoc n m i]

0 commit comments

Comments
 (0)