@@ -158,13 +158,23 @@ impl BoolReader {
158
158
// Do not inline this because inlining seems to worsen performance.
//
// Reads one boolean using the given `probability`. The fast reader is tried
// first; when it yields `Some(b)`, that bit is returned through
// `BitResult::ok`. When it yields `None`, the read is delegated to
// `cold_read_bool` with the same `probability`.
159
159
#[ inline( never) ]
160
160
pub ( crate ) fn read_bool ( & mut self , probability : u8 ) -> BitResult < bool > {
161
- if let Some ( b) = self . fast ( ) . read_bit ( probability) {
161
+ if let Some ( b) = self . fast ( ) . read_bool ( probability) {
162
162
return BitResult :: ok ( b) ;
163
163
}
164
164
165
165
self . cold_read_bool ( probability)
166
166
}
167
167
168
+ // Do not inline this because inlining seems to worsen performance.
//
// Reads one flag (no probability argument). The fast reader's `read_flag`
// is tried first; when it yields `Some(b)`, that bit is returned through
// `BitResult::ok`. When it yields `None`, the read is delegated to
// `cold_read_flag`.
169
+ #[ inline( never) ]
170
+ pub ( crate ) fn read_flag ( & mut self ) -> BitResult < bool > {
171
+ if let Some ( b) = self . fast ( ) . read_flag ( ) {
172
+ return BitResult :: ok ( b) ;
173
+ }
174
+
175
+ self . cold_read_flag ( )
176
+ }
177
+
168
178
// Do not inline this because inlining seems to worsen performance.
169
179
#[ inline( never) ]
170
180
pub ( crate ) fn read_literal ( & mut self , n : u8 ) -> BitResult < u8 > {
@@ -206,13 +216,6 @@ impl BoolReader {
206
216
self . cold_read_with_tree ( tree, usize:: from ( first_node. index ) )
207
217
}
208
218
209
- // This should be inlined to allow it to share the instruction cache with
210
- // `read_bool`, as both functions are short and called often.
// Reads a flag by forwarding to `read_bool` with a fixed probability of 128.
211
- #[ inline]
212
- pub ( crate ) fn read_flag ( & mut self ) -> BitResult < bool > {
213
- self . read_bool ( 128 )
214
- }
215
-
216
219
// As a similar (but different) speedup to BitResult, the FastReader reads
217
220
// bits under an assumption and validates it at the end.
218
221
//
@@ -312,15 +315,21 @@ impl BoolReader {
312
315
self . cold_read_bit ( probability)
313
316
}
314
317
318
// Slow-path flag read: forwards to `cold_read_bit` with a fixed probability
// of 128 and returns its `BitResult` unchanged.
+ #[ cold]
319
+ #[ inline( never) ]
320
+ fn cold_read_flag ( & mut self ) -> BitResult < bool > {
321
+ self . cold_read_bit ( 128 )
322
+ }
323
+
315
324
#[ cold]
316
325
#[ inline( never) ]
317
326
fn cold_read_literal ( & mut self , n : u8 ) -> BitResult < u8 > {
318
327
let mut v = 0u8 ;
319
328
let mut res = self . start_accumulated_result ( ) ;
320
329
321
330
for _ in 0 ..n {
322
- let b = self . cold_read_bit ( 128 ) . or_accumulate ( & mut res) ;
323
- v = ( v << 1 ) + b as u8 ;
331
+ let b = self . cold_read_flag ( ) . or_accumulate ( & mut res) ;
332
+ v = ( v << 1 ) + u8 :: from ( b ) ;
324
333
}
325
334
326
335
self . keep_accumulating ( res, v)
@@ -330,13 +339,13 @@ impl BoolReader {
330
339
#[ inline( never) ]
331
340
fn cold_read_optional_signed_value ( & mut self , n : u8 ) -> BitResult < i32 > {
332
341
let mut res = self . start_accumulated_result ( ) ;
333
- let flag = self . cold_read_bool ( 128 ) . or_accumulate ( & mut res) ;
342
+ let flag = self . cold_read_flag ( ) . or_accumulate ( & mut res) ;
334
343
if !flag {
335
344
// We should not read further bits if the flag is not set.
336
345
return self . keep_accumulating ( res, 0 ) ;
337
346
}
338
347
let magnitude = self . cold_read_literal ( n) . or_accumulate ( & mut res) ;
339
- let sign = self . cold_read_bool ( 128 ) . or_accumulate ( & mut res) ;
348
+ let sign = self . cold_read_flag ( ) . or_accumulate ( & mut res) ;
340
349
341
350
let value = if sign {
342
351
-i32:: from ( magnitude)
@@ -380,24 +389,29 @@ impl FastReader<'_> {
380
389
}
381
390
}
382
391
383
// Reads one bit with the given `probability` via `fast_read_bit`, then hands
// the bit to `commit_if_valid` and returns that call's `Option`.
// Takes `self` by value, so the reader is consumed by the call.
// (The `-`/`+` pair below only renames the method from `read_bit` to `read_bool`.)
- fn read_bit ( mut self , probability : u8 ) -> Option < bool > {
392
+ fn read_bool ( mut self , probability : u8 ) -> Option < bool > {
384
393
let bit = self . fast_read_bit ( probability) ;
385
394
self . commit_if_valid ( bit)
386
395
}
387
396
397
// Reads one flag via `fast_read_flag`, then hands the bit to
// `commit_if_valid` and returns that call's `Option`.
// Takes `self` by value, so the reader is consumed by the call.
+ fn read_flag ( mut self ) -> Option < bool > {
398
+ let value = self . fast_read_flag ( ) ;
399
+ self . commit_if_valid ( value)
400
+ }
401
+
388
402
// Reads an `n`-bit literal via `fast_read_literal`, then hands the value to
// `commit_if_valid` and returns that call's `Option`.
// Takes `self` by value, so the reader is consumed by the call.
fn read_literal ( mut self , n : u8 ) -> Option < u8 > {
389
403
let value = self . fast_read_literal ( n) ;
390
404
self . commit_if_valid ( value)
391
405
}
392
406
393
407
fn read_optional_signed_value ( mut self , n : u8 ) -> Option < i32 > {
394
- let flag = self . fast_read_bit ( 128 ) ;
408
+ let flag = self . fast_read_flag ( ) ;
395
409
if !flag {
396
410
// We should not read further bits if the flag is not set.
397
411
return self . commit_if_valid ( 0 ) ;
398
412
}
399
413
let magnitude = self . fast_read_literal ( n) ;
400
- let sign = self . fast_read_bit ( 128 ) ;
414
+ let sign = self . fast_read_flag ( ) ;
401
415
let value = if sign {
402
416
-i32:: from ( magnitude)
403
417
} else {
@@ -467,11 +481,67 @@ impl FastReader<'_> {
467
481
retval
468
482
}
469
483
484
// Reads a single flag and returns it.
//
// Works on a local copy of the fields of `self.uncommitted_state` and stores
// the updated copy back just before returning. A missing chunk is not
// reported here; it is replaced by a default value and only `chunk_index`
// records that a read past the available chunks happened.
+ fn fast_read_flag ( & mut self ) -> bool {
485
+ let State {
486
+ mut chunk_index,
487
+ mut value,
488
+ mut range,
489
+ mut bit_count,
490
+ } = self . uncommitted_state ;
// Refill when `bit_count` is negative: the next chunk is read as a
// big-endian `u32` and shifted into the low 32 bits of `value`.
491
+
492
+ if bit_count < 0 {
493
+ let chunk = self . chunks . get ( chunk_index) . copied ( ) ;
494
+ // We ignore invalid data inside the `fast_` functions,
495
+ // but we increase `chunk_index` below, so we can check
496
+ // whether we read invalid data in `commit_if_valid`.
497
+ let chunk = chunk. unwrap_or_default ( ) ;
498
+
499
+ let v = u32:: from_be_bytes ( chunk) ;
500
+ chunk_index += 1 ;
501
+ value <<= 32 ;
502
+ value |= u64:: from ( v) ;
503
+ bit_count += 32 ;
504
+ }
505
+ debug_assert ! ( bit_count >= 0 ) ;
// `split` is `range` minus its rounded-down half, so
// `half_range + split == range` and `split >= half_range`.
506
+
507
+ let half_range = range / 2 ;
508
+ let split = range - half_range;
509
+ let bigsplit = u64:: from ( split) << bit_count;
// If `value >= bigsplit` the flag is `true`: `bigsplit` is subtracted from
// `value` and `range` becomes `half_range`. Otherwise the flag is `false`,
// `value` is left untouched and `range` becomes `split`.
510
+
511
+ let retval = if let Some ( new_value) = value. checked_sub ( bigsplit) {
512
+ range = half_range;
513
+ value = new_value;
514
+ true
515
+ } else {
516
+ range = split;
517
+ false
518
+ } ;
519
+ debug_assert ! ( range > 0 ) ;
520
+
521
+ // Compute shift required to satisfy `range >= 128`.
522
+ // Apply that shift to `range` and `bit_count`.
523
+ //
524
+ // Subtract 24 because we only care about leading zeros in the
525
+ // lowest byte of `range` which is a `u32`.
526
+ let shift = range. leading_zeros ( ) . saturating_sub ( 24 ) ;
527
+ range <<= shift;
528
+ bit_count -= shift as i32 ;
529
+ debug_assert ! ( range >= 128 ) ;
// Store the updated local copy back into `self.uncommitted_state`.
530
+
531
+ self . uncommitted_state = State {
532
+ chunk_index,
533
+ value,
534
+ range,
535
+ bit_count,
536
+ } ;
537
+ retval
538
+ }
539
+
470
540
// Reads `n` single bits and packs them into a `u8`: each new bit is appended
// as the lowest bit, so the first bit read ends up most significant.
// The `-`/`+` lines below swap `fast_read_bit(128)` for `fast_read_flag()`
// and the `as u8` cast for `u8::from`.
// NOTE(review): presumably callers pass `n <= 8`; for larger `n` the earliest
// bits are shifted out of the `u8` -- confirm against callers.
fn fast_read_literal ( & mut self , n : u8 ) -> u8 {
471
541
let mut v = 0u8 ;
472
542
for _ in 0 ..n {
473
- let b = self . fast_read_bit ( 128 ) ;
474
- v = ( v << 1 ) + b as u8 ;
543
+ let b = self . fast_read_flag ( ) ;
544
+ v = ( v << 1 ) + u8 :: from ( b ) ;
475
545
}
476
546
v
477
547
}
@@ -502,7 +572,7 @@ mod tests {
502
572
buf. as_mut_slice ( ) . as_flattened_mut ( ) [ ..size] . copy_from_slice ( & data[ ..] ) ;
503
573
reader. init ( buf, size) . unwrap ( ) ;
504
574
let mut res = reader. start_accumulated_result ( ) ;
505
- assert_eq ! ( false , reader. read_bool ( 128 ) . or_accumulate( & mut res) ) ;
575
+ assert_eq ! ( false , reader. read_flag ( ) . or_accumulate( & mut res) ) ;
506
576
assert_eq ! ( true , reader. read_bool( 10 ) . or_accumulate( & mut res) ) ;
507
577
assert_eq ! ( false , reader. read_bool( 250 ) . or_accumulate( & mut res) ) ;
508
578
assert_eq ! ( 1 , reader. read_literal( 1 ) . or_accumulate( & mut res) ) ;
@@ -521,7 +591,7 @@ mod tests {
521
591
buf. as_mut_slice ( ) . as_flattened_mut ( ) [ ..size] . copy_from_slice ( & data[ ..] ) ;
522
592
reader. init ( buf, size) . unwrap ( ) ;
523
593
let mut res = reader. start_accumulated_result ( ) ;
524
- assert_eq ! ( false , reader. read_bool ( 128 ) . or_accumulate( & mut res) ) ;
594
+ assert_eq ! ( false , reader. read_flag ( ) . or_accumulate( & mut res) ) ;
525
595
assert_eq ! ( true , reader. read_bool( 10 ) . or_accumulate( & mut res) ) ;
526
596
assert_eq ! ( false , reader. read_bool( 250 ) . or_accumulate( & mut res) ) ;
527
597
assert_eq ! ( 1 , reader. read_literal( 1 ) . or_accumulate( & mut res) ) ;
0 commit comments