Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions src/Init/Data/BitVec/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import Init.Data.Nat.Bitwise.Lemmas
import Init.Data.Nat.Power2
import Init.Data.Int.Bitwise
import Init.Data.BitVec.BasicAux
import Init.Data.Vector.Basic

/-!
We define the basic algebraic structure of bitvectors. We choose the `Fin` representation over
Expand Down Expand Up @@ -760,4 +761,40 @@ def reverse : {w : Nat} → BitVec w → BitVec w
| 0, x => x
| w + 1, x => concat (reverse (x.truncate w)) (x.msb)

/- ### vectors of bitvectors -/

/-- Split a bitvector into a vector of bitvectors of equal length where the vector ends on the
most significant part ('BE' for 'big endian'). See also `BitVec.splitLE`. -/
def splitBE (m n : Nat) (x : BitVec (m * n)) : Vector (BitVec m) n :=
(Vector.range n).map (fun i => x.extractLsb' (m * i) m)

/-- Split a bitvector into a vector of bitvectors of equal length where the vector ends on the
least significant part ('LE' for 'little endian'). See also `BitVec.splitBE`. -/
def splitLE (m n : Nat) (x : BitVec (m * n)) : Vector (BitVec m) n :=
(Vector.range n).map (fun i => x.extractLsb' (m * (n - i - 1)) m)

end BitVec

/-- Flatten a `Vector α n` to a `BitVec (m * n)` using a function `f : α → BitVec m`.
The most significant bits are expected to be at the start of the vector. -/
def Vector.flatMapBitVecLE (m : Nat) (v : Vector α n) (f : α → BitVec m) : BitVec (m * n) :=
match n with
| 0 => 0#0
| n + 1 => ((v.pop.flatMapBitVecLE m f).cast (by simp) : BitVec (m * n)) ++ f v.back

/-- Flatten a vector of bitvectors of equal length to a single bitvector.
The most significant bits are expected to be at the start of the vector. -/
def Vector.flattenBitVecLE (m : Nat) (v : Vector (BitVec m) n) : BitVec (m * n) :=
v.flatMapBitVecLE m id

/-- Flatten a `Vector α n` to a `BitVec (m * n)` using a function `f : α → BitVec m`.
The most significant bits are expected to be at the start of the vector. -/
def Vector.flatMapBitVecBE (m : Nat) (v : Vector α n) (f : α → BitVec m) : BitVec (m * n) :=
match n with
| 0 => 0#0
| n + 1 => ((v.tail.flatMapBitVecBE m f).cast (by simp) : BitVec (m * n)) ++ f v.head

/-- Flatten a vector of bitvectors of equal length to a single bitvector.
The most significant bits are expected to be at the start of the vector. -/
def Vector.flattenBitVecBE (m : Nat) (v : Vector (BitVec m) n) : BitVec (m * n) :=
v.flatMapBitVecBE m id
81 changes: 81 additions & 0 deletions src/Init/Data/BitVec/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import Init.Data.Int.Bitwise.Lemmas
import Init.Data.Int.LemmasAux
import Init.Data.Int.Pow
import Init.Data.Int.LemmasAux
import Init.Data.Vector.Lemmas

set_option linter.missingDocs true

Expand Down Expand Up @@ -5127,6 +5128,86 @@ theorem msb_replicate {n w : Nat} {x : BitVec w} :
simp only [BitVec.msb, getMsbD_replicate, Nat.zero_mod]
cases n <;> cases w <;> simp

/-! ### Vectors of bitvectors -/

@[simp]
theorem pop_splitLE : (splitLE m (n + 1) x).pop = splitLE m n (x.extractLsb' m _) := by
ext k hk i hi
simp at hk
have : m * (n - k - 1) + i < m * n := by
apply Nat.lt_of_lt_of_le (Nat.add_lt_add_left hi _)
rw [← Nat.mul_add_one]
apply Nat.mul_le_mul_left
omega
simp only [Nat.add_one_sub_one, splitLE, Vector.range, Vector.map_mk, Vector.pop_mk,
Vector.getElem_mk, Array.getElem_pop, Array.getElem_map, Array.getElem_range,
getElem_extractLsb', this, getLsbD_eq_getElem, ← Nat.add_assoc]
rw [Nat.sub_add_comm (Nat.le_of_lt hk), Nat.sub_add_comm (Nat.le_sub_of_add_le' hk),
Nat.mul_add m (n - k - 1), Nat.add_comm m]
simp

@[simp]
theorem back_splitLE : (splitLE m (n + 1) x).back = x.extractLsb' 0 m := by
ext i hi
simp [splitLE, Vector.back, Vector.range, Nat.add_sub_cancel_left]

@[simp]
theorem flattenBitVecLE_splitLE {m n : Nat} {x : BitVec (m * n)} :
(x.splitLE m n).flattenBitVecLE m = x := by
induction n with
| zero =>
simp [splitLE, Vector.flattenBitVecLE, Vector.flatMapBitVecLE, of_length_zero]
| succ n ih =>
simp only [Vector.flattenBitVecLE, Vector.flatMapBitVecLE, Nat.add_one_sub_one, pop_splitLE,
cast_eq, back_splitLE, id_eq] at ih ⊢
rw [ih]
ext i hi
by_cases hi' : i < m
· simp [getElem_append, hi', BitVec.getLsbD_eq_getElem hi]
· simp [getElem_append, hi', Nat.add_sub_of_le (Nat.le_of_not_lt hi'),
BitVec.getLsbD_eq_getElem hi]

theorem getElem_flattenBitVecLE {m n : Nat} (v : Vector (BitVec m) n) {i : Nat} (hi : i < m * n) :
(v.flattenBitVecLE m)[i] =
(v[n - (i / m + 1)]'(Nat.sub_lt (Nat.pos_of_lt_mul_left hi) (i / m).zero_lt_succ))[i % m]'
(Nat.mod_lt i (Nat.pos_of_lt_mul_right hi)) := by
induction n generalizing i with
| zero =>
simp at hi
| succ n ih =>
simp only [Vector.flattenBitVecLE, Vector.flatMapBitVecLE, Nat.add_one_sub_one, cast_eq, id_eq]
have hi' := hi
rw [Nat.mul_add, Nat.mul_one] at hi
apply (BitVec.getElem_append hi).trans
by_cases him : i < m
· simp [him, Vector.back, Nat.mod_eq_of_lt him, Nat.div_eq_of_lt him]
· simp only [Vector.flattenBitVecLE] at ih
simp only [him, ↓reduceDIte, ih v.pop (i := i - m) (by omega)]
simp [Vector.pop, ← Nat.mod_eq_sub_mod (Nat.ge_of_not_lt him),
← Nat.div_eq_sub_div (Nat.pos_of_lt_mul_right hi') (Nat.ge_of_not_lt him)]

theorem getElem_splitLE {i : Nat} (x : BitVec (m * n)) (hi : i < n) :
(x.splitLE m n)[i] = x.extractLsb' (m * (n - i - 1)) m := by
simp [splitLE, Vector.range]

@[simp]
theorem splitLE_flattenBitVecLE {m n : Nat} {v : Vector (BitVec m) n} :
(v.flattenBitVecLE m).splitLE m n = v := by
ext i hi j hj
have : m * (n - i - 1) + j < m * n := by
apply Nat.lt_of_lt_of_le (Nat.add_lt_add_left hj _)
rw [← Nat.mul_add_one]
apply Nat.mul_le_mul_left
omega
simp only [getElem_splitLE, getElem_extractLsb', getLsbD_eq_getElem this, getElem_flattenBitVecLE,
Nat.mul_add_mod_self_left, Nat.mod_eq_of_lt hj]
simp only [Nat.add_div (Nat.pos_of_lt_mul_right this),
Nat.mul_div_cancel_left _ (Nat.pos_of_lt_mul_right this), Nat.div_eq_of_lt hj, Nat.add_zero,
Nat.mul_mod_right, Nat.zero_add, Nat.not_le_of_gt (m := j % m) (n := m) (Nat.mod_lt_of_lt hj),
↓reduceIte]
congr
omega

/-! ### Decidable quantifiers -/

theorem forall_zero_iff {P : BitVec 0 → Prop} :
Expand Down
Loading