diff --git a/ot/mpz-ot-core/src/kos/error.rs b/ot/mpz-ot-core/src/kos/error.rs index b91e4eb7..3e230367 100644 --- a/ot/mpz-ot-core/src/kos/error.rs +++ b/ot/mpz-ot-core/src/kos/error.rs @@ -4,6 +4,8 @@ pub enum SenderError { #[error("invalid state: expected {0}")] InvalidState(String), + #[error("invalid count, must be a multiple of 64: {0}")] + InvalidCount(usize), #[error("count mismatch: expected {0}, got {1}")] CountMismatch(usize, usize), #[error("id mismatch: expected {0}, got {1}")] @@ -22,6 +24,8 @@ pub enum SenderError { pub enum ReceiverError { #[error("invalid state: expected {0}")] InvalidState(String), + #[error("invalid count, must be a multiple of 64: {0}")] + InvalidCount(usize), #[error("count mismatch: expected {0}, got {1}")] CountMismatch(usize, usize), #[error("id mismatch: expected {0}, got {1}")] diff --git a/ot/mpz-ot-core/src/kos/mod.rs b/ot/mpz-ot-core/src/kos/mod.rs index 438539b2..a2c268e1 100644 --- a/ot/mpz-ot-core/src/kos/mod.rs +++ b/ot/mpz-ot-core/src/kos/mod.rs @@ -28,6 +28,20 @@ pub(crate) type RngSeed = ::Seed; /// AES-128 CTR used for encryption. pub(crate) type Aes128Ctr = ctr::Ctr64LE; +/// Pads the number of OTs to accomodate for the KOS extension check and +/// the extension matrix transpose optimization. +pub fn pad_ot_count(mut count: usize) -> usize { + // Add OTs for the KOS extension check. + count += CSP + SSP; + // Round up the OTs to extend to the nearest multiple of 64 (matrix transpose optimization). + (count + 63) & !63 +} + +/// Returns the size in bytes of the extension matrix for a given number of OTs. +pub fn extension_matrix_size(count: usize) -> usize { + count * CSP / 8 +} + #[cfg(test)] mod tests { use super::*; diff --git a/ot/mpz-ot-core/src/kos/msgs.rs b/ot/mpz-ot-core/src/kos/msgs.rs index afdf2b63..7ee1bc13 100644 --- a/ot/mpz-ot-core/src/kos/msgs.rs +++ b/ot/mpz-ot-core/src/kos/msgs.rs @@ -18,6 +18,7 @@ use crate::msgs::Derandomize; #[allow(missing_docs)] pub enum Message { BaseMsg(BaseMsg), + StartExtend(StartExtend), Extend(Extend), Check(Check), Derandomize(Derandomize), @@ -33,15 +34,50 @@ impl From> for std::io::Error { } } -/// Extension message sent by the receiver. +/// Extension message sent by the receiver to agree upon the number of OTs to set up. #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -pub struct Extend { +pub struct StartExtend { /// The number of OTs to set up. pub count: usize, +} + +/// Extension message sent by the receiver. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct Extend { /// The receiver's extension vectors. pub us: Vec, } +impl Extend { + /// Returns an iterator over the chunks of the message. + pub fn into_chunks(self, chunk_size: usize) -> ExtendChunks { + ExtendChunks { + chunk_size, + us: self.us.into_iter(), + } + } +} + +/// Iterator over the chunks of an extension message. +pub struct ExtendChunks { + chunk_size: usize, + us: as IntoIterator>::IntoIter, +} + +impl Iterator for ExtendChunks { + type Item = Extend; + + fn next(&mut self) -> Option { + if self.us.len() == 0 { + None + } else { + Some(Extend { + us: self.us.by_ref().take(self.chunk_size).collect::>(), + }) + } + } +} + /// Values for the correlation check sent by the receiver. #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] diff --git a/ot/mpz-ot-core/src/kos/receiver.rs b/ot/mpz-ot-core/src/kos/receiver.rs index b7b8fb59..89c28c9d 100644 --- a/ot/mpz-ot-core/src/kos/receiver.rs +++ b/ot/mpz-ot-core/src/kos/receiver.rs @@ -117,6 +117,8 @@ impl Receiver { /// Perform the IKNP OT extension. /// + /// The provided count _must_ be a multiple of 64, otherwise an error will be returned. + /// /// # Sacrificial OTs /// /// Performing the consistency check sacrifices 256 OTs, so be sure to @@ -132,7 +134,7 @@ impl Receiver { /// /// # Arguments /// - /// * `count` - The number of OTs to extend. + /// * `count` - The number of OTs to extend (must be a multiple of 64). pub fn extend(&mut self, count: usize) -> Result { if self.state.extended { return Err(ReceiverError::InvalidState( @@ -140,8 +142,9 @@ impl Receiver { )); } - // Round up the OTs to extend to the nearest multiple of 64 (matrix transpose optimization). - let count = (count + 63) & !63; + if count % 64 != 0 { + return Err(ReceiverError::InvalidCount(count)); + } const NROWS: usize = CSP; let row_width = count / 8; @@ -196,7 +199,7 @@ impl Receiver { ); self.state.unchecked_choices.extend(choices); - Ok(Extend { count, us }) + Ok(Extend { us }) } /// Performs the correlation check for all outstanding OTS. diff --git a/ot/mpz-ot-core/src/kos/sender.rs b/ot/mpz-ot-core/src/kos/sender.rs index 8e692727..1e0ceee1 100644 --- a/ot/mpz-ot-core/src/kos/sender.rs +++ b/ot/mpz-ot-core/src/kos/sender.rs @@ -1,5 +1,6 @@ use crate::{ kos::{ + extension_matrix_size, msgs::{Check, Ciphertexts, Extend, SenderPayload}, Aes128Ctr, Rng, RngSeed, SenderConfig, SenderError, CSP, SSP, }, @@ -96,6 +97,8 @@ impl Sender { /// Perform the IKNP OT extension. /// + /// The provided count _must_ be a multiple of 64, otherwise an error will be returned. + /// /// # Sacrificial OTs /// /// Performing the consistency check sacrifices 256 OTs, so be sure to extend enough to @@ -111,8 +114,8 @@ impl Sender { /// /// # Arguments /// - /// * `count` - The number of additional OTs to extend - /// * `extend` - The receiver's setup message + /// * `count` - The number of additional OTs to extend (must be a multiple of 64). + /// * `extend` - The receiver's setup message. pub fn extend(&mut self, count: usize, extend: Extend) -> Result<(), SenderError> { if self.state.extended { return Err(SenderError::InvalidState( @@ -120,23 +123,16 @@ impl Sender { )); } - // Round up the OTs to extend to the nearest multiple of 64 (matrix transpose optimization). - let count = (count + 63) & !63; + if count % 64 != 0 { + return Err(SenderError::InvalidCount(count)); + } const NROWS: usize = CSP; let row_width = count / 8; - let Extend { - us, - count: receiver_count, - } = extend; - - // Make sure the number of OTs to extend matches the receiver's setup message. - if receiver_count != count { - return Err(SenderError::CountMismatch(receiver_count, count)); - } + let Extend { us } = extend; - if us.len() != NROWS * row_width { + if us.len() != extension_matrix_size(count) { return Err(SenderError::InvalidExtend); } diff --git a/ot/mpz-ot/Cargo.toml b/ot/mpz-ot/Cargo.toml index ec171c99..96d15c8e 100644 --- a/ot/mpz-ot/Cargo.toml +++ b/ot/mpz-ot/Cargo.toml @@ -32,6 +32,7 @@ itybity.workspace = true enum-try-as-inner.workspace = true opaque-debug.workspace = true serde = { workspace = true, optional = true } +cfg-if.workspace = true [dev-dependencies] rstest = { workspace = true } diff --git a/ot/mpz-ot/src/kos/mod.rs b/ot/mpz-ot/src/kos/mod.rs index 39aca16b..d43542fb 100644 --- a/ot/mpz-ot/src/kos/mod.rs +++ b/ot/mpz-ot/src/kos/mod.rs @@ -18,6 +18,16 @@ pub use mpz_ot_core::kos::{ }; use utils_aio::{sink::IoSink, stream::IoStream}; +// If we're testing we use a smaller chunk size to make sure the chunking code paths are tested. +cfg_if::cfg_if! { + if #[cfg(test)] { + pub(crate) const EXTEND_CHUNK_SIZE: usize = 1024; + } else { + /// The size of the chunks used to send the extension matrix, 4MB. + pub(crate) const EXTEND_CHUNK_SIZE: usize = 4 * 1024 * 1024; + } +} + /// Converts a sink of KOS messages into a sink of base OT messages. pub(crate) fn into_base_sink<'a, Si: IoSink> + Send + Unpin, T: Send + 'a>( sink: &'a mut Si, diff --git a/ot/mpz-ot/src/kos/receiver.rs b/ot/mpz-ot/src/kos/receiver.rs index 1562e76a..1d48a349 100644 --- a/ot/mpz-ot/src/kos/receiver.rs +++ b/ot/mpz-ot/src/kos/receiver.rs @@ -3,7 +3,8 @@ use futures::SinkExt; use itybity::{FromBitIterator, IntoBitIterator}; use mpz_core::{cointoss, prg::Prg, Block, ProtocolMessage}; use mpz_ot_core::kos::{ - msgs::Message, receiver_state as state, Receiver as ReceiverCore, ReceiverConfig, CSP, SSP, + msgs::{Message, StartExtend}, + pad_ot_count, receiver_state as state, Receiver as ReceiverCore, ReceiverConfig, CSP, }; use enum_try_as_inner::EnumTryAsInner; @@ -15,7 +16,9 @@ use utils_aio::{ stream::{ExpectStreamExt, IoStream}, }; -use super::{into_base_sink, into_base_stream, ReceiverError, ReceiverVerifyError}; +use super::{ + into_base_sink, into_base_stream, ReceiverError, ReceiverVerifyError, EXTEND_CHUNK_SIZE, +}; use crate::{ OTError, OTReceiver, OTSender, OTSetup, RandomOTReceiver, VerifiableOTReceiver, VerifiableOTSender, @@ -90,9 +93,11 @@ where let mut ext_receiver = std::mem::replace(&mut self.state, State::Error).try_into_extension()?; - // Extend the OTs, adding padding for the consistency check. + let count = pad_ot_count(count); + + // Extend the OTs. let (mut ext_receiver, extend) = Backend::spawn(move || { - let extend = ext_receiver.extend(count + CSP + SSP); + let extend = ext_receiver.extend(count); (ext_receiver, extend) }) @@ -105,7 +110,11 @@ where let (cointoss_sender, cointoss_commitment) = cointoss::Sender::new(vec![seed]).send(); // Send the extend message and cointoss commitment - sink.feed(Message::Extend(extend)).await?; + sink.feed(Message::StartExtend(StartExtend { count })) + .await?; + for extend in extend.into_chunks(EXTEND_CHUNK_SIZE) { + sink.feed(Message::Extend(extend)).await?; + } sink.feed(Message::CointossCommit(cointoss_commitment)) .await?; sink.flush().await?; diff --git a/ot/mpz-ot/src/kos/sender.rs b/ot/mpz-ot/src/kos/sender.rs index 407ddadd..91358b90 100644 --- a/ot/mpz-ot/src/kos/sender.rs +++ b/ot/mpz-ot/src/kos/sender.rs @@ -4,7 +4,9 @@ use futures_util::SinkExt; use itybity::IntoBits; use mpz_core::{cointoss, prg::Prg, Block, ProtocolMessage}; use mpz_ot_core::kos::{ - msgs::Message, sender_state as state, Sender as SenderCore, SenderConfig, CSP, SSP, + extension_matrix_size, + msgs::{Extend, Message, StartExtend}, + pad_ot_count, sender_state as state, Sender as SenderCore, SenderConfig, CSP, }; use rand::{thread_rng, Rng}; use rand_core::{RngCore, SeedableRng}; @@ -138,23 +140,44 @@ where let mut ext_sender = std::mem::replace(&mut self.state, State::Error).try_into_extension()?; - // Receive extend message from the receiver. - let extend = stream + let count = pad_ot_count(count); + + let StartExtend { + count: receiver_count, + } = stream .expect_next() .await? - .try_into_extend() + .try_into_start_extend() .map_err(SenderError::from)?; + if count != receiver_count { + return Err(SenderError::ConfigError( + "sender and receiver count mismatch".to_string(), + )); + } + + let expected_us = extension_matrix_size(count); + let mut extend = Extend { + us: Vec::with_capacity(expected_us), + }; + + // Receive extension matrix from the receiver. + while extend.us.len() < expected_us { + let Extend { us: chunk } = stream + .expect_next() + .await? + .try_into_extend() + .map_err(SenderError::from)?; + + extend.us.extend(chunk); + } + // Receive coin toss commitments from the receiver. let commitment = stream.expect_next().await?.try_into_cointoss_commit()?; - // Extend the OTs, adding padding for the consistency check. - let mut ext_sender = Backend::spawn(move || { - ext_sender - .extend(count + CSP + SSP, extend) - .map(|_| ext_sender) - }) - .await?; + // Extend the OTs. + let mut ext_sender = + Backend::spawn(move || ext_sender.extend(count, extend).map(|_| ext_sender)).await?; // Execute coin toss protocol for consistency check. let seed: Block = thread_rng().gen();