From c99a18e861346ad14fc03499aaa74a9c1b198170 Mon Sep 17 00:00:00 2001
From: somethingelseentirely
 <74891396+somethingelseentirely@users.noreply.github.com>
Date: Tue, 10 Sep 2024 10:41:05 +0200
Subject: [PATCH] remove redundant bits_per_element property from WaveletMatrix
 (#12)

---
 src/wavelet/mod.rs | 98 ++++++++++++++++++++++++----------------------
 1 file changed, 51 insertions(+), 47 deletions(-)

diff --git a/src/wavelet/mod.rs b/src/wavelet/mod.rs
index 7d65ab2..1baa313 100644
--- a/src/wavelet/mod.rs
+++ b/src/wavelet/mod.rs
@@ -66,7 +66,6 @@ use std::ops::Range;
 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
 pub struct WaveletMatrix {
     data: Box<[RsVec]>,
-    bits_per_element: u16,
 }
 
 impl WaveletMatrix {
@@ -124,7 +123,6 @@ impl WaveletMatrix {
 
         Self {
             data: data.into_iter().map(BitVec::into).collect(),
-            bits_per_element,
         }
     }
 
@@ -221,7 +219,6 @@ impl WaveletMatrix {
 
         Self {
             data: data.into_iter().map(BitVec::into).collect(),
-            bits_per_element,
         }
     }
 
@@ -297,7 +294,7 @@ impl WaveletMatrix {
     /// The function is used by the `get_value` and `get_u64` functions, deduplicating code.
     #[inline(always)]
     fn reconstruct_value_unchecked<F: FnMut(u64)>(&self, mut i: usize, mut target_func: F) {
-        for level in 0..self.bits_per_element as usize {
+        for level in 0..self.bits_per_element() {
             let bit = self.data[level].get_unchecked(i);
             target_func(bit);
             if bit == 0 {
@@ -346,8 +343,8 @@ impl WaveletMatrix {
     /// [`get_value`]: WaveletMatrix::get_value
     #[must_use]
     pub fn get_value_unchecked(&self, i: usize) -> BitVec {
-        let mut value = BitVec::from_zeros(self.bits_per_element as usize);
-        let mut level = self.bits_per_element - 1;
+        let mut value = BitVec::from_zeros(self.bits_per_element());
+        let mut level = self.bits_per_element() - 1;
         self.reconstruct_value_unchecked(i, |bit| {
             value.set_unchecked(level as usize, bit);
             level = level.saturating_sub(1);
@@ -373,7 +370,7 @@ impl WaveletMatrix {
     /// ```
     #[must_use]
     pub fn get_u64(&self, i: usize) -> Option<u64> {
-        if self.bits_per_element > 64 || self.data.is_empty() || i >= self.data[0].len() {
+        if self.bits_per_element() > 64 || self.data.is_empty() || i >= self.data[0].len() {
             None
         } else {
             Some(self.get_u64_unchecked(i))
@@ -419,7 +416,7 @@ impl WaveletMatrix {
     #[must_use]
     pub fn rank_range_unchecked(&self, mut range: Range<usize>, symbol: &BitVec) -> usize {
         for (level, data) in self.data.iter().enumerate() {
-            if symbol.get_unchecked((self.bits_per_element - 1) as usize - level) == 0 {
+            if symbol.get_unchecked((self.bits_per_element() - 1) - level) == 0 {
                 range.start = data.rank0(range.start);
                 range.end = data.rank0(range.end);
             } else {
@@ -456,7 +453,7 @@ impl WaveletMatrix {
     pub fn rank_range(&self, range: Range<usize>, symbol: &BitVec) -> Option<usize> {
         if range.start >= self.len()
             || range.end > self.len()
-            || symbol.len() != self.bits_per_element as usize
+            || symbol.len() != self.bits_per_element()
         {
             None
         } else {
@@ -483,7 +480,7 @@ impl WaveletMatrix {
     #[must_use]
     pub fn rank_range_u64_unchecked(&self, mut range: Range<usize>, symbol: u64) -> usize {
         for (level, data) in self.data.iter().enumerate() {
-            if (symbol >> ((self.bits_per_element - 1) as usize - level)) & 1 == 0 {
+            if (symbol >> ((self.bits_per_element() - 1) - level)) & 1 == 0 {
                 range.start = data.rank0(range.start);
                 range.end = data.rank0(range.end);
             } else {
@@ -516,7 +513,7 @@ impl WaveletMatrix {
     /// ```
     #[must_use]
     pub fn rank_range_u64(&self, range: Range<usize>, symbol: u64) -> Option<usize> {
-        if range.start >= self.len() || range.end > self.len() || self.bits_per_element > 64 {
+        if range.start >= self.len() || range.end > self.len() || self.bits_per_element() > 64 {
             None
         } else {
             Some(self.rank_range_u64_unchecked(range, symbol))
@@ -584,7 +581,7 @@ impl WaveletMatrix {
         if offset > i
             || offset >= self.len()
             || i > self.len()
-            || symbol.len() != self.bits_per_element as usize
+            || symbol.len() != self.bits_per_element()
         {
             None
         } else {
@@ -644,7 +641,7 @@ impl WaveletMatrix {
     /// ```
     #[must_use]
     pub fn rank_offset_u64(&self, offset: usize, i: usize, symbol: u64) -> Option<usize> {
-        if offset > i || offset >= self.len() || i > self.len() || self.bits_per_element > 64 {
+        if offset > i || offset >= self.len() || i > self.len() || self.bits_per_element() > 64 {
             None
         } else {
             Some(self.rank_offset_u64_unchecked(offset, i, symbol))
@@ -697,7 +694,7 @@ impl WaveletMatrix {
     /// [`BitVec`]: BitVec
     #[must_use]
     pub fn rank(&self, i: usize, symbol: &BitVec) -> Option<usize> {
-        if i > self.len() || symbol.len() != self.bits_per_element as usize {
+        if i > self.len() || symbol.len() != self.bits_per_element() {
             None
         } else {
             Some(self.rank_range_unchecked(0..i, symbol))
@@ -745,7 +742,7 @@ impl WaveletMatrix {
     /// ```
     #[must_use]
     pub fn rank_u64(&self, i: usize, symbol: u64) -> Option<usize> {
-        if i > self.len() || self.bits_per_element > 64 {
+        if i > self.len() || self.bits_per_element() > 64 {
             None
         } else {
             Some(self.rank_range_u64_unchecked(0..i, symbol))
@@ -777,7 +774,7 @@ impl WaveletMatrix {
         let mut range_start = offset;
 
         for (level, data) in self.data.iter().enumerate() {
-            if symbol.get_unchecked((self.bits_per_element - 1) as usize - level) == 0 {
+            if symbol.get_unchecked((self.bits_per_element() - 1) - level) == 0 {
                 range_start = data.rank0(range_start);
             } else {
                 range_start = data.rank0 + data.rank1(range_start);
@@ -787,7 +784,7 @@ impl WaveletMatrix {
         let mut range_end = range_start + rank;
 
         for (level, data) in self.data.iter().enumerate().rev() {
-            if symbol.get_unchecked((self.bits_per_element - 1) as usize - level) == 0 {
+            if symbol.get_unchecked((self.bits_per_element() - 1) - level) == 0 {
                 range_end = data.select0(range_end);
             } else {
                 range_end = data.select1(range_end - data.rank0);
@@ -822,7 +819,7 @@ impl WaveletMatrix {
     /// [`BitVec`]: BitVec
     #[must_use]
     pub fn select_offset(&self, offset: usize, rank: usize, symbol: &BitVec) -> Option<usize> {
-        if offset >= self.len() || symbol.len() != self.bits_per_element as usize {
+        if offset >= self.len() || symbol.len() != self.bits_per_element() {
             None
         } else {
             let idx = self.select_offset_unchecked(offset, rank, symbol);
@@ -857,7 +854,7 @@ impl WaveletMatrix {
         let mut range_start = offset;
 
         for (level, data) in self.data.iter().enumerate() {
-            if (symbol >> ((self.bits_per_element - 1) as usize - level)) & 1 == 0 {
+            if (symbol >> ((self.bits_per_element() - 1) - level)) & 1 == 0 {
                 range_start = data.rank0(range_start);
             } else {
                 range_start = data.rank0 + data.rank1(range_start);
@@ -867,7 +864,7 @@ impl WaveletMatrix {
         let mut range_end = range_start + rank;
 
         for (level, data) in self.data.iter().enumerate().rev() {
-            if (symbol >> ((self.bits_per_element - 1) as usize - level)) & 1 == 0 {
+            if (symbol >> ((self.bits_per_element() - 1) - level)) & 1 == 0 {
                 range_end = data.select0(range_end);
             } else {
                 range_end = data.select1(range_end - data.rank0);
@@ -899,7 +896,7 @@ impl WaveletMatrix {
     /// ```
     #[must_use]
     pub fn select_offset_u64(&self, offset: usize, rank: usize, symbol: u64) -> Option<usize> {
-        if offset >= self.len() || self.bits_per_element > 64 {
+        if offset >= self.len() || self.bits_per_element() > 64 {
             None
         } else {
             let idx = self.select_offset_u64_unchecked(offset, rank, symbol);
@@ -956,7 +953,7 @@ impl WaveletMatrix {
     /// [`BitVec`]: BitVec
     #[must_use]
     pub fn select(&self, rank: usize, symbol: &BitVec) -> Option<usize> {
-        if symbol.len() == self.bits_per_element as usize {
+        if symbol.len() == self.bits_per_element() {
             let idx = self.select_unchecked(rank, symbol);
             if idx < self.len() {
                 Some(idx)
@@ -1008,7 +1005,7 @@ impl WaveletMatrix {
     /// ```
     #[must_use]
     pub fn select_u64(&self, rank: usize, symbol: u64) -> Option<usize> {
-        if self.bits_per_element > 64 {
+        if self.bits_per_element() > 64 {
             None
         } else {
             let idx = self.select_u64_unchecked(rank, symbol);
@@ -1036,7 +1033,7 @@ impl WaveletMatrix {
     /// [`quantile`]: WaveletMatrix::quantile
     #[must_use]
     pub fn quantile_unchecked(&self, range: Range<usize>, k: usize) -> BitVec {
-        let result = BitVec::from_zeros(self.bits_per_element as usize);
+        let result = BitVec::from_zeros(self.bits_per_element());
 
         self.partial_quantile_search_unchecked(range, k, 0, result)
     }
@@ -1054,7 +1051,7 @@ impl WaveletMatrix {
         start_level: usize,
         mut prefix: BitVec,
     ) -> BitVec {
-        debug_assert!(prefix.len() == self.bits_per_element as usize);
+        debug_assert!(prefix.len() == self.bits_per_element());
         debug_assert!(!range.is_empty());
         debug_assert!(range.end <= self.len());
 
@@ -1070,7 +1067,7 @@ impl WaveletMatrix {
             } else {
                 // the element is among the ones, so we set the bit to 1, and move the range
                 // into the 1-partition of the next level
-                prefix.set_unchecked((self.bits_per_element - 1) as usize - level, 1);
+                prefix.set_unchecked((self.bits_per_element() - 1) - level, 1);
                 k -= zeros;
                 range.start = data.rank0 + (range.start - zeros_start); // range.start - zeros_start is the rank1 of range.start
                 range.end = data.rank0 + (range.end - zeros_end); // same here
@@ -1183,7 +1180,7 @@ impl WaveletMatrix {
         start_level: usize,
         mut prefix: u64,
     ) -> u64 {
-        debug_assert!(self.bits_per_element <= 64);
+        debug_assert!(self.bits_per_element() <= 64);
         debug_assert!(!range.is_empty());
         debug_assert!(range.end <= self.len());
 
@@ -1230,7 +1227,7 @@ impl WaveletMatrix {
     pub fn quantile_u64(&self, range: Range<usize>, k: usize) -> Option<u64> {
         if range.start >= self.len()
             || range.end > self.len()
-            || self.bits_per_element > 64
+            || self.bits_per_element() > 64
             || k >= range.end - range.start
         {
             None
@@ -1274,7 +1271,7 @@ impl WaveletMatrix {
     /// ```
     #[must_use]
     pub fn get_sorted_u64(&self, i: usize) -> Option<u64> {
-        if i >= self.len() || self.bits_per_element > 64 {
+        if i >= self.len() || self.bits_per_element() > 64 {
             None
         } else {
             Some(self.get_sorted_u64_unchecked(i))
@@ -1547,7 +1544,7 @@ impl WaveletMatrix {
     /// ```
     #[must_use]
     pub fn range_median_u64(&self, range: Range<usize>) -> Option<u64> {
-        if range.is_empty() || self.bits_per_element > 64 || range.end > self.len() {
+        if range.is_empty() || self.bits_per_element() > 64 || range.end > self.len() {
             None
         } else {
             let k = (range.end - 1 - range.start) / 2;
@@ -1676,7 +1673,7 @@ impl WaveletMatrix {
     /// [`BitVec`]: BitVec
     #[must_use]
     pub fn predecessor(&self, range: Range<usize>, symbol: &BitVec) -> Option<BitVec> {
-        if symbol.len() != self.bits_per_element as usize
+        if symbol.len() != self.bits_per_element()
             || range.is_empty()
             || self.is_empty()
             || range.end > self.len()
@@ -1687,10 +1684,10 @@ impl WaveletMatrix {
         self.predecessor_generic_unchecked(
             range,
             symbol,
-            BitVec::from_zeros(self.bits_per_element as usize),
-            |level, symbol| symbol.get_unchecked((self.bits_per_element - 1) as usize - level),
+            BitVec::from_zeros(self.bits_per_element()),
+            |level, symbol| symbol.get_unchecked((self.bits_per_element() - 1) - level),
             |bit, level, result| {
-                result.set_unchecked((self.bits_per_element - 1) as usize - level, bit);
+                result.set_unchecked((self.bits_per_element() - 1) - level, bit);
             },
             Self::partial_quantile_search_unchecked,
         )
@@ -1720,7 +1717,7 @@ impl WaveletMatrix {
     /// ```
     #[must_use]
     pub fn predecessor_u64(&self, range: Range<usize>, symbol: u64) -> Option<u64> {
-        if self.bits_per_element > 64
+        if self.bits_per_element() > 64
             || range.is_empty()
             || self.is_empty()
             || range.end > self.len()
@@ -1732,7 +1729,7 @@ impl WaveletMatrix {
             range,
             &symbol,
             0,
-            |level, symbol| symbol >> ((self.bits_per_element - 1) as usize - level) & 1,
+            |level, symbol| symbol >> ((self.bits_per_element() - 1) - level) & 1,
             |bit, _level, result| {
                 // we ignore the level here, and instead rely on the fact that the bits are set in order.
                 // we have to do that, because the quantile_search_u64 does the same.
@@ -1863,7 +1860,7 @@ impl WaveletMatrix {
     /// [`BitVec`]: BitVec
     #[must_use]
     pub fn successor(&self, range: Range<usize>, symbol: &BitVec) -> Option<BitVec> {
-        if symbol.len() != self.bits_per_element as usize
+        if symbol.len() != self.bits_per_element()
             || range.is_empty()
             || self.is_empty()
             || range.end > self.len()
@@ -1874,10 +1871,10 @@ impl WaveletMatrix {
         self.successor_generic_unchecked(
             range,
             symbol,
-            BitVec::from_zeros(self.bits_per_element as usize),
-            |level, symbol| symbol.get_unchecked((self.bits_per_element - 1) as usize - level),
+            BitVec::from_zeros(self.bits_per_element()),
+            |level, symbol| symbol.get_unchecked((self.bits_per_element() - 1) - level),
             |bit, level, result| {
-                result.set_unchecked((self.bits_per_element - 1) as usize - level, bit);
+                result.set_unchecked((self.bits_per_element() - 1) - level, bit);
             },
             Self::partial_quantile_search_unchecked,
         )
@@ -1907,7 +1904,7 @@ impl WaveletMatrix {
     /// ```
     #[must_use]
     pub fn successor_u64(&self, range: Range<usize>, symbol: u64) -> Option<u64> {
-        if self.bits_per_element > 64
+        if self.bits_per_element() > 64
             || range.is_empty()
             || self.is_empty()
             || range.end > self.len()
@@ -1919,7 +1916,7 @@ impl WaveletMatrix {
             range,
             &symbol,
             0,
-            |level, symbol| symbol >> ((self.bits_per_element - 1) as usize - level) & 1,
+            |level, symbol| symbol >> ((self.bits_per_element() - 1) - level) & 1,
             |bit, _level, result| {
                 // we ignore the level here, and instead rely on the fact that the bits are set in order.
                 // we have to do that, because the quantile_search_u64 does the same.
@@ -1946,7 +1943,7 @@ impl WaveletMatrix {
     /// ```
     #[must_use]
     pub fn iter_u64(&self) -> Option<WaveletNumRefIter> {
-        if self.bits_per_element > 64 {
+        if self.bits_per_element() > 64 {
             None
         } else {
             Some(WaveletNumRefIter::new(self))
@@ -1958,7 +1955,7 @@ impl WaveletMatrix {
     /// If the number of bits per element exceeds 64, `None` is returned.
     #[must_use]
     pub fn into_iter_u64(self) -> Option<WaveletNumIter> {
-        if self.bits_per_element > 64 {
+        if self.bits_per_element() > 64 {
             None
         } else {
             Some(WaveletNumIter::new(self))
@@ -1997,7 +1994,7 @@ impl WaveletMatrix {
     /// ```
     #[must_use]
     pub fn iter_sorted_u64(&self) -> Option<WaveletSortedNumRefIter> {
-        if self.bits_per_element > 64 {
+        if self.bits_per_element() > 64 {
             None
         } else {
             Some(WaveletSortedNumRefIter::new(self))
@@ -2009,7 +2006,7 @@ impl WaveletMatrix {
     /// If the number of bits per element exceeds 64, `None` is returned.
     #[must_use]
     pub fn into_iter_sorted_u64(self) -> Option<WaveletSortedNumIter> {
-        if self.bits_per_element > 64 {
+        if self.bits_per_element() > 64 {
             None
         } else {
             Some(WaveletSortedNumIter::new(self))
@@ -2018,8 +2015,15 @@ impl WaveletMatrix {
 
     /// Get the number of bits per element in the alphabet of the encoded sequence.
     #[must_use]
+    #[inline(always)]
+    pub fn bits_per_element(&self) -> usize {
+        self.data.len()
+    }
+
+    /// Get the number of bits per element in the alphabet of the encoded sequence as a `u16`.
+    #[deprecated(since = "1.5.1", note = "please use `bits_per_element` instead")]
     pub fn bit_len(&self) -> u16 {
-        self.bits_per_element
+        self.bits_per_element() as u16
     }
 
     /// Get the number of elements stored in the encoded sequence.