diff --git a/build.rs b/build.rs index 41abfcd4b1..e09a250ddd 100644 --- a/build.rs +++ b/build.rs @@ -962,6 +962,7 @@ fn prefix_all_symbols(pp: char, prefix_prefix: &str, prefix: &str) -> String { "bn_sqr8x_internal", "bn_sqrx8x_internal", "bsaes_ctr32_encrypt_blocks", + "bssl_constant_time_test_conditional_memxor", "bssl_constant_time_test_main", "chacha20_poly1305_open", "chacha20_poly1305_seal", diff --git a/crypto/constant_time_test.c b/crypto/constant_time_test.c index 170525f470..11be4974df 100644 --- a/crypto/constant_time_test.c +++ b/crypto/constant_time_test.c @@ -88,36 +88,36 @@ static int test_select_w(crypto_word_t a, crypto_word_t b) { return 0; } -static crypto_word_t test_values_s[] = { - 0, - 1, - 1024, - 12345, - 32000, +static crypto_word_t test_values_w[] = { + 0, + 1, + 1024, + 12345, + 32000, #if defined(OPENSSL_64_BIT) - 0xffffffff / 2 - 1, - 0xffffffff / 2, - 0xffffffff / 2 + 1, - 0xffffffff - 1, - 0xffffffff, + 0xffffffff / 2 - 1, + 0xffffffff / 2, + 0xffffffff / 2 + 1, + 0xffffffff - 1, + 0xffffffff, #endif - SIZE_MAX / 2 - 1, - SIZE_MAX / 2, - SIZE_MAX / 2 + 1, - SIZE_MAX - 1, - SIZE_MAX + SIZE_MAX / 2 - 1, + SIZE_MAX / 2, + SIZE_MAX / 2 + 1, + SIZE_MAX - 1, + SIZE_MAX }; int bssl_constant_time_test_main(void) { int num_failed = 0; for (size_t i = 0; - i < sizeof(test_values_s) / sizeof(test_values_s[0]); ++i) { - crypto_word_t a = test_values_s[i]; + i < sizeof(test_values_w) / sizeof(test_values_w[0]); ++i) { + crypto_word_t a = test_values_w[i]; num_failed += test_is_zero_w(a); for (size_t j = 0; - j < sizeof(test_values_s) / sizeof(test_values_s[0]); ++j) { - crypto_word_t b = test_values_s[j]; + j < sizeof(test_values_w) / sizeof(test_values_w[0]); ++j) { + crypto_word_t b = test_values_w[j]; num_failed += test_binary_op_w(&constant_time_eq_w, a, b, a == b); num_failed += test_binary_op_w(&constant_time_eq_w, b, a, b == a); num_failed += test_select_w(a, b); @@ -126,3 +126,10 @@ int bssl_constant_time_test_main(void) { return num_failed == 0; } + +// Exposes `constant_time_conditional_memxor` to Rust for tests only. +void bssl_constant_time_test_conditional_memxor(uint8_t dst[256], + const uint8_t src[256], + crypto_word_t b) { + constant_time_conditional_memxor(dst, src, 256, b); +} diff --git a/src/constant_time.rs b/src/constant_time.rs index 9ccf8f6653..e41ad6187e 100644 --- a/src/constant_time.rs +++ b/src/constant_time.rs @@ -37,7 +37,8 @@ prefixed_extern! { #[cfg(test)] mod tests { - use crate::{bssl, error}; + use crate::limb::LimbMask; + use crate::{bssl, error, rand}; #[test] fn test_constant_time() -> Result<(), error::Unspecified> { @@ -46,4 +47,43 @@ mod tests { } Result::from(unsafe { bssl_constant_time_test_main() }) } + + #[test] + fn constant_time_conditional_memxor() -> Result<(), error::Unspecified> { + let rng = rand::SystemRandom::new(); + for _ in 0..256 { + let mut out = rand::generate::<[u8; 256]>(&rng)?.expose(); + let input = rand::generate::<[u8; 256]>(&rng)?.expose(); + + // Mask to 16 bits to make zero more likely than it would otherwise be. + let b = (rand::generate::<[u8; 1]>(&rng)?.expose()[0] & 0x0f) != 0; + + let ref_in = input; + let mut ref_out = out; + if b { + ref_out + .iter_mut() + .zip(ref_in.iter()) + .for_each(|(out, input)| { + *out ^= input; + }); + } + + prefixed_extern! { + fn bssl_constant_time_test_conditional_memxor(dst: &mut [u8; 256], src: &[u8; 256], b: LimbMask); + } + unsafe { + bssl_constant_time_test_conditional_memxor( + &mut out, + &input, + if b { LimbMask::True } else { LimbMask::False }, + ); + } + + assert_eq!(ref_in, input); + assert_eq!(ref_out, out); + } + + Ok(()) + } }