From 1352dcbeaeb75930bfea485c01fd7226a19f3f1b Mon Sep 17 00:00:00 2001
From: Brian Smith <brian@briansmith.org>
Date: Mon, 3 Jun 2024 12:49:31 -0700
Subject: [PATCH] aes: Clarify counter overflow checking.

Move `Counter` and `Iv` to aes/counter.rs.

Create a more robust internal API for counter/nonce/IV management that
makes the usage within AES-GCM more clearly correct. The new design is
easier to test.

`git difftool HEAD^1:src/aead/aes.rs src/aead/aes/counter.rs`
---
 src/aead/aes.rs          |  30 +++-
 src/aead/aes/bs.rs       |  14 +-
 src/aead/aes/counter.rs  | 351 +++++++++++++++++++++++++++++++++++++--
 src/aead/aes/fallback.rs |  20 ++-
 src/aead/aes/ffi.rs      |  35 ++--
 src/aead/aes/hw.rs       |  22 ++-
 src/aead/aes/vp.rs       | 111 +++++++++----
 src/aead/aes_gcm.rs      |  97 +++++++----
 src/limb.rs              |   7 +-
 src/polyfill.rs          |   8 +
 10 files changed, 585 insertions(+), 110 deletions(-)

diff --git a/src/aead/aes.rs b/src/aead/aes.rs
index e83c0b5934..9d28430646 100644
--- a/src/aead/aes.rs
+++ b/src/aead/aes.rs
@@ -21,7 +21,10 @@ use crate::{
 use cfg_if::cfg_if;
 use core::ops::RangeFrom;
 
-pub(super) use self::{counter::Iv, ffi::Counter};
+pub(super) use self::{
+    counter::{CounterOverflowError, Iv, IvBlock},
+    ffi::Counter,
+};
 
 #[macro_use]
 mod ffi;
@@ -114,6 +117,14 @@ pub enum KeyBytes<'a> {
     AES_256(&'a [u8; AES_256_KEY_LEN]),
 }
 
+pub(super) struct InOutLenInconsistentWithIvBlockLenError(());
+impl InOutLenInconsistentWithIvBlockLenError {
+    #[cold]
+    fn new() -> Self {
+        Self(())
+    }
+}
+
 pub(super) type Block = [u8; BLOCK_LEN];
 pub(super) const BLOCK_LEN: usize = 16;
 pub(super) const ZERO_BLOCK: Block = [0u8; BLOCK_LEN];
@@ -124,7 +135,12 @@ pub(super) trait EncryptBlock {
 }
 
 pub(super) trait EncryptCtr32 {
-    fn ctr32_encrypt_within(&self, in_out: &mut [u8], src: RangeFrom<usize>, ctr: &mut Counter);
+    fn ctr32_encrypt_within(
+        &self,
+        in_out: &mut [u8],
+        src: RangeFrom<usize>,
+        iv_block: IvBlock,
+    ) -> Result<(), InOutLenInconsistentWithIvBlockLenError>;
 }
 
 #[allow(dead_code)]
@@ -144,11 +160,11 @@ fn encrypt_iv_xor_block_using_encrypt_block(
 
 #[allow(dead_code)]
 fn encrypt_iv_xor_block_using_ctr32(key: &impl EncryptCtr32, iv: Iv, mut block: Block) -> Block {
-    // This is OK because we're only encrypting one block, and `iv` is already
-    // reserved for us to use.
-    let mut ctr = Counter(iv.into_block_less_safe());
-    key.ctr32_encrypt_within(&mut block, 0.., &mut ctr);
-    block
+    let iv_block = IvBlock::from_iv(iv);
+    match key.ctr32_encrypt_within(&mut block, 0.., iv_block) {
+        Ok(()) => block,
+        Result::<_, InOutLenInconsistentWithIvBlockLenError>::Err(_) => unreachable!(),
+    }
 }
 
 #[cfg(test)]
diff --git a/src/aead/aes/bs.rs b/src/aead/aes/bs.rs
index f1c5408a26..c9f0246fd0 100644
--- a/src/aead/aes/bs.rs
+++ b/src/aead/aes/bs.rs
@@ -14,7 +14,7 @@
 
 #![cfg(target_arch = "arm")]
 
-use super::{Counter, AES_KEY};
+use super::{IvBlock, AES_KEY};
 use core::ops::RangeFrom;
 
 /// SAFETY:
@@ -31,8 +31,8 @@ pub(super) unsafe fn ctr32_encrypt_blocks_with_vpaes_key(
     in_out: &mut [u8],
     src: RangeFrom<usize>,
     vpaes_key: &AES_KEY,
-    ctr: &mut Counter,
-) {
+    iv_block: IvBlock,
+) -> Result<(), super::InOutLenInconsistentWithIvBlockLenError> {
     prefixed_extern! {
         // bsaes_ctr32_encrypt_blocks requires transformation of an existing
        // VPAES key; there is no `bsaes_set_encrypt_key`.
@@ -57,6 +57,12 @@ pub(super) unsafe fn ctr32_encrypt_blocks_with_vpaes_key(
     //  * `bsaes_ctr32_encrypt_blocks` satisfies the contract for
     //    `ctr32_encrypt_blocks`.
     unsafe {
-        ctr32_encrypt_blocks!(bsaes_ctr32_encrypt_blocks, in_out, src, &bsaes_key, ctr);
+        ctr32_encrypt_blocks!(
+            bsaes_ctr32_encrypt_blocks,
+            in_out,
+            src,
+            &bsaes_key,
+            iv_block
+        )
     }
 }
diff --git a/src/aead/aes/counter.rs b/src/aead/aes/counter.rs
index f9a6194d67..b45a5d2304 100644
--- a/src/aead/aes/counter.rs
+++ b/src/aead/aes/counter.rs
@@ -17,48 +17,377 @@ use super::{
     ffi::Counter,
     Block, BLOCK_LEN,
 };
+use crate::polyfill::{nonzerousize_from_nonzerou32, unwrap_const};
+use core::num::{NonZeroU32, NonZeroUsize};
 
 // `Counter` is `ffi::Counter` as its representation is dictated by its use in
 // the FFI.
 impl Counter {
-    pub fn one(nonce: Nonce) -> Self {
+    pub fn one_two(nonce: Nonce) -> (Iv, Self) {
         let mut value = [0u8; BLOCK_LEN];
         value[..NONCE_LEN].copy_from_slice(nonce.as_ref());
         value[BLOCK_LEN - 1] = 1;
-        Self(value)
+        let iv = Iv::new_less_safe(value);
+        value[BLOCK_LEN - 1] = 2;
+        (iv, Self(value))
     }
 
-    pub fn increment(&mut self) -> Iv {
+    pub fn try_into_iv(self) -> Result<Iv, CounterOverflowError> {
         let iv = Iv(self.0);
-        self.increment_by_less_safe(1);
-        iv
+        let [.., c0, c1, c2, c3] = &self.0;
+        let old_value: u32 = u32::from_be_bytes([*c0, *c1, *c2, *c3]);
+        if old_value == 0 {
+            return Err(CounterOverflowError::new());
+        }
+        Ok(iv)
+    }
+
+    pub fn increment_by(
+        &mut self,
+        increment_by: NonZeroUsize,
+    ) -> Result<IvBlock, CounterOverflowError> {
+        #[cold]
+        #[inline(never)]
+        fn overflowed(sum: u32) -> Result<u32, CounterOverflowError> {
+            match sum {
+                0 => Ok(0),
+                _ => Err(CounterOverflowError::new()),
+            }
+        }
+
+        let iv = Iv(self.0);
+
+        let increment_by = match NonZeroU32::try_from(increment_by) {
+            Ok(value) => value,
+            _ => return Err(CounterOverflowError::new()),
+        };
+
+        let [.., c0, c1, c2, c3] = &mut self.0;
+        let old_value: u32 = u32::from_be_bytes([*c0, *c1, *c2, *c3]);
+        if old_value == 0 {
+            return Err(CounterOverflowError::new());
+        }
+        let new_value = match old_value.overflowing_add(increment_by.get()) {
+            (sum, false) => sum,
+            (sum, true) => overflowed(sum)?,
+        };
+        [*c0, *c1, *c2, *c3] = u32::to_be_bytes(new_value);
+
+        Ok(IvBlock {
+            initial_iv: iv,
+            len: increment_by,
+        })
     }
 
-    pub(super) fn increment_by_less_safe(&mut self, increment_by: u32) {
+    #[cfg(target_arch = "x86")]
+    pub(super) fn increment_unchecked_less_safe(&mut self) -> Iv {
+        let iv = Iv(self.0);
+
         let [.., c0, c1, c2, c3] = &mut self.0;
         let old_value: u32 = u32::from_be_bytes([*c0, *c1, *c2, *c3]);
-        let new_value = old_value + increment_by;
+        debug_assert_ne!(old_value, 0);
+        // TODO: unchecked_add?
+        let new_value = old_value.wrapping_add(1);
+        // Note that it *is* valid for new_value to be zero!
         [*c0, *c1, *c2, *c3] = u32::to_be_bytes(new_value);
+
+        iv
+    }
+}
+
+pub(in super::super) struct CounterOverflowError(());
+
+impl CounterOverflowError {
+    #[cold]
+    fn new() -> Self {
+        Self(())
+    }
+}
+
+pub(in super::super) struct IvBlock {
+    initial_iv: Iv,
+    // invariant: 0 < len && len <= u32::MAX
+    len: NonZeroU32,
+}
+
+impl IvBlock {
+    pub(super) fn from_iv(iv: Iv) -> Self {
+        const _1: NonZeroU32 = unwrap_const(NonZeroU32::new(1));
+        Self {
+            initial_iv: iv,
+            len: _1,
+        }
+    }
+
+    // This conversion cannot fail.
+    pub fn len(&self) -> NonZeroUsize {
+        nonzerousize_from_nonzerou32(self.len)
+    }
+
+    // "Less safe" because this subverts the IV reuse prevention machinery. The
+    // caller must ensure the IV is used only once.
+    pub(super) fn into_initial_iv(self) -> Iv {
+        self.initial_iv
+    }
+
+    #[cfg(any(target_arch = "arm", test))]
+    pub(super) fn split_at(
+        self,
+        num_blocks: usize,
+    ) -> Result<(Option<IvBlock>, Option<IvBlock>), CounterOverflowError> {
+        let num_before = u32::try_from(num_blocks).map_err(|_| CounterOverflowError::new())?;
+        let num_after = self
+            .len
+            .get()
+            .checked_sub(num_before)
+            .ok_or_else(CounterOverflowError::new)?;
+
+        let num_before = match NonZeroU32::new(num_before) {
+            Some(num_blocks) => num_blocks,
+            None => return Ok((None, Some(self))),
+        };
+        let num_after = match NonZeroU32::new(num_after) {
+            Some(num_after) => num_after,
+            None => return Ok((Some(self), None)),
+        };
+        let mut ctr = Counter(self.initial_iv.0);
+        match ctr.increment_by(nonzerousize_from_nonzerou32(num_before)) {
+            Ok(before) => {
+                let after = Self {
+                    initial_iv: Iv::new_less_safe(ctr.0),
+                    len: num_after,
+                };
+                Ok((Some(before), Some(after)))
+            }
+            Result::<_, CounterOverflowError>::Err(_) => {
+                unreachable!()
+            }
+        }
+    }
+
+    #[cfg(target_arch = "x86")]
+    pub(super) fn into_counter_less_safe(
+        self,
+        input_blocks: usize,
+    ) -> Result<Counter, super::InOutLenInconsistentWithIvBlockLenError> {
+        if input_blocks != self.len().get() {
+            return Err(super::InOutLenInconsistentWithIvBlockLenError::new());
+        }
+        Ok(Counter(self.initial_iv.0))
+    }
+}
 
 /// The IV for a single block encryption.
 ///
 /// Intentionally not `Clone` to ensure each is used only once.
-pub struct Iv(Block);
+pub(in super::super) struct Iv(Block);
 
 impl Iv {
+    // This is "less safe" because it subverts the counter reuse protection.
+    // The caller needs to ensure that the IV isn't reused.
     pub(super) fn new_less_safe(value: Block) -> Self {
         Self(value)
     }
 
+    /// "Less safe" because it defeats attempts to use the type system to prevent reuse of the IV.
+    #[inline]
     pub(super) fn into_block_less_safe(self) -> Block {
         self.0
     }
 }
 
-impl From<Counter> for Iv {
-    fn from(counter: Counter) -> Self {
-        Self(counter.0)
-    }
-}
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::polyfill::usize_from_u32;
+
+    const DUMMY_ONCE_VALUE: [u8; NONCE_LEN] = [
+        0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc,
+    ];
+    fn dummy_nonce() -> Nonce {
+        Nonce::assume_unique_for_key(DUMMY_ONCE_VALUE)
+    }
+
+    fn dummy_value(counter: [u8; 4]) -> [u8; BLOCK_LEN] {
+        let mut value = [0u8; BLOCK_LEN];
+        value[..NONCE_LEN].copy_from_slice(&DUMMY_ONCE_VALUE);
+        value[NONCE_LEN..].copy_from_slice(&counter);
+        value
+    }
+
+    const _1: NonZeroUsize = unwrap_const(NonZeroUsize::new(1));
+    const _2: NonZeroUsize = unwrap_const(NonZeroUsize::new(2));
+    const MAX: NonZeroUsize = unwrap_const(NonZeroUsize::new(usize_from_u32(u32::MAX)));
+    const MAX_MINUS_1: NonZeroUsize = unwrap_const(NonZeroUsize::new(MAX.get() - 1));
+    const MAX_MINUS_2: NonZeroUsize = unwrap_const(NonZeroUsize::new(MAX.get() - 2));
+
+    const USIZE_MAX: NonZeroUsize = unwrap_const(NonZeroUsize::new(usize::MAX));
+
+    #[cfg(not(any(target_pointer_width = "16", target_pointer_width = "32")))]
+    const MAX_PLUS_1: NonZeroUsize = unwrap_const(NonZeroUsize::new(MAX.get() + 1));
+
+    #[test]
+    fn one_is_one() {
+        let (one, _two) = Counter::one_two(dummy_nonce());
+        let as_block = one.into_block_less_safe();
+        assert_eq!(as_block, dummy_value([0, 0, 0, 1]));
+    }
+
+    #[test]
+    fn two_is_two() {
+        let (_one, two) = Counter::one_two(dummy_nonce());
+        let as_block = two.try_into_iv().ok().unwrap().into_block_less_safe();
+        assert_eq!(as_block, dummy_value([0, 0, 0, 2]));
+    }
+
+    #[test]
+    fn smallest_increment() {
+        let (_, mut ctr) = Counter::one_two(dummy_nonce());
+        let _: IvBlock = ctr.increment_by(_1).ok().unwrap();
+        assert_eq!(
+            ctr.try_into_iv().ok().unwrap().into_block_less_safe(),
+            dummy_value([0, 0, 0, 3])
+        );
+    }
+
+    #[test]
+    fn carries() {
+        let (_, mut ctr) = Counter::one_two(dummy_nonce());
+        let iv_block: IvBlock = ctr
+            .increment_by(NonZeroUsize::new(0xfe).unwrap())
+            .ok()
+            .unwrap();
+        assert_eq!(
+            iv_block.into_initial_iv().into_block_less_safe(),
+            dummy_value([0, 0, 0, 2])
+        );
+        let iv_block = ctr
+            .increment_by(NonZeroUsize::new(0xff_00).unwrap())
+            .ok()
+            .unwrap();
+        assert_eq!(
+            iv_block.into_initial_iv().into_block_less_safe(),
+            dummy_value([0, 0, 1, 0])
+        );
+        let iv_block = ctr
+            .increment_by(NonZeroUsize::new(0xff_00_00).unwrap())
+            .ok()
+            .unwrap();
+        assert_eq!(
+            iv_block.into_initial_iv().into_block_less_safe(),
+            dummy_value([0, 1, 0, 0])
+        );
+        let iv_block = ctr
+            .increment_by(NonZeroUsize::new(0xff_00_00_00).unwrap())
+            .ok()
+            .unwrap();
+        assert_eq!(
+            iv_block.into_initial_iv().into_block_less_safe(),
+            dummy_value([1, 0, 0, 0])
+        );
+        assert_eq!(&ctr.0[..], dummy_value([0, 0, 0, 0]));
+        assert!(ctr.try_into_iv().is_err()); // Because it is zero
+    }
+
+    #[test]
+    fn large_increment() {
+        let (_, mut ctr) = Counter::one_two(dummy_nonce());
+        let _: IvBlock = ctr.increment_by(MAX_MINUS_2).ok().unwrap();
+        let iv_block = ctr.increment_by(_1).ok().unwrap();
+        assert_eq!(
+            iv_block.into_initial_iv().into_block_less_safe(),
+            dummy_value([0xff, 0xff, 0xff, 0xff])
+        );
+        assert!(ctr.increment_by(_1).is_err());
+    }
+
+    #[test]
+    fn larger_increment_then_increment() {
+        let (_, mut ctr) = Counter::one_two(dummy_nonce());
+        let _: IvBlock = ctr.increment_by(MAX_MINUS_1).ok().unwrap();
+        assert_eq!(&ctr.0[..], dummy_value([0, 0, 0, 0]));
+        assert!(ctr.increment_by(MAX_MINUS_1).is_err());
+    }
+
+    #[test]
+    fn larger_increment_then_into_iv() {
+        let (_, mut ctr) = Counter::one_two(dummy_nonce());
+        let _: IvBlock = ctr.increment_by(MAX_MINUS_1).ok().unwrap();
+        assert_eq!(&ctr.0[..], dummy_value([0, 0, 0, 0]));
+        assert!(ctr.try_into_iv().is_err());
+    }
+
+    #[test]
+    fn even_larger_increment() {
+        let (_, mut ctr) = Counter::one_two(dummy_nonce());
+        assert!(ctr.increment_by(MAX).is_err());
+    }
+
+    #[cfg(not(any(target_pointer_width = "16", target_pointer_width = "32")))]
+    #[test]
+    fn even_larger_still_increment() {
+        const MAX_PLUS_1: NonZeroUsize = unwrap_const(NonZeroUsize::new(MAX.get() + 1));
+        let (_, mut ctr) = Counter::one_two(dummy_nonce());
+        assert!(ctr.increment_by(MAX_PLUS_1).is_err());
+    }
+
+    #[test]
+    fn way_too_large_increment() {
+        let (_, mut ctr) = Counter::one_two(dummy_nonce());
+        assert!(ctr.increment_by(USIZE_MAX).is_err());
+    }
+
+    #[test]
+    fn split_at_start() {
+        let (_, mut ctr) = Counter::one_two(dummy_nonce());
+        let iv_block = ctr.increment_by(_1).ok().unwrap();
+        let (a, b) = iv_block.split_at(0).ok().unwrap();
+        assert!(a.is_none());
+        assert_eq!(
+            b.unwrap().into_initial_iv().into_block_less_safe(),
+            dummy_value([0, 0, 0, 2])
+        );
+    }
+
+    #[test]
+    fn split_at_end() {
+        let (_, mut ctr) = Counter::one_two(dummy_nonce());
+        let iv_block = ctr.increment_by(_1).ok().unwrap();
+        let (a, b) = iv_block.split_at(1).ok().unwrap();
+        assert_eq!(
+            a.unwrap().into_initial_iv().into_block_less_safe(),
+            dummy_value([0, 0, 0, 2])
+        );
+        assert!(b.is_none());
+    }
+
+    #[test]
+    fn split_at_middle() {
+        let (_, mut ctr) = Counter::one_two(dummy_nonce());
+        let iv_block = ctr.increment_by(_2).ok().unwrap();
+        let (a, b) = iv_block.split_at(1).ok().unwrap();
+        assert_eq!(
+            a.unwrap().into_initial_iv().into_block_less_safe(),
+            dummy_value([0, 0, 0, 2])
+        );
+        assert_eq!(
+            b.unwrap().into_initial_iv().into_block_less_safe(),
+            dummy_value([0, 0, 0, 3])
+        );
+    }
+
+    #[test]
+    fn split_at_overflow() {
+        let (_, mut ctr) = Counter::one_two(dummy_nonce());
+        let iv_block = ctr.increment_by(_1).ok().unwrap();
+        assert!(iv_block.split_at(2).is_err());
+    }
+
+    #[cfg(not(any(target_pointer_width = "16", target_pointer_width = "32")))]
+    #[test]
+    fn split_at_u32_max_plus_1() {
+        let (_, mut ctr) = Counter::one_two(dummy_nonce());
+        let iv_block = ctr.increment_by(MAX_MINUS_2).ok().unwrap();
+        assert!(iv_block.split_at(MAX_PLUS_1.get()).is_err());
+    }
+}
diff --git a/src/aead/aes/fallback.rs b/src/aead/aes/fallback.rs
index 00caa694ab..ea2e988889 100644
--- a/src/aead/aes/fallback.rs
+++ b/src/aead/aes/fallback.rs
@@ -12,7 +12,10 @@
 // OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
 // CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 
-use super::{Block, Counter, EncryptBlock, EncryptCtr32, Iv, KeyBytes, AES_KEY};
+use super::{
+    Block, EncryptBlock, EncryptCtr32, InOutLenInconsistentWithIvBlockLenError, Iv, IvBlock,
+    KeyBytes, AES_KEY,
+};
 use crate::error;
 use core::ops::RangeFrom;
 
@@ -39,9 +42,20 @@ impl EncryptBlock for Key {
 }
 
 impl EncryptCtr32 for Key {
-    fn ctr32_encrypt_within(&self, in_out: &mut [u8], src: RangeFrom<usize>, ctr: &mut Counter) {
+    fn ctr32_encrypt_within(
+        &self,
+        in_out: &mut [u8],
+        src: RangeFrom<usize>,
+        iv_block: IvBlock,
+    ) -> Result<(), InOutLenInconsistentWithIvBlockLenError> {
         unsafe {
-            ctr32_encrypt_blocks!(aes_nohw_ctr32_encrypt_blocks, in_out, src, &self.inner, ctr)
+            ctr32_encrypt_blocks!(
+                aes_nohw_ctr32_encrypt_blocks,
+                in_out,
+                src,
+                &self.inner,
+                iv_block
+            )
         }
     }
 }
diff --git a/src/aead/aes/ffi.rs b/src/aead/aes/ffi.rs
index 840845059b..8ad76c72da 100644
--- a/src/aead/aes/ffi.rs
+++ b/src/aead/aes/ffi.rs
@@ -12,9 +12,9 @@
 // OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
 // CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 
-use super::{Block, KeyBytes, BLOCK_LEN};
+use super::{Block, InOutLenInconsistentWithIvBlockLenError, IvBlock, KeyBytes, BLOCK_LEN};
 use crate::{bits::BitLength, c, error, polyfill::slice};
-use core::{num::NonZeroUsize, ops::RangeFrom};
+use core::ops::RangeFrom;
 
 /// nonce || big-endian counter.
 #[repr(transparent)]
@@ -127,9 +127,9 @@ impl AES_KEY {
 /// * The caller must ensure that the function `$name` satisfies the conditions
 ///   for the `f` parameter to `ctr32_encrypt_blocks`.
 macro_rules! ctr32_encrypt_blocks {
-    ($name:ident, $in_out:expr, $src:expr, $key:expr, $ctr:expr $(,)? ) => {{
+    ($name:ident, $in_out:expr, $src:expr, $key:expr, $iv_block:expr $(,)? ) => {{
         use crate::{
-            aead::aes::{ffi::AES_KEY, Counter, BLOCK_LEN},
+            aead::aes::{ffi::AES_KEY, BLOCK_LEN},
             c,
         };
         prefixed_extern! {
@@ -138,10 +138,10 @@ macro_rules! ctr32_encrypt_blocks {
                 output: *mut [u8; BLOCK_LEN],
                 blocks: c::NonZero_size_t,
                 key: &AES_KEY,
-                ivec: &Counter,
+                ivec: &[u8; BLOCK_LEN],
             );
         }
-        $key.ctr32_encrypt_blocks($name, $in_out, $src, $ctr)
+        $key.ctr32_encrypt_blocks($name, $in_out, $src, $iv_block)
     }};
}

@@ -165,23 +165,22 @@ impl AES_KEY {
             output: *mut [u8; BLOCK_LEN],
             blocks: c::NonZero_size_t,
             key: &AES_KEY,
-            ivec: &Counter,
+            ivec: &[u8; BLOCK_LEN],
         ),
         in_out: &mut [u8],
         src: RangeFrom<usize>,
-        ctr: &mut Counter,
-    ) {
+        iv_block: IvBlock,
+    ) -> Result<(), InOutLenInconsistentWithIvBlockLenError> {
         let (input, leftover) = slice::as_chunks(&in_out[src]);
         debug_assert_eq!(leftover.len(), 0);
-        let blocks = match NonZeroUsize::new(input.len()) {
-            Some(blocks) => blocks,
-            None => {
-                return;
-            }
-        };
+        if input.len() != iv_block.len().get() {
+            return Err(InOutLenInconsistentWithIvBlockLenError::new());
+        }
+        debug_assert!(!input.is_empty());
 
-        let blocks_u32: u32 = blocks.get().try_into().unwrap();
+        let iv_block_len = iv_block.len();
+        let initial_iv = iv_block.into_initial_iv().into_block_less_safe();
 
         let input = input.as_ptr();
         let output: *mut [u8; BLOCK_LEN] = in_out.as_mut_ptr().cast();
@@ -196,9 +195,9 @@ impl AES_KEY {
         //   * The caller is responsible for ensuring `key` was initialized by the
         //     `set_encrypt_key!` invocation required by `f`.
         unsafe {
-            f(input, output, blocks, self, ctr);
+            f(input, output, iv_block_len, self, &initial_iv);
         }
-        ctr.increment_by_less_safe(blocks_u32);
+        Ok(())
     }
 }
diff --git a/src/aead/aes/hw.rs b/src/aead/aes/hw.rs
index c7b1e51de7..218f93e4e7 100644
--- a/src/aead/aes/hw.rs
+++ b/src/aead/aes/hw.rs
@@ -14,7 +14,10 @@
 
 #![cfg(any(target_arch = "aarch64", target_arch = "x86", target_arch = "x86_64"))]
 
-use super::{Block, Counter, EncryptBlock, EncryptCtr32, Iv, KeyBytes, AES_KEY};
+use super::{
+    Block, EncryptBlock, EncryptCtr32, InOutLenInconsistentWithIvBlockLenError, Iv, IvBlock,
+    KeyBytes, AES_KEY,
+};
 use crate::{cpu, error};
 use core::ops::RangeFrom;
 
@@ -56,9 +59,22 @@ impl EncryptBlock for Key {
 }
 
 impl EncryptCtr32 for Key {
-    fn ctr32_encrypt_within(&self, in_out: &mut [u8], src: RangeFrom<usize>, ctr: &mut Counter) {
+    fn ctr32_encrypt_within(
+        &self,
+        in_out: &mut [u8],
+        src: RangeFrom<usize>,
+        iv_block: IvBlock,
+    ) -> Result<(), InOutLenInconsistentWithIvBlockLenError> {
         #[cfg(target_arch = "x86_64")]
         let _: cpu::Features = cpu::features();
-        unsafe { ctr32_encrypt_blocks!(aes_hw_ctr32_encrypt_blocks, in_out, src, &self.inner, ctr) }
+        unsafe {
+            ctr32_encrypt_blocks!(
+                aes_hw_ctr32_encrypt_blocks,
+                in_out,
+                src,
+                &self.inner,
+                iv_block
+            )
+        }
     }
 }
diff --git a/src/aead/aes/vp.rs b/src/aead/aes/vp.rs
index 0893a9873c..6fc1ed4f0d 100644
--- a/src/aead/aes/vp.rs
+++ b/src/aead/aes/vp.rs
@@ -19,7 +19,10 @@
     target_arch = "x86_64"
 ))]
 
-use super::{Block, Counter, EncryptBlock, EncryptCtr32, Iv, KeyBytes, AES_KEY};
+use super::{
+    Block, EncryptBlock, EncryptCtr32, InOutLenInconsistentWithIvBlockLenError, Iv, IvBlock,
+    KeyBytes, AES_KEY,
+};
 use crate::{cpu, error};
 use core::ops::RangeFrom;
 
@@ -57,34 +60,53 @@ impl EncryptBlock for Key {
 
 #[cfg(any(target_arch = "aarch64", target_arch = "x86_64"))]
 impl EncryptCtr32 for Key {
-    fn ctr32_encrypt_within(&self, in_out: &mut [u8], src: RangeFrom<usize>, ctr: &mut Counter) {
-        unsafe { ctr32_encrypt_blocks!(vpaes_ctr32_encrypt_blocks, in_out, src, &self.inner, ctr) }
+    fn ctr32_encrypt_within(
+        &self,
+        in_out: &mut [u8],
+        src: RangeFrom<usize>,
+        iv_block: IvBlock,
+    ) -> Result<(), InOutLenInconsistentWithIvBlockLenError> {
+        unsafe {
+            ctr32_encrypt_blocks!(
+                vpaes_ctr32_encrypt_blocks,
+                in_out,
+                src,
+                &self.inner,
+                iv_block
+            )
+        }
     }
 }
 
 #[cfg(target_arch = "arm")]
 impl EncryptCtr32 for Key {
-    fn ctr32_encrypt_within(&self, in_out: &mut [u8], src: RangeFrom<usize>, ctr: &mut Counter) {
+    fn ctr32_encrypt_within(
+        &self,
+        in_out: &mut [u8],
+        src: RangeFrom<usize>,
+        iv_block: IvBlock,
+    ) -> Result<(), InOutLenInconsistentWithIvBlockLenError> {
         use super::{bs, BLOCK_LEN};
 
-        let in_out = {
-            let blocks = in_out[src.clone()].len() / BLOCK_LEN;
-
-            // bsaes operates in batches of 8 blocks.
-            let bsaes_blocks = if blocks >= 8 && (blocks % 8) < 6 {
-                // It's faster to use bsaes for all the full batches and then
-                // switch to vpaes for the last partial batch (if any).
-                blocks - (blocks % 8)
-            } else if blocks >= 8 {
-                // It's faster to let bsaes handle everything including
-                // the last partial batch.
-                blocks
-            } else {
-                // It's faster to let vpaes handle everything.
-                0
-            };
-            let bsaes_in_out_len = bsaes_blocks * BLOCK_LEN;
+        let blocks = in_out[src.clone()].len() / BLOCK_LEN;
+
+        // bsaes operates in batches of 8 blocks.
+        let bsaes_blocks = if blocks >= 8 && (blocks % 8) < 6 {
+            // It's faster to use bsaes for all the full batches and then
+            // switch to vpaes for the last partial batch (if any).
+            blocks - (blocks % 8)
+        } else if blocks >= 8 {
+            // It's faster to let bsaes handle everything including
+            // the last partial batch.
+            blocks
+        } else {
+            // It's faster to let vpaes handle everything.
+            0
+        };
+        let (bsaes_iv_block, vpaes_iv_block) = match iv_block.split_at(bsaes_blocks) {
+            Ok(split) => split,
+            // `bsaes_blocks <= blocks == iv_block.len()`, so this can't overflow.
+            Result::<_, super::CounterOverflowError>::Err(_) => unreachable!(),
+        };
+        let bsaes_in_out_len = bsaes_blocks * BLOCK_LEN;
 
+        if let Some(iv_block) = bsaes_iv_block {
             // SAFETY:
             //  * self.inner was initialized with `vpaes_set_encrypt_key` above,
             //    as required by `bsaes_ctr32_encrypt_blocks_with_vpaes_key`.
             unsafe {
                 bs::ctr32_encrypt_blocks_with_vpaes_key(
                     &mut in_out[..(src.start + bsaes_in_out_len)],
                     src.clone(),
                     &self.inner,
-                    ctr,
-                );
-            }
+                    iv_block,
+                )
+            }?;
+        }
 
-            &mut in_out[bsaes_in_out_len..]
-        };
+        if let Some(iv_block) = vpaes_iv_block {
+            let in_out = &mut in_out[bsaes_in_out_len..];
+            // SAFETY:
+            //  * self.inner was initialized with `vpaes_set_encrypt_key` above,
+            //    as required by `vpaes_ctr32_encrypt_blocks`.
+            //  * `vpaes_ctr32_encrypt_blocks` satisfies the contract for
+            //    `ctr32_encrypt_blocks`.
+            unsafe {
+                ctr32_encrypt_blocks!(
+                    vpaes_ctr32_encrypt_blocks,
+                    in_out,
+                    src,
+                    &self.inner,
+                    iv_block
+                )
+            }?;
+        }
 
-        // SAFETY:
-        //  * self.inner was initialized with `vpaes_set_encrypt_key` above,
-        //    as required by `vpaes_ctr32_encrypt_blocks`.
-        //  * `vpaes_ctr32_encrypt_blocks` satisfies the contract for
-        //    `ctr32_encrypt_blocks`.
-        unsafe { ctr32_encrypt_blocks!(vpaes_ctr32_encrypt_blocks, in_out, src, &self.inner, ctr) }
+        Ok(())
     }
 }
 
@@ -122,9 +155,19 @@ impl EncryptBlock for Key {
 
 #[cfg(target_arch = "x86")]
 impl EncryptCtr32 for Key {
-    fn ctr32_encrypt_within(&self, in_out: &mut [u8], src: RangeFrom<usize>, ctr: &mut Counter) {
+    fn ctr32_encrypt_within(
+        &self,
+        in_out: &mut [u8],
+        src: RangeFrom<usize>,
+        iv_block: IvBlock,
+    ) -> Result<(), InOutLenInconsistentWithIvBlockLenError> {
+        let (input_blocks, leftover): (&[[u8; super::BLOCK_LEN]], _) =
+            crate::polyfill::slice::as_chunks(&in_out[src.clone()]);
+        debug_assert_eq!(leftover.len(), 0);
+        let mut ctr = iv_block.into_counter_less_safe(input_blocks.len())?;
         super::super::shift::shift_full_blocks(in_out, src, |input| {
-            self.encrypt_iv_xor_block(ctr.increment(), *input)
+            self.encrypt_iv_xor_block(ctr.increment_unchecked_less_safe(), *input)
         });
+        Ok(())
     }
 }
diff --git a/src/aead/aes_gcm.rs b/src/aead/aes_gcm.rs
index feb4df1e28..ef7f0db2d8 100644
--- a/src/aead/aes_gcm.rs
+++ b/src/aead/aes_gcm.rs
@@ -13,14 +13,17 @@
 // CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 
 use super::{
-    aes::{self, Counter, BLOCK_LEN, ZERO_BLOCK},
+    aes::{
+        self, Counter, CounterOverflowError, InOutLenInconsistentWithIvBlockLenError, BLOCK_LEN,
+        ZERO_BLOCK,
+    },
     gcm, shift, Aad, Nonce, Tag,
 };
 use crate::{
     cpu, error,
     polyfill::{slice, sliceutil::overwrite_at_start, usize_from_u64_saturated},
 };
-use core::ops::RangeFrom;
+use core::{num::NonZeroUsize, ops::RangeFrom};
 
 #[cfg(target_arch = "x86_64")]
 use aes::EncryptCtr32 as _;
@@ -118,8 +121,10 @@ pub(super) fn seal(
     aad: Aad<&[u8]>,
     in_out: &mut [u8],
 ) -> Result<Tag, error::Unspecified> {
-    let mut ctr = Counter::one(nonce);
-    let tag_iv = ctr.increment();
+    let (tag_iv, ctr) = Counter::one_two(nonce);
+
+    #[cfg(any(target_arch = "aarch64", target_arch = "x86_64"))]
+    let mut ctr = ctr;
 
     match key {
         #[cfg(target_arch = "x86_64")]
@@ -160,8 +165,16 @@
             }
         };
         let (whole, remainder) = slice::as_chunks_mut(ramaining);
-        aes_key.ctr32_encrypt_within(slice::flatten_mut(whole), 0.., &mut ctr);
-        auth.update_blocks(whole);
+        if let Some(whole_len) = NonZeroUsize::new(whole.len()) {
+            let iv_block = ctr
+                .increment_by(whole_len)
+                .map_err(|_: CounterOverflowError| error::Unspecified)?;
+            match aes_key.ctr32_encrypt_within(slice::flatten_mut(whole), 0.., iv_block) {
+                Ok(()) => {}
+                Result::<_, InOutLenInconsistentWithIvBlockLenError>::Err(_) => unreachable!(),
+            }
+            auth.update_blocks(whole);
+        }
 
         seal_finish(aes_key, auth, remainder, ctr, tag_iv)
     }
@@ -240,7 +253,14 @@ fn seal_strided<A: aes::EncryptBlock + aes::EncryptCtr32, G: gcm::UpdateBlocks>(
     let (whole, remainder) = slice::as_chunks_mut(in_out);
 
     for chunk in whole.chunks_mut(CHUNK_BLOCKS) {
-        aes_key.ctr32_encrypt_within(slice::flatten_mut(chunk), 0.., &mut ctr);
+        let chunk_len = NonZeroUsize::new(chunk.len()).ok_or(error::Unspecified)?;
+        let iv_block = ctr
+            .increment_by(chunk_len)
+            .map_err(|_: CounterOverflowError| error::Unspecified)?;
+        match aes_key.ctr32_encrypt_within(slice::flatten_mut(chunk), 0.., iv_block) {
+            Ok(()) => {}
+            Err(_) => unreachable!(),
+        }
         auth.update_blocks(chunk);
     }
 
@@ -257,7 +277,8 @@ fn seal_finish(
     if !remainder.is_empty() {
         let mut input = ZERO_BLOCK;
         overwrite_at_start(&mut input, remainder);
-        let mut output = aes_key.encrypt_iv_xor_block(ctr.into(), input);
+        let iv = ctr.try_into_iv().map_err(|_| error::Unspecified)?;
+        let mut output = aes_key.encrypt_iv_xor_block(iv, input);
         output[remainder.len()..].fill(0);
         auth.update_block(output);
         overwrite_at_start(remainder, &output);
@@ -278,8 +299,10 @@ pub(super) fn open(
     #[cfg(any(target_arch = "aarch64", target_arch = "x86_64"))]
     let input = in_out.get(src.clone()).ok_or(error::Unspecified)?;
 
-    let mut ctr = Counter::one(nonce);
-    let tag_iv = ctr.increment();
+    let (tag_iv, ctr) = Counter::one_two(nonce);
+
+    #[cfg(any(target_arch = "aarch64", target_arch = "x86_64"))]
+    let mut ctr = ctr;
 
     match key {
         #[cfg(target_arch = "x86_64")]
@@ -320,22 +343,31 @@ pub(super) fn open(
                 unreachable!()
             }
         };
-        // Authenticate any remaining whole blocks.
+
         let input = match in_out.get(src.clone()) {
             Some(remaining_input) => remaining_input,
             None => unreachable!(),
         };
-        let (whole, _) = slice::as_chunks(input);
-        auth.update_blocks(whole);
-        let whole_len = slice::flatten(whole).len();
-
-        // Decrypt any remaining whole blocks.
-        aes_key.ctr32_encrypt_within(
-            &mut in_out[..(src.start + whole_len)],
-            src.clone(),
-            &mut ctr,
-        );
+        let (whole, _) = slice::as_chunks(input);
+        let whole_len = if let Some(whole_len) = NonZeroUsize::new(whole.len()) {
+            let iv_block = ctr
+                .increment_by(whole_len)
+                .map_err(|_: CounterOverflowError| error::Unspecified)?;
+            auth.update_blocks(whole);
+            let whole_len = slice::flatten(whole).len();
+            match aes_key.ctr32_encrypt_within(
+                &mut in_out[..(src.start + whole_len)],
+                src.clone(),
+                iv_block,
+            ) {
+                Ok(()) => {}
+                Result::<_, InOutLenInconsistentWithIvBlockLenError>::Err(_) => unreachable!(),
+            }
+            whole_len
+        } else {
+            0
+        };
 
         let in_out = match in_out.get_mut(whole_len..) {
             Some(partial) => partial,
@@ -445,16 +477,24 @@ fn open_strided<A: aes::EncryptBlock + aes::EncryptCtr32, G: gcm::UpdateBlocks>(
             let ciphertext = &in_out[input..][..chunk_len];
             let (ciphertext, leftover) = slice::as_chunks(ciphertext);
             debug_assert_eq!(leftover.len(), 0);
+            let num_blocks = match NonZeroUsize::new(ciphertext.len()) {
+                Some(blocks) => blocks,
+                None => break,
+            };
+            let iv_block = ctr
+                .increment_by(num_blocks)
+                .map_err(|_| error::Unspecified)?;
             auth.update_blocks(ciphertext);
-            aes_key.ctr32_encrypt_within(
+            match aes_key.ctr32_encrypt_within(
                 &mut in_out[output..][..(chunk_len + in_prefix_len)],
                 in_prefix_len..,
-                &mut ctr,
-            );
+                iv_block,
+            ) {
+                Ok(()) => {}
+                Result::<_, InOutLenInconsistentWithIvBlockLenError>::Err(_) => unreachable!(),
+            }
             output += chunk_len;
             input += chunk_len;
         }
@@ -471,11 +511,12 @@ fn open_finish(
     ctr: Counter,
     tag_iv: aes::Iv,
 ) -> Result<Tag, error::Unspecified> {
+    let iv = ctr.try_into_iv().map_err(|_| error::Unspecified)?;
     shift::shift_partial((src.start, remainder), |remainder| {
         let mut input = ZERO_BLOCK;
         overwrite_at_start(&mut input, remainder);
         auth.update_block(input);
-        aes_key.encrypt_iv_xor_block(ctr.into(), input)
+        aes_key.encrypt_iv_xor_block(iv, input)
     });
 
     Ok(finish(aes_key, auth, tag_iv))
diff --git a/src/limb.rs b/src/limb.rs
index b18f929bab..2833497aeb 100644
--- a/src/limb.rs
+++ b/src/limb.rs
@@ -18,10 +18,13 @@
 //! Limbs ordered least-significant-limb to most-significant-limb. The bits
 //! limbs use the native endianness.
 
-use crate::{c, error, polyfill::ArrayFlatMap};
+use crate::{
+    c, error,
+    polyfill::{usize_from_u32, ArrayFlatMap},
+};
 
 #[cfg(any(test, feature = "alloc"))]
-use crate::{bits, constant_time, polyfill::usize_from_u32};
+use crate::{bits, constant_time};
 
 #[cfg(feature = "alloc")]
 use core::num::Wrapping;
diff --git a/src/polyfill.rs b/src/polyfill.rs
index 4d5a0ec1f0..32380e5edf 100644
--- a/src/polyfill.rs
+++ b/src/polyfill.rs
@@ -15,6 +15,8 @@
 //! Polyfills for functionality that will (hopefully) be added to Rust's
 //! standard library soon.
 
+use core::num::{NonZeroU32, NonZeroUsize};
+
 #[cfg(any(target_pointer_width = "32", target_pointer_width = "64"))]
 #[inline(always)]
 pub const fn u64_from_usize(x: usize) -> u64 {
@@ -26,6 +28,12 @@ pub const fn usize_from_u32(x: u32) -> usize {
     x as usize
 }
 
+#[cfg(any(target_pointer_width = "32", target_pointer_width = "64"))]
+pub const fn nonzerousize_from_nonzerou32(x: NonZeroU32) -> NonZeroUsize {
+    let value = usize_from_u32(x.get());
+    unsafe { NonZeroUsize::new_unchecked(value) }
+}
+
 #[cfg(all(target_arch = "aarch64", target_pointer_width = "64"))]
 #[allow(clippy::cast_possible_truncation)]
 pub fn usize_from_u64(x: u64) -> usize {
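
--
Reviewer note (illustrative, not part of the commit): the sketch below shows how
the pieces introduced above are intended to compose inside the `aes` module --
reserve all needed counter blocks up front with `Counter::increment_by`, then
hand the resulting `IvBlock` to `EncryptCtr32::ctr32_encrypt_within`. The helper
function name and its error mapping are hypothetical; it assumes it lives where
`Counter`, `IvBlock`, `EncryptCtr32`, `BLOCK_LEN`, `Nonce`, and
`core::num::NonZeroUsize` are already in scope.

    // Hypothetical helper, for illustration only.
    fn ctr32_encrypt_whole_blocks(
        key: &impl EncryptCtr32,
        nonce: Nonce,
        in_out: &mut [u8], // assumed to be a whole number of AES blocks
    ) -> Result<(), error::Unspecified> {
        // Counter block 1 is reserved for the tag, so the returned counter
        // starts at block 2.
        let (_tag_iv, mut ctr) = Counter::one_two(nonce);
        if let Some(num_blocks) = NonZeroUsize::new(in_out.len() / BLOCK_LEN) {
            // Counter overflow is checked here, once, rather than inside the
            // per-block encryption loop.
            let iv_block = ctr
                .increment_by(num_blocks)
                .map_err(|_: CounterOverflowError| error::Unspecified)?;
            // Fails only if the IvBlock length disagrees with the input length.
            key.ctr32_encrypt_within(in_out, 0.., iv_block)
                .map_err(|_| error::Unspecified)?;
        }
        Ok(())
    }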