diff --git a/src/Init/Data/BitVec/Basic.lean b/src/Init/Data/BitVec/Basic.lean index 91e1ae9d34e2..c83cb696124c 100644 --- a/src/Init/Data/BitVec/Basic.lean +++ b/src/Init/Data/BitVec/Basic.lean @@ -170,6 +170,9 @@ instance : GetElem (BitVec w) Nat Bool fun _ i => i < w where @[simp] theorem getLsb?_eq_getElem? (x : BitVec w) (i : Nat) : x.getLsb? i = x[i]? := rfl +theorem getElem_eq_toNat_testBit (x : BitVec w) (i : Fin w) : + x[i] = x.toNat.testBit i := rfl + end getElem section Int diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index 234df0166ae3..3ec55d6bbc72 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -107,6 +107,8 @@ theorem eq_of_toNat_eq {n} : ∀ {x y : BitVec n}, x.toNat = y.toNat → x = y theorem testBit_toNat (x : BitVec w) : x.toNat.testBit i = x.getLsbD i := rfl +theorem testBit_toNat_getElem (x : BitVec w) (i : Fin w) : x.toNat.testBit i = x[i.val] := rfl + theorem getMsb'_eq_getLsb' (x : BitVec w) (i : Fin w) : x.getMsb' i = x.getLsb' ⟨w - 1 - i, by omega⟩ := by simp only [getMsb', getLsb'] @@ -168,9 +170,8 @@ theorem getMsbD_eq_getMsb?_getD (x : BitVec w) (i : Nat) : intros omega --- We choose `eq_of_getLsbD_eq` as the `@[ext]` theorem for `BitVec` --- somewhat arbitrarily over `eq_of_getMsbD_eq`. -@[ext] theorem eq_of_getLsbD_eq {x y : BitVec w} +@[ext] +theorem eq_of_getLsbD_eq {x y : BitVec w} (pred : ∀(i : Fin w), x.getLsbD i.val = y.getLsbD i.val) : x = y := by apply eq_of_toNat_eq apply Nat.eq_of_testBit_eq @@ -181,6 +182,30 @@ theorem getMsbD_eq_getMsb?_getD (x : BitVec w) (i : Nat) : have p : i ≥ w := Nat.le_of_not_gt i_lt simp [testBit_toNat, getLsbD_ge _ _ p] +theorem eq_of_getElem_eq_fin {x y : BitVec w} + (pred : ∀(i : Fin w), x.getLsbD i.val = y.getLsbD i.val) : x = y := by + apply eq_of_toNat_eq + apply Nat.eq_of_testBit_eq + intro i + if i_lt : i < w then + exact pred ⟨i, i_lt⟩ + else + have _ : 2 ^ w ≤ 2 ^ i := Nat.pow_le_pow_of_le (by omega) (by omega) + rw [Nat.testBit_lt_two_pow (by omega), Nat.testBit_lt_two_pow (by omega)] + +theorem eq_of_getElem_eq_nat {x y : BitVec w} + (pred : ∀ (i : Nat) (_ : i < w), x[i] = y[i]) : x = y := by + apply eq_of_toNat_eq + apply Nat.eq_of_testBit_eq + intro i + if i_lt : i < w then + rw [testBit_toNat_getElem x ⟨i, i_lt⟩] + rw [testBit_toNat_getElem y ⟨i, i_lt⟩] + exact pred i i_lt + else + have _ : 2 ^ w ≤ 2 ^ i := Nat.pow_le_pow_of_le (by omega) (by omega) + rw [Nat.testBit_lt_two_pow (by omega), Nat.testBit_lt_two_pow (by omega)] + theorem eq_of_getMsbD_eq {x y : BitVec w} (pred : ∀(i : Fin w), x.getMsbD i.val = y.getMsbD i.val) : x = y := by simp only [getMsbD] at pred @@ -201,7 +226,7 @@ theorem eq_of_getMsbD_eq {x y : BitVec w} simpa [q_lt, Nat.sub_sub_self, r] using q -- This cannot be a `@[simp]` lemma, as it would be tried at every term. -theorem of_length_zero {x : BitVec 0} : x = 0#0 := by ext; simp +theorem of_length_zero {x : BitVec 0} : x = 0#0 := by ext; simp [BitVec.eq_nil x] @[simp] theorem toNat_zero_length (x : BitVec 0) : x.toNat = 0 := by simp [of_length_zero] theorem getLsbD_zero_length (x : BitVec 0) : x.getLsbD i = false := by simp @@ -482,6 +507,20 @@ theorem nat_eq_toNat (x : BitVec w) (y : Nat) getLsbD (zeroExtend m x) i = (decide (i < m) && getLsbD x i) := by simp [getLsbD, toNat_zeroExtend, Nat.testBit_mod_two_pow] +@[simp] theorem getElem_zeroExtend_fin (m : Nat) (x : BitVec n) (i : Fin n) (h : i < m) : + (zeroExtend m x)[i] = x[i] := by + rw [getElem_eq_toNat_testBit] + have rlb := BitVec.getElem_eq_toNat_testBit (zeroExtend m x) ⟨i.val, h⟩ + simp only [Fin.getElem_fin, toNat_truncate, Nat.testBit_mod_two_pow] at rlb + simp [rlb, h] + +@[simp] theorem getElem_zeroExtend_nat (m : Nat) (x : BitVec n) (i : Nat) (h1 : i < n) (h : i < m) : + (zeroExtend m x)[i] = x[i] := by + have rla := BitVec.getElem_eq_toNat_testBit (x) ⟨i, h1⟩ + have rlb := BitVec.getElem_eq_toNat_testBit (zeroExtend m x) ⟨i, h⟩ + simp only [Fin.getElem_fin, toNat_truncate, Nat.testBit_mod_two_pow] at rlb rla + simp [rla, rlb, h] + @[simp] theorem getMsbD_zeroExtend_add {x : BitVec w} (h : k ≤ i) : (x.zeroExtend (w + k)).getMsbD i = x.getMsbD (i - k) := by by_cases h : w = 0 @@ -509,6 +548,18 @@ theorem msb_truncate (x : BitVec w) : (x.truncate (k + 1)).msb = x.getLsbD k := revert p cases getLsbD x i <;> simp; omega +@[simp] theorem zeroExtend_zeroExtend_of_le_getElem_fin (x : BitVec w) (h : k ≤ l) : + (x.zeroExtend l).zeroExtend k = x.zeroExtend k := by + apply eq_of_getElem_eq_fin + intros i + simp [getElem_zeroExtend_fin, h] + +@[simp] theorem zeroExtend_zeroExtend_of_le_getElem_nat (x : BitVec w) (h : k ≤ l) : + (x.zeroExtend l).zeroExtend k = x.zeroExtend k := by + apply eq_of_getElem_eq_nat + intros i _ + simp [getElem_zeroExtend_nat, h] + @[simp] theorem truncate_truncate_of_le (x : BitVec w) (h : k ≤ l) : (x.truncate l).truncate k = x.truncate k := zeroExtend_zeroExtend_of_le x h