diff --git a/src/Init/Data/BitVec/Basic.lean b/src/Init/Data/BitVec/Basic.lean index dd1db9ba316d..e888a2251b64 100644 --- a/src/Init/Data/BitVec/Basic.lean +++ b/src/Init/Data/BitVec/Basic.lean @@ -9,6 +9,7 @@ import Init.Data.Nat.Bitwise.Lemmas import Init.Data.Nat.Power2 import Init.Data.Int.Bitwise import Init.Data.BitVec.BasicAux +import Init.Data.Vector.Basic /-! We define the basic algebraic structure of bitvectors. We choose the `Fin` representation over @@ -760,4 +761,40 @@ def reverse : {w : Nat} → BitVec w → BitVec w | 0, x => x | w + 1, x => concat (reverse (x.truncate w)) (x.msb) +/- ### vectors of bitvectors -/ + +/-- Split a bitvector into a vector of bitvectors of equal length where the vector ends on the +most significant part ('BE' for 'big endian'). See also `BitVec.splitLE`. -/ +def splitBE (m n : Nat) (x : BitVec (m * n)) : Vector (BitVec m) n := + (Vector.range n).map (fun i => x.extractLsb' (m * i) m) + +/-- Split a bitvector into a vector of bitvectors of equal length where the vector ends on the +least significant part ('LE' for 'little endian'). See also `BitVec.splitBE`. -/ +def splitLE (m n : Nat) (x : BitVec (m * n)) : Vector (BitVec m) n := + (Vector.range n).map (fun i => x.extractLsb' (m * (n - i - 1)) m) + end BitVec + +/-- Flatten a `Vector α n` to a `BitVec (m * n)` using a function `f : α → BitVec m`. +The most significant bits are expected to be at the start of the vector. -/ +def Vector.flatMapBitVecLE (m : Nat) (v : Vector α n) (f : α → BitVec m) : BitVec (m * n) := + match n with + | 0 => 0#0 + | n + 1 => ((v.pop.flatMapBitVecLE m f).cast (by simp) : BitVec (m * n)) ++ f v.back + +/-- Flatten a vector of bitvectors of equal length to a single bitvector. +The most significant bits are expected to be at the start of the vector. -/ +def Vector.flattenBitVecLE (m : Nat) (v : Vector (BitVec m) n) : BitVec (m * n) := + v.flatMapBitVecLE m id + +/-- Flatten a `Vector α n` to a `BitVec (m * n)` using a function `f : α → BitVec m`. +The most significant bits are expected to be at the start of the vector. -/ +def Vector.flatMapBitVecBE (m : Nat) (v : Vector α n) (f : α → BitVec m) : BitVec (m * n) := + match n with + | 0 => 0#0 + | n + 1 => ((v.tail.flatMapBitVecBE m f).cast (by simp) : BitVec (m * n)) ++ f v.head + +/-- Flatten a vector of bitvectors of equal length to a single bitvector. +The most significant bits are expected to be at the start of the vector. -/ +def Vector.flattenBitVecBE (m : Nat) (v : Vector (BitVec m) n) : BitVec (m * n) := + v.flatMapBitVecBE m id diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index d80fc889377c..7ccbc08c1b27 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -16,6 +16,7 @@ import Init.Data.Int.Bitwise.Lemmas import Init.Data.Int.LemmasAux import Init.Data.Int.Pow import Init.Data.Int.LemmasAux +import Init.Data.Vector.Lemmas set_option linter.missingDocs true @@ -5127,6 +5128,86 @@ theorem msb_replicate {n w : Nat} {x : BitVec w} : simp only [BitVec.msb, getMsbD_replicate, Nat.zero_mod] cases n <;> cases w <;> simp +/-! ### Vectors of bitvectors -/ + +@[simp] +theorem pop_splitLE : (splitLE m (n + 1) x).pop = splitLE m n (x.extractLsb' m _) := by + ext k hk i hi + simp at hk + have : m * (n - k - 1) + i < m * n := by + apply Nat.lt_of_lt_of_le (Nat.add_lt_add_left hi _) + rw [← Nat.mul_add_one] + apply Nat.mul_le_mul_left + omega + simp only [Nat.add_one_sub_one, splitLE, Vector.range, Vector.map_mk, Vector.pop_mk, + Vector.getElem_mk, Array.getElem_pop, Array.getElem_map, Array.getElem_range, + getElem_extractLsb', this, getLsbD_eq_getElem, ← Nat.add_assoc] + rw [Nat.sub_add_comm (Nat.le_of_lt hk), Nat.sub_add_comm (Nat.le_sub_of_add_le' hk), + Nat.mul_add m (n - k - 1), Nat.add_comm m] + simp + +@[simp] +theorem back_splitLE : (splitLE m (n + 1) x).back = x.extractLsb' 0 m := by + ext i hi + simp [splitLE, Vector.back, Vector.range, Nat.add_sub_cancel_left] + +@[simp] +theorem flattenBitVecLE_splitLE {m n : Nat} {x : BitVec (m * n)} : + (x.splitLE m n).flattenBitVecLE m = x := by + induction n with + | zero => + simp [splitLE, Vector.flattenBitVecLE, Vector.flatMapBitVecLE, of_length_zero] + | succ n ih => + simp only [Vector.flattenBitVecLE, Vector.flatMapBitVecLE, Nat.add_one_sub_one, pop_splitLE, + cast_eq, back_splitLE, id_eq] at ih ⊢ + rw [ih] + ext i hi + by_cases hi' : i < m + · simp [getElem_append, hi', BitVec.getLsbD_eq_getElem hi] + · simp [getElem_append, hi', Nat.add_sub_of_le (Nat.le_of_not_lt hi'), + BitVec.getLsbD_eq_getElem hi] + +theorem getElem_flattenBitVecLE {m n : Nat} (v : Vector (BitVec m) n) {i : Nat} (hi : i < m * n) : + (v.flattenBitVecLE m)[i] = + (v[n - (i / m + 1)]'(Nat.sub_lt (Nat.pos_of_lt_mul_left hi) (i / m).zero_lt_succ))[i % m]' + (Nat.mod_lt i (Nat.pos_of_lt_mul_right hi)) := by + induction n generalizing i with + | zero => + simp at hi + | succ n ih => + simp only [Vector.flattenBitVecLE, Vector.flatMapBitVecLE, Nat.add_one_sub_one, cast_eq, id_eq] + have hi' := hi + rw [Nat.mul_add, Nat.mul_one] at hi + apply (BitVec.getElem_append hi).trans + by_cases him : i < m + · simp [him, Vector.back, Nat.mod_eq_of_lt him, Nat.div_eq_of_lt him] + · simp only [Vector.flattenBitVecLE] at ih + simp only [him, ↓reduceDIte, ih v.pop (i := i - m) (by omega)] + simp [Vector.pop, ← Nat.mod_eq_sub_mod (Nat.ge_of_not_lt him), + ← Nat.div_eq_sub_div (Nat.pos_of_lt_mul_right hi') (Nat.ge_of_not_lt him)] + +theorem getElem_splitLE {i : Nat} (x : BitVec (m * n)) (hi : i < n) : + (x.splitLE m n)[i] = x.extractLsb' (m * (n - i - 1)) m := by + simp [splitLE, Vector.range] + +@[simp] +theorem splitLE_flattenBitVecLE {m n : Nat} {v : Vector (BitVec m) n} : + (v.flattenBitVecLE m).splitLE m n = v := by + ext i hi j hj + have : m * (n - i - 1) + j < m * n := by + apply Nat.lt_of_lt_of_le (Nat.add_lt_add_left hj _) + rw [← Nat.mul_add_one] + apply Nat.mul_le_mul_left + omega + simp only [getElem_splitLE, getElem_extractLsb', getLsbD_eq_getElem this, getElem_flattenBitVecLE, + Nat.mul_add_mod_self_left, Nat.mod_eq_of_lt hj] + simp only [Nat.add_div (Nat.pos_of_lt_mul_right this), + Nat.mul_div_cancel_left _ (Nat.pos_of_lt_mul_right this), Nat.div_eq_of_lt hj, Nat.add_zero, + Nat.mul_mod_right, Nat.zero_add, Nat.not_le_of_gt (m := j % m) (n := m) (Nat.mod_lt_of_lt hj), + ↓reduceIte] + congr + omega + /-! ### Decidable quantifiers -/ theorem forall_zero_iff {P : BitVec 0 → Prop} :