Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: chunk KOS extend message #98

Merged
merged 8 commits into from
Feb 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions ot/mpz-ot-core/src/kos/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
pub enum SenderError {
#[error("invalid state: expected {0}")]
InvalidState(String),
#[error("invalid count, must be a multiple of 64: {0}")]
InvalidCount(usize),
#[error("count mismatch: expected {0}, got {1}")]
CountMismatch(usize, usize),
#[error("id mismatch: expected {0}, got {1}")]
Expand All @@ -22,6 +24,8 @@ pub enum SenderError {
pub enum ReceiverError {
#[error("invalid state: expected {0}")]
InvalidState(String),
#[error("invalid count, must be a multiple of 64: {0}")]
InvalidCount(usize),
#[error("count mismatch: expected {0}, got {1}")]
CountMismatch(usize, usize),
#[error("id mismatch: expected {0}, got {1}")]
Expand Down
14 changes: 14 additions & 0 deletions ot/mpz-ot-core/src/kos/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,20 @@ pub(crate) type RngSeed = <Rng as SeedableRng>::Seed;
/// AES-128 CTR used for encryption.
pub(crate) type Aes128Ctr = ctr::Ctr64LE<aes::Aes128>;

/// Pads the number of OTs to accomodate for the KOS extension check and
/// the extension matrix transpose optimization.
pub fn pad_ot_count(mut count: usize) -> usize {
// Add OTs for the KOS extension check.
count += CSP + SSP;
// Round up the OTs to extend to the nearest multiple of 64 (matrix transpose optimization).
(count + 63) & !63
}

/// Returns the size in bytes of the extension matrix for a given number of OTs.
pub fn extension_matrix_size(count: usize) -> usize {
count * CSP / 8
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
40 changes: 38 additions & 2 deletions ot/mpz-ot-core/src/kos/msgs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use crate::msgs::Derandomize;
#[allow(missing_docs)]
pub enum Message<BaseMsg> {
BaseMsg(BaseMsg),
StartExtend(StartExtend),
Extend(Extend),
Check(Check),
Derandomize(Derandomize),
Expand All @@ -33,15 +34,50 @@ impl<BaseMsg> From<MessageError<BaseMsg>> for std::io::Error {
}
}

/// Extension message sent by the receiver.
/// Extension message sent by the receiver to agree upon the number of OTs to set up.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Extend {
pub struct StartExtend {
/// The number of OTs to set up.
pub count: usize,
}

/// Extension message sent by the receiver.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Extend {
/// The receiver's extension vectors.
pub us: Vec<u8>,
}

impl Extend {
/// Returns an iterator over the chunks of the message.
pub fn into_chunks(self, chunk_size: usize) -> ExtendChunks {
ExtendChunks {
chunk_size,
us: self.us.into_iter(),
}
}
}

/// Iterator over the chunks of an extension message.
pub struct ExtendChunks {
chunk_size: usize,
us: <Vec<u8> as IntoIterator>::IntoIter,
}

impl Iterator for ExtendChunks {
type Item = Extend;

fn next(&mut self) -> Option<Self::Item> {
if self.us.len() == 0 {
None
} else {
Some(Extend {
us: self.us.by_ref().take(self.chunk_size).collect::<Vec<_>>(),
})
}
}
}

/// Values for the correlation check sent by the receiver.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
Expand Down
11 changes: 7 additions & 4 deletions ot/mpz-ot-core/src/kos/receiver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ impl Receiver<state::Extension> {

/// Perform the IKNP OT extension.
///
/// The provided count _must_ be a multiple of 64, otherwise an error will be returned.
///
/// # Sacrificial OTs
///
/// Performing the consistency check sacrifices 256 OTs, so be sure to
Expand All @@ -132,16 +134,17 @@ impl Receiver<state::Extension> {
///
/// # Arguments
///
/// * `count` - The number of OTs to extend.
/// * `count` - The number of OTs to extend (must be a multiple of 64).
pub fn extend(&mut self, count: usize) -> Result<Extend, ReceiverError> {
if self.state.extended {
return Err(ReceiverError::InvalidState(
"extending more than once is currently disabled".to_string(),
));
}

// Round up the OTs to extend to the nearest multiple of 64 (matrix transpose optimization).
let count = (count + 63) & !63;
if count % 64 != 0 {
return Err(ReceiverError::InvalidCount(count));
}

const NROWS: usize = CSP;
let row_width = count / 8;
Expand Down Expand Up @@ -196,7 +199,7 @@ impl Receiver<state::Extension> {
);
self.state.unchecked_choices.extend(choices);

Ok(Extend { count, us })
Ok(Extend { us })
}

/// Performs the correlation check for all outstanding OTS.
Expand Down
24 changes: 10 additions & 14 deletions ot/mpz-ot-core/src/kos/sender.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::{
kos::{
extension_matrix_size,
msgs::{Check, Ciphertexts, Extend, SenderPayload},
Aes128Ctr, Rng, RngSeed, SenderConfig, SenderError, CSP, SSP,
},
Expand Down Expand Up @@ -96,6 +97,8 @@ impl Sender<state::Extension> {

/// Perform the IKNP OT extension.
///
/// The provided count _must_ be a multiple of 64, otherwise an error will be returned.
///
/// # Sacrificial OTs
///
/// Performing the consistency check sacrifices 256 OTs, so be sure to extend enough to
Expand All @@ -111,32 +114,25 @@ impl Sender<state::Extension> {
///
/// # Arguments
///
/// * `count` - The number of additional OTs to extend
/// * `extend` - The receiver's setup message
/// * `count` - The number of additional OTs to extend (must be a multiple of 64).
/// * `extend` - The receiver's setup message.
pub fn extend(&mut self, count: usize, extend: Extend) -> Result<(), SenderError> {
if self.state.extended {
return Err(SenderError::InvalidState(
"extending more than once is currently disabled".to_string(),
));
}

// Round up the OTs to extend to the nearest multiple of 64 (matrix transpose optimization).
let count = (count + 63) & !63;
if count % 64 != 0 {
return Err(SenderError::InvalidCount(count));
}

const NROWS: usize = CSP;
let row_width = count / 8;

let Extend {
us,
count: receiver_count,
} = extend;

// Make sure the number of OTs to extend matches the receiver's setup message.
if receiver_count != count {
return Err(SenderError::CountMismatch(receiver_count, count));
}
let Extend { us } = extend;

if us.len() != NROWS * row_width {
if us.len() != extension_matrix_size(count) {
return Err(SenderError::InvalidExtend);
}

Expand Down
1 change: 1 addition & 0 deletions ot/mpz-ot/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ itybity.workspace = true
enum-try-as-inner.workspace = true
opaque-debug.workspace = true
serde = { workspace = true, optional = true }
cfg-if.workspace = true

[dev-dependencies]
rstest = { workspace = true }
Expand Down
10 changes: 10 additions & 0 deletions ot/mpz-ot/src/kos/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,16 @@ pub use mpz_ot_core::kos::{
};
use utils_aio::{sink::IoSink, stream::IoStream};

// If we're testing we use a smaller chunk size to make sure the chunking code paths are tested.
cfg_if::cfg_if! {
if #[cfg(test)] {
pub(crate) const EXTEND_CHUNK_SIZE: usize = 1024;
} else {
/// The size of the chunks used to send the extension matrix, 4MB.
pub(crate) const EXTEND_CHUNK_SIZE: usize = 4 * 1024 * 1024;
}
}

/// Converts a sink of KOS messages into a sink of base OT messages.
pub(crate) fn into_base_sink<'a, Si: IoSink<msgs::Message<T>> + Send + Unpin, T: Send + 'a>(
sink: &'a mut Si,
Expand Down
19 changes: 14 additions & 5 deletions ot/mpz-ot/src/kos/receiver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ use futures::SinkExt;
use itybity::{FromBitIterator, IntoBitIterator};
use mpz_core::{cointoss, prg::Prg, Block, ProtocolMessage};
use mpz_ot_core::kos::{
msgs::Message, receiver_state as state, Receiver as ReceiverCore, ReceiverConfig, CSP, SSP,
msgs::{Message, StartExtend},
pad_ot_count, receiver_state as state, Receiver as ReceiverCore, ReceiverConfig, CSP,
};

use enum_try_as_inner::EnumTryAsInner;
Expand All @@ -15,7 +16,9 @@ use utils_aio::{
stream::{ExpectStreamExt, IoStream},
};

use super::{into_base_sink, into_base_stream, ReceiverError, ReceiverVerifyError};
use super::{
into_base_sink, into_base_stream, ReceiverError, ReceiverVerifyError, EXTEND_CHUNK_SIZE,
};
use crate::{
OTError, OTReceiver, OTSender, OTSetup, RandomOTReceiver, VerifiableOTReceiver,
VerifiableOTSender,
Expand Down Expand Up @@ -90,9 +93,11 @@ where
let mut ext_receiver =
std::mem::replace(&mut self.state, State::Error).try_into_extension()?;

// Extend the OTs, adding padding for the consistency check.
let count = pad_ot_count(count);

// Extend the OTs.
let (mut ext_receiver, extend) = Backend::spawn(move || {
let extend = ext_receiver.extend(count + CSP + SSP);
let extend = ext_receiver.extend(count);
themighty1 marked this conversation as resolved.
Show resolved Hide resolved

(ext_receiver, extend)
})
Expand All @@ -105,7 +110,11 @@ where
let (cointoss_sender, cointoss_commitment) = cointoss::Sender::new(vec![seed]).send();

// Send the extend message and cointoss commitment
sink.feed(Message::Extend(extend)).await?;
sink.feed(Message::StartExtend(StartExtend { count }))
.await?;
for extend in extend.into_chunks(EXTEND_CHUNK_SIZE) {
sink.feed(Message::Extend(extend)).await?;
}
sink.feed(Message::CointossCommit(cointoss_commitment))
.await?;
sink.flush().await?;
Expand Down
45 changes: 34 additions & 11 deletions ot/mpz-ot/src/kos/sender.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ use futures_util::SinkExt;
use itybity::IntoBits;
use mpz_core::{cointoss, prg::Prg, Block, ProtocolMessage};
use mpz_ot_core::kos::{
msgs::Message, sender_state as state, Sender as SenderCore, SenderConfig, CSP, SSP,
extension_matrix_size,
msgs::{Extend, Message, StartExtend},
pad_ot_count, sender_state as state, Sender as SenderCore, SenderConfig, CSP,
};
use rand::{thread_rng, Rng};
use rand_core::{RngCore, SeedableRng};
Expand Down Expand Up @@ -138,23 +140,44 @@ where
let mut ext_sender =
std::mem::replace(&mut self.state, State::Error).try_into_extension()?;

// Receive extend message from the receiver.
let extend = stream
let count = pad_ot_count(count);

let StartExtend {
count: receiver_count,
} = stream
.expect_next()
.await?
.try_into_extend()
.try_into_start_extend()
.map_err(SenderError::from)?;

if count != receiver_count {
return Err(SenderError::ConfigError(
"sender and receiver count mismatch".to_string(),
));
}

let expected_us = extension_matrix_size(count);
let mut extend = Extend {
us: Vec::with_capacity(expected_us),
};

// Receive extension matrix from the receiver.
while extend.us.len() < expected_us {
let Extend { us: chunk } = stream
.expect_next()
.await?
.try_into_extend()
.map_err(SenderError::from)?;

extend.us.extend(chunk);
sinui0 marked this conversation as resolved.
Show resolved Hide resolved
}

// Receive coin toss commitments from the receiver.
let commitment = stream.expect_next().await?.try_into_cointoss_commit()?;

// Extend the OTs, adding padding for the consistency check.
let mut ext_sender = Backend::spawn(move || {
ext_sender
.extend(count + CSP + SSP, extend)
.map(|_| ext_sender)
})
.await?;
// Extend the OTs.
let mut ext_sender =
Backend::spawn(move || ext_sender.extend(count, extend).map(|_| ext_sender)).await?;

// Execute coin toss protocol for consistency check.
let seed: Block = thread_rng().gen();
Expand Down
Loading