Skip to content

Commit

Permalink
remove redundant bits_per_element property from WaveletMatrix (#12)
Browse files Browse the repository at this point in the history
  • Loading branch information
somethingelseentirely authored Sep 10, 2024
1 parent d669432 commit c99a18e
Showing 1 changed file with 51 additions and 47 deletions.
98 changes: 51 additions & 47 deletions src/wavelet/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ use std::ops::Range;
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct WaveletMatrix {
data: Box<[RsVec]>,
bits_per_element: u16,
}

impl WaveletMatrix {
Expand Down Expand Up @@ -124,7 +123,6 @@ impl WaveletMatrix {

Self {
data: data.into_iter().map(BitVec::into).collect(),
bits_per_element,
}
}

Expand Down Expand Up @@ -221,7 +219,6 @@ impl WaveletMatrix {

Self {
data: data.into_iter().map(BitVec::into).collect(),
bits_per_element,
}
}

Expand Down Expand Up @@ -297,7 +294,7 @@ impl WaveletMatrix {
/// The function is used by the `get_value` and `get_u64` functions, deduplicating code.
#[inline(always)]
fn reconstruct_value_unchecked<F: FnMut(u64)>(&self, mut i: usize, mut target_func: F) {
for level in 0..self.bits_per_element as usize {
for level in 0..self.bits_per_element() {
let bit = self.data[level].get_unchecked(i);
target_func(bit);
if bit == 0 {
Expand Down Expand Up @@ -346,8 +343,8 @@ impl WaveletMatrix {
/// [`get_value`]: WaveletMatrix::get_value
#[must_use]
pub fn get_value_unchecked(&self, i: usize) -> BitVec {
let mut value = BitVec::from_zeros(self.bits_per_element as usize);
let mut level = self.bits_per_element - 1;
let mut value = BitVec::from_zeros(self.bits_per_element());
let mut level = self.bits_per_element() - 1;
self.reconstruct_value_unchecked(i, |bit| {
value.set_unchecked(level as usize, bit);
level = level.saturating_sub(1);
Expand All @@ -373,7 +370,7 @@ impl WaveletMatrix {
/// ```
#[must_use]
pub fn get_u64(&self, i: usize) -> Option<u64> {
if self.bits_per_element > 64 || self.data.is_empty() || i >= self.data[0].len() {
if self.bits_per_element() > 64 || self.data.is_empty() || i >= self.data[0].len() {
None
} else {
Some(self.get_u64_unchecked(i))
Expand Down Expand Up @@ -419,7 +416,7 @@ impl WaveletMatrix {
#[must_use]
pub fn rank_range_unchecked(&self, mut range: Range<usize>, symbol: &BitVec) -> usize {
for (level, data) in self.data.iter().enumerate() {
if symbol.get_unchecked((self.bits_per_element - 1) as usize - level) == 0 {
if symbol.get_unchecked((self.bits_per_element() - 1) - level) == 0 {
range.start = data.rank0(range.start);
range.end = data.rank0(range.end);
} else {
Expand Down Expand Up @@ -456,7 +453,7 @@ impl WaveletMatrix {
pub fn rank_range(&self, range: Range<usize>, symbol: &BitVec) -> Option<usize> {
if range.start >= self.len()
|| range.end > self.len()
|| symbol.len() != self.bits_per_element as usize
|| symbol.len() != self.bits_per_element()
{
None
} else {
Expand All @@ -483,7 +480,7 @@ impl WaveletMatrix {
#[must_use]
pub fn rank_range_u64_unchecked(&self, mut range: Range<usize>, symbol: u64) -> usize {
for (level, data) in self.data.iter().enumerate() {
if (symbol >> ((self.bits_per_element - 1) as usize - level)) & 1 == 0 {
if (symbol >> ((self.bits_per_element() - 1) - level)) & 1 == 0 {
range.start = data.rank0(range.start);
range.end = data.rank0(range.end);
} else {
Expand Down Expand Up @@ -516,7 +513,7 @@ impl WaveletMatrix {
/// ```
#[must_use]
pub fn rank_range_u64(&self, range: Range<usize>, symbol: u64) -> Option<usize> {
if range.start >= self.len() || range.end > self.len() || self.bits_per_element > 64 {
if range.start >= self.len() || range.end > self.len() || self.bits_per_element() > 64 {
None
} else {
Some(self.rank_range_u64_unchecked(range, symbol))
Expand Down Expand Up @@ -584,7 +581,7 @@ impl WaveletMatrix {
if offset > i
|| offset >= self.len()
|| i > self.len()
|| symbol.len() != self.bits_per_element as usize
|| symbol.len() != self.bits_per_element()
{
None
} else {
Expand Down Expand Up @@ -644,7 +641,7 @@ impl WaveletMatrix {
/// ```
#[must_use]
pub fn rank_offset_u64(&self, offset: usize, i: usize, symbol: u64) -> Option<usize> {
if offset > i || offset >= self.len() || i > self.len() || self.bits_per_element > 64 {
if offset > i || offset >= self.len() || i > self.len() || self.bits_per_element() > 64 {
None
} else {
Some(self.rank_offset_u64_unchecked(offset, i, symbol))
Expand Down Expand Up @@ -697,7 +694,7 @@ impl WaveletMatrix {
/// [`BitVec`]: BitVec
#[must_use]
pub fn rank(&self, i: usize, symbol: &BitVec) -> Option<usize> {
if i > self.len() || symbol.len() != self.bits_per_element as usize {
if i > self.len() || symbol.len() != self.bits_per_element() {
None
} else {
Some(self.rank_range_unchecked(0..i, symbol))
Expand Down Expand Up @@ -745,7 +742,7 @@ impl WaveletMatrix {
/// ```
#[must_use]
pub fn rank_u64(&self, i: usize, symbol: u64) -> Option<usize> {
if i > self.len() || self.bits_per_element > 64 {
if i > self.len() || self.bits_per_element() > 64 {
None
} else {
Some(self.rank_range_u64_unchecked(0..i, symbol))
Expand Down Expand Up @@ -777,7 +774,7 @@ impl WaveletMatrix {
let mut range_start = offset;

for (level, data) in self.data.iter().enumerate() {
if symbol.get_unchecked((self.bits_per_element - 1) as usize - level) == 0 {
if symbol.get_unchecked((self.bits_per_element() - 1) - level) == 0 {
range_start = data.rank0(range_start);
} else {
range_start = data.rank0 + data.rank1(range_start);
Expand All @@ -787,7 +784,7 @@ impl WaveletMatrix {
let mut range_end = range_start + rank;

for (level, data) in self.data.iter().enumerate().rev() {
if symbol.get_unchecked((self.bits_per_element - 1) as usize - level) == 0 {
if symbol.get_unchecked((self.bits_per_element() - 1) - level) == 0 {
range_end = data.select0(range_end);
} else {
range_end = data.select1(range_end - data.rank0);
Expand Down Expand Up @@ -822,7 +819,7 @@ impl WaveletMatrix {
/// [`BitVec`]: BitVec
#[must_use]
pub fn select_offset(&self, offset: usize, rank: usize, symbol: &BitVec) -> Option<usize> {
if offset >= self.len() || symbol.len() != self.bits_per_element as usize {
if offset >= self.len() || symbol.len() != self.bits_per_element() {
None
} else {
let idx = self.select_offset_unchecked(offset, rank, symbol);
Expand Down Expand Up @@ -857,7 +854,7 @@ impl WaveletMatrix {
let mut range_start = offset;

for (level, data) in self.data.iter().enumerate() {
if (symbol >> ((self.bits_per_element - 1) as usize - level)) & 1 == 0 {
if (symbol >> ((self.bits_per_element() - 1) - level)) & 1 == 0 {
range_start = data.rank0(range_start);
} else {
range_start = data.rank0 + data.rank1(range_start);
Expand All @@ -867,7 +864,7 @@ impl WaveletMatrix {
let mut range_end = range_start + rank;

for (level, data) in self.data.iter().enumerate().rev() {
if (symbol >> ((self.bits_per_element - 1) as usize - level)) & 1 == 0 {
if (symbol >> ((self.bits_per_element() - 1) - level)) & 1 == 0 {
range_end = data.select0(range_end);
} else {
range_end = data.select1(range_end - data.rank0);
Expand Down Expand Up @@ -899,7 +896,7 @@ impl WaveletMatrix {
/// ```
#[must_use]
pub fn select_offset_u64(&self, offset: usize, rank: usize, symbol: u64) -> Option<usize> {
if offset >= self.len() || self.bits_per_element > 64 {
if offset >= self.len() || self.bits_per_element() > 64 {
None
} else {
let idx = self.select_offset_u64_unchecked(offset, rank, symbol);
Expand Down Expand Up @@ -956,7 +953,7 @@ impl WaveletMatrix {
/// [`BitVec`]: BitVec
#[must_use]
pub fn select(&self, rank: usize, symbol: &BitVec) -> Option<usize> {
if symbol.len() == self.bits_per_element as usize {
if symbol.len() == self.bits_per_element() {
let idx = self.select_unchecked(rank, symbol);
if idx < self.len() {
Some(idx)
Expand Down Expand Up @@ -1008,7 +1005,7 @@ impl WaveletMatrix {
/// ```
#[must_use]
pub fn select_u64(&self, rank: usize, symbol: u64) -> Option<usize> {
if self.bits_per_element > 64 {
if self.bits_per_element() > 64 {
None
} else {
let idx = self.select_u64_unchecked(rank, symbol);
Expand Down Expand Up @@ -1036,7 +1033,7 @@ impl WaveletMatrix {
/// [`quantile`]: WaveletMatrix::quantile
#[must_use]
pub fn quantile_unchecked(&self, range: Range<usize>, k: usize) -> BitVec {
let result = BitVec::from_zeros(self.bits_per_element as usize);
let result = BitVec::from_zeros(self.bits_per_element());

self.partial_quantile_search_unchecked(range, k, 0, result)
}
Expand All @@ -1054,7 +1051,7 @@ impl WaveletMatrix {
start_level: usize,
mut prefix: BitVec,
) -> BitVec {
debug_assert!(prefix.len() == self.bits_per_element as usize);
debug_assert!(prefix.len() == self.bits_per_element());
debug_assert!(!range.is_empty());
debug_assert!(range.end <= self.len());

Expand All @@ -1070,7 +1067,7 @@ impl WaveletMatrix {
} else {
// the element is among the ones, so we set the bit to 1, and move the range
// into the 1-partition of the next level
prefix.set_unchecked((self.bits_per_element - 1) as usize - level, 1);
prefix.set_unchecked((self.bits_per_element() - 1) - level, 1);
k -= zeros;
range.start = data.rank0 + (range.start - zeros_start); // range.start - zeros_start is the rank1 of range.start
range.end = data.rank0 + (range.end - zeros_end); // same here
Expand Down Expand Up @@ -1183,7 +1180,7 @@ impl WaveletMatrix {
start_level: usize,
mut prefix: u64,
) -> u64 {
debug_assert!(self.bits_per_element <= 64);
debug_assert!(self.bits_per_element() <= 64);
debug_assert!(!range.is_empty());
debug_assert!(range.end <= self.len());

Expand Down Expand Up @@ -1230,7 +1227,7 @@ impl WaveletMatrix {
pub fn quantile_u64(&self, range: Range<usize>, k: usize) -> Option<u64> {
if range.start >= self.len()
|| range.end > self.len()
|| self.bits_per_element > 64
|| self.bits_per_element() > 64
|| k >= range.end - range.start
{
None
Expand Down Expand Up @@ -1274,7 +1271,7 @@ impl WaveletMatrix {
/// ```
#[must_use]
pub fn get_sorted_u64(&self, i: usize) -> Option<u64> {
if i >= self.len() || self.bits_per_element > 64 {
if i >= self.len() || self.bits_per_element() > 64 {
None
} else {
Some(self.get_sorted_u64_unchecked(i))
Expand Down Expand Up @@ -1547,7 +1544,7 @@ impl WaveletMatrix {
/// ```
#[must_use]
pub fn range_median_u64(&self, range: Range<usize>) -> Option<u64> {
if range.is_empty() || self.bits_per_element > 64 || range.end > self.len() {
if range.is_empty() || self.bits_per_element() > 64 || range.end > self.len() {
None
} else {
let k = (range.end - 1 - range.start) / 2;
Expand Down Expand Up @@ -1676,7 +1673,7 @@ impl WaveletMatrix {
/// [`BitVec`]: BitVec
#[must_use]
pub fn predecessor(&self, range: Range<usize>, symbol: &BitVec) -> Option<BitVec> {
if symbol.len() != self.bits_per_element as usize
if symbol.len() != self.bits_per_element()
|| range.is_empty()
|| self.is_empty()
|| range.end > self.len()
Expand All @@ -1687,10 +1684,10 @@ impl WaveletMatrix {
self.predecessor_generic_unchecked(
range,
symbol,
BitVec::from_zeros(self.bits_per_element as usize),
|level, symbol| symbol.get_unchecked((self.bits_per_element - 1) as usize - level),
BitVec::from_zeros(self.bits_per_element()),
|level, symbol| symbol.get_unchecked((self.bits_per_element() - 1) - level),
|bit, level, result| {
result.set_unchecked((self.bits_per_element - 1) as usize - level, bit);
result.set_unchecked((self.bits_per_element() - 1) - level, bit);
},
Self::partial_quantile_search_unchecked,
)
Expand Down Expand Up @@ -1720,7 +1717,7 @@ impl WaveletMatrix {
/// ```
#[must_use]
pub fn predecessor_u64(&self, range: Range<usize>, symbol: u64) -> Option<u64> {
if self.bits_per_element > 64
if self.bits_per_element() > 64
|| range.is_empty()
|| self.is_empty()
|| range.end > self.len()
Expand All @@ -1732,7 +1729,7 @@ impl WaveletMatrix {
range,
&symbol,
0,
|level, symbol| symbol >> ((self.bits_per_element - 1) as usize - level) & 1,
|level, symbol| symbol >> ((self.bits_per_element() - 1) - level) & 1,
|bit, _level, result| {
// we ignore the level here, and instead rely on the fact that the bits are set in order.
// we have to do that, because the quantile_search_u64 does the same.
Expand Down Expand Up @@ -1863,7 +1860,7 @@ impl WaveletMatrix {
/// [`BitVec`]: BitVec
#[must_use]
pub fn successor(&self, range: Range<usize>, symbol: &BitVec) -> Option<BitVec> {
if symbol.len() != self.bits_per_element as usize
if symbol.len() != self.bits_per_element()
|| range.is_empty()
|| self.is_empty()
|| range.end > self.len()
Expand All @@ -1874,10 +1871,10 @@ impl WaveletMatrix {
self.successor_generic_unchecked(
range,
symbol,
BitVec::from_zeros(self.bits_per_element as usize),
|level, symbol| symbol.get_unchecked((self.bits_per_element - 1) as usize - level),
BitVec::from_zeros(self.bits_per_element()),
|level, symbol| symbol.get_unchecked((self.bits_per_element() - 1) - level),
|bit, level, result| {
result.set_unchecked((self.bits_per_element - 1) as usize - level, bit);
result.set_unchecked((self.bits_per_element() - 1) - level, bit);
},
Self::partial_quantile_search_unchecked,
)
Expand Down Expand Up @@ -1907,7 +1904,7 @@ impl WaveletMatrix {
/// ```
#[must_use]
pub fn successor_u64(&self, range: Range<usize>, symbol: u64) -> Option<u64> {
if self.bits_per_element > 64
if self.bits_per_element() > 64
|| range.is_empty()
|| self.is_empty()
|| range.end > self.len()
Expand All @@ -1919,7 +1916,7 @@ impl WaveletMatrix {
range,
&symbol,
0,
|level, symbol| symbol >> ((self.bits_per_element - 1) as usize - level) & 1,
|level, symbol| symbol >> ((self.bits_per_element() - 1) - level) & 1,
|bit, _level, result| {
// we ignore the level here, and instead rely on the fact that the bits are set in order.
// we have to do that, because the quantile_search_u64 does the same.
Expand All @@ -1946,7 +1943,7 @@ impl WaveletMatrix {
/// ```
#[must_use]
pub fn iter_u64(&self) -> Option<WaveletNumRefIter> {
if self.bits_per_element > 64 {
if self.bits_per_element() > 64 {
None
} else {
Some(WaveletNumRefIter::new(self))
Expand All @@ -1958,7 +1955,7 @@ impl WaveletMatrix {
/// If the number of bits per element exceeds 64, `None` is returned.
#[must_use]
pub fn into_iter_u64(self) -> Option<WaveletNumIter> {
if self.bits_per_element > 64 {
if self.bits_per_element() > 64 {
None
} else {
Some(WaveletNumIter::new(self))
Expand Down Expand Up @@ -1997,7 +1994,7 @@ impl WaveletMatrix {
/// ```
#[must_use]
pub fn iter_sorted_u64(&self) -> Option<WaveletSortedNumRefIter> {
if self.bits_per_element > 64 {
if self.bits_per_element() > 64 {
None
} else {
Some(WaveletSortedNumRefIter::new(self))
Expand All @@ -2009,7 +2006,7 @@ impl WaveletMatrix {
/// If the number of bits per element exceeds 64, `None` is returned.
#[must_use]
pub fn into_iter_sorted_u64(self) -> Option<WaveletSortedNumIter> {
if self.bits_per_element > 64 {
if self.bits_per_element() > 64 {
None
} else {
Some(WaveletSortedNumIter::new(self))
Expand All @@ -2018,8 +2015,15 @@ impl WaveletMatrix {

/// Get the number of bits per element in the alphabet of the encoded sequence.
///
/// The value is not stored separately: it is the length of the internal `data` slice,
/// which the other methods of this type index by bit level.
#[must_use]
#[inline(always)]
pub fn bits_per_element(&self) -> usize {
self.data.len()
}

/// Get the number of bits per element in the alphabet of the encoded sequence.
#[deprecated(since = "1.5.1", note = "please use `bits_per_element` instead")]
pub fn bit_len(&self) -> u16 {
self.bits_per_element
self.bits_per_element() as u16
}

/// Get the number of elements stored in the encoded sequence.
Expand Down

0 comments on commit c99a18e

Please sign in to comment.