Skip to content

Commit 3607c21

Browse files
SLiV9 authored and kornelski committed
Optimize FastReader::read_flag
1 parent 7b25ba8 commit 3607c21

File tree

2 files changed

+90
-20
lines changed

2 files changed

+90
-20
lines changed

src/bool_reader.rs

Lines changed: 89 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -158,13 +158,23 @@ impl BoolReader {
158158
// Do not inline this because inlining seems to worsen performance.
159159
#[inline(never)]
160160
pub(crate) fn read_bool(&mut self, probability: u8) -> BitResult<bool> {
161-
if let Some(b) = self.fast().read_bit(probability) {
161+
if let Some(b) = self.fast().read_bool(probability) {
162162
return BitResult::ok(b);
163163
}
164164

165165
self.cold_read_bool(probability)
166166
}
167167

168+
// Do not inline this because inlining seems to worsen performance.
169+
#[inline(never)]
170+
pub(crate) fn read_flag(&mut self) -> BitResult<bool> {
171+
if let Some(b) = self.fast().read_flag() {
172+
return BitResult::ok(b);
173+
}
174+
175+
self.cold_read_flag()
176+
}
177+
168178
// Do not inline this because inlining seems to worsen performance.
169179
#[inline(never)]
170180
pub(crate) fn read_literal(&mut self, n: u8) -> BitResult<u8> {
@@ -206,13 +216,6 @@ impl BoolReader {
206216
self.cold_read_with_tree(tree, usize::from(first_node.index))
207217
}
208218

209-
// This should be inlined to allow it to share the instruction cache with
210-
// `read_bool`, as both functions are short and called often.
211-
#[inline]
212-
pub(crate) fn read_flag(&mut self) -> BitResult<bool> {
213-
self.read_bool(128)
214-
}
215-
216219
// As a similar (but different) speedup to BitResult, the FastReader reads
217220
// bits under an assumption and validates it at the end.
218221
//
@@ -312,15 +315,21 @@ impl BoolReader {
312315
self.cold_read_bit(probability)
313316
}
314317

318+
#[cold]
319+
#[inline(never)]
320+
fn cold_read_flag(&mut self) -> BitResult<bool> {
321+
self.cold_read_bit(128)
322+
}
323+
315324
#[cold]
316325
#[inline(never)]
317326
fn cold_read_literal(&mut self, n: u8) -> BitResult<u8> {
318327
let mut v = 0u8;
319328
let mut res = self.start_accumulated_result();
320329

321330
for _ in 0..n {
322-
let b = self.cold_read_bit(128).or_accumulate(&mut res);
323-
v = (v << 1) + b as u8;
331+
let b = self.cold_read_flag().or_accumulate(&mut res);
332+
v = (v << 1) + u8::from(b);
324333
}
325334

326335
self.keep_accumulating(res, v)
@@ -330,13 +339,13 @@ impl BoolReader {
330339
#[inline(never)]
331340
fn cold_read_optional_signed_value(&mut self, n: u8) -> BitResult<i32> {
332341
let mut res = self.start_accumulated_result();
333-
let flag = self.cold_read_bool(128).or_accumulate(&mut res);
342+
let flag = self.cold_read_flag().or_accumulate(&mut res);
334343
if !flag {
335344
// We should not read further bits if the flag is not set.
336345
return self.keep_accumulating(res, 0);
337346
}
338347
let magnitude = self.cold_read_literal(n).or_accumulate(&mut res);
339-
let sign = self.cold_read_bool(128).or_accumulate(&mut res);
348+
let sign = self.cold_read_flag().or_accumulate(&mut res);
340349

341350
let value = if sign {
342351
-i32::from(magnitude)
@@ -380,24 +389,29 @@ impl FastReader<'_> {
380389
}
381390
}
382391

383-
fn read_bit(mut self, probability: u8) -> Option<bool> {
392+
fn read_bool(mut self, probability: u8) -> Option<bool> {
384393
let bit = self.fast_read_bit(probability);
385394
self.commit_if_valid(bit)
386395
}
387396

397+
fn read_flag(mut self) -> Option<bool> {
398+
let value = self.fast_read_flag();
399+
self.commit_if_valid(value)
400+
}
401+
388402
fn read_literal(mut self, n: u8) -> Option<u8> {
389403
let value = self.fast_read_literal(n);
390404
self.commit_if_valid(value)
391405
}
392406

393407
fn read_optional_signed_value(mut self, n: u8) -> Option<i32> {
394-
let flag = self.fast_read_bit(128);
408+
let flag = self.fast_read_flag();
395409
if !flag {
396410
// We should not read further bits if the flag is not set.
397411
return self.commit_if_valid(0);
398412
}
399413
let magnitude = self.fast_read_literal(n);
400-
let sign = self.fast_read_bit(128);
414+
let sign = self.fast_read_flag();
401415
let value = if sign {
402416
-i32::from(magnitude)
403417
} else {
@@ -467,11 +481,67 @@ impl FastReader<'_> {
467481
retval
468482
}
469483

484+
fn fast_read_flag(&mut self) -> bool {
485+
let State {
486+
mut chunk_index,
487+
mut value,
488+
mut range,
489+
mut bit_count,
490+
} = self.uncommitted_state;
491+
492+
if bit_count < 0 {
493+
let chunk = self.chunks.get(chunk_index).copied();
494+
// We ignore invalid data inside the `fast_` functions,
495+
// but we increase `chunk_index` below, so we can check
496+
// whether we read invalid data in `commit_if_valid`.
497+
let chunk = chunk.unwrap_or_default();
498+
499+
let v = u32::from_be_bytes(chunk);
500+
chunk_index += 1;
501+
value <<= 32;
502+
value |= u64::from(v);
503+
bit_count += 32;
504+
}
505+
debug_assert!(bit_count >= 0);
506+
507+
let half_range = range / 2;
508+
let split = range - half_range;
509+
let bigsplit = u64::from(split) << bit_count;
510+
511+
let retval = if let Some(new_value) = value.checked_sub(bigsplit) {
512+
range = half_range;
513+
value = new_value;
514+
true
515+
} else {
516+
range = split;
517+
false
518+
};
519+
debug_assert!(range > 0);
520+
521+
// Compute shift required to satisfy `range >= 128`.
522+
// Apply that shift to `range` and `bit_count`.
523+
//
524+
// Subtract 24 because we only care about leading zeros in the
525+
// lowest byte of `range` which is a `u32`.
526+
let shift = range.leading_zeros().saturating_sub(24);
527+
range <<= shift;
528+
bit_count -= shift as i32;
529+
debug_assert!(range >= 128);
530+
531+
self.uncommitted_state = State {
532+
chunk_index,
533+
value,
534+
range,
535+
bit_count,
536+
};
537+
retval
538+
}
539+
470540
fn fast_read_literal(&mut self, n: u8) -> u8 {
471541
let mut v = 0u8;
472542
for _ in 0..n {
473-
let b = self.fast_read_bit(128);
474-
v = (v << 1) + b as u8;
543+
let b = self.fast_read_flag();
544+
v = (v << 1) + u8::from(b);
475545
}
476546
v
477547
}
@@ -502,7 +572,7 @@ mod tests {
502572
buf.as_mut_slice().as_flattened_mut()[..size].copy_from_slice(&data[..]);
503573
reader.init(buf, size).unwrap();
504574
let mut res = reader.start_accumulated_result();
505-
assert_eq!(false, reader.read_bool(128).or_accumulate(&mut res));
575+
assert_eq!(false, reader.read_flag().or_accumulate(&mut res));
506576
assert_eq!(true, reader.read_bool(10).or_accumulate(&mut res));
507577
assert_eq!(false, reader.read_bool(250).or_accumulate(&mut res));
508578
assert_eq!(1, reader.read_literal(1).or_accumulate(&mut res));
@@ -521,7 +591,7 @@ mod tests {
521591
buf.as_mut_slice().as_flattened_mut()[..size].copy_from_slice(&data[..]);
522592
reader.init(buf, size).unwrap();
523593
let mut res = reader.start_accumulated_result();
524-
assert_eq!(false, reader.read_bool(128).or_accumulate(&mut res));
594+
assert_eq!(false, reader.read_flag().or_accumulate(&mut res));
525595
assert_eq!(true, reader.read_bool(10).or_accumulate(&mut res));
526596
assert_eq!(false, reader.read_bool(250).or_accumulate(&mut res));
527597
assert_eq!(1, reader.read_literal(1).or_accumulate(&mut res));

src/lossless.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -316,7 +316,7 @@ impl<R: BufRead> LosslessDecoder<R> {
316316
entropy_image = data
317317
.chunks_exact(4)
318318
.map(|pixel| {
319-
let meta_huff_code = u16::from(pixel[0]) << 8 | u16::from(pixel[1]);
319+
let meta_huff_code = (u16::from(pixel[0]) << 8) | u16::from(pixel[1]);
320320
if u32::from(meta_huff_code) >= num_huff_groups {
321321
num_huff_groups = u32::from(meta_huff_code) + 1;
322322
}

0 commit comments

Comments
 (0)