Skip to content

Commit d222874

Browse files
committed
feat: sshiftRight and ushiftRight bitblasting
1 parent 90dab5e commit d222874

File tree

3 files changed

+158
-6
lines changed

3 files changed

+158
-6
lines changed

src/Init/Data/BitVec/Basic.lean

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -534,6 +534,8 @@ def sshiftRight (a : BitVec n) (s : Nat) : BitVec n := .ofInt n (a.toInt >>> s)
534534
instance {n} : HShiftLeft (BitVec m) (BitVec n) (BitVec m) := ⟨fun x y => x <<< y.toNat⟩
535535
instance {n} : HShiftRight (BitVec m) (BitVec n) (BitVec m) := ⟨fun x y => x >>> y.toNat⟩
536536

537+
def sshiftRight' (a : BitVec n) (s : BitVec m) : BitVec n := a.sshiftRight s.toNat
538+
537539
/-- Auxiliary function for `rotateLeft`, which does not take into account the case where
538540
the rotation amount is greater than the bitvector width. -/
539541
def rotateLeftAux (x : BitVec w) (n : Nat) : BitVec w :=

src/Init/Data/BitVec/Bitblast.lean

Lines changed: 100 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -403,12 +403,8 @@ theorem shiftLeftRec_eq {x : BitVec w₁} {y : BitVec w₂} {n : Nat} :
403403
induction n generalizing x y
404404
case zero =>
405405
ext i
406-
simp only [shiftLeftRec_zero, twoPow_zero, Nat.reduceAdd, truncate_one]
407-
suffices (y &&& 1#w₂) = zeroExtend w₂ (ofBool (y.getLsb 0)) by simp [this]
408-
ext i
409-
by_cases h : (↑i : Nat) = 0
410-
· simp [h, Bool.and_comm]
411-
· simp [h]; omega
406+
simp only [shiftLeftRec_zero, twoPow_zero, Nat.reduceAdd, truncate_one,
407+
and_one_eq_zeroExtend_ofBool_getLsb]
412408
case succ n ih =>
413409
simp only [shiftLeftRec_succ, and_twoPow]
414410
rw [ih]
@@ -431,4 +427,102 @@ theorem shiftLeft_eq_shiftLeftRec (x : BitVec w₁) (y : BitVec w₂) :
431427
· simp [of_length_zero]
432428
· simp [shiftLeftRec_eq]
433429

430+
/- ### Logical shift right (ushiftRight) recurrence for bitblasting -/
431+
432+
def ushiftRight_rec (x : BitVec w₁) (y : BitVec w₂) (n : Nat) : BitVec w₁ :=
433+
let shiftAmt := (y &&& (twoPow w₂ n))
434+
match n with
435+
| 0 => x >>> shiftAmt
436+
| n + 1 => (ushiftRight_rec x y n) >>> shiftAmt
437+
438+
@[simp]
439+
theorem ushiftRight_rec_zero (x : BitVec w₁) (y : BitVec w₂) :
440+
ushiftRight_rec x y 0 = x >>> (y &&& twoPow w₂ 0) := by
441+
simp [ushiftRight_rec]
442+
443+
@[simp]
444+
theorem ushiftRight_rec_succ (x : BitVec w₁) (y : BitVec w₂) :
445+
ushiftRight_rec x y (n + 1) =
446+
(ushiftRight_rec x y n) >>> (y &&& twoPow w₂ (n + 1)) := by
447+
simp [ushiftRight_rec]
448+
449+
theorem ushiftRight'_ushiftRight' {x y z : BitVec w} :
450+
x >>> y >>> z = x >>> (y.toNat + z.toNat) := by
451+
simp [shiftRight_add]
452+
453+
theorem ushiftRight_or_eq_ushiftRight_ushiftRight_of_and_eq_zero {x : BitVec w₁} {y z : BitVec w₂}
454+
(h : y &&& z = 0#w₂) :
455+
x >>> (y ||| z) = x >>> y >>> z := by
456+
simp [← add_eq_or_of_and_eq_zero _ _ h, toNat_add_of_and_eq_zero h, shiftRight_add]
457+
458+
theorem getLsb_ushiftRight' (x : BitVec w₁) (y : BitVec w₂) (i : Nat) :
459+
(x >>> y).getLsb i = x.getLsb (y.toNat + i) := by
460+
simp [getLsb_ushiftRight]
461+
462+
theorem ushiftRight_rec_eq (x : BitVec w₁) (y : BitVec w₂) (n : Nat) :
463+
ushiftRight_rec x y n = x >>> (y.truncate (n + 1)).zeroExtend w₂ := by
464+
induction n generalizing x y
465+
case zero =>
466+
ext i
467+
simp only [ushiftRight_rec_zero, twoPow_zero, Nat.reduceAdd,
468+
and_one_eq_zeroExtend_ofBool_getLsb, truncate_one]
469+
case succ n ih =>
470+
simp only [ushiftRight_rec_succ, and_twoPow]
471+
rw [ih]
472+
by_cases h : y.getLsb (n + 1) <;> simp only [h, ↓reduceIte]
473+
· rw [zeroExtend_truncate_succ_eq_zeroExtend_truncate_or_twoPow_of_getLsb_true h]
474+
rw [ushiftRight_or_eq_ushiftRight_ushiftRight_of_and_eq_zero]
475+
simp
476+
· simp [zeroExtend_truncate_succ_eq_zeroExtend_truncate_of_getLsb_false (i := n + 1), h]
477+
478+
theorem shiftRight_eq_shiftRight_rec (x : BitVec w₁) (y : BitVec w₂) :
479+
x >>> y = ushiftRight_rec x y (w₂ - 1) := by
480+
rcases w₂ with rfl | w₂
481+
· simp [of_length_zero]
482+
· simp [ushiftRight_rec_eq]
483+
484+
/- ### Arithmetic shift right (sshiftRight) recurrence -/
485+
486+
def sshiftRightRec (x : BitVec w) (y : BitVec w₂) (n : Nat) : BitVec w :=
487+
let shiftAmt := (y &&& (twoPow w₂ n))
488+
match n with
489+
| 0 => x.sshiftRight' shiftAmt
490+
| n + 1 => (sshiftRightRec x y n).sshiftRight' shiftAmt
491+
492+
@[simp]
493+
theorem sshiftRightRec_zero_eq (x : BitVec w) (y : BitVec w₂) :
494+
sshiftRightRec x y 0 = x.sshiftRight' (y &&& 1#w₂) := by
495+
simp only [sshiftRightRec, twoPow_zero]
496+
497+
@[simp]
498+
theorem sshiftRightRec_succ_eq (x : BitVec w) (y : BitVec w₂) (n : Nat) :
499+
sshiftRightRec x y (n + 1) = (sshiftRightRec x y n).sshiftRight' (y &&& twoPow w₂ (n + 1)) := by
500+
simp [sshiftRightRec]
501+
502+
theorem sshiftRight_or_eq_sshiftRight_sshiftRight_of_and_eq_zero {x : BitVec w} {y z : BitVec w₂}
503+
(h : y &&& z = 0#w₂) :
504+
x.sshiftRight' (y ||| z) = (x.sshiftRight' y).sshiftRight' z := by
505+
simp [sshiftRight', ← add_eq_or_of_and_eq_zero _ _ h,
506+
toNat_add_of_and_eq_zero h, sshiftRight'_add]
507+
508+
theorem sshiftRightRec_eq (x : BitVec w₁) (y : BitVec w₂) (n : Nat) :
509+
sshiftRightRec x y n = x.sshiftRight' ((y.truncate (n + 1)).zeroExtend w₂) := by
510+
induction n generalizing x y
511+
case zero =>
512+
ext i
513+
simp [ushiftRight_rec_zero, twoPow_zero, Nat.reduceAdd, and_one_eq_zeroExtend_ofBool_getLsb,
514+
truncate_one]
515+
case succ n ih =>
516+
simp only [sshiftRightRec_succ_eq, and_twoPow, ih]
517+
by_cases h : y.getLsb (n + 1) <;> simp [h]
518+
· simp [zeroExtend_truncate_succ_eq_zeroExtend_truncate_or_twoPow_of_getLsb_true h,
519+
sshiftRight_or_eq_sshiftRight_sshiftRight_of_and_eq_zero]
520+
· simp [zeroExtend_truncate_succ_eq_zeroExtend_truncate_of_getLsb_false (i := n + 1), h]
521+
522+
theorem sshiftRight_eq_sshiftRightRec (x : BitVec w₁) (y : BitVec w₂) :
523+
(x.sshiftRight' y).getLsb i = (sshiftRightRec x y (w₂ - 1)).getLsb i := by
524+
rcases w₂ with rfl | w₂
525+
· simp [of_length_zero]
526+
· simp [sshiftRightRec_eq]
527+
434528
end BitVec

src/Init/Data/BitVec/Lemmas.lean

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -731,6 +731,16 @@ theorem getLsb_shiftLeft' {x : BitVec w₁} {y : BitVec w₂} {i : Nat} :
731731
getLsb (x >>> i) j = getLsb x (i+j) := by
732732
unfold getLsb ; simp
733733

734+
@[simp]
735+
theorem ushiftRight_zero_eq (x : BitVec w) : x >>> 0 = x := by
736+
simp [bv_toNat]
737+
738+
/-! ### ushiftRight reductions from BitVec to Nat -/
739+
740+
@[simp]
741+
theorem ushiftRight_eq' (x : BitVec w) (y : BitVec w₂) :
742+
x >>> y = x >>> y.toNat := by rfl
743+
734744
/-! ### sshiftRight -/
735745

736746
theorem sshiftRight_eq {x : BitVec n} {i : Nat} :
@@ -795,6 +805,40 @@ theorem getLsb_sshiftRight (x : BitVec w) (s i : Nat) :
795805
Nat.not_lt, decide_eq_true_eq]
796806
omega
797807

808+
/-- The msb after arithmetic shifting right equals the original msb. -/
809+
theorem sshiftRight_msb_eq_msb {n : Nat} {x : BitVec w} :
810+
(x.sshiftRight n).msb = x.msb := by
811+
rw [msb_eq_getLsb_last, getLsb_sshiftRight, msb_eq_getLsb_last]
812+
by_cases hw₀ : w = 0
813+
· simp [hw₀]
814+
· simp only [show ¬(w ≤ w - 1) by omega, decide_False, Bool.not_false, Bool.true_and,
815+
ite_eq_right_iff]
816+
intros h
817+
simp [show n = 0 by omega]
818+
819+
theorem sshiftRight_add {x : BitVec w} {m n : Nat} :
820+
x.sshiftRight (m + n) = (x.sshiftRight m).sshiftRight n := by
821+
ext i
822+
simp only [getLsb_sshiftRight, Nat.add_assoc]
823+
by_cases h₁ : w ≤ (i : Nat)
824+
· simp [h₁]
825+
· simp only [h₁, decide_False, Bool.not_false, Bool.true_and]
826+
by_cases h₂ : n + ↑i < w
827+
· simp [h₂]
828+
· simp only [h₂, ↓reduceIte]
829+
by_cases h₃ : m + (n + ↑i) < w
830+
· simp [h₃]
831+
omega
832+
· simp [h₃, sshiftRight_msb_eq_msb]
833+
834+
/-! ### shiftRight reductions from BitVec to Nat -/
835+
836+
@[simp]
837+
theorem sshiftRight'_zero (x : BitVec w) :
838+
x.sshiftRight' (0#w₂) = x := by
839+
ext i
840+
simp [sshiftRight', getLsb_sshiftRight]
841+
798842
/-! ### signExtend -/
799843

800844
/-- Equation theorem for `Int.sub` when both arguments are `Int.ofNat` -/
@@ -929,6 +973,10 @@ theorem shiftRight_add {w : Nat} (x : BitVec w) (n m : Nat) :
929973
ext i
930974
simp [Nat.add_assoc n m i]
931975

976+
theorem sshiftRight'_add {x : BitVec w₁} {y : BitVec w₂} {z : BitVec w₃} :
977+
x.sshiftRight (y.toNat + z.toNat) = (x.sshiftRight' y).sshiftRight' z := by
978+
simp [sshiftRight', shiftRight_add, sshiftRight_add]
979+
932980
@[deprecated shiftRight_add (since := "2024-06-02")]
933981
theorem shiftRight_shiftRight {w : Nat} (x : BitVec w) (n m : Nat) :
934982
(x >>> n) >>> m = x >>> (n + m) := by
@@ -1549,4 +1597,12 @@ theorem zeroExtend_truncate_succ_eq_zeroExtend_truncate_or_twoPow_of_getLsb_true
15491597
simp [hx]
15501598
· by_cases hik' : k < i + 1 <;> simp [hik, hik'] <;> omega
15511599

1600+
/-- Bitwise `and` of `(x : BitVec w`) with `1#w` equals zero extending the `lsb` to `w`. -/
1601+
theorem and_one_eq_zeroExtend_ofBool_getLsb {x : BitVec w} :
1602+
(x &&& 1#w) = zeroExtend w (ofBool (x.getLsb 0)) := by
1603+
ext i
1604+
simp only [getLsb_and, getLsb_one, getLsb_zeroExtend, Fin.is_lt, decide_True, getLsb_ofBool,
1605+
Bool.true_and]
1606+
by_cases h : (0 = (i : Nat)) <;> simp [h] <;> omega
1607+
15521608
end BitVec

0 commit comments

Comments
 (0)