From 7b25ba81c1a942d1d39ea5fd7e91ea08690f3ea2 Mon Sep 17 00:00:00 2001 From: Sander in 't Veld Date: Sat, 28 Dec 2024 18:23:22 +0100 Subject: [PATCH 1/3] Faster BoolReader --- src/bool_reader.rs | 543 +++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 1 + src/vp8.rs | 471 +++++++++++++++++---------------------- 3 files changed, 751 insertions(+), 264 deletions(-) create mode 100644 src/bool_reader.rs diff --git a/src/bool_reader.rs b/src/bool_reader.rs new file mode 100644 index 0000000..725c629 --- /dev/null +++ b/src/bool_reader.rs @@ -0,0 +1,543 @@ +use crate::decoder::DecodingError; + +use super::vp8::TreeNode; + +#[must_use] +#[repr(transparent)] +pub(crate) struct BitResult { + value_if_not_past_eof: T, +} + +#[must_use] +pub(crate) struct BitResultAccumulator; + +impl BitResult { + const fn ok(value: T) -> Self { + Self { + value_if_not_past_eof: value, + } + } + + /// Instead of checking this result now, accumulate the burden of checking + /// into an accumulator. This accumulator must be checked in the end. + #[inline(always)] + pub(crate) fn or_accumulate(self, acc: &mut BitResultAccumulator) -> T { + let _ = acc; + self.value_if_not_past_eof + } +} + +impl BitResult { + fn err() -> Self { + Self { + value_if_not_past_eof: T::default(), + } + } +} + +#[cfg_attr(test, derive(Debug))] +pub(crate) struct BoolReader { + chunks: Box<[[u8; 4]]>, + state: State, + final_bytes: [u8; 3], + final_bytes_remaining: i8, +} + +#[cfg_attr(test, derive(Debug))] +#[derive(Clone, Copy)] +struct State { + chunk_index: usize, + value: u64, + range: u32, + bit_count: i32, +} + +#[cfg_attr(test, derive(Debug))] +struct FastReader<'a> { + chunks: &'a [[u8; 4]], + uncommitted_state: State, + save_state: &'a mut State, +} + +impl BoolReader { + pub(crate) fn new() -> BoolReader { + let state = State { + chunk_index: 0, + value: 0, + range: 255, + bit_count: -8, + }; + BoolReader { + chunks: Box::new([]), + state, + final_bytes: [0; 3], + final_bytes_remaining: Self::FINAL_BYTES_REMAINING_EOF, + } + } + + pub(crate) fn init(&mut self, mut buf: Vec<[u8; 4]>, len: usize) -> Result<(), DecodingError> { + let mut final_bytes = [0; 3]; + let final_bytes_remaining = if len == 4 * buf.len() { + 0 + } else { + // Pop the last chunk (which is partial), then get length. + let Some(last_chunk) = buf.pop() else { + return Err(DecodingError::NotEnoughInitData); + }; + let len_rounded_down = 4 * buf.len(); + let num_bytes_popped = len - len_rounded_down; + debug_assert!(num_bytes_popped <= 3); + final_bytes[..num_bytes_popped].copy_from_slice(&last_chunk[..num_bytes_popped]); + for i in num_bytes_popped..4 { + debug_assert_eq!(last_chunk[i], 0, "unexpected {last_chunk:?}"); + } + num_bytes_popped as i8 + }; + + let chunks = buf.into_boxed_slice(); + let state = State { + chunk_index: 0, + value: 0, + range: 255, + bit_count: -8, + }; + *self = Self { + chunks, + state, + final_bytes, + final_bytes_remaining, + }; + Ok(()) + } + + /// Start a span of reading operations from the buffer, without stopping + /// when the buffer runs out. For all valid webp images, the buffer will not + /// run out prematurely. Conversely if the buffer ends early, the webp image + /// cannot be correctly decoded and any intermediate results need to be + /// discarded anyway. + /// + /// Each call to `start_accumulated_result` must be followed by a call to + /// `check` on the *same* `BoolReader`. + #[inline(always)] + pub(crate) fn start_accumulated_result(&mut self) -> BitResultAccumulator { + BitResultAccumulator + } + + /// Check that the read operations done so far were all valid. + #[inline(always)] + pub(crate) fn check( + &self, + acc: BitResultAccumulator, + value_if_not_past_eof: T, + ) -> Result { + // The accumulator does not store any state because doing so is + // too computationally expensive. Passing it around is a bit of + // formality (that is optimized out) to ensure we call `check` . + // Instead we check whether we have read past the end of the file. + let BitResultAccumulator = acc; + + if self.is_past_eof() { + Err(DecodingError::BitStreamError) + } else { + Ok(value_if_not_past_eof) + } + } + + fn keep_accumulating( + &self, + acc: BitResultAccumulator, + value_if_not_past_eof: T, + ) -> BitResult { + // The BitResult will be checked later by a different accumulator. + // Because it does not carry state, that is fine. + let BitResultAccumulator = acc; + + BitResult::ok(value_if_not_past_eof) + } + + // Do not inline this because inlining seems to worsen performance. + #[inline(never)] + pub(crate) fn read_bool(&mut self, probability: u8) -> BitResult { + if let Some(b) = self.fast().read_bit(probability) { + return BitResult::ok(b); + } + + self.cold_read_bool(probability) + } + + // Do not inline this because inlining seems to worsen performance. + #[inline(never)] + pub(crate) fn read_literal(&mut self, n: u8) -> BitResult { + if let Some(v) = self.fast().read_literal(n) { + return BitResult::ok(v); + } + + self.cold_read_literal(n) + } + + // Do not inline this because inlining seems to worsen performance. + #[inline(never)] + pub(crate) fn read_optional_signed_value(&mut self, n: u8) -> BitResult { + if let Some(v) = self.fast().read_optional_signed_value(n) { + return BitResult::ok(v); + } + + self.cold_read_optional_signed_value(n) + } + + // This is generic and inlined just to skip the first bounds check. + #[inline] + pub(crate) fn read_with_tree(&mut self, tree: &[TreeNode; N]) -> BitResult { + let first_node = tree[0]; + self.read_with_tree_with_first_node(tree, first_node) + } + + // Do not inline this because inlining significantly worsens performance. + #[inline(never)] + pub(crate) fn read_with_tree_with_first_node( + &mut self, + tree: &[TreeNode], + first_node: TreeNode, + ) -> BitResult { + if let Some(v) = self.fast().read_with_tree(tree, first_node) { + return BitResult::ok(v); + } + + self.cold_read_with_tree(tree, usize::from(first_node.index)) + } + + // This should be inlined to allow it to share the instruction cache with + // `read_bool`, as both functions are short and called often. + #[inline] + pub(crate) fn read_flag(&mut self) -> BitResult { + self.read_bool(128) + } + + // As a similar (but different) speedup to BitResult, the FastReader reads + // bits under an assumption and validates it at the end. + // + // The idea here is that for normal-sized webp images, the vast majority + // of bits are somewhere other than in the last four bytes. Therefore we + // can pretend the buffer has infinite size. After we are done reading, + // we check if we actually read past the end of `self.chunks`. + // If so, we backtrack (or rather we discard `uncommitted_state`) + // and try again with the slow approach. This might result in doing double + // work for those last few bytes -- in fact we even keep retrying the fast + // method to save an if-statement --, but more than make up for that by + // speeding up reading from the other thousands or millions of bytes. + fn fast(&mut self) -> FastReader<'_> { + FastReader { + chunks: &self.chunks, + uncommitted_state: self.state, + save_state: &mut self.state, + } + } + + const FINAL_BYTES_REMAINING_EOF: i8 = -0xE; + + fn load_from_final_bytes(&mut self) { + match self.final_bytes_remaining { + 1.. => { + self.final_bytes_remaining -= 1; + let byte = self.final_bytes[0]; + self.final_bytes.rotate_left(1); + self.state.value <<= 8; + self.state.value |= u64::from(byte); + self.state.bit_count += 8; + } + 0 => { + // libwebp seems to (sometimes?) allow bitstreams that read one byte past the end. + // This replicates that logic. + self.final_bytes_remaining -= 1; + self.state.value <<= 8; + self.state.bit_count += 8; + } + _ => { + self.final_bytes_remaining = Self::FINAL_BYTES_REMAINING_EOF; + } + } + } + + fn is_past_eof(&self) -> bool { + self.final_bytes_remaining == Self::FINAL_BYTES_REMAINING_EOF + } + + fn cold_read_bit(&mut self, probability: u8) -> BitResult { + if self.state.bit_count < 0 { + if let Some(chunk) = self.chunks.get(self.state.chunk_index).copied() { + let v = u32::from_be_bytes(chunk); + self.state.chunk_index += 1; + self.state.value <<= 32; + self.state.value |= u64::from(v); + self.state.bit_count += 32; + } else { + self.load_from_final_bytes(); + if self.is_past_eof() { + return BitResult::err(); + } + } + } + debug_assert!(self.state.bit_count >= 0); + + let probability = u32::from(probability); + let split = 1 + (((self.state.range - 1) * probability) >> 8); + let bigsplit = u64::from(split) << self.state.bit_count; + + let retval = if let Some(new_value) = self.state.value.checked_sub(bigsplit) { + self.state.range -= split; + self.state.value = new_value; + true + } else { + self.state.range = split; + false + }; + debug_assert!(self.state.range > 0); + + // Compute shift required to satisfy `self.state.range >= 128`. + // Apply that shift to `self.state.range` and `self.state.bitcount`. + // + // Subtract 24 because we only care about leading zeros in the + // lowest byte of `self.state.range` which is a `u32`. + let shift = self.state.range.leading_zeros().saturating_sub(24); + self.state.range <<= shift; + self.state.bit_count -= shift as i32; + debug_assert!(self.state.range >= 128); + + BitResult::ok(retval) + } + + #[cold] + #[inline(never)] + fn cold_read_bool(&mut self, probability: u8) -> BitResult { + self.cold_read_bit(probability) + } + + #[cold] + #[inline(never)] + fn cold_read_literal(&mut self, n: u8) -> BitResult { + let mut v = 0u8; + let mut res = self.start_accumulated_result(); + + for _ in 0..n { + let b = self.cold_read_bit(128).or_accumulate(&mut res); + v = (v << 1) + b as u8; + } + + self.keep_accumulating(res, v) + } + + #[cold] + #[inline(never)] + fn cold_read_optional_signed_value(&mut self, n: u8) -> BitResult { + let mut res = self.start_accumulated_result(); + let flag = self.cold_read_bool(128).or_accumulate(&mut res); + if !flag { + // We should not read further bits if the flag is not set. + return self.keep_accumulating(res, 0); + } + let magnitude = self.cold_read_literal(n).or_accumulate(&mut res); + let sign = self.cold_read_bool(128).or_accumulate(&mut res); + + let value = if sign { + -i32::from(magnitude) + } else { + i32::from(magnitude) + }; + self.keep_accumulating(res, value) + } + + #[cold] + #[inline(never)] + fn cold_read_with_tree(&mut self, tree: &[TreeNode], start: usize) -> BitResult { + let mut index = start; + let mut res = self.start_accumulated_result(); + + loop { + let node = tree[index]; + let prob = node.prob; + let b = self.cold_read_bit(prob).or_accumulate(&mut res); + let t = if b { node.right } else { node.left }; + let new_index = usize::from(t); + if new_index < tree.len() { + index = new_index; + } else { + let value = TreeNode::value_from_branch(t); + return self.keep_accumulating(res, value); + } + } + } +} + +impl FastReader<'_> { + fn commit_if_valid(self, value_if_not_past_eof: T) -> Option { + // If `chunk_index > self.chunks.len()`, it means we used zeroes + // instead of an actual chunk and `value_if_not_past_eof` is nonsense. + if self.uncommitted_state.chunk_index <= self.chunks.len() { + *self.save_state = self.uncommitted_state; + Some(value_if_not_past_eof) + } else { + None + } + } + + fn read_bit(mut self, probability: u8) -> Option { + let bit = self.fast_read_bit(probability); + self.commit_if_valid(bit) + } + + fn read_literal(mut self, n: u8) -> Option { + let value = self.fast_read_literal(n); + self.commit_if_valid(value) + } + + fn read_optional_signed_value(mut self, n: u8) -> Option { + let flag = self.fast_read_bit(128); + if !flag { + // We should not read further bits if the flag is not set. + return self.commit_if_valid(0); + } + let magnitude = self.fast_read_literal(n); + let sign = self.fast_read_bit(128); + let value = if sign { + -i32::from(magnitude) + } else { + i32::from(magnitude) + }; + self.commit_if_valid(value) + } + + fn read_with_tree(mut self, tree: &[TreeNode], first_node: TreeNode) -> Option { + let value = self.fast_read_with_tree(tree, first_node); + self.commit_if_valid(value) + } + + fn fast_read_bit(&mut self, probability: u8) -> bool { + let State { + mut chunk_index, + mut value, + mut range, + mut bit_count, + } = self.uncommitted_state; + + if bit_count < 0 { + let chunk = self.chunks.get(chunk_index).copied(); + // We ignore invalid data inside the `fast_` functions, + // but we increase `chunk_index` below, so we can check + // whether we read invalid data in `commit_if_valid`. + let chunk = chunk.unwrap_or_default(); + + let v = u32::from_be_bytes(chunk); + chunk_index += 1; + value <<= 32; + value |= u64::from(v); + bit_count += 32; + } + debug_assert!(bit_count >= 0); + + let probability = u32::from(probability); + let split = 1 + (((range - 1) * probability) >> 8); + let bigsplit = u64::from(split) << bit_count; + + let retval = if let Some(new_value) = value.checked_sub(bigsplit) { + range -= split; + value = new_value; + true + } else { + range = split; + false + }; + debug_assert!(range > 0); + + // Compute shift required to satisfy `range >= 128`. + // Apply that shift to `range` and `self.bitcount`. + // + // Subtract 24 because we only care about leading zeros in the + // lowest byte of `range` which is a `u32`. + let shift = range.leading_zeros().saturating_sub(24); + range <<= shift; + bit_count -= shift as i32; + debug_assert!(range >= 128); + + self.uncommitted_state = State { + chunk_index, + value, + range, + bit_count, + }; + retval + } + + fn fast_read_literal(&mut self, n: u8) -> u8 { + let mut v = 0u8; + for _ in 0..n { + let b = self.fast_read_bit(128); + v = (v << 1) + b as u8; + } + v + } + + fn fast_read_with_tree(&mut self, tree: &[TreeNode], mut node: TreeNode) -> i8 { + loop { + let prob = node.prob; + let b = self.fast_read_bit(prob); + let i = if b { node.right } else { node.left }; + let Some(next_node) = tree.get(usize::from(i)) else { + return TreeNode::value_from_branch(i); + }; + node = *next_node; + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_bool_reader_hello_short() { + let mut reader = BoolReader::new(); + let data = b"hel"; + let size = data.len(); + let mut buf = vec![[0u8; 4]; 1]; + buf.as_mut_slice().as_flattened_mut()[..size].copy_from_slice(&data[..]); + reader.init(buf, size).unwrap(); + let mut res = reader.start_accumulated_result(); + assert_eq!(false, reader.read_bool(128).or_accumulate(&mut res)); + assert_eq!(true, reader.read_bool(10).or_accumulate(&mut res)); + assert_eq!(false, reader.read_bool(250).or_accumulate(&mut res)); + assert_eq!(1, reader.read_literal(1).or_accumulate(&mut res)); + assert_eq!(5, reader.read_literal(3).or_accumulate(&mut res)); + assert_eq!(64, reader.read_literal(8).or_accumulate(&mut res)); + assert_eq!(185, reader.read_literal(8).or_accumulate(&mut res)); + reader.check(res, ()).unwrap(); + } + + #[test] + fn test_bool_reader_hello_long() { + let mut reader = BoolReader::new(); + let data = b"hello world"; + let size = data.len(); + let mut buf = vec![[0u8; 4]; (size + 3) / 4]; + buf.as_mut_slice().as_flattened_mut()[..size].copy_from_slice(&data[..]); + reader.init(buf, size).unwrap(); + let mut res = reader.start_accumulated_result(); + assert_eq!(false, reader.read_bool(128).or_accumulate(&mut res)); + assert_eq!(true, reader.read_bool(10).or_accumulate(&mut res)); + assert_eq!(false, reader.read_bool(250).or_accumulate(&mut res)); + assert_eq!(1, reader.read_literal(1).or_accumulate(&mut res)); + assert_eq!(5, reader.read_literal(3).or_accumulate(&mut res)); + assert_eq!(64, reader.read_literal(8).or_accumulate(&mut res)); + assert_eq!(185, reader.read_literal(8).or_accumulate(&mut res)); + assert_eq!(31, reader.read_literal(8).or_accumulate(&mut res)); + reader.check(res, ()).unwrap(); + } + + #[test] + fn test_bool_reader_uninit() { + let mut reader = BoolReader::new(); + let mut res = reader.start_accumulated_result(); + let _ = reader.read_flag().or_accumulate(&mut res); + let result = reader.check(res, ()); + assert!(result.is_err()); + } +} diff --git a/src/lib.rs b/src/lib.rs index 93edeb1..645e250 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,6 +13,7 @@ pub use self::decoder::{DecodingError, LoopCount, WebPDecoder}; pub use self::encoder::{ColorType, EncoderParams, EncodingError, WebPEncoder}; mod alpha_blending; +mod bool_reader; mod decoder; mod encoder; mod extended; diff --git a/src/vp8.rs b/src/vp8.rs index 1d772a6..b30a5c3 100644 --- a/src/vp8.rs +++ b/src/vp8.rs @@ -10,14 +10,16 @@ //! * [rfc-6386](http://tools.ietf.org/html/rfc6386) - The VP8 Data Format and Decoding Guide //! * [VP8.pdf](http://static.googleusercontent.com/media/research.google.com/en//pubs/archive/37073.pdf) - An overview of of the VP8 format -use byteorder_lite::{BigEndian, LittleEndian, ReadBytesExt}; +use byteorder_lite::{LittleEndian, ReadBytesExt}; use std::cmp; use std::default::Default; -use std::io::{Cursor, ErrorKind, Read}; +use std::io::Read; use crate::decoder::DecodingError; -use super::{loop_filter, transform}; +use super::bool_reader::BoolReader; +use super::loop_filter; +use super::transform; const MAX_SEGMENTS: usize = 4; const NUM_DCT_TOKENS: usize = 12; @@ -96,23 +98,77 @@ enum IntraMode { type Prob = u8; -static SEGMENT_ID_TREE: [i8; 6] = [2, 4, -0, -1, -2, -3]; +#[derive(Clone, Copy)] +pub(crate) struct TreeNode { + pub left: u8, + pub right: u8, + pub prob: Prob, + pub index: u8, +} + +impl TreeNode { + const UNINIT: TreeNode = TreeNode { + left: 0, + right: 0, + prob: 0, + index: 0, + }; + + const fn prepare_branch(t: i8) -> u8 { + if t > 0 { + (t as u8) / 2 + } else { + let value = -t; + 0x80 | (value as u8) + } + } + + pub(crate) const fn value_from_branch(t: u8) -> i8 { + (t & !0x80) as i8 + } +} + +const fn tree_nodes_from( + tree: [i8; N], + probs: [Prob; M], +) -> [TreeNode; M] { + if N != 2 * M { + panic!("invalid tree with probs"); + } + let mut nodes = [TreeNode::UNINIT; M]; + let mut i = 0; + while i < M { + nodes[i].left = TreeNode::prepare_branch(tree[2 * i]); + nodes[i].right = TreeNode::prepare_branch(tree[2 * i + 1]); + nodes[i].prob = probs[i]; + nodes[i].index = i as u8; + i += 1; + } + nodes +} + +const SEGMENT_ID_TREE: [i8; 6] = [2, 4, -0, -1, -2, -3]; + +const SEGMENT_TREE_NODE_DEFAULTS: [TreeNode; 3] = tree_nodes_from(SEGMENT_ID_TREE, [255; 3]); // Section 11.2 // Tree for determining the keyframe luma intra prediction modes: -static KEYFRAME_YMODE_TREE: [i8; 8] = [-B_PRED, 2, 4, 6, -DC_PRED, -V_PRED, -H_PRED, -TM_PRED]; +const KEYFRAME_YMODE_TREE: [i8; 8] = [-B_PRED, 2, 4, 6, -DC_PRED, -V_PRED, -H_PRED, -TM_PRED]; // Default probabilities for decoding the keyframe luma modes -static KEYFRAME_YMODE_PROBS: [Prob; 4] = [145, 156, 163, 128]; +const KEYFRAME_YMODE_PROBS: [Prob; 4] = [145, 156, 163, 128]; + +const KEYFRAME_YMODE_NODES: [TreeNode; 4] = + tree_nodes_from(KEYFRAME_YMODE_TREE, KEYFRAME_YMODE_PROBS); // Tree for determining the keyframe B_PRED mode: -static KEYFRAME_BPRED_MODE_TREE: [i8; 18] = [ +const KEYFRAME_BPRED_MODE_TREE: [i8; 18] = [ -B_DC_PRED, 2, -B_TM_PRED, 4, -B_VE_PRED, 6, 8, 12, -B_HE_PRED, 10, -B_RD_PRED, -B_VR_PRED, -B_LD_PRED, 14, -B_VL_PRED, 16, -B_HD_PRED, -B_HU_PRED, ]; // Probabilities for the BPRED_MODE_TREE -static KEYFRAME_BPRED_MODE_PROBS: [[[u8; 9]; 10]; 10] = [ +const KEYFRAME_BPRED_MODE_PROBS: [[[Prob; 9]; 10]; 10] = [ [ [231, 120, 48, 89, 115, 113, 120, 152, 112], [152, 179, 64, 126, 170, 118, 46, 70, 95], @@ -235,17 +291,36 @@ static KEYFRAME_BPRED_MODE_PROBS: [[[u8; 9]; 10]; 10] = [ ], ]; +const KEYFRAME_BPRED_MODE_NODES: [[[TreeNode; 9]; 10]; 10] = { + let mut output = [[[TreeNode::UNINIT; 9]; 10]; 10]; + let mut i = 0; + while i < output.len() { + let mut j = 0; + while j < output[i].len() { + output[i][j] = + tree_nodes_from(KEYFRAME_BPRED_MODE_TREE, KEYFRAME_BPRED_MODE_PROBS[i][j]); + j += 1; + } + i += 1; + } + output +}; + // Section 11.4 Tree for determining macroblock the chroma mode -static KEYFRAME_UV_MODE_TREE: [i8; 6] = [-DC_PRED, 2, -V_PRED, 4, -H_PRED, -TM_PRED]; +const KEYFRAME_UV_MODE_TREE: [i8; 6] = [-DC_PRED, 2, -V_PRED, 4, -H_PRED, -TM_PRED]; // Probabilities for determining macroblock mode -static KEYFRAME_UV_MODE_PROBS: [Prob; 3] = [142, 114, 183]; +const KEYFRAME_UV_MODE_PROBS: [Prob; 3] = [142, 114, 183]; + +const KEYFRAME_UV_MODE_NODES: [TreeNode; 3] = + tree_nodes_from(KEYFRAME_UV_MODE_TREE, KEYFRAME_UV_MODE_PROBS); // Section 13.4 type TokenProbTables = [[[[Prob; NUM_DCT_TOKENS - 1]; 3]; 8]; 4]; +type TokenProbTreeNodes = [[[[TreeNode; NUM_DCT_TOKENS - 1]; 3]; 8]; 4]; // Probabilities that a token's probability will be updated -static COEFF_UPDATE_PROBS: TokenProbTables = [ +const COEFF_UPDATE_PROBS: TokenProbTables = [ [ [ [255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255], @@ -418,7 +493,7 @@ static COEFF_UPDATE_PROBS: TokenProbTables = [ // Section 13.5 // Default Probabilities for tokens -static COEFF_PROBS: TokenProbTables = [ +const COEFF_PROBS: TokenProbTables = [ [ [ [128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128], @@ -589,6 +664,24 @@ static COEFF_PROBS: TokenProbTables = [ ], ]; +const COEFF_PROB_NODES: TokenProbTreeNodes = { + let mut output = [[[[TreeNode::UNINIT; 11]; 3]; 8]; 4]; + let mut i = 0; + while i < output.len() { + let mut j = 0; + while j < output[i].len() { + let mut k = 0; + while k < output[i][j].len() { + output[i][j][k] = tree_nodes_from(DCT_TOKEN_TREE, COEFF_PROBS[i][j][k]); + k += 1; + } + j += 1; + } + i += 1; + } + output +}; + // DCT Tokens const DCT_0: i8 = 0; const DCT_1: i8 = 1; @@ -603,12 +696,12 @@ const DCT_CAT5: i8 = 9; const DCT_CAT6: i8 = 10; const DCT_EOB: i8 = 11; -static DCT_TOKEN_TREE: [i8; 22] = [ +const DCT_TOKEN_TREE: [i8; 22] = [ -DCT_EOB, 2, -DCT_0, 4, -DCT_1, 6, 8, 12, -DCT_2, 10, -DCT_3, -DCT_4, 14, 16, -DCT_CAT1, -DCT_CAT2, 18, 20, -DCT_CAT3, -DCT_CAT4, -DCT_CAT5, -DCT_CAT6, ]; -static PROB_DCT_CAT: [[Prob; 12]; 6] = [ +const PROB_DCT_CAT: [[Prob; 12]; 6] = [ [159, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [165, 145, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [173, 148, 140, 0, 0, 0, 0, 0, 0, 0, 0, 0], @@ -617,11 +710,11 @@ static PROB_DCT_CAT: [[Prob; 12]; 6] = [ [254, 254, 243, 230, 196, 177, 153, 140, 133, 130, 129, 0], ]; -static DCT_CAT_BASE: [u8; 6] = [5, 7, 11, 19, 35, 67]; -static COEFF_BANDS: [u8; 16] = [0, 1, 2, 3, 6, 4, 5, 6, 6, 6, 6, 6, 6, 6, 6, 7]; +const DCT_CAT_BASE: [u8; 6] = [5, 7, 11, 19, 35, 67]; +const COEFF_BANDS: [u8; 16] = [0, 1, 2, 3, 6, 4, 5, 6, 6, 6, 6, 6, 6, 6, 6, 7]; #[rustfmt::skip] -static DC_QUANT: [i16; 128] = [ +const DC_QUANT: [i16; 128] = [ 4, 5, 6, 7, 8, 9, 10, 10, 11, 12, 13, 14, 15, 16, 17, 17, 18, 19, 20, 20, 21, 21, 22, 22, @@ -641,7 +734,7 @@ static DC_QUANT: [i16; 128] = [ ]; #[rustfmt::skip] -static AC_QUANT: [i16; 128] = [ +const AC_QUANT: [i16; 128] = [ 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, @@ -660,128 +753,7 @@ static AC_QUANT: [i16; 128] = [ 249, 254, 259, 264, 269, 274, 279, 284, ]; -static ZIGZAG: [u8; 16] = [0, 1, 4, 8, 5, 2, 3, 6, 9, 12, 13, 10, 7, 11, 14, 15]; - -struct BoolReader { - reader: Cursor>, - range: u32, - value: u32, - bit_count: u8, - eof: bool, -} - -impl BoolReader { - pub(crate) fn new() -> Self { - Self { - reader: Default::default(), - range: 0, - value: 0, - bit_count: 0, - eof: false, - } - } - - pub(crate) fn init(&mut self, buf: Vec) -> Result<(), DecodingError> { - if buf.len() < 2 { - return Err(DecodingError::NotEnoughInitData); - } - - self.reader = Cursor::new(buf); - self.value = u32::from(self.reader.read_u16::()?); - self.range = 255; - self.bit_count = 0; - - Ok(()) - } - - pub(crate) fn read_bool(&mut self, probability: u8) -> Result { - let split = 1 + (((self.range - 1) * u32::from(probability)) >> 8); - let bigsplit = split << 8; - - let retval = if self.value >= bigsplit { - self.range -= split; - self.value -= bigsplit; - true - } else { - self.range = split; - false - }; - - if self.range < 128 { - // Compute shift required to satisfy `self.range >= 128`. - // Apply that shift to `self.range`, `self.value`, and `self.bitcount`. - // - // Subtract 24 because we only care about leading zeros in the - // lowest byte of `self.range` which is a `u32`. - let shift = self.range.leading_zeros() - 24; - self.value <<= shift; - self.range <<= shift; - self.bit_count += shift as u8; - - if self.bit_count >= 8 { - self.bit_count %= 8; - - // libwebp seems to (sometimes?) allow bitstreams that read one byte past the end. - // This match statement replicates that logic. - match self.reader.read_u8() { - Ok(v) => self.value |= u32::from(v) << self.bit_count, - Err(e) if e.kind() == ErrorKind::UnexpectedEof && !self.eof => self.eof = true, - Err(e) => return Err(DecodingError::IoError(e)), - }; - } - } - - Ok(retval) - } - - pub(crate) fn read_literal(&mut self, n: u8) -> Result { - let mut v = 0u8; - let mut n = n; - - while n != 0 { - v = (v << 1) + u8::from(self.read_bool(128u8)?); - n -= 1; - } - - Ok(v) - } - - pub(crate) fn read_magnitude_and_sign(&mut self, n: u8) -> Result { - let magnitude = self.read_literal(n)?; - let sign = self.read_literal(1)?; - - if sign == 1 { - Ok(-i32::from(magnitude)) - } else { - Ok(i32::from(magnitude)) - } - } - - pub(crate) fn read_with_tree( - &mut self, - tree: &[i8], - probs: &[Prob], - start: isize, - ) -> Result { - let mut index = start; - - loop { - let a = self.read_bool(probs[index as usize >> 1])?; - let b = index + isize::from(a); - index = tree[b as usize] as isize; - - if index <= 0 { - break; - } - } - - Ok(-index as i8) - } - - pub(crate) fn read_flag(&mut self) -> Result { - Ok(0 != self.read_literal(1)?) - } -} +const ZIGZAG: [u8; 16] = [0, 1, 4, 8, 5, 2, 3, 6, 9, 12, 13, 10, 7, 11, 14, 15]; #[derive(Default, Clone, Copy)] struct MacroBlock { @@ -1067,8 +1039,8 @@ pub struct Vp8Decoder { partitions: [BoolReader; 8], num_partitions: u8, - segment_tree_probs: [Prob; 3], - token_probs: Box, + segment_tree_nodes: [TreeNode; 3], + token_probs: Box, // Section 9.10 prob_intra: Prob, @@ -1120,8 +1092,8 @@ impl Vp8Decoder { num_partitions: 1, - segment_tree_probs: [255u8; 3], - token_probs: Box::new(COEFF_PROBS), + segment_tree_nodes: SEGMENT_TREE_NODE_DEFAULTS, + token_probs: Box::new(COEFF_PROB_NODES), // Section 9.10 prob_intra: 0u8, @@ -1138,20 +1110,20 @@ impl Vp8Decoder { } fn update_token_probabilities(&mut self) -> Result<(), DecodingError> { + let mut res = self.b.start_accumulated_result(); for (i, is) in COEFF_UPDATE_PROBS.iter().enumerate() { for (j, js) in is.iter().enumerate() { for (k, ks) in js.iter().enumerate() { for (t, prob) in ks.iter().enumerate().take(NUM_DCT_TOKENS - 1) { - if self.b.read_bool(*prob)? { - let v = self.b.read_literal(8)?; - self.token_probs[i][j][k][t] = v; + if self.b.read_bool(*prob).or_accumulate(&mut res) { + let v = self.b.read_literal(8).or_accumulate(&mut res); + self.token_probs[i][j][k][t].prob = v; } } } } } - - Ok(()) + self.b.check(res, ()) } fn init_partitions(&mut self, n: usize) -> Result<(), DecodingError> { @@ -1164,16 +1136,20 @@ impl Vp8Decoder { .read_u24::() .expect("Reading from &[u8] can't fail and the chunk is complete"); - let mut buf = vec![0; size as usize]; - self.r.read_exact(buf.as_mut_slice())?; - - self.partitions[i].init(buf)?; + let size = size as usize; + let mut buf = vec![[0; 4]; (size + 3) / 4]; + let bytes: &mut [u8] = buf.as_mut_slice().as_flattened_mut(); + self.r.read_exact(&mut bytes[..size])?; + self.partitions[i].init(buf, size)?; } } let mut buf = Vec::new(); self.r.read_to_end(&mut buf)?; - self.partitions[n - 1].init(buf)?; + let size = buf.len(); + let mut chunks = vec![[0; 4]; (size + 3) / 4]; + chunks.as_mut_slice().as_flattened_mut()[..size].copy_from_slice(&buf); + self.partitions[n - 1].init(chunks, size)?; Ok(()) } @@ -1187,36 +1163,14 @@ impl Vp8Decoder { AC_QUANT[index.clamp(0, 127) as usize] } - let yac_abs = self.b.read_literal(7)?; - let ydc_delta = if self.b.read_flag()? { - self.b.read_magnitude_and_sign(4)? - } else { - 0 - }; + let mut res = self.b.start_accumulated_result(); - let y2dc_delta = if self.b.read_flag()? { - self.b.read_magnitude_and_sign(4)? - } else { - 0 - }; - - let y2ac_delta = if self.b.read_flag()? { - self.b.read_magnitude_and_sign(4)? - } else { - 0 - }; - - let uvdc_delta = if self.b.read_flag()? { - self.b.read_magnitude_and_sign(4)? - } else { - 0 - }; - - let uvac_delta = if self.b.read_flag()? { - self.b.read_magnitude_and_sign(4)? - } else { - 0 - }; + let yac_abs = self.b.read_literal(7).or_accumulate(&mut res); + let ydc_delta = self.b.read_optional_signed_value(4).or_accumulate(&mut res); + let y2dc_delta = self.b.read_optional_signed_value(4).or_accumulate(&mut res); + let y2ac_delta = self.b.read_optional_signed_value(4).or_accumulate(&mut res); + let uvdc_delta = self.b.read_optional_signed_value(4).or_accumulate(&mut res); + let uvac_delta = self.b.read_optional_signed_value(4).or_accumulate(&mut res); let n = if self.segments_enabled { MAX_SEGMENTS @@ -1253,77 +1207,64 @@ impl Vp8Decoder { } } - Ok(()) + self.b.check(res, ()) } fn read_loop_filter_adjustments(&mut self) -> Result<(), DecodingError> { - if self.b.read_flag()? { - for i in 0usize..4 { - let ref_frame_delta_update_flag = self.b.read_flag()?; + let mut res = self.b.start_accumulated_result(); - self.ref_delta[i] = if ref_frame_delta_update_flag { - self.b.read_magnitude_and_sign(6)? - } else { - 0i32 - }; + if self.b.read_flag().or_accumulate(&mut res) { + for i in 0usize..4 { + self.ref_delta[i] = self.b.read_optional_signed_value(6).or_accumulate(&mut res); } for i in 0usize..4 { - let mb_mode_delta_update_flag = self.b.read_flag()?; - - self.mode_delta[i] = if mb_mode_delta_update_flag { - self.b.read_magnitude_and_sign(6)? - } else { - 0i32 - }; + self.mode_delta[i] = self.b.read_optional_signed_value(6).or_accumulate(&mut res); } } - Ok(()) + self.b.check(res, ()) } fn read_segment_updates(&mut self) -> Result<(), DecodingError> { + let mut res = self.b.start_accumulated_result(); + // Section 9.3 - self.segments_update_map = self.b.read_flag()?; - let update_segment_feature_data = self.b.read_flag()?; + self.segments_update_map = self.b.read_flag().or_accumulate(&mut res); + let update_segment_feature_data = self.b.read_flag().or_accumulate(&mut res); if update_segment_feature_data { - let segment_feature_mode = self.b.read_flag()?; + let segment_feature_mode = self.b.read_flag().or_accumulate(&mut res); for i in 0usize..MAX_SEGMENTS { self.segment[i].delta_values = !segment_feature_mode; } for i in 0usize..MAX_SEGMENTS { - let update = self.b.read_flag()?; - - self.segment[i].quantizer_level = if update { - self.b.read_magnitude_and_sign(7)? - } else { - 0i32 - } as i8; + self.segment[i].quantizer_level = + self.b.read_optional_signed_value(7).or_accumulate(&mut res) as i8; } for i in 0usize..MAX_SEGMENTS { - let update = self.b.read_flag()?; - - self.segment[i].loopfilter_level = if update { - self.b.read_magnitude_and_sign(6)? - } else { - 0i32 - } as i8; + self.segment[i].loopfilter_level = + self.b.read_optional_signed_value(6).or_accumulate(&mut res) as i8; } } if self.segments_update_map { for i in 0usize..3 { - let update = self.b.read_flag()?; + let update = self.b.read_flag().or_accumulate(&mut res); - self.segment_tree_probs[i] = if update { self.b.read_literal(8)? } else { 255 }; + let prob = if update { + self.b.read_literal(8).or_accumulate(&mut res) + } else { + 255 + }; + self.segment_tree_nodes[i].prob = prob; } } - Ok(()) + self.b.check(res, ()) } fn read_frame_header(&mut self) -> Result<(), DecodingError> { @@ -1366,37 +1307,42 @@ impl Vp8Decoder { self.left_border = vec![129u8; 1 + 16]; } - let mut buf = vec![0; first_partition_size as usize]; - self.r.read_exact(&mut buf)?; + let size = first_partition_size as usize; + let mut buf = vec![[0; 4]; (size + 3) / 4]; + let bytes: &mut [u8] = buf.as_mut_slice().as_flattened_mut(); + self.r.read_exact(&mut bytes[..size])?; // initialise binary decoder - self.b.init(buf)?; + self.b.init(buf, size)?; + let mut res = self.b.start_accumulated_result(); if self.frame.keyframe { - let color_space = self.b.read_literal(1)?; - self.frame.pixel_type = self.b.read_literal(1)?; + let color_space = self.b.read_literal(1).or_accumulate(&mut res); + self.frame.pixel_type = self.b.read_literal(1).or_accumulate(&mut res); if color_space != 0 { return Err(DecodingError::ColorSpaceInvalid(color_space)); } } - self.segments_enabled = self.b.read_flag()?; + self.segments_enabled = self.b.read_flag().or_accumulate(&mut res); if self.segments_enabled { self.read_segment_updates()?; } - self.frame.filter_type = self.b.read_flag()?; - self.frame.filter_level = self.b.read_literal(6)?; - self.frame.sharpness_level = self.b.read_literal(3)?; + self.frame.filter_type = self.b.read_flag().or_accumulate(&mut res); + self.frame.filter_level = self.b.read_literal(6).or_accumulate(&mut res); + self.frame.sharpness_level = self.b.read_literal(3).or_accumulate(&mut res); - let lf_adjust_enable = self.b.read_flag()?; + let lf_adjust_enable = self.b.read_flag().or_accumulate(&mut res); if lf_adjust_enable { self.read_loop_filter_adjustments()?; } - self.num_partitions = (1usize << self.b.read_literal(2)? as usize) as u8; - let num_partitions = self.num_partitions as usize; + let num_partitions = 1 << self.b.read_literal(2).or_accumulate(&mut res) as usize; + self.b.check(res, ())?; + + self.num_partitions = num_partitions as u8; self.init_partitions(num_partitions)?; self.read_quantization_indices()?; @@ -1414,12 +1360,14 @@ impl Vp8Decoder { self.update_token_probabilities()?; - let mb_no_skip_coeff = self.b.read_literal(1)?; + let mut res = self.b.start_accumulated_result(); + let mb_no_skip_coeff = self.b.read_literal(1).or_accumulate(&mut res); self.prob_skip_false = if mb_no_skip_coeff == 1 { - Some(self.b.read_literal(8)?) + Some(self.b.read_literal(8).or_accumulate(&mut res)) } else { None }; + self.b.check(res, ())?; if !self.frame.keyframe { // 9.10 remaining frame data @@ -1438,22 +1386,21 @@ impl Vp8Decoder { fn read_macroblock_header(&mut self, mbx: usize) -> Result { let mut mb = MacroBlock::default(); + let mut res = self.b.start_accumulated_result(); if self.segments_enabled && self.segments_update_map { - mb.segmentid = self - .b - .read_with_tree(&SEGMENT_ID_TREE, &self.segment_tree_probs, 0)? - as u8; + mb.segmentid = + (self.b.read_with_tree(&self.segment_tree_nodes)).or_accumulate(&mut res) as u8; }; - mb.coeffs_skipped = if self.prob_skip_false.is_some() { - self.b.read_bool(*self.prob_skip_false.as_ref().unwrap())? + mb.coeffs_skipped = if let Some(prob) = self.prob_skip_false { + self.b.read_bool(prob).or_accumulate(&mut res) } else { false }; let inter_predicted = if !self.frame.keyframe { - self.b.read_bool(self.prob_intra)? + self.b.read_bool(self.prob_intra).or_accumulate(&mut res) } else { false }; @@ -1466,9 +1413,7 @@ impl Vp8Decoder { if self.frame.keyframe { // intra prediction - let luma = self - .b - .read_with_tree(&KEYFRAME_YMODE_TREE, &KEYFRAME_YMODE_PROBS, 0)?; + let luma = (self.b.read_with_tree(&KEYFRAME_YMODE_NODES)).or_accumulate(&mut res); mb.luma_mode = LumaMode::from_i8(luma).ok_or(DecodingError::LumaPredictionModeInvalid(luma))?; @@ -1480,10 +1425,9 @@ impl Vp8Decoder { let top = self.top[mbx].bpred[12 + x]; let left = self.left.bpred[y]; let intra = self.b.read_with_tree( - &KEYFRAME_BPRED_MODE_TREE, - &KEYFRAME_BPRED_MODE_PROBS[top as usize][left as usize], - 0, - )?; + &KEYFRAME_BPRED_MODE_NODES[top as usize][left as usize], + ); + let intra = intra.or_accumulate(&mut res); let bmode = IntraMode::from_i8(intra) .ok_or(DecodingError::IntraPredictionModeInvalid(intra))?; mb.bpred[x + y * 4] = bmode; @@ -1501,9 +1445,7 @@ impl Vp8Decoder { } } - let chroma = - self.b - .read_with_tree(&KEYFRAME_UV_MODE_TREE, &KEYFRAME_UV_MODE_PROBS, 0)?; + let chroma = (self.b.read_with_tree(&KEYFRAME_UV_MODE_NODES)).or_accumulate(&mut res); mb.chroma_mode = ChromaMode::from_i8(chroma) .ok_or(DecodingError::ChromaPredictionModeInvalid(chroma))?; } @@ -1512,7 +1454,7 @@ impl Vp8Decoder { self.top[mbx].luma_mode = mb.luma_mode; self.top[mbx].bpred = mb.bpred; - Ok(mb) + self.b.check(res, mb) } fn intra_predict_luma(&mut self, mbx: usize, mby: usize, mb: &MacroBlock, resdata: &[i32]) { @@ -1689,21 +1631,21 @@ impl Vp8Decoder { let first = if plane == 0 { 1usize } else { 0usize }; let probs = &self.token_probs[plane]; - let tree = &DCT_TOKEN_TREE; let reader = &mut self.partitions[p]; + let mut res = reader.start_accumulated_result(); + let mut complexity = complexity; let mut has_coefficients = false; let mut skip = false; for i in first..16usize { - let table = &probs[COEFF_BANDS[i] as usize][complexity]; + let band = COEFF_BANDS[i] as usize; + let tree = &probs[band][complexity]; - let token = if !skip { - reader.read_with_tree(tree, table, 0)? - } else { - reader.read_with_tree(tree, table, 2)? - }; + let token = reader + .read_with_tree_with_first_node(tree, tree[skip as usize]) + .or_accumulate(&mut res); let mut abs_value = i32::from(match token { DCT_EOB => break, @@ -1726,7 +1668,8 @@ impl Vp8Decoder { if t == 0 { break; } - extra = extra + extra + i16::from(reader.read_bool(t)?); + let b = reader.read_bool(t).or_accumulate(&mut res); + extra = extra + extra + b as i16; } i16::from(DCT_CAT_BASE[(category - DCT_CAT1) as usize]) + extra @@ -1745,17 +1688,17 @@ impl Vp8Decoder { 2 }; - if reader.read_bool(128)? { + if reader.read_flag().or_accumulate(&mut res) { abs_value = -abs_value; } - block[ZIGZAG[i] as usize] = - abs_value * i32::from(if ZIGZAG[i] > 0 { acq } else { dcq }); + let zigzag = ZIGZAG[i] as usize; + block[zigzag] = abs_value * i32::from(if zigzag > 0 { acq } else { dcq }); has_coefficients = true; } - Ok(has_coefficients) + reader.check(res, has_coefficients) } fn read_residual_data( From 3607c213a948b3c9e45f664bc2f1b0d687e86da1 Mon Sep 17 00:00:00 2001 From: Sander in 't Veld Date: Sun, 29 Dec 2024 16:39:31 +0100 Subject: [PATCH 2/3] Optimize FastReader::read_flag --- src/bool_reader.rs | 108 +++++++++++++++++++++++++++++++++++++-------- src/lossless.rs | 2 +- 2 files changed, 90 insertions(+), 20 deletions(-) diff --git a/src/bool_reader.rs b/src/bool_reader.rs index 725c629..4c2ce8d 100644 --- a/src/bool_reader.rs +++ b/src/bool_reader.rs @@ -158,13 +158,23 @@ impl BoolReader { // Do not inline this because inlining seems to worsen performance. #[inline(never)] pub(crate) fn read_bool(&mut self, probability: u8) -> BitResult { - if let Some(b) = self.fast().read_bit(probability) { + if let Some(b) = self.fast().read_bool(probability) { return BitResult::ok(b); } self.cold_read_bool(probability) } + // Do not inline this because inlining seems to worsen performance. + #[inline(never)] + pub(crate) fn read_flag(&mut self) -> BitResult { + if let Some(b) = self.fast().read_flag() { + return BitResult::ok(b); + } + + self.cold_read_flag() + } + // Do not inline this because inlining seems to worsen performance. #[inline(never)] pub(crate) fn read_literal(&mut self, n: u8) -> BitResult { @@ -206,13 +216,6 @@ impl BoolReader { self.cold_read_with_tree(tree, usize::from(first_node.index)) } - // This should be inlined to allow it to share the instruction cache with - // `read_bool`, as both functions are short and called often. - #[inline] - pub(crate) fn read_flag(&mut self) -> BitResult { - self.read_bool(128) - } - // As a similar (but different) speedup to BitResult, the FastReader reads // bits under an assumption and validates it at the end. // @@ -312,6 +315,12 @@ impl BoolReader { self.cold_read_bit(probability) } + #[cold] + #[inline(never)] + fn cold_read_flag(&mut self) -> BitResult { + self.cold_read_bit(128) + } + #[cold] #[inline(never)] fn cold_read_literal(&mut self, n: u8) -> BitResult { @@ -319,8 +328,8 @@ impl BoolReader { let mut res = self.start_accumulated_result(); for _ in 0..n { - let b = self.cold_read_bit(128).or_accumulate(&mut res); - v = (v << 1) + b as u8; + let b = self.cold_read_flag().or_accumulate(&mut res); + v = (v << 1) + u8::from(b); } self.keep_accumulating(res, v) @@ -330,13 +339,13 @@ impl BoolReader { #[inline(never)] fn cold_read_optional_signed_value(&mut self, n: u8) -> BitResult { let mut res = self.start_accumulated_result(); - let flag = self.cold_read_bool(128).or_accumulate(&mut res); + let flag = self.cold_read_flag().or_accumulate(&mut res); if !flag { // We should not read further bits if the flag is not set. return self.keep_accumulating(res, 0); } let magnitude = self.cold_read_literal(n).or_accumulate(&mut res); - let sign = self.cold_read_bool(128).or_accumulate(&mut res); + let sign = self.cold_read_flag().or_accumulate(&mut res); let value = if sign { -i32::from(magnitude) @@ -380,24 +389,29 @@ impl FastReader<'_> { } } - fn read_bit(mut self, probability: u8) -> Option { + fn read_bool(mut self, probability: u8) -> Option { let bit = self.fast_read_bit(probability); self.commit_if_valid(bit) } + fn read_flag(mut self) -> Option { + let value = self.fast_read_flag(); + self.commit_if_valid(value) + } + fn read_literal(mut self, n: u8) -> Option { let value = self.fast_read_literal(n); self.commit_if_valid(value) } fn read_optional_signed_value(mut self, n: u8) -> Option { - let flag = self.fast_read_bit(128); + let flag = self.fast_read_flag(); if !flag { // We should not read further bits if the flag is not set. return self.commit_if_valid(0); } let magnitude = self.fast_read_literal(n); - let sign = self.fast_read_bit(128); + let sign = self.fast_read_flag(); let value = if sign { -i32::from(magnitude) } else { @@ -467,11 +481,67 @@ impl FastReader<'_> { retval } + fn fast_read_flag(&mut self) -> bool { + let State { + mut chunk_index, + mut value, + mut range, + mut bit_count, + } = self.uncommitted_state; + + if bit_count < 0 { + let chunk = self.chunks.get(chunk_index).copied(); + // We ignore invalid data inside the `fast_` functions, + // but we increase `chunk_index` below, so we can check + // whether we read invalid data in `commit_if_valid`. + let chunk = chunk.unwrap_or_default(); + + let v = u32::from_be_bytes(chunk); + chunk_index += 1; + value <<= 32; + value |= u64::from(v); + bit_count += 32; + } + debug_assert!(bit_count >= 0); + + let half_range = range / 2; + let split = range - half_range; + let bigsplit = u64::from(split) << bit_count; + + let retval = if let Some(new_value) = value.checked_sub(bigsplit) { + range = half_range; + value = new_value; + true + } else { + range = split; + false + }; + debug_assert!(range > 0); + + // Compute shift required to satisfy `range >= 128`. + // Apply that shift to `range` and `self.bitcount`. + // + // Subtract 24 because we only care about leading zeros in the + // lowest byte of `range` which is a `u32`. + let shift = range.leading_zeros().saturating_sub(24); + range <<= shift; + bit_count -= shift as i32; + debug_assert!(range >= 128); + + self.uncommitted_state = State { + chunk_index, + value, + range, + bit_count, + }; + retval + } + fn fast_read_literal(&mut self, n: u8) -> u8 { let mut v = 0u8; for _ in 0..n { - let b = self.fast_read_bit(128); - v = (v << 1) + b as u8; + let b = self.fast_read_flag(); + v = (v << 1) + u8::from(b); } v } @@ -502,7 +572,7 @@ mod tests { buf.as_mut_slice().as_flattened_mut()[..size].copy_from_slice(&data[..]); reader.init(buf, size).unwrap(); let mut res = reader.start_accumulated_result(); - assert_eq!(false, reader.read_bool(128).or_accumulate(&mut res)); + assert_eq!(false, reader.read_flag().or_accumulate(&mut res)); assert_eq!(true, reader.read_bool(10).or_accumulate(&mut res)); assert_eq!(false, reader.read_bool(250).or_accumulate(&mut res)); assert_eq!(1, reader.read_literal(1).or_accumulate(&mut res)); @@ -521,7 +591,7 @@ mod tests { buf.as_mut_slice().as_flattened_mut()[..size].copy_from_slice(&data[..]); reader.init(buf, size).unwrap(); let mut res = reader.start_accumulated_result(); - assert_eq!(false, reader.read_bool(128).or_accumulate(&mut res)); + assert_eq!(false, reader.read_flag().or_accumulate(&mut res)); assert_eq!(true, reader.read_bool(10).or_accumulate(&mut res)); assert_eq!(false, reader.read_bool(250).or_accumulate(&mut res)); assert_eq!(1, reader.read_literal(1).or_accumulate(&mut res)); diff --git a/src/lossless.rs b/src/lossless.rs index 14070f9..51c06bb 100644 --- a/src/lossless.rs +++ b/src/lossless.rs @@ -316,7 +316,7 @@ impl LosslessDecoder { entropy_image = data .chunks_exact(4) .map(|pixel| { - let meta_huff_code = u16::from(pixel[0]) << 8 | u16::from(pixel[1]); + let meta_huff_code = (u16::from(pixel[0]) << 8) | u16::from(pixel[1]); if u32::from(meta_huff_code) >= num_huff_groups { num_huff_groups = u32::from(meta_huff_code) + 1; } From cadb88f4f5838152c918d7727ce76c1bb582537b Mon Sep 17 00:00:00 2001 From: Sander in 't Veld Date: Wed, 1 Jan 2025 20:03:43 +0100 Subject: [PATCH 3/3] Rename BoolReader to ArithmeticDecoder --- src/lib.rs | 2 +- src/vp8.rs | 41 +++++----- ...ol_reader.rs => vp8_arithmetic_decoder.rs} | 80 +++++++++---------- 3 files changed, 61 insertions(+), 62 deletions(-) rename src/{bool_reader.rs => vp8_arithmetic_decoder.rs} (89%) diff --git a/src/lib.rs b/src/lib.rs index 645e250..dedc9f8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,7 +13,6 @@ pub use self::decoder::{DecodingError, LoopCount, WebPDecoder}; pub use self::encoder::{ColorType, EncoderParams, EncodingError, WebPEncoder}; mod alpha_blending; -mod bool_reader; mod decoder; mod encoder; mod extended; @@ -22,5 +21,6 @@ mod loop_filter; mod lossless; mod lossless_transform; mod transform; +mod vp8_arithmetic_decoder; pub mod vp8; diff --git a/src/vp8.rs b/src/vp8.rs index b30a5c3..6e79eb8 100644 --- a/src/vp8.rs +++ b/src/vp8.rs @@ -17,9 +17,8 @@ use std::io::Read; use crate::decoder::DecodingError; -use super::bool_reader::BoolReader; -use super::loop_filter; -use super::transform; +use super::vp8_arithmetic_decoder::ArithmeticDecoder; +use super::{loop_filter, transform}; const MAX_SEGMENTS: usize = 4; const NUM_DCT_TOKENS: usize = 12; @@ -1021,7 +1020,7 @@ struct Segment { /// Only decodes keyframes pub struct Vp8Decoder { r: R, - b: BoolReader, + b: ArithmeticDecoder, mbwidth: u16, mbheight: u16, @@ -1036,7 +1035,7 @@ pub struct Vp8Decoder { ref_delta: [i32; 4], mode_delta: [i32; 4], - partitions: [BoolReader; 8], + partitions: [ArithmeticDecoder; 8], num_partitions: u8, segment_tree_nodes: [TreeNode; 3], @@ -1065,7 +1064,7 @@ impl Vp8Decoder { Self { r, - b: BoolReader::new(), + b: ArithmeticDecoder::new(), mbwidth: 0, mbheight: 0, @@ -1080,14 +1079,14 @@ impl Vp8Decoder { mode_delta: [0; 4], partitions: [ - BoolReader::new(), - BoolReader::new(), - BoolReader::new(), - BoolReader::new(), - BoolReader::new(), - BoolReader::new(), - BoolReader::new(), - BoolReader::new(), + ArithmeticDecoder::new(), + ArithmeticDecoder::new(), + ArithmeticDecoder::new(), + ArithmeticDecoder::new(), + ArithmeticDecoder::new(), + ArithmeticDecoder::new(), + ArithmeticDecoder::new(), + ArithmeticDecoder::new(), ], num_partitions: 1, @@ -1631,9 +1630,9 @@ impl Vp8Decoder { let first = if plane == 0 { 1usize } else { 0usize }; let probs = &self.token_probs[plane]; - let reader = &mut self.partitions[p]; + let decoder = &mut self.partitions[p]; - let mut res = reader.start_accumulated_result(); + let mut res = decoder.start_accumulated_result(); let mut complexity = complexity; let mut has_coefficients = false; @@ -1643,7 +1642,7 @@ impl Vp8Decoder { let band = COEFF_BANDS[i] as usize; let tree = &probs[band][complexity]; - let token = reader + let token = decoder .read_with_tree_with_first_node(tree, tree[skip as usize]) .or_accumulate(&mut res); @@ -1668,8 +1667,8 @@ impl Vp8Decoder { if t == 0 { break; } - let b = reader.read_bool(t).or_accumulate(&mut res); - extra = extra + extra + b as i16; + let b = decoder.read_bool(t).or_accumulate(&mut res); + extra = extra + extra + i16::from(b); } i16::from(DCT_CAT_BASE[(category - DCT_CAT1) as usize]) + extra @@ -1688,7 +1687,7 @@ impl Vp8Decoder { 2 }; - if reader.read_flag().or_accumulate(&mut res) { + if decoder.read_flag().or_accumulate(&mut res) { abs_value = -abs_value; } @@ -1698,7 +1697,7 @@ impl Vp8Decoder { has_coefficients = true; } - reader.check(res, has_coefficients) + decoder.check(res, has_coefficients) } fn read_residual_data( diff --git a/src/bool_reader.rs b/src/vp8_arithmetic_decoder.rs similarity index 89% rename from src/bool_reader.rs rename to src/vp8_arithmetic_decoder.rs index 4c2ce8d..7f8c4be 100644 --- a/src/bool_reader.rs +++ b/src/vp8_arithmetic_decoder.rs @@ -36,7 +36,7 @@ impl BitResult { } #[cfg_attr(test, derive(Debug))] -pub(crate) struct BoolReader { +pub(crate) struct ArithmeticDecoder { chunks: Box<[[u8; 4]]>, state: State, final_bytes: [u8; 3], @@ -53,21 +53,21 @@ struct State { } #[cfg_attr(test, derive(Debug))] -struct FastReader<'a> { +struct FastDecoder<'a> { chunks: &'a [[u8; 4]], uncommitted_state: State, save_state: &'a mut State, } -impl BoolReader { - pub(crate) fn new() -> BoolReader { +impl ArithmeticDecoder { + pub(crate) fn new() -> ArithmeticDecoder { let state = State { chunk_index: 0, value: 0, range: 255, bit_count: -8, }; - BoolReader { + ArithmeticDecoder { chunks: Box::new([]), state, final_bytes: [0; 3], @@ -117,7 +117,7 @@ impl BoolReader { /// discarded anyway. /// /// Each call to `start_accumulated_result` must be followed by a call to - /// `check` on the *same* `BoolReader`. + /// `check` on the *same* `ArithmeticDecoder`. #[inline(always)] pub(crate) fn start_accumulated_result(&mut self) -> BitResultAccumulator { BitResultAccumulator @@ -216,7 +216,7 @@ impl BoolReader { self.cold_read_with_tree(tree, usize::from(first_node.index)) } - // As a similar (but different) speedup to BitResult, the FastReader reads + // As a similar (but different) speedup to BitResult, the FastDecoder reads // bits under an assumption and validates it at the end. // // The idea here is that for normal-sized webp images, the vast majority @@ -228,8 +228,8 @@ impl BoolReader { // work for those last few bytes -- in fact we even keep retrying the fast // method to save an if-statement --, but more than make up for that by // speeding up reading from the other thousands or millions of bytes. - fn fast(&mut self) -> FastReader<'_> { - FastReader { + fn fast(&mut self) -> FastDecoder<'_> { + FastDecoder { chunks: &self.chunks, uncommitted_state: self.state, save_state: &mut self.state, @@ -377,7 +377,7 @@ impl BoolReader { } } -impl FastReader<'_> { +impl FastDecoder<'_> { fn commit_if_valid(self, value_if_not_past_eof: T) -> Option { // If `chunk_index > self.chunks.len()`, it means we used zeroes // instead of an actual chunk and `value_if_not_past_eof` is nonsense. @@ -564,50 +564,50 @@ mod tests { use super::*; #[test] - fn test_bool_reader_hello_short() { - let mut reader = BoolReader::new(); + fn test_arithmetic_decoder_hello_short() { + let mut decoder = ArithmeticDecoder::new(); let data = b"hel"; let size = data.len(); let mut buf = vec![[0u8; 4]; 1]; buf.as_mut_slice().as_flattened_mut()[..size].copy_from_slice(&data[..]); - reader.init(buf, size).unwrap(); - let mut res = reader.start_accumulated_result(); - assert_eq!(false, reader.read_flag().or_accumulate(&mut res)); - assert_eq!(true, reader.read_bool(10).or_accumulate(&mut res)); - assert_eq!(false, reader.read_bool(250).or_accumulate(&mut res)); - assert_eq!(1, reader.read_literal(1).or_accumulate(&mut res)); - assert_eq!(5, reader.read_literal(3).or_accumulate(&mut res)); - assert_eq!(64, reader.read_literal(8).or_accumulate(&mut res)); - assert_eq!(185, reader.read_literal(8).or_accumulate(&mut res)); - reader.check(res, ()).unwrap(); + decoder.init(buf, size).unwrap(); + let mut res = decoder.start_accumulated_result(); + assert_eq!(false, decoder.read_flag().or_accumulate(&mut res)); + assert_eq!(true, decoder.read_bool(10).or_accumulate(&mut res)); + assert_eq!(false, decoder.read_bool(250).or_accumulate(&mut res)); + assert_eq!(1, decoder.read_literal(1).or_accumulate(&mut res)); + assert_eq!(5, decoder.read_literal(3).or_accumulate(&mut res)); + assert_eq!(64, decoder.read_literal(8).or_accumulate(&mut res)); + assert_eq!(185, decoder.read_literal(8).or_accumulate(&mut res)); + decoder.check(res, ()).unwrap(); } #[test] - fn test_bool_reader_hello_long() { - let mut reader = BoolReader::new(); + fn test_arithmetic_decoder_hello_long() { + let mut decoder = ArithmeticDecoder::new(); let data = b"hello world"; let size = data.len(); let mut buf = vec![[0u8; 4]; (size + 3) / 4]; buf.as_mut_slice().as_flattened_mut()[..size].copy_from_slice(&data[..]); - reader.init(buf, size).unwrap(); - let mut res = reader.start_accumulated_result(); - assert_eq!(false, reader.read_flag().or_accumulate(&mut res)); - assert_eq!(true, reader.read_bool(10).or_accumulate(&mut res)); - assert_eq!(false, reader.read_bool(250).or_accumulate(&mut res)); - assert_eq!(1, reader.read_literal(1).or_accumulate(&mut res)); - assert_eq!(5, reader.read_literal(3).or_accumulate(&mut res)); - assert_eq!(64, reader.read_literal(8).or_accumulate(&mut res)); - assert_eq!(185, reader.read_literal(8).or_accumulate(&mut res)); - assert_eq!(31, reader.read_literal(8).or_accumulate(&mut res)); - reader.check(res, ()).unwrap(); + decoder.init(buf, size).unwrap(); + let mut res = decoder.start_accumulated_result(); + assert_eq!(false, decoder.read_flag().or_accumulate(&mut res)); + assert_eq!(true, decoder.read_bool(10).or_accumulate(&mut res)); + assert_eq!(false, decoder.read_bool(250).or_accumulate(&mut res)); + assert_eq!(1, decoder.read_literal(1).or_accumulate(&mut res)); + assert_eq!(5, decoder.read_literal(3).or_accumulate(&mut res)); + assert_eq!(64, decoder.read_literal(8).or_accumulate(&mut res)); + assert_eq!(185, decoder.read_literal(8).or_accumulate(&mut res)); + assert_eq!(31, decoder.read_literal(8).or_accumulate(&mut res)); + decoder.check(res, ()).unwrap(); } #[test] - fn test_bool_reader_uninit() { - let mut reader = BoolReader::new(); - let mut res = reader.start_accumulated_result(); - let _ = reader.read_flag().or_accumulate(&mut res); - let result = reader.check(res, ()); + fn test_arithmetic_decoder_uninit() { + let mut decoder = ArithmeticDecoder::new(); + let mut res = decoder.start_accumulated_result(); + let _ = decoder.read_flag().or_accumulate(&mut res); + let result = decoder.check(res, ()); assert!(result.is_err()); } }