Skip to content

Commit d38dc72

Browse files
tobiasgrosserkim-em
authored andcommitted
chore: introduce BitVec.setWidth to unify zeroExtend and truncate
incomplete deprecations chore: complete deprecations
1 parent 4641ed8 commit d38dc72

File tree

9 files changed

+341
-203
lines changed

9 files changed

+341
-203
lines changed

src/Init/Data/BitVec/Basic.lean

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -453,13 +453,15 @@ SMT-Lib name: `extract`.
453453
def extractLsb (hi lo : Nat) (x : BitVec n) : BitVec (hi - lo + 1) := extractLsb' lo _ x
454454

455455
/--
456-
A version of `zeroExtend` that requires a proof, but is a noop.
456+
A version of `setWidth` that requires a proof, but is a noop.
457457
-/
458-
def zeroExtend' {n w : Nat} (le : n ≤ w) (x : BitVec n) : BitVec w :=
458+
def setWidth' {n w : Nat} (le : n ≤ w) (x : BitVec n) : BitVec w :=
459459
x.toNat#'(by
460460
apply Nat.lt_of_lt_of_le x.isLt
461461
exact Nat.pow_le_pow_of_le_right (by trivial) le)
462462

463+
@[deprecated setWidth' (since := "2024-09-18"), inherit_doc setWidth'] abbrev zeroExtend' := @setWidth'
464+
463465
/--
464466
`shiftLeftZeroExtend x n` returns `zeroExtend (w+n) x <<< n` without
465467
needing to compute `x % 2^(2+n)`.
@@ -472,22 +474,35 @@ def shiftLeftZeroExtend (msbs : BitVec w) (m : Nat) : BitVec (w + m) :=
472474
(msbs.toNat <<< m)#'(shiftLeftLt msbs.isLt m)
473475

474476
/--
475-
Zero extend vector `x` of length `w` by adding zeros in the high bits until it has length `v`.
476-
If `v < w` then it truncates the high bits instead.
477+
Transform `x` of length `w` into a bitvector of length `v`, by either:
478+
- zero extending, that is, adding zeros in the high bits until it has length `v`, if `v > w`, or
479+
- truncating the high bits, if `v < w`.
477480
478481
SMT-Lib name: `zero_extend`.
479482
-/
480-
def zeroExtend (v : Nat) (x : BitVec w) : BitVec v :=
483+
def setWidth (v : Nat) (x : BitVec w) : BitVec v :=
481484
if h : w ≤ v then
482-
zeroExtend' h x
485+
setWidth' h x
483486
else
484487
.ofNat v x.toNat
485488

486489
/--
487-
Truncate the high bits of bitvector `x` of length `w`, resulting in a vector of length `v`.
488-
If `v > w` then it zero-extends the vector instead.
490+
Transform `x` of length `w` into a bitvector of length `v`, by either:
491+
- zero extending, that is, adding zeros in the high bits until it has length `v`, if `v > w`, or
492+
- truncating the high bits, if `v < w`.
493+
494+
SMT-Lib name: `zero_extend`.
495+
-/
496+
abbrev zeroExtend := @setWidth
497+
498+
/--
499+
Transform `x` of length `w` into a bitvector of length `v`, by either:
500+
- zero extending, that is, adding zeros in the high bits until it has length `v`, if `v > w`, or
501+
- truncating the high bits, if `v < w`.
502+
503+
SMT-Lib name: `zero_extend`.
489504
-/
490-
abbrev truncate := @zeroExtend
505+
abbrev truncate := @setWidth
491506

492507
/--
493508
Sign extend a vector of length `w`, extending with `i` additional copies of the most significant
@@ -638,7 +653,7 @@ input is on the left, so `0xAB#8 ++ 0xCD#8 = 0xABCD#16`.
638653
SMT-Lib name: `concat`.
639654
-/
640655
def append (msbs : BitVec n) (lsbs : BitVec m) : BitVec (n+m) :=
641-
shiftLeftZeroExtend msbs m ||| zeroExtend' (Nat.le_add_left m n) lsbs
656+
shiftLeftZeroExtend msbs m ||| setWidth' (Nat.le_add_left m n) lsbs
642657

643658
instance : HAppend (BitVec w) (BitVec v) (BitVec (w + v)) := ⟨.append⟩
644659

src/Init/Data/BitVec/Bitblast.lean

Lines changed: 39 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -139,11 +139,11 @@ def adc (x y : BitVec w) : Bool → Bool × BitVec w :=
139139
iunfoldr fun (i : Fin w) c => adcb (x.getLsbD i) (y.getLsbD i) c
140140

141141
theorem getLsbD_add_add_bool {i : Nat} (i_lt : i < w) (x y : BitVec w) (c : Bool) :
142-
getLsbD (x + y + zeroExtend w (ofBool c)) i =
142+
getLsbD (x + y + setWidth w (ofBool c)) i =
143143
Bool.xor (getLsbD x i) (Bool.xor (getLsbD y i) (carry i x y c)) := by
144144
let ⟨x, x_lt⟩ := x
145145
let ⟨y, y_lt⟩ := y
146-
simp only [getLsbD, toNat_add, toNat_zeroExtend, i_lt, toNat_ofFin, toNat_ofBool,
146+
simp only [getLsbD, toNat_add, toNat_setWidth, i_lt, toNat_ofFin, toNat_ofBool,
147147
Nat.mod_add_mod, Nat.add_mod_mod]
148148
apply Eq.trans
149149
rw [← Nat.div_add_mod x (2^i), ← Nat.div_add_mod y (2^i)]
@@ -165,11 +165,11 @@ theorem getLsbD_add {i : Nat} (i_lt : i < w) (x y : BitVec w) :
165165
simpa using getLsbD_add_add_bool i_lt x y false
166166

167167
theorem adc_spec (x y : BitVec w) (c : Bool) :
168-
adc x y c = (carry w x y c, x + y + zeroExtend w (ofBool c)) := by
168+
adc x y c = (carry w x y c, x + y + setWidth w (ofBool c)) := by
169169
simp only [adc]
170170
apply iunfoldr_replace
171171
(fun i => carry i x y c)
172-
(x + y + zeroExtend w (ofBool c))
172+
(x + y + setWidth w (ofBool c))
173173
c
174174
case init =>
175175
simp [carry, Nat.mod_one]
@@ -306,12 +306,12 @@ theorem mulRec_succ_eq (x y : BitVec w) (s : Nat) :
306306
Recurrence lemma: truncating to `i+1` bits and then zero extending to `w`
307307
equals truncating upto `i` bits `[0..i-1]`, and then adding the `i`th bit of `x`.
308308
-/
309-
theorem zeroExtend_truncate_succ_eq_zeroExtend_truncate_add_twoPow (x : BitVec w) (i : Nat) :
310-
zeroExtend w (x.truncate (i + 1)) =
311-
zeroExtend w (x.truncate i) + (x &&& twoPow w i) := by
309+
theorem setWidth_setWidth_succ_eq_setWidth_setWidth_add_twoPow (x : BitVec w) (i : Nat) :
310+
setWidth w (x.setWidth (i + 1)) =
311+
setWidth w (x.setWidth i) + (x &&& twoPow w i) := by
312312
rw [add_eq_or_of_and_eq_zero]
313313
· ext k
314-
simp only [getLsbD_zeroExtend, Fin.is_lt, decide_True, Bool.true_and, getLsbD_or, getLsbD_and]
314+
simp only [getLsbD_setWidth, Fin.is_lt, decide_True, Bool.true_and, getLsbD_or, getLsbD_and]
315315
by_cases hik : i = k
316316
· subst hik
317317
simp
@@ -322,41 +322,49 @@ theorem zeroExtend_truncate_succ_eq_zeroExtend_truncate_add_twoPow (x : BitVec w
322322
· have hik'' : ¬ (k < i) := by omega
323323
simp [hik', hik'']
324324
· ext k
325-
simp only [and_twoPow, getLsbD_and, getLsbD_zeroExtend, Fin.is_lt, decide_True, Bool.true_and,
325+
simp only [and_twoPow, getLsbD_and, getLsbD_setWidth, Fin.is_lt, decide_True, Bool.true_and,
326326
getLsbD_zero, and_eq_false_imp, and_eq_true, decide_eq_true_eq, and_imp]
327327
by_cases hi : x.getLsbD i <;> simp [hi] <;> omega
328328

329+
@[deprecated setWidth_setWidth_succ_eq_setWidth_setWidth_add_twoPow (since := "2024-09-18"),
330+
inherit_doc setWidth_setWidth_succ_eq_setWidth_setWidth_add_twoPow]
331+
abbrev zeroExtend_truncate_succ_eq_zeroExtend_truncate_add_twoPow :=
332+
@setWidth_setWidth_succ_eq_setWidth_setWidth_add_twoPow
333+
329334
/--
330335
Recurrence lemma: multiplying `x` with the first `s` bits of `y` is the
331336
same as truncating `y` to `s` bits, then zero extending to the original length,
332337
and performing the multplication. -/
333-
theorem mulRec_eq_mul_signExtend_truncate (x y : BitVec w) (s : Nat) :
334-
mulRec x y s = x * ((y.truncate (s + 1)).zeroExtend w) := by
338+
theorem mulRec_eq_mul_signExtend_setWidth (x y : BitVec w) (s : Nat) :
339+
mulRec x y s = x * ((y.setWidth (s + 1)).setWidth w) := by
335340
induction s
336341
case zero =>
337342
simp only [mulRec_zero_eq, ofNat_eq_ofNat, Nat.reduceAdd]
338343
by_cases y.getLsbD 0
339344
case pos hy =>
340-
simp only [hy, ↓reduceIte, truncate, zeroExtend_one_eq_ofBool_getLsb_zero,
345+
simp only [hy, ↓reduceIte, setWidth_one_eq_ofBool_getLsb_zero,
341346
ofBool_true, ofNat_eq_ofNat]
342-
rw [zeroExtend_ofNat_one_eq_ofNat_one_of_lt (by omega)]
347+
rw [setWidth_ofNat_one_eq_ofNat_one_of_lt (by omega)]
343348
simp
344349
case neg hy =>
345-
simp [hy, zeroExtend_one_eq_ofBool_getLsb_zero]
350+
simp [hy, setWidth_one_eq_ofBool_getLsb_zero]
346351
case succ s' hs =>
347352
rw [mulRec_succ_eq, hs]
348353
have heq :
349354
(if y.getLsbD (s' + 1) = true then x <<< (s' + 1) else 0) =
350355
(x * (y &&& (BitVec.twoPow w (s' + 1)))) := by
351356
simp only [ofNat_eq_ofNat, and_twoPow]
352357
by_cases hy : y.getLsbD (s' + 1) <;> simp [hy]
353-
rw [heq, ← BitVec.mul_add, ← zeroExtend_truncate_succ_eq_zeroExtend_truncate_add_twoPow]
358+
rw [heq, ← BitVec.mul_add, ← setWidth_setWidth_succ_eq_setWidth_setWidth_add_twoPow]
359+
360+
@[deprecated mulRec_eq_mul_signExtend_setWidth (since := "2024-09-18"),
361+
inherit_doc mulRec_eq_mul_signExtend_setWidth]
362+
abbrev mulRec_eq_mul_signExtend_truncate := @mulRec_eq_mul_signExtend_setWidth
354363

355364
theorem getLsbD_mul (x y : BitVec w) (i : Nat) :
356365
(x * y).getLsbD i = (mulRec x y w).getLsbD i := by
357-
simp only [mulRec_eq_mul_signExtend_truncate]
358-
rw [truncate, ← truncate_eq_zeroExtend, ← truncate_eq_zeroExtend,
359-
truncate_truncate_of_le]
366+
simp only [mulRec_eq_mul_signExtend_setWidth]
367+
rw [setWidth_setWidth_of_le]
360368
· simp
361369
· omega
362370

@@ -402,22 +410,22 @@ theorem shiftLeft_or_of_and_eq_zero {x : BitVec w₁} {y z : BitVec w₂}
402410
`shiftLeftRec x y n` shifts `x` to the left by the first `n` bits of `y`.
403411
-/
404412
theorem shiftLeftRec_eq {x : BitVec w₁} {y : BitVec w₂} {n : Nat} :
405-
shiftLeftRec x y n = x <<< (y.truncate (n + 1)).zeroExtend w₂ := by
413+
shiftLeftRec x y n = x <<< (y.setWidth (n + 1)).setWidth w₂ := by
406414
induction n generalizing x y
407415
case zero =>
408416
ext i
409-
simp only [shiftLeftRec_zero, twoPow_zero, Nat.reduceAdd, truncate_one,
410-
and_one_eq_zeroExtend_ofBool_getLsbD]
417+
simp only [shiftLeftRec_zero, twoPow_zero, Nat.reduceAdd, setWidth_one,
418+
and_one_eq_setWidth_ofBool_getLsbD]
411419
case succ n ih =>
412420
simp only [shiftLeftRec_succ, and_twoPow]
413421
rw [ih]
414422
by_cases h : y.getLsbD (n + 1)
415423
· simp only [h, ↓reduceIte]
416-
rw [zeroExtend_truncate_succ_eq_zeroExtend_truncate_or_twoPow_of_getLsbD_true h,
424+
rw [setWidth_setWidth_succ_eq_setWidth_setWidth_or_twoPow_of_getLsbD_true h,
417425
shiftLeft_or_of_and_eq_zero]
418426
simp [and_twoPow]
419427
· simp only [h, false_eq_true, ↓reduceIte, shiftLeft_zero']
420-
rw [zeroExtend_truncate_succ_eq_zeroExtend_truncate_of_getLsbD_false (i := n + 1)]
428+
rw [setWidth_setWidth_succ_eq_setWidth_setWidth_of_getLsbD_false (i := n + 1)]
421429
simp [h]
422430

423431
/--
@@ -466,18 +474,18 @@ theorem sshiftRight'_or_of_and_eq_zero {x : BitVec w₁} {y z : BitVec w₂}
466474
toNat_add_of_and_eq_zero h, sshiftRight_add]
467475

468476
theorem sshiftRightRec_eq (x : BitVec w₁) (y : BitVec w₂) (n : Nat) :
469-
sshiftRightRec x y n = x.sshiftRight' ((y.truncate (n + 1)).zeroExtend w₂) := by
477+
sshiftRightRec x y n = x.sshiftRight' ((y.setWidth (n + 1)).setWidth w₂) := by
470478
induction n generalizing x y
471479
case zero =>
472480
ext i
473-
simp [twoPow_zero, Nat.reduceAdd, and_one_eq_zeroExtend_ofBool_getLsbD, truncate_one]
481+
simp [twoPow_zero, Nat.reduceAdd, and_one_eq_setWidth_ofBool_getLsbD, setWidth_one]
474482
case succ n ih =>
475483
simp only [sshiftRightRec_succ_eq, and_twoPow, ih]
476484
by_cases h : y.getLsbD (n + 1)
477-
· rw [zeroExtend_truncate_succ_eq_zeroExtend_truncate_or_twoPow_of_getLsbD_true h,
485+
· rw [setWidth_setWidth_succ_eq_setWidth_setWidth_or_twoPow_of_getLsbD_true h,
478486
sshiftRight'_or_of_and_eq_zero (by simp [and_twoPow]), h]
479487
simp
480-
· rw [zeroExtend_truncate_succ_eq_zeroExtend_truncate_of_getLsbD_false (i := n + 1)
488+
· rw [setWidth_setWidth_succ_eq_setWidth_setWidth_of_getLsbD_false (i := n + 1)
481489
(by simp [h])]
482490
simp [h]
483491

@@ -529,20 +537,20 @@ theorem ushiftRight'_or_of_and_eq_zero {x : BitVec w₁} {y z : BitVec w₂}
529537
simp [← add_eq_or_of_and_eq_zero _ _ h, toNat_add_of_and_eq_zero h, shiftRight_add]
530538

531539
theorem ushiftRightRec_eq (x : BitVec w₁) (y : BitVec w₂) (n : Nat) :
532-
ushiftRightRec x y n = x >>> (y.truncate (n + 1)).zeroExtend w₂ := by
540+
ushiftRightRec x y n = x >>> (y.setWidth (n + 1)).setWidth w₂ := by
533541
induction n generalizing x y
534542
case zero =>
535543
ext i
536544
simp only [ushiftRightRec_zero, twoPow_zero, Nat.reduceAdd,
537-
and_one_eq_zeroExtend_ofBool_getLsbD, truncate_one]
545+
and_one_eq_setWidth_ofBool_getLsbD, setWidth_one]
538546
case succ n ih =>
539547
simp only [ushiftRightRec_succ, and_twoPow]
540548
rw [ih]
541549
by_cases h : y.getLsbD (n + 1) <;> simp only [h, ↓reduceIte]
542-
· rw [zeroExtend_truncate_succ_eq_zeroExtend_truncate_or_twoPow_of_getLsbD_true h,
550+
· rw [setWidth_setWidth_succ_eq_setWidth_setWidth_or_twoPow_of_getLsbD_true h,
543551
ushiftRight'_or_of_and_eq_zero]
544552
simp [and_twoPow]
545-
· simp [zeroExtend_truncate_succ_eq_zeroExtend_truncate_of_getLsbD_false, h]
553+
· simp [setWidth_setWidth_succ_eq_setWidth_setWidth_of_getLsbD_false, h]
546554

547555
/--
548556
Show that `x >>> y` can be written in terms of `ushiftRightRec`.

src/Init/Data/BitVec/Folds.lean

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ private theorem iunfoldr.eq_test
4848
simp only [init, eq_nil]
4949
case step =>
5050
intro i
51-
simp_all [truncate_succ]
51+
simp_all [setWidth_succ]
5252

5353
theorem iunfoldr_getLsbD' {f : Fin w → α → α × Bool} (state : Nat → α)
5454
(ind : ∀(i : Fin w), (f i (state i.val)).fst = state (i.val+1)) :

0 commit comments

Comments
 (0)