From 40aa96acee0144051943967492fe6c2c9d1f1e55 Mon Sep 17 00:00:00 2001 From: Xiang Xie Date: Fri, 28 Jun 2024 23:55:36 +0800 Subject: [PATCH 01/17] add ferret with io --- crates/mpz-common/src/ideal.rs | 4 +- crates/mpz-core/src/ggm_tree.rs | 52 ++-- crates/mpz-ot-core/src/ferret/mod.rs | 19 +- crates/mpz-ot-core/src/ferret/mpcot/mod.rs | 5 +- .../mpz-ot-core/src/ferret/mpcot/receiver.rs | 30 +-- .../src/ferret/mpcot/receiver_regular.rs | 30 +-- crates/mpz-ot-core/src/ferret/mpcot/sender.rs | 30 +-- .../src/ferret/mpcot/sender_regular.rs | 30 +-- crates/mpz-ot-core/src/ferret/receiver.rs | 20 +- crates/mpz-ot-core/src/ferret/sender.rs | 24 +- crates/mpz-ot-core/src/ferret/spcot/mod.rs | 81 ++++-- .../mpz-ot-core/src/ferret/spcot/receiver.rs | 243 +++++++++++++----- crates/mpz-ot-core/src/ferret/spcot/sender.rs | 156 +++++++---- crates/mpz-ot/src/ferret/error.rs | 67 +++++ crates/mpz-ot/src/ferret/mod.rs | 175 +++++++++++++ crates/mpz-ot/src/ferret/mpcot/error.rs | 59 +++++ crates/mpz-ot/src/ferret/mpcot/mod.rs | 165 ++++++++++++ crates/mpz-ot/src/ferret/mpcot/receiver.rs | 192 ++++++++++++++ crates/mpz-ot/src/ferret/mpcot/sender.rs | 166 ++++++++++++ crates/mpz-ot/src/ferret/receiver.rs | 192 ++++++++++++++ crates/mpz-ot/src/ferret/sender.rs | 160 ++++++++++++ crates/mpz-ot/src/ferret/spcot/error.rs | 59 +++++ crates/mpz-ot/src/ferret/spcot/mod.rs | 103 ++++++++ crates/mpz-ot/src/ferret/spcot/receiver.rs | 164 ++++++++++++ crates/mpz-ot/src/ferret/spcot/sender.rs | 144 +++++++++++ crates/mpz-ot/src/ideal/cot.rs | 11 +- crates/mpz-ot/src/lib.rs | 1 + 27 files changed, 2126 insertions(+), 256 deletions(-) create mode 100644 crates/mpz-ot/src/ferret/error.rs create mode 100644 crates/mpz-ot/src/ferret/mod.rs create mode 100644 crates/mpz-ot/src/ferret/mpcot/error.rs create mode 100644 crates/mpz-ot/src/ferret/mpcot/mod.rs create mode 100644 crates/mpz-ot/src/ferret/mpcot/receiver.rs create mode 100644 crates/mpz-ot/src/ferret/mpcot/sender.rs create mode 100644 crates/mpz-ot/src/ferret/receiver.rs create mode 100644 crates/mpz-ot/src/ferret/sender.rs create mode 100644 crates/mpz-ot/src/ferret/spcot/error.rs create mode 100644 crates/mpz-ot/src/ferret/spcot/mod.rs create mode 100644 crates/mpz-ot/src/ferret/spcot/receiver.rs create mode 100644 crates/mpz-ot/src/ferret/spcot/sender.rs diff --git a/crates/mpz-common/src/ideal.rs b/crates/mpz-common/src/ideal.rs index 804472ef..7fcb1628 100644 --- a/crates/mpz-common/src/ideal.rs +++ b/crates/mpz-common/src/ideal.rs @@ -18,7 +18,7 @@ struct Buffer { } /// The ideal functionality from the perspective of Alice. -#[derive(Debug)] +#[derive(Debug, Default)] pub struct Alice { f: Arc>, buffer: Arc>, @@ -79,7 +79,7 @@ impl Alice { } /// The ideal functionality from the perspective of Bob. -#[derive(Debug)] +#[derive(Debug, Default)] pub struct Bob { f: Arc>, buffer: Arc>, diff --git a/crates/mpz-core/src/ggm_tree.rs b/crates/mpz-core/src/ggm_tree.rs index 913fffb6..840efcc6 100644 --- a/crates/mpz-core/src/ggm_tree.rs +++ b/crates/mpz-core/src/ggm_tree.rs @@ -32,33 +32,35 @@ impl GgmTree { assert_eq!(k0.len(), self.depth); assert_eq!(k1.len(), self.depth); let mut buf = [Block::ZERO; 8]; - self.tkprp.expand_1to2(tree, seed); - k0[0] = tree[0]; - k1[0] = tree[1]; + if self.depth > 1 { + self.tkprp.expand_1to2(tree, seed); + k0[0] = tree[0]; + k1[0] = tree[1]; - self.tkprp.expand_2to4(&mut buf, tree); - k0[1] = buf[0] ^ buf[2]; - k1[1] = buf[1] ^ buf[3]; - tree[0..4].copy_from_slice(&buf[0..4]); - - for h in 2..self.depth { - k0[h] = Block::ZERO; - k1[h] = Block::ZERO; - - // How many nodes there are in this layer - let sz = 1 << h; - for i in (0..=sz - 4).rev().step_by(4) { - self.tkprp.expand_4to8(&mut buf, &tree[i..]); - k0[h] ^= buf[0]; - k0[h] ^= buf[2]; - k0[h] ^= buf[4]; - k0[h] ^= buf[6]; - k1[h] ^= buf[1]; - k1[h] ^= buf[3]; - k1[h] ^= buf[5]; - k1[h] ^= buf[7]; + self.tkprp.expand_2to4(&mut buf, tree); + k0[1] = buf[0] ^ buf[2]; + k1[1] = buf[1] ^ buf[3]; + tree[0..4].copy_from_slice(&buf[0..4]); - tree[2 * i..2 * i + 8].copy_from_slice(&buf); + for h in 2..self.depth { + k0[h] = Block::ZERO; + k1[h] = Block::ZERO; + + // How many nodes there are in this layer + let sz = 1 << h; + for i in (0..=sz - 4).rev().step_by(4) { + self.tkprp.expand_4to8(&mut buf, &tree[i..]); + k0[h] ^= buf[0]; + k0[h] ^= buf[2]; + k0[h] ^= buf[4]; + k0[h] ^= buf[6]; + k1[h] ^= buf[1]; + k1[h] ^= buf[3]; + k1[h] ^= buf[5]; + k1[h] ^= buf[7]; + + tree[2 * i..2 * i + 8].copy_from_slice(&buf); + } } } } diff --git a/crates/mpz-ot-core/src/ferret/mod.rs b/crates/mpz-ot-core/src/ferret/mod.rs index 3ad7701e..bbbf264a 100644 --- a/crates/mpz-ot-core/src/ferret/mod.rs +++ b/crates/mpz-ot-core/src/ferret/mod.rs @@ -36,11 +36,12 @@ pub const LPN_PARAMETERS_UNIFORM: LpnParameters = LpnParameters { }; /// The type of Lpn parameters. -#[derive(Debug)] +#[derive(Debug, Clone, Copy, Default)] pub enum LpnType { /// Uniform error distribution. Uniform, /// Regular error distribution. + #[default] Regular, } @@ -48,7 +49,6 @@ pub enum LpnType { mod tests { use super::*; - use msgs::LpnMatrixSeed; use receiver::Receiver; use sender::Sender; @@ -56,7 +56,6 @@ mod tests { use crate::test::assert_cot; use crate::{MPCOTReceiverOutput, MPCOTSenderOutput, RCOTReceiverOutput, RCOTSenderOutput}; use mpz_core::{lpn::LpnParameters, prg::Prg}; - use rand::SeedableRng; const LPN_PARAMETERS_TEST: LpnParameters = LpnParameters { n: 9600, @@ -66,7 +65,7 @@ mod tests { #[test] fn ferret_test() { - let mut prg = Prg::from_seed([1u8; 16].into()); + let mut prg = Prg::new(); let delta = prg.random_block(); let mut ideal_cot = IdealCOT::default(); let mut ideal_mpcot = IdealMpcot::default(); @@ -101,18 +100,8 @@ mod tests { ) .unwrap(); - let LpnMatrixSeed { - seed: lpn_matrix_seed, - } = seed; - let mut sender = sender - .setup( - delta, - LPN_PARAMETERS_TEST, - LpnType::Regular, - lpn_matrix_seed, - &v, - ) + .setup(delta, LPN_PARAMETERS_TEST, LpnType::Regular, seed, &v) .unwrap(); // extend once diff --git a/crates/mpz-ot-core/src/ferret/mpcot/mod.rs b/crates/mpz-ot-core/src/ferret/mpcot/mod.rs index e74dc38a..047780d4 100644 --- a/crates/mpz-ot-core/src/ferret/mpcot/mod.rs +++ b/crates/mpz-ot-core/src/ferret/mpcot/mod.rs @@ -16,11 +16,10 @@ mod tests { use crate::ideal::spcot::IdealSpcot; use crate::{SPCOTReceiverOutput, SPCOTSenderOutput}; use mpz_core::prg::Prg; - use rand::SeedableRng; #[test] fn mpcot_general_test() { - let mut prg = Prg::from_seed([1u8; 16].into()); + let mut prg = Prg::new(); let delta = prg.random_block(); let mut ideal_spcot = IdealSpcot::new_with_delta(delta); @@ -96,7 +95,7 @@ mod tests { #[test] fn mpcot_regular_test() { - let mut prg = Prg::from_seed([2u8; 16].into()); + let mut prg = Prg::new(); let delta = prg.random_block(); let mut ideal_spcot = IdealSpcot::new_with_delta(delta); diff --git a/crates/mpz-ot-core/src/ferret/mpcot/receiver.rs b/crates/mpz-ot-core/src/ferret/mpcot/receiver.rs index 0f8613af..e4d362da 100644 --- a/crates/mpz-ot-core/src/ferret/mpcot/receiver.rs +++ b/crates/mpz-ot-core/src/ferret/mpcot/receiver.rs @@ -32,11 +32,11 @@ impl Receiver { /// # Argument /// /// * `hash_seed` - Random seed to generate hashes, will be sent to the sender. - pub fn setup(self, hash_seed: Block) -> (Receiver, HashSeed) { + pub fn setup(self, hash_seed: Block) -> (Receiver, HashSeed) { let mut prg = Prg::from_seed(hash_seed); let hashes = std::array::from_fn(|_| AesEncryptor::new(prg.random_block())); let recv = Receiver { - state: state::PreExtension { + state: state::Extension { counter: 0, hashes: Arc::new(hashes), }, @@ -48,7 +48,7 @@ impl Receiver { } } -impl Receiver { +impl Receiver { /// Performs the hash procedure in MPCOT extension. /// Outputs the length of each bucket plus 1. /// @@ -63,7 +63,7 @@ impl Receiver { self, alphas: &[u32], n: u32, - ) -> Result<(Receiver, Vec<(usize, u32)>), ReceiverError> { + ) -> Result<(Receiver, Vec<(usize, u32)>), ReceiverError> { if alphas.len() as u32 > n { return Err(ReceiverError::InvalidInput( "length of alphas should not exceed n".to_string(), @@ -104,7 +104,7 @@ impl Receiver { } let receiver = Receiver { - state: state::Extension { + state: state::ExtensionInternal { counter: self.state.counter, m, n, @@ -117,7 +117,7 @@ impl Receiver { Ok((receiver, p)) } } -impl Receiver { +impl Receiver { /// Performs MPCOT extension. /// /// See Step 5 in Figure 7. @@ -128,7 +128,7 @@ impl Receiver { pub fn extend( self, rt: &[Vec], - ) -> Result<(Receiver, Vec), ReceiverError> { + ) -> Result<(Receiver, Vec), ReceiverError> { if rt.len() != self.state.m { return Err(ReceiverError::InvalidInput( "the length rt should be m".to_string(), @@ -165,7 +165,7 @@ impl Receiver { } let receiver = Receiver { - state: state::PreExtension { + state: state::Extension { counter: self.state.counter + 1, hashes: self.state.hashes, }, @@ -182,8 +182,8 @@ pub mod state { pub trait Sealed {} impl Sealed for super::Initialized {} - impl Sealed for super::PreExtension {} impl Sealed for super::Extension {} + impl Sealed for super::ExtensionInternal {} } /// The receiver's state. @@ -200,20 +200,20 @@ pub mod state { /// The receiver's state before extending. /// /// In this state the receiver performs pre extension in MPCOT (potentially multiple times). - pub struct PreExtension { + pub struct Extension { /// Current MPCOT counter pub(super) counter: usize, /// The hashes to generate Cuckoo hash table. pub(super) hashes: Arc<[AesEncryptor; CUCKOO_HASH_NUM]>, } - impl State for PreExtension {} + impl State for Extension {} - opaque_debug::implement!(PreExtension); + opaque_debug::implement!(Extension); /// The receiver's state of extension. /// /// In this state the receiver performs MPCOT extension (potentially multiple times). - pub struct Extension { + pub struct ExtensionInternal { /// Current MPCOT counter pub(super) counter: usize, /// Current length of Cuckoo hash table, will possibly be changed in each extension. @@ -228,7 +228,7 @@ pub mod state { pub(super) buckets_length: Vec, } - impl State for Extension {} + impl State for ExtensionInternal {} - opaque_debug::implement!(Extension); + opaque_debug::implement!(ExtensionInternal); } diff --git a/crates/mpz-ot-core/src/ferret/mpcot/receiver_regular.rs b/crates/mpz-ot-core/src/ferret/mpcot/receiver_regular.rs index 2b226108..e1e7edfe 100644 --- a/crates/mpz-ot-core/src/ferret/mpcot/receiver_regular.rs +++ b/crates/mpz-ot-core/src/ferret/mpcot/receiver_regular.rs @@ -19,13 +19,13 @@ impl Receiver { } /// Completes the setup phase of the protocol. - pub fn setup(self) -> Receiver { + pub fn setup(self) -> Receiver { Receiver { - state: state::PreExtension { counter: 0 }, + state: state::Extension { counter: 0 }, } } } -impl Receiver { +impl Receiver { /// Performs the prepare procedure in MPCOT extension. /// Outputs the indices for SPCOT. /// @@ -38,7 +38,7 @@ impl Receiver { self, alphas: &[u32], n: u32, - ) -> Result<(Receiver, Vec<(usize, u32)>), ReceiverError> { + ) -> Result<(Receiver, Vec<(usize, u32)>), ReceiverError> { let t = alphas.len() as u32; if t > n { return Err(ReceiverError::InvalidInput( @@ -91,7 +91,7 @@ impl Receiver { .collect(); let receiver = Receiver { - state: state::Extension { + state: state::ExtensionInternal { counter: self.state.counter, n, queries_length, @@ -103,7 +103,7 @@ impl Receiver { } } -impl Receiver { +impl Receiver { /// Performs MPCOT extension. /// /// # Arguments. @@ -112,7 +112,7 @@ impl Receiver { pub fn extend( self, rt: &[Vec], - ) -> Result<(Receiver, Vec), ReceiverError> { + ) -> Result<(Receiver, Vec), ReceiverError> { if rt .iter() .zip(self.state.queries_depth.iter()) @@ -130,7 +130,7 @@ impl Receiver { } let receiver = Receiver { - state: state::PreExtension { + state: state::Extension { counter: self.state.counter + 1, }, }; @@ -145,8 +145,8 @@ pub mod state { pub trait Sealed {} impl Sealed for super::Initialized {} - impl Sealed for super::PreExtension {} impl Sealed for super::Extension {} + impl Sealed for super::ExtensionInternal {} } /// The receiver's state. @@ -162,19 +162,19 @@ pub mod state { /// The receiver's state before extending. /// /// In this state the receiver performs pre extension in MPCOT (potentially multiple times). - pub struct PreExtension { + pub struct Extension { /// Current MPCOT counter pub(super) counter: usize, } - impl State for PreExtension {} + impl State for Extension {} - opaque_debug::implement!(PreExtension); + opaque_debug::implement!(Extension); /// The receiver's state after the setup phase. /// /// In this state the receiver performs MPCOT extension (potentially multiple times). - pub struct Extension { + pub struct ExtensionInternal { /// Current MPCOT counter #[allow(dead_code)] pub(super) counter: usize, @@ -186,7 +186,7 @@ pub mod state { pub(super) queries_depth: Vec, } - impl State for Extension {} + impl State for ExtensionInternal {} - opaque_debug::implement!(Extension); + opaque_debug::implement!(ExtensionInternal); } diff --git a/crates/mpz-ot-core/src/ferret/mpcot/sender.rs b/crates/mpz-ot-core/src/ferret/mpcot/sender.rs index f1e49105..ad025574 100644 --- a/crates/mpz-ot-core/src/ferret/mpcot/sender.rs +++ b/crates/mpz-ot-core/src/ferret/mpcot/sender.rs @@ -31,12 +31,12 @@ impl Sender { /// /// * `delta` - The sender's global secret. /// * `hash_seed` - The seed for Cuckoo hash sent by the receiver. - pub fn setup(self, delta: Block, hash_seed: HashSeed) -> Sender { + pub fn setup(self, delta: Block, hash_seed: HashSeed) -> Sender { let HashSeed { seed: hash_seed } = hash_seed; let mut prg = Prg::from_seed(hash_seed); let hashes = std::array::from_fn(|_| AesEncryptor::new(prg.random_block())); Sender { - state: state::PreExtension { + state: state::Extension { delta, counter: 0, hashes: Arc::new(hashes), @@ -45,7 +45,7 @@ impl Sender { } } -impl Sender { +impl Sender { /// Performs the hash procedure in MPCOT extension. /// Outputs the length of each bucket plus 1. /// @@ -59,7 +59,7 @@ impl Sender { self, t: u32, n: u32, - ) -> Result<(Sender, Vec), SenderError> { + ) -> Result<(Sender, Vec), SenderError> { if t > n { return Err(SenderError::InvalidInput( "t should not exceed n".to_string(), @@ -86,7 +86,7 @@ impl Sender { } let sender = Sender { - state: state::Extension { + state: state::ExtensionInternal { delta: self.state.delta, counter: self.state.counter, m, @@ -101,7 +101,7 @@ impl Sender { } } -impl Sender { +impl Sender { /// Performs MPCOT extension. /// /// See Step 5 in Figure 7. @@ -112,7 +112,7 @@ impl Sender { pub fn extend( self, st: &[Vec], - ) -> Result<(Sender, Vec), SenderError> { + ) -> Result<(Sender, Vec), SenderError> { if st.len() != self.state.m { return Err(SenderError::InvalidInput( "the length st should be m".to_string(), @@ -147,7 +147,7 @@ impl Sender { } let sender = Sender { - state: state::PreExtension { + state: state::Extension { delta: self.state.delta, counter: self.state.counter + 1, hashes: self.state.hashes, @@ -166,8 +166,8 @@ pub mod state { pub trait Sealed {} impl Sealed for super::Initialized {} - impl Sealed for super::PreExtension {} impl Sealed for super::Extension {} + impl Sealed for super::ExtensionInternal {} } /// The sender's state. @@ -184,7 +184,7 @@ pub mod state { /// The sender's state before extending. /// /// In this state the sender performs pre extension in MPCOT (potentially multiple times). - pub struct PreExtension { + pub struct Extension { /// Sender's global secret. pub(super) delta: Block, /// Current MPCOT counter @@ -193,13 +193,13 @@ pub mod state { pub(super) hashes: Arc<[AesEncryptor; CUCKOO_HASH_NUM]>, } - impl State for PreExtension {} - opaque_debug::implement!(PreExtension); + impl State for Extension {} + opaque_debug::implement!(Extension); /// The sender's state of extension. /// /// In this state the sender performs MPCOT extension (potentially multiple times). - pub struct Extension { + pub struct ExtensionInternal { /// Sender's global secret. pub(super) delta: Block, /// Current MPCOT counter @@ -217,7 +217,7 @@ pub mod state { pub(super) buckets_length: Vec, } - impl State for Extension {} + impl State for ExtensionInternal {} - opaque_debug::implement!(Extension); + opaque_debug::implement!(ExtensionInternal); } diff --git a/crates/mpz-ot-core/src/ferret/mpcot/sender_regular.rs b/crates/mpz-ot-core/src/ferret/mpcot/sender_regular.rs index db0646b6..7afa5106 100644 --- a/crates/mpz-ot-core/src/ferret/mpcot/sender_regular.rs +++ b/crates/mpz-ot-core/src/ferret/mpcot/sender_regular.rs @@ -23,14 +23,14 @@ impl Sender { /// # Argument. /// /// * `delta` - The sender's global secret. - pub fn setup(self, delta: Block) -> Sender { + pub fn setup(self, delta: Block) -> Sender { Sender { - state: state::PreExtension { delta, counter: 0 }, + state: state::Extension { delta, counter: 0 }, } } } -impl Sender { +impl Sender { /// Performs the prepare procedure in MPCOT extension. /// Outputs the information for SPCOT. /// @@ -42,7 +42,7 @@ impl Sender { self, t: u32, n: u32, - ) -> Result<(Sender, Vec), SenderError> { + ) -> Result<(Sender, Vec), SenderError> { if t > n { return Err(SenderError::InvalidInput( "t should not exceed n".to_string(), @@ -78,7 +78,7 @@ impl Sender { } let sender = Sender { - state: state::Extension { + state: state::ExtensionInternal { delta: self.state.delta, counter: self.state.counter, n, @@ -91,7 +91,7 @@ impl Sender { } } -impl Sender { +impl Sender { /// Performs MPCOT extension. /// /// # Arguments. @@ -100,7 +100,7 @@ impl Sender { pub fn extend( self, st: &[Vec], - ) -> Result<(Sender, Vec), SenderError> { + ) -> Result<(Sender, Vec), SenderError> { if st .iter() .zip(self.state.queries_depth.iter()) @@ -117,7 +117,7 @@ impl Sender { } let sender = Sender { - state: state::PreExtension { + state: state::Extension { delta: self.state.delta, counter: self.state.counter + 1, }, @@ -135,8 +135,8 @@ pub mod state { pub trait Sealed {} impl Sealed for super::Initialized {} - impl Sealed for super::PreExtension {} impl Sealed for super::Extension {} + impl Sealed for super::ExtensionInternal {} } /// The sender's state. @@ -153,20 +153,20 @@ pub mod state { /// The sender's state before extending. /// /// In this state the sender performs pre extension in MPCOT (potentially multiple times). - pub struct PreExtension { + pub struct Extension { /// Sender's global secret. pub(super) delta: Block, /// Current MPCOT counter pub(super) counter: usize, } - impl State for PreExtension {} - opaque_debug::implement!(PreExtension); + impl State for Extension {} + opaque_debug::implement!(Extension); /// The sender's state after the setup phase. /// /// In this state the sender performs MPCOT extension (potentially multiple times). - pub struct Extension { + pub struct ExtensionInternal { /// Sender's global secret. pub(super) delta: Block, /// Current MPCOT counter @@ -179,7 +179,7 @@ pub mod state { pub(super) queries_depth: Vec, } - impl State for Extension {} + impl State for ExtensionInternal {} - opaque_debug::implement!(Extension); + opaque_debug::implement!(ExtensionInternal); } diff --git a/crates/mpz-ot-core/src/ferret/receiver.rs b/crates/mpz-ot-core/src/ferret/receiver.rs index 4d08c69b..e5939c60 100644 --- a/crates/mpz-ot-core/src/ferret/receiver.rs +++ b/crates/mpz-ot-core/src/ferret/receiver.rs @@ -4,7 +4,10 @@ use mpz_core::{ Block, }; -use crate::ferret::{error::ReceiverError, LpnType}; +use crate::{ + ferret::{error::ReceiverError, LpnType}, + TransferId, +}; use super::msgs::LpnMatrixSeed; @@ -59,6 +62,7 @@ impl Receiver { u: u.to_vec(), w: w.to_vec(), e: Vec::default(), + id: TransferId::default(), }, }, LpnMatrixSeed { seed }, @@ -69,10 +73,6 @@ impl Receiver { impl Receiver { /// The prepare precedure of extension, sample error vectors and outputs information for MPCOT. /// See step 3 and 4. - /// - /// # Arguments. - /// - /// * `lpn_type` - The type of LPN parameters. pub fn get_mpcot_query(&mut self) -> (Vec, usize) { match self.state.lpn_type { LpnType::Uniform => { @@ -105,6 +105,8 @@ impl Receiver { return Err(ReceiverError("the length of r should be n".to_string())); } + self.state.id.next(); + // Compute z = A * w + r. let mut z = r.to_vec(); self.state.lpn_encoder.compute(&mut z, &self.state.w); @@ -133,6 +135,11 @@ impl Receiver { Ok((x_, z_)) } + + /// Returns id + pub fn id(&self) -> TransferId { + self.state.id + } } /// The receiver's state. @@ -176,6 +183,9 @@ pub mod state { /// Receiver's lpn error vector. pub(super) e: Vec, + + /// TransferID + pub(super) id: TransferId, } impl State for Extension {} diff --git a/crates/mpz-ot-core/src/ferret/sender.rs b/crates/mpz-ot-core/src/ferret/sender.rs index 9e8db180..2af3e4ae 100644 --- a/crates/mpz-ot-core/src/ferret/sender.rs +++ b/crates/mpz-ot-core/src/ferret/sender.rs @@ -4,7 +4,12 @@ use mpz_core::{ Block, }; -use crate::ferret::{error::SenderError, LpnType}; +use crate::{ + ferret::{error::SenderError, LpnType}, + TransferId, +}; + +use super::msgs::LpnMatrixSeed; /// Ferret sender. #[derive(Debug, Default)] @@ -36,7 +41,7 @@ impl Sender { delta: Block, lpn_parameters: LpnParameters, lpn_type: LpnType, - seed: Block, + seed: LpnMatrixSeed, v: &[Block], ) -> Result, SenderError> { if v.len() != lpn_parameters.k { @@ -44,6 +49,7 @@ impl Sender { "the length of v should be equal to k".to_string(), )); } + let LpnMatrixSeed { seed } = seed; let lpn_encoder = LpnEncoder::<10>::new(seed, lpn_parameters.k as u32); Ok(Sender { @@ -54,6 +60,7 @@ impl Sender { lpn_type, lpn_encoder, v: v.to_vec(), + id: TransferId::default(), }, }) } @@ -63,6 +70,7 @@ impl Sender { /// Outputs the information for MPCOT. /// /// See step 3 and 4. + #[inline] pub fn get_mpcot_query(&self) -> (u32, u32) { ( self.state.lpn_parameters.t as u32, @@ -83,6 +91,8 @@ impl Sender { return Err(SenderError("the length of s should be n".to_string())); } + self.state.id.next(); + // Compute y = A * v + s let mut y = s.to_vec(); self.state.lpn_encoder.compute(&mut y, &self.state.v); @@ -97,10 +107,17 @@ impl Sender { Ok(y_) } + + /// Returns id + pub fn id(&self) -> TransferId { + self.state.id + } } /// The sender's state. pub mod state { + use crate::TransferId; + use super::*; mod sealed { @@ -141,6 +158,9 @@ pub mod state { /// Sender's COT message in the setup phase. pub(super) v: Vec, + + /// TransferID. + pub(crate) id: TransferId, } impl State for Extension {} diff --git a/crates/mpz-ot-core/src/ferret/spcot/mod.rs b/crates/mpz-ot-core/src/ferret/spcot/mod.rs index 802efb66..63ebea15 100644 --- a/crates/mpz-ot-core/src/ferret/spcot/mod.rs +++ b/crates/mpz-ot-core/src/ferret/spcot/mod.rs @@ -7,8 +7,6 @@ pub mod sender; #[cfg(test)] mod tests { - use mpz_core::prg::Prg; - use super::{receiver::Receiver as SpcotReceiver, sender::Sender as SpcotSender}; use crate::{ferret::CSP, ideal::cot::IdealCOT, RCOTReceiverOutput, RCOTSenderOutput}; @@ -18,49 +16,82 @@ mod tests { let sender = SpcotSender::new(); let receiver = SpcotReceiver::new(); - let mut prg = Prg::new(); - let sender_seed = prg.random_block(); let delta = ideal_cot.delta(); - let mut sender = sender.setup(delta, sender_seed); + let mut sender = sender.setup(delta); let mut receiver = receiver.setup(); - let h1 = 8; - let alpha1 = 3; + let hs = [8, 4, 10]; + let alphas = [3, 2, 4]; - // Extend once - let (msg_for_sender, msg_for_receiver) = ideal_cot.random_correlated(h1); + let h_sum = hs.iter().sum(); + // batch extension + let (msg_for_sender, msg_for_receiver) = ideal_cot.random_correlated(h_sum); let RCOTReceiverOutput { - choices: rs, - msgs: ts, + choices: rss, + msgs: tss, .. } = msg_for_receiver; - let RCOTSenderOutput { msgs: qs, .. } = msg_for_sender; - let maskbits = receiver.extend_mask_bits(h1, alpha1, &rs).unwrap(); - let msg_from_sender = sender.extend(h1, &qs, maskbits).unwrap(); + let RCOTSenderOutput { msgs: qss, .. } = msg_for_sender; + + let maskbits = receiver.extend_mask_bits(&hs, &alphas, &rss).unwrap(); + + let msg_from_sender = sender.extend(&hs, &qss, &maskbits).unwrap(); + + receiver + .extend(&hs, &alphas, &tss, &msg_from_sender) + .unwrap(); + + // Check + let (msg_for_sender, msg_for_receiver) = ideal_cot.random_correlated(CSP); + + let RCOTReceiverOutput { + choices: x_star, + msgs: z_star, + .. + } = msg_for_receiver; + + let RCOTSenderOutput { msgs: y_star, .. } = msg_for_sender; + + let check_from_receiver = receiver.check_pre(&x_star).unwrap(); - receiver.extend(h1, alpha1, &ts, msg_from_sender).unwrap(); + let (mut output_sender, check) = sender.check(&y_star, check_from_receiver).unwrap(); - // Extend twice - let h2 = 4; - let alpha2 = 2; + let output_receiver = receiver.check(&z_star, check).unwrap(); - let (msg_for_sender, msg_for_receiver) = ideal_cot.random_correlated(h2); + assert!(output_sender + .iter_mut() + .zip(output_receiver.iter()) + .all(|(vs, (ws, alpha))| { + vs[*alpha as usize] ^= delta; + vs == ws + })); + + // extend twice + let hs = [6, 9, 8]; + let alphas = [2, 1, 3]; + + let h_sum = hs.iter().sum(); + + let (msg_for_sender, msg_for_receiver) = ideal_cot.random_correlated(h_sum); let RCOTReceiverOutput { - choices: rs, - msgs: ts, + choices: rss, + msgs: tss, .. } = msg_for_receiver; - let RCOTSenderOutput { msgs: qs, .. } = msg_for_sender; - let maskbits = receiver.extend_mask_bits(h2, alpha2, &rs).unwrap(); + let RCOTSenderOutput { msgs: qss, .. } = msg_for_sender; + + let maskbits = receiver.extend_mask_bits(&hs, &alphas, &rss).unwrap(); - let msg_from_sender = sender.extend(h2, &qs, maskbits).unwrap(); + let msg_from_sender = sender.extend(&hs, &qss, &maskbits).unwrap(); - receiver.extend(h2, alpha2, &ts, msg_from_sender).unwrap(); + receiver + .extend(&hs, &alphas, &tss, &msg_from_sender) + .unwrap(); // Check let (msg_for_sender, msg_for_receiver) = ideal_cot.random_correlated(CSP); diff --git a/crates/mpz-ot-core/src/ferret/spcot/receiver.rs b/crates/mpz-ot-core/src/ferret/spcot/receiver.rs index 5e860f31..baf10ae2 100644 --- a/crates/mpz-ot-core/src/ferret/spcot/receiver.rs +++ b/crates/mpz-ot-core/src/ferret/spcot/receiver.rs @@ -6,6 +6,10 @@ use mpz_core::{ utils::blake3, Block, }; use rand_core::SeedableRng; +#[cfg(feature = "rayon")] +use rayon::iter::{ + IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator, +}; use super::msgs::{CheckFromReceiver, CheckFromSender, ExtendFromSender, MaskBits}; @@ -43,71 +47,101 @@ impl Receiver { } impl Receiver { - /// Performs the mask bit step in extension. + /// Performs the mask bit step in batch in extension. /// /// See step 4 in Figure 6. /// /// # Arguments /// - /// * `h` - The depth of the GGM tree. - /// * `alpha` - The chosen position. - /// * `rs` - The message from COT ideal functionality for the receiver. Only the random bits are used. + /// * `hs` - The depths of the GGM trees. + /// * `alphas` - The vector of chosen positions. + /// * `rss` - The message from COT ideal functionality for the receiver for all the tress. Only the random bits are used. pub fn extend_mask_bits( &mut self, - h: usize, - alpha: u32, - rs: &[bool], - ) -> Result { + hs: &[usize], + alphas: &[u32], + rss: &[bool], + ) -> Result, ReceiverError> { if self.state.extended { return Err(ReceiverError::InvalidState( "extension is not allowed".to_string(), )); } - if alpha >= (1 << h) { + if alphas.len() != hs.len() { + return Err(ReceiverError::InvalidLength( + "the length of alphas should be the length of hs".to_string(), + )); + } + + if alphas + .iter() + .zip(hs.iter()) + .any(|(alpha, h)| *alpha >= (1 << h)) + { return Err(ReceiverError::InvalidInput( "the input pos should be no more than 2^h-1".to_string(), )); } - if rs.len() != h { + let h_sum = hs.iter().sum(); + + if rss.len() != h_sum { return Err(ReceiverError::InvalidLength( - "the length of r should be h".to_string(), + "the length of r should be the sum of h".to_string(), )); } - // Step 4 in Figure 6 + let mut rs_s = vec![Vec::::new(); hs.len()]; + let mut rss_vec = rss.to_vec(); + for (index, h) in hs.iter().enumerate() { + rs_s[index] = rss_vec.drain(0..*h).collect(); + } - let bs: Vec = alpha - .iter_msb0() - .skip(32 - h) - // Computes alpha_i XOR r_i XOR 1. - .zip(rs.iter()) - .map(|(alpha, &r)| alpha == r) - .collect(); + // Step 4 in Figure 6 + let mut bss = vec![Vec::::new(); hs.len()]; + + let iter = bss + .iter_mut() + .zip(alphas.iter()) + .zip(hs.iter()) + .zip(rs_s.iter()) + .map(|(((bs, alpha), h), rs)| (bs, alpha, h, rs)); + + for (bs, alpha, h, rs) in iter { + *bs = alpha + .iter_msb0() + .skip(32 - h) + // Computes alpha_i XOR r_i XOR 1. + .zip(rs.iter()) + .map(|(alpha, &r)| alpha == r) + .collect(); + } // Updates hasher. - self.state.hasher.update(&bs.to_bytes()); + self.state.hasher.update(&bss.to_bytes()); + + let res: Vec = bss.into_iter().map(|bs| MaskBits { bs }).collect(); - Ok(MaskBits { bs }) + Ok(res) } - /// Performs the GGM reconstruction step in extension. This function can be called multiple times before checking. + /// Performs the GGM reconstruction step in batch in extension. This function can be called multiple times before checking. /// /// See step 5 in Figure 6. /// /// # Arguments /// - /// * `h` - The depth of the GGM tree. - /// * `alpha` - The chosen position. - /// * `ts` - The message from COT ideal functionality for the receiver. Only the chosen blocks are used. - /// * `extendfs` - The message sent by the sender. + /// * `hs` - The depths of the GGM trees. + /// * `alphas` - The vector of chosen positions. + /// * `tss` - The message from COT ideal functionality for the receiver. Only the chosen blocks are used. + /// * `extendfss` - The vector of messages sent by the sender. pub fn extend( &mut self, - h: usize, - alpha: u32, - ts: &[Block], - extendfs: ExtendFromSender, + hs: &[usize], + alphas: &[u32], + tss: &[Block], + extendfss: &[ExtendFromSender], ) -> Result<(), ReceiverError> { if self.state.extended { return Err(ReceiverError::InvalidState( @@ -115,61 +149,122 @@ impl Receiver { )); } - if alpha >= (1 << h) { + if alphas.len() != hs.len() { + return Err(ReceiverError::InvalidLength( + "the length of alphas should be the length of hs".to_string(), + )); + } + + if alphas + .iter() + .zip(hs.iter()) + .any(|(alpha, h)| *alpha >= (1 << h)) + { return Err(ReceiverError::InvalidInput( "the input pos should be no more than 2^h-1".to_string(), )); } - let ExtendFromSender { ms, sum } = extendfs; - if ts.len() != h { + let h_sum = hs.iter().sum(); + + if tss.len() != h_sum { return Err(ReceiverError::InvalidLength( - "the length of t should be h".to_string(), + "the length of tss should be the sum of h".to_string(), )); } - if ms.len() != h { + let mut ts_s = vec![Vec::::new(); hs.len()]; + let mut tss_vec = tss.to_vec(); + for (index, h) in hs.iter().enumerate() { + ts_s[index] = tss_vec.drain(0..*h).collect(); + } + + if extendfss.len() != hs.len() { return Err(ReceiverError::InvalidLength( - "the length of M should be h".to_string(), + "the length of extendfss should be the length of hs".to_string(), )); } - // Updates hasher - self.state.hasher.update(&ms.to_bytes()); - self.state.hasher.update(&sum.to_bytes()); - - let alpha_bar_vec: Vec = alpha.iter_msb0().skip(32 - h).map(|a| !a).collect(); - - // Step 5 in Figure 6. - let k: Vec = ms - .into_iter() - .zip(ts) - .zip(alpha_bar_vec.iter()) - .enumerate() - .map(|(i, (([m0, m1], &t), &b))| { - let tweak: Block = bytemuck::cast([i, self.state.exec_counter]); - if !b { - // H(t, i|ell) ^ M0 - FIXED_KEY_AES.tccr(tweak, t) ^ m0 - } else { - // H(t, i|ell) ^ M1 - FIXED_KEY_AES.tccr(tweak, t) ^ m1 - } - }) - .collect(); + let mut ms_s = vec![Vec::<[Block; 2]>::new(); hs.len()]; + let mut sum_s = vec![Block::ZERO; hs.len()]; - // Reconstructs GGM tree except `ws[alpha]`. - let ggm_tree = GgmTree::new(h); - let mut tree = vec![Block::ZERO; 1 << h]; - ggm_tree.reconstruct(&mut tree, &k, &alpha_bar_vec); + for (index, extendfs) in extendfss.iter().enumerate() { + ms_s[index].clone_from(&extendfs.ms); + sum_s[index] = extendfs.sum; + } + + if ms_s.iter().zip(hs.iter()).any(|(ms, h)| ms.len() != *h) { + return Err(ReceiverError::InvalidLength( + "the length of ms should be h".to_string(), + )); + } + // Updates hasher + self.state.hasher.update(&ms_s.to_bytes()); + self.state.hasher.update(&sum_s.to_bytes()); + + let mut trees = vec![Vec::::new(); hs.len()]; + + cfg_if::cfg_if! { + if #[cfg(feature = "rayon")]{ + let iter = alphas + .par_iter() + .zip(ms_s.par_iter()) + .zip(sum_s.par_iter()) + .zip(hs.par_iter()) + .zip(ts_s.par_iter()) + .zip(trees.par_iter_mut()) + .map(|(((((alpha, ms), sum), h), ts), tree)| (alpha, ms, sum, h, ts, tree)); + }else{ + let iter = alphas + .iter() + .zip(ms_s.iter()) + .zip(sum_s.iter()) + .zip(hs.iter()) + .zip(ts_s.iter()) + .zip(trees.iter_mut()) + .map(|(((((alpha, ms), sum), h), ts), tree)| (alpha, ms, sum, h, ts, tree)); + } + } - // Sets `tree[alpha]`, which is `ws[alpha]`. - tree[alpha as usize] = tree.iter().fold(sum, |acc, &x| acc ^ x); + iter.for_each(|(alpha, ms, sum, h, ts, tree)| { + let alpha_bar_vec: Vec = alpha.iter_msb0().skip(32 - h).map(|a| !a).collect(); + + // Step 5 in Figure 6. + let k: Vec = ms + .iter() + .zip(ts) + .zip(alpha_bar_vec.iter()) + .enumerate() + .map(|(i, (([m0, m1], &t), &b))| { + let tweak: Block = bytemuck::cast([i, self.state.exec_counter]); + if !b { + // H(t, i|ell) ^ M0 + FIXED_KEY_AES.tccr(tweak, t) ^ *m0 + } else { + // H(t, i|ell) ^ M1 + FIXED_KEY_AES.tccr(tweak, t) ^ *m1 + } + }) + .collect(); + + // Reconstructs GGM tree except `ws[alpha]`. + let ggm_tree = GgmTree::new(*h); + *tree = vec![Block::ZERO; 1 << h]; + ggm_tree.reconstruct(tree, &k, &alpha_bar_vec); + + // Sets `tree[alpha]`, which is `ws[alpha]`. + tree[(*alpha) as usize] = tree.iter().fold(*sum, |acc, &x| acc ^ x); + }); + + for tree in trees { + self.state.unchecked_ws.extend_from_slice(&tree); + } - self.state.unchecked_ws.extend_from_slice(&tree); - self.state.alphas_and_length.push((alpha, 1 << h)); + for (alpha, h) in alphas.iter().zip(hs.iter()) { + self.state.alphas_and_length.push((*alpha, 1 << h)); + } - self.state.exec_counter += 1; + self.state.exec_counter += hs.len(); Ok(()) } @@ -248,7 +343,6 @@ impl Receiver { } self.state.cot_counter += self.state.unchecked_ws.len(); - self.state.extended = true; let mut res = Vec::new(); for (alpha, n) in &self.state.alphas_and_length { @@ -256,8 +350,19 @@ impl Receiver { res.push((tmp, *alpha)); } + self.state.hasher = blake3::Hasher::new(); + self.state.alphas_and_length.clear(); + self.state.chis.clear(); + self.state.unchecked_ws.clear(); + Ok(res) } + + /// Complete extension. + #[inline] + pub fn finalize(&mut self) { + self.state.extended = true; + } } /// The receiver's state. diff --git a/crates/mpz-ot-core/src/ferret/spcot/sender.rs b/crates/mpz-ot-core/src/ferret/spcot/sender.rs index fef1327e..a62ad3bb 100644 --- a/crates/mpz-ot-core/src/ferret/spcot/sender.rs +++ b/crates/mpz-ot-core/src/ferret/spcot/sender.rs @@ -5,6 +5,10 @@ use mpz_core::{ utils::blake3, Block, }; use rand_core::SeedableRng; +#[cfg(feature = "rayon")] +use rayon::iter::{ + IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator, +}; use super::msgs::{CheckFromReceiver, CheckFromSender, ExtendFromSender, MaskBits}; @@ -29,8 +33,7 @@ impl Sender { /// # Arguments /// /// * `delta` - The sender's global secret. - /// * `seed` - The random seed to generate PRG. - pub fn setup(self, delta: Block, seed: Block) -> Sender { + pub fn setup(self, delta: Block) -> Sender { Sender { state: state::Extension { delta, @@ -39,7 +42,6 @@ impl Sender { cot_counter: 0, exec_counter: 0, extended: false, - prg: Prg::from_seed(seed), hasher: blake3::Hasher::new(), }, } @@ -47,85 +49,137 @@ impl Sender { } impl Sender { - /// Performs the SPCOT extension. + /// Performs batch SPCOT extension. /// /// See Step 1-5 in Figure 6. /// /// # Arguments /// - /// * `h` - The depth of the GGM tree. - /// * `qs`- The blocks received by calling the COT functionality. - /// * `mask`- The mask bits sent by the receiver. + /// * `hs` - The depths of the GGM trees. + /// * `qss`- The blocks received by calling the COT functionality for hs trees. + /// * `masks`- The vector of mask bits sent by the receiver. pub fn extend( &mut self, - h: usize, - qs: &[Block], - mask: MaskBits, - ) -> Result { + hs: &[usize], + qss: &[Block], + masks: &[MaskBits], + ) -> Result, SenderError> { if self.state.extended { return Err(SenderError::InvalidState( "extension is not allowed".to_string(), )); } - if qs.len() != h { + let h_sum = hs.iter().sum(); + + if qss.len() != h_sum { return Err(SenderError::InvalidLength( - "the length of q should be h".to_string(), + "the length of qss should be the sum of h".to_string(), )); } - let MaskBits { bs } = mask; + let mut qs_s = vec![Vec::::new(); hs.len()]; + let mut qss_vec = qss.to_vec(); + for (index, h) in hs.iter().enumerate() { + qs_s[index] = qss_vec.drain(0..*h).collect(); + } - if bs.len() != h { + if masks.len() != hs.len() { + return Err(SenderError::InvalidLength( + "the length of masks should be the length of hs".to_string(), + )); + } + + let bss: Vec> = masks.iter().map(|m| m.clone().bs).collect(); + + if bss.iter().zip(hs.iter()).any(|(b, h)| b.len() != *h) { return Err(SenderError::InvalidLength( "the length of b should be h".to_string(), )); } // Updates hasher. - self.state.hasher.update(&bs.to_bytes()); + self.state.hasher.update(&bss.to_bytes()); // Step 3-4, Figure 6. // Generates a GGM tree with depth h and seed s. - let s = self.state.prg.random_block(); - let ggm_tree = GgmTree::new(h); - let mut k0 = vec![Block::ZERO; h]; - let mut k1 = vec![Block::ZERO; h]; - let mut tree = vec![Block::ZERO; 1 << h]; - ggm_tree.gen(s, &mut tree, &mut k0, &mut k1); + let mut trees = vec![Vec::::new(); hs.len()]; + let mut ms_s = vec![Vec::<[Block; 2]>::new(); hs.len()]; + let mut sum_s = vec![Block::ZERO; hs.len()]; + + cfg_if::cfg_if! { + if #[cfg(feature = "rayon")]{ + let iter = trees + .par_iter_mut().zip(hs.par_iter()) + .zip(qs_s.par_iter()) + .zip(bss.par_iter()) + .zip(ms_s.par_iter_mut()) + .zip(sum_s.par_iter_mut()) + .map(|(((((tree, h), qs), bs), ms), sum)| (tree, h, qs, bs, ms, sum)); + }else{ + let iter = trees + .iter_mut() + .zip(hs.iter()) + .zip(qs_s.iter()) + .zip(bss.iter()) + .zip(ms_s.iter_mut()) + .zip(sum_s.iter_mut()) + .map(|(((((tree, h), qs), bs), ms), sum)| (tree, h, qs, bs, ms, sum)); + } + } + + iter.for_each(|(tree, h, qs, bs, ms, sum)| { + let s = Prg::new().random_block(); + let ggm_tree = GgmTree::new(*h); + let mut k0 = vec![Block::ZERO; *h]; + let mut k1 = vec![Block::ZERO; *h]; + *tree = vec![Block::ZERO; 1 << h]; + ggm_tree.gen(s, tree, &mut k0, &mut k1); + + // Computes the sum of the leaves and delta. + *sum = tree.iter().fold(self.state.delta, |acc, &x| acc ^ x); + + // Computes M0 and M1. + for (((i, &q), b), (k0, k1)) in + qs.iter().enumerate().zip(bs).zip(k0.into_iter().zip(k1)) + { + let mut m = if *b { + [q ^ self.state.delta, q] + } else { + [q, q ^ self.state.delta] + }; + let tweak: Block = bytemuck::cast([i, self.state.exec_counter]); + FIXED_KEY_AES.tccr_many(&[tweak, tweak], &mut m); + m[0] ^= k0; + m[1] ^= k1; + ms.push(m); + } + }); // Stores the tree, i.e., the possible output of sender. - self.state.unchecked_vs.extend_from_slice(&tree); + for tree in trees { + self.state.unchecked_vs.extend_from_slice(&tree); + } // Stores the length of this extension. - self.state.vs_length.push(1 << h); - - // Computes the sum of the leaves and delta. - let sum = tree.iter().fold(self.state.delta, |acc, &x| acc ^ x); - - // Computes M0 and M1. - let mut ms: Vec<[Block; 2]> = Vec::with_capacity(qs.len()); - for (((i, &q), b), (k0, k1)) in qs.iter().enumerate().zip(bs).zip(k0.into_iter().zip(k1)) { - let mut m = if b { - [q ^ self.state.delta, q] - } else { - [q, q ^ self.state.delta] - }; - let tweak: Block = bytemuck::cast([i, self.state.exec_counter]); - FIXED_KEY_AES.tccr_many(&[tweak, tweak], &mut m); - m[0] ^= k0; - m[1] ^= k1; - ms.push(m); + for h in hs { + self.state.vs_length.push(1 << h); } // Updates hasher - self.state.hasher.update(&ms.to_bytes()); - self.state.hasher.update(&sum.to_bytes()); + self.state.hasher.update(&ms_s.to_bytes()); + self.state.hasher.update(&sum_s.to_bytes()); - self.state.exec_counter += 1; + self.state.exec_counter += hs.len(); + + let res: Vec = ms_s + .into_iter() + .zip(sum_s.iter()) + .map(|(ms, &sum)| ExtendFromSender { ms, sum }) + .collect(); - Ok(ExtendFromSender { ms, sum }) + Ok(res) } /// Performs the consistency check for the resulting COTs. @@ -193,10 +247,18 @@ impl Sender { res.push(tmp); } - self.state.extended = true; + self.state.hasher = blake3::Hasher::new(); + self.state.unchecked_vs.clear(); + self.state.vs_length.clear(); Ok((res, CheckFromSender { hashed_v })) } + + /// Complete extension. + #[inline] + pub fn finalize(&mut self) { + self.state.extended = true; + } } /// The sender's state. @@ -239,8 +301,6 @@ pub mod state { /// This is to prevent the receiver from extending twice pub(super) extended: bool, - /// A PRG to generate random strings. - pub(super) prg: Prg, /// A hasher to generate chi seed. pub(super) hasher: blake3::Hasher, } diff --git a/crates/mpz-ot/src/ferret/error.rs b/crates/mpz-ot/src/ferret/error.rs new file mode 100644 index 00000000..6952f0ec --- /dev/null +++ b/crates/mpz-ot/src/ferret/error.rs @@ -0,0 +1,67 @@ +use crate::OTError; + +/// A Ferret sender error. +#[derive(Debug, thiserror::Error)] +#[allow(missing_docs)] +pub enum SenderError { + #[error(transparent)] + IOError(#[from] std::io::Error), + #[error(transparent)] + CoreError(#[from] mpz_ot_core::ferret::error::SenderError), + #[error(transparent)] + MPCOTSenderError(#[from] crate::ferret::mpcot::SenderError), + #[error(transparent)] + RandomCOTError(#[from] OTError), + #[error("{0}")] + StateError(String), + #[error("{0}")] + MPCOTSenderTypeError(String), +} + +impl From for OTError { + fn from(err: SenderError) -> Self { + match err { + SenderError::IOError(e) => e.into(), + e => OTError::SenderError(Box::new(e)), + } + } +} + +impl From for SenderError { + fn from(err: crate::ferret::sender::StateError) -> Self { + SenderError::StateError(err.to_string()) + } +} + +/// A Ferret receiver error. +#[derive(Debug, thiserror::Error)] +#[allow(missing_docs)] +pub enum ReceiverError { + #[error(transparent)] + IOError(#[from] std::io::Error), + #[error(transparent)] + CoreError(#[from] mpz_ot_core::ferret::error::ReceiverError), + #[error(transparent)] + MPCOTReceiverError(#[from] crate::ferret::mpcot::ReceiverError), + #[error(transparent)] + RandomCOTError(#[from] OTError), + #[error("{0}")] + StateError(String), + #[error("{0}")] + MPCOTReceiverTypeError(String), +} + +impl From for OTError { + fn from(err: ReceiverError) -> Self { + match err { + ReceiverError::IOError(e) => e.into(), + e => OTError::ReceiverError(Box::new(e)), + } + } +} + +impl From for ReceiverError { + fn from(err: crate::ferret::receiver::StateError) -> Self { + ReceiverError::StateError(err.to_string()) + } +} diff --git a/crates/mpz-ot/src/ferret/mod.rs b/crates/mpz-ot/src/ferret/mod.rs new file mode 100644 index 00000000..2b2047b9 --- /dev/null +++ b/crates/mpz-ot/src/ferret/mod.rs @@ -0,0 +1,175 @@ +//! An implementation of the [`Ferret`](https://eprint.iacr.org/2020/924.pdf) protocol. +mod error; +mod mpcot; +mod receiver; +mod sender; +mod spcot; + +pub use error::{ReceiverError, SenderError}; +pub use receiver::Receiver; +pub use sender::Sender; + +use mpz_core::lpn::LpnParameters; +use mpz_ot_core::ferret::LpnType; + +/// Configuration of Ferret. +#[derive(Debug)] +pub struct FerretConfig { + rcot: RandomCOT, + setup_rcot: SetupRandomCOT, + lpn_parameters: LpnParameters, + lpn_type: LpnType, +} + +impl FerretConfig { + /// Create a new instance. + /// + /// # Arguments. + /// + /// * `rcot` - The rcot for MPCOT. + /// * `setup_rcot` - The rcot for setup. + /// * `lpn_parameters` - The parameters of LPN. + /// * `lpn_type` - The type of LPN. + pub fn new( + rcot: RandomCOT, + setup_rcot: SetupRandomCOT, + lpn_parameters: LpnParameters, + lpn_type: LpnType, + ) -> Self { + Self { + rcot, + setup_rcot, + lpn_parameters, + lpn_type, + } + } + + /// Get rcot + pub fn rcot(&self) -> RandomCOT { + self.rcot.clone() + } + + /// Get the setup rcot + pub fn setup_rcot(&mut self) -> &mut SetupRandomCOT { + &mut self.setup_rcot + } + + /// Get the lpn type + pub fn lpn_type(&self) -> LpnType { + self.lpn_type + } + + /// Get the lpn parameters + pub fn lpn_parameters(&self) -> LpnParameters { + self.lpn_parameters + } +} + +#[cfg(test)] +mod tests { + use futures::TryFutureExt; + use mpz_common::executor::test_st_executor; + use mpz_core::{lpn::LpnParameters, Block}; + use mpz_ot_core::{ferret::LpnType, test::assert_cot, RCOTReceiverOutput, RCOTSenderOutput}; + + use crate::{ + ideal::cot::{ideal_rcot, IdealCOTReceiver, IdealCOTSender}, + OTError, RandomCOTReceiver, RandomCOTSender, + }; + + use super::*; + + // l = n - k = 8380 + const LPN_PARAMETERS_TEST: LpnParameters = LpnParameters { + n: 9600, + k: 1220, + t: 600, + }; + + fn setup() -> ( + Sender, + Receiver, + Block, + ) { + let (mut rcot_sender, rcot_receiver) = ideal_rcot(); + + let sender_config = FerretConfig::new( + rcot_sender.clone(), + rcot_sender.clone(), + LPN_PARAMETERS_TEST, + LpnType::Regular, + ); + + let receiver_config = FerretConfig::new( + rcot_receiver.clone(), + rcot_receiver, + LPN_PARAMETERS_TEST, + LpnType::Regular, + ); + + let delta = rcot_sender.alice().get_mut().delta(); + + let sender = Sender::new(sender_config); + + let receiver = Receiver::new(receiver_config); + + (sender, receiver, delta) + } + + #[tokio::test] + async fn test_ferret() { + let (mut ctx_sender, mut ctx_receiver) = test_st_executor(8); + + let (mut sender, mut receiver, delta) = setup(); + + tokio::try_join!( + sender + .setup_with_delta(&mut ctx_sender, delta) + .map_err(OTError::from), + receiver.setup(&mut ctx_receiver).map_err(OTError::from) + ) + .unwrap(); + + // extend once. + let count = 8000; + let ( + RCOTSenderOutput { + id: sender_id, + msgs: u, + }, + RCOTReceiverOutput { + id: receiver_id, + choices: b, + msgs: w, + }, + ) = tokio::try_join!( + sender.send_random_correlated(&mut ctx_sender, count), + receiver.receive_random_correlated(&mut ctx_receiver, count) + ) + .unwrap(); + + assert_eq!(sender_id, receiver_id); + assert_cot(delta, &b, &u, &w); + + // extend twice + let count = 9000; + let ( + RCOTSenderOutput { + id: sender_id, + msgs: u, + }, + RCOTReceiverOutput { + id: receiver_id, + choices: b, + msgs: w, + }, + ) = tokio::try_join!( + sender.send_random_correlated(&mut ctx_sender, count), + receiver.receive_random_correlated(&mut ctx_receiver, count) + ) + .unwrap(); + + assert_eq!(sender_id, receiver_id); + assert_cot(delta, &b, &u, &w); + } +} diff --git a/crates/mpz-ot/src/ferret/mpcot/error.rs b/crates/mpz-ot/src/ferret/mpcot/error.rs new file mode 100644 index 00000000..238808d0 --- /dev/null +++ b/crates/mpz-ot/src/ferret/mpcot/error.rs @@ -0,0 +1,59 @@ +use crate::OTError; + +/// A MPCOT sender error. +#[derive(Debug, thiserror::Error)] +#[allow(missing_docs)] +pub enum SenderError { + #[error(transparent)] + IOError(#[from] std::io::Error), + #[error(transparent)] + CoreError(#[from] mpz_ot_core::ferret::mpcot::error::SenderError), + #[error(transparent)] + SPCOTSenderError(#[from] crate::ferret::spcot::SenderError), + #[error("{0}")] + StateError(String), +} + +impl From for OTError { + fn from(err: SenderError) -> Self { + match err { + SenderError::IOError(e) => e.into(), + e => OTError::SenderError(Box::new(e)), + } + } +} + +impl From for SenderError { + fn from(err: crate::ferret::mpcot::sender::StateError) -> Self { + SenderError::StateError(err.to_string()) + } +} + +/// A MPCOT receiver error +#[derive(Debug, thiserror::Error)] +#[allow(missing_docs)] +pub enum ReceiverError { + #[error(transparent)] + IOError(#[from] std::io::Error), + #[error(transparent)] + CoreError(#[from] mpz_ot_core::ferret::mpcot::error::ReceiverError), + #[error(transparent)] + SpcotReceiverError(#[from] crate::ferret::spcot::ReceiverError), + #[error("{0}")] + StateError(String), +} + +impl From for OTError { + fn from(err: ReceiverError) -> Self { + match err { + ReceiverError::IOError(e) => e.into(), + e => OTError::ReceiverError(Box::new(e)), + } + } +} + +impl From for ReceiverError { + fn from(err: crate::ferret::mpcot::receiver::StateError) -> Self { + ReceiverError::StateError(err.to_string()) + } +} diff --git a/crates/mpz-ot/src/ferret/mpcot/mod.rs b/crates/mpz-ot/src/ferret/mpcot/mod.rs new file mode 100644 index 00000000..598b5734 --- /dev/null +++ b/crates/mpz-ot/src/ferret/mpcot/mod.rs @@ -0,0 +1,165 @@ +//! Implementation of the Multiple-Point COT (mpcot) protocol in the [`Ferret`](https://eprint.iacr.org/2020/924.pdf) paper. + +mod error; +mod receiver; +mod sender; + +pub(crate) use error::{ReceiverError, SenderError}; +pub(crate) use receiver::Receiver; +pub(crate) use sender::Sender; + +#[cfg(test)] +mod tests { + use futures::TryFutureExt; + use mpz_common::executor::test_st_executor; + use mpz_core::Block; + use mpz_ot_core::ferret::LpnType; + + use crate::{ + ideal::cot::{ideal_rcot, IdealCOTReceiver, IdealCOTSender}, + OTError, + }; + + use receiver::Receiver; + use sender::Sender; + + use super::*; + + fn setup( + lpn_type: LpnType, + ) -> ( + Sender, + Receiver, + IdealCOTSender, + IdealCOTReceiver, + Block, + ) { + let (mut rcot_sender, rcot_receiver) = ideal_rcot(); + + let delta = rcot_sender.alice().get_mut().delta(); + + let sender = Sender::new(lpn_type); + + let receiver = Receiver::new(lpn_type); + + (sender, receiver, rcot_sender, rcot_receiver, delta) + } + + #[tokio::test] + async fn test_mpcot() { + let (mut ctx_sender, mut ctx_receiver) = test_st_executor(8); + + let (mut sender, mut receiver, rcot_sender, rcot_receiver, delta) = setup(LpnType::Uniform); + + let alphas = [0, 1, 3, 4, 2]; + let t = alphas.len(); + let n = 10; + + tokio::try_join!( + sender + .setup_with_delta(&mut ctx_sender, delta, rcot_sender) + .map_err(OTError::from), + receiver + .setup(&mut ctx_receiver, rcot_receiver) + .map_err(OTError::from) + ) + .unwrap(); + + let (mut output_sender, output_receiver) = tokio::try_join!( + sender + .extend(&mut ctx_sender, t as u32, n) + .map_err(OTError::from), + receiver + .extend(&mut ctx_receiver, &alphas, n) + .map_err(OTError::from) + ) + .unwrap(); + + for i in alphas { + output_sender[i as usize] ^= delta; + } + + assert_eq!(output_sender, output_receiver); + + // extend twice + let alphas = [5, 1, 7, 2]; + let t = alphas.len(); + let n = 16; + + let (mut output_sender, output_receiver) = tokio::try_join!( + sender + .extend(&mut ctx_sender, t as u32, n) + .map_err(OTError::from), + receiver + .extend(&mut ctx_receiver, &alphas, n) + .map_err(OTError::from) + ) + .unwrap(); + + for i in alphas { + output_sender[i as usize] ^= delta; + } + + assert_eq!(output_sender, output_receiver); + + sender.finalize().unwrap(); + receiver.finalize().unwrap(); + + let (mut sender, mut receiver, rcot_sender, rcot_receiver, delta) = setup(LpnType::Regular); + + // extend once. + let alphas = [0, 3, 4, 7, 9]; + let t = alphas.len(); + let n = 10; + + tokio::try_join!( + sender + .setup_with_delta(&mut ctx_sender, delta, rcot_sender) + .map_err(OTError::from), + receiver + .setup(&mut ctx_receiver, rcot_receiver) + .map_err(OTError::from) + ) + .unwrap(); + + let (mut output_sender, output_receiver) = tokio::try_join!( + sender + .extend(&mut ctx_sender, t as u32, n) + .map_err(OTError::from), + receiver + .extend(&mut ctx_receiver, &alphas, n) + .map_err(OTError::from) + ) + .unwrap(); + + for i in alphas { + output_sender[i as usize] ^= delta; + } + + assert_eq!(output_sender, output_receiver); + + // extend twice. + let alphas = [0, 3, 7, 9, 14, 15]; + let t = alphas.len(); + let n = 16; + + let (mut output_sender, output_receiver) = tokio::try_join!( + sender + .extend(&mut ctx_sender, t as u32, n) + .map_err(OTError::from), + receiver + .extend(&mut ctx_receiver, &alphas, n) + .map_err(OTError::from) + ) + .unwrap(); + + for i in alphas { + output_sender[i as usize] ^= delta; + } + + assert_eq!(output_sender, output_receiver); + + sender.finalize().unwrap(); + receiver.finalize().unwrap(); + } +} diff --git a/crates/mpz-ot/src/ferret/mpcot/receiver.rs b/crates/mpz-ot/src/ferret/mpcot/receiver.rs new file mode 100644 index 00000000..e2553efd --- /dev/null +++ b/crates/mpz-ot/src/ferret/mpcot/receiver.rs @@ -0,0 +1,192 @@ +use crate::{ + ferret::{mpcot::error::ReceiverError, spcot::Receiver as SpcotReceiver}, + RandomCOTReceiver, +}; +use enum_try_as_inner::EnumTryAsInner; + +use mpz_common::Context; +use mpz_core::{prg::Prg, Block}; +use mpz_ot_core::ferret::{ + mpcot::{ + receiver::{state as uniform_state, Receiver as UniformReceiverCore}, + receiver_regular::{state as regular_state, Receiver as RegularReceiverCore}, + }, + LpnType, +}; +use serio::SinkExt; +use utils_aio::non_blocking_backend::{Backend, NonBlockingBackend}; + +#[derive(Debug, EnumTryAsInner)] +#[derive_err(Debug)] +pub(crate) enum State { + UniformInitialized(UniformReceiverCore), + UniformExtension(UniformReceiverCore), + RegularInitialized(RegularReceiverCore), + RegularExtension(RegularReceiverCore), + Complete, + Error, +} + +/// MPCOT receiver. +#[derive(Debug)] +pub(crate) struct Receiver { + state: State, + spcot: SpcotReceiver, + lpn_type: LpnType, +} + +impl Receiver { + /// Creates a new Sender. + /// + /// # Arguments. + /// + /// * `lpn_type` - The type of LPN. + pub(crate) fn new(lpn_type: LpnType) -> Self { + match lpn_type { + LpnType::Uniform => Self { + state: State::UniformInitialized(UniformReceiverCore::new()), + spcot: crate::ferret::spcot::Receiver::new(), + lpn_type, + }, + LpnType::Regular => Self { + state: State::RegularInitialized(RegularReceiverCore::new()), + spcot: crate::ferret::spcot::Receiver::new(), + lpn_type, + }, + } + } + + /// Performs setup for receiver. + /// + /// # Arguments + /// + /// * `ctx` - The context. + /// * `rcot` - The random COT used by Receiver. + pub(crate) async fn setup( + &mut self, + ctx: &mut Ctx, + rcot: RandomCOT, + ) -> Result<(), ReceiverError> { + match self.lpn_type { + LpnType::Uniform => { + let ext_receiver = std::mem::replace(&mut self.state, State::Error) + .try_into_uniform_initialized()?; + + let hash_seed = Prg::new().random_block(); + + let (ext_receiver, hash_seed) = ext_receiver.setup(hash_seed); + + ctx.io_mut().send(hash_seed).await?; + + self.state = State::UniformExtension(ext_receiver); + } + LpnType::Regular => { + let ext_receiver = std::mem::replace(&mut self.state, State::Error) + .try_into_regular_initialized()?; + + let ext_receiver = ext_receiver.setup(); + + self.state = State::RegularExtension(ext_receiver); + } + } + + self.spcot.setup(rcot)?; + + Ok(()) + } + + /// Performs MPCOT extension. + /// + /// + /// # Arguments + /// + /// * `ctx` - The context, + /// * `alphas` - The queried indices. + /// * `n` - The total number of indices. + pub(crate) async fn extend( + &mut self, + ctx: &mut Ctx, + alphas: &[u32], + n: u32, + ) -> Result, ReceiverError> + where + RandomCOT: RandomCOTReceiver, + { + let alphas_vec = alphas.to_vec(); + + match self.lpn_type { + LpnType::Uniform => { + let ext_receiver = std::mem::replace(&mut self.state, State::Error) + .try_into_uniform_extension()?; + + let (ext_receiver, h_and_pos) = + Backend::spawn(move || ext_receiver.pre_extend(&alphas_vec, n)).await?; + + let mut hs = vec![0usize; h_and_pos.len()]; + + let mut pos = vec![0u32; h_and_pos.len()]; + for (index, (h, p)) in h_and_pos.iter().enumerate() { + hs[index] = *h; + pos[index] = *p; + } + + self.spcot.extend(ctx, &pos, &hs).await?; + + let rt = self.spcot.check(ctx).await?; + + let rt: Vec> = rt.into_iter().map(|(elem, _)| elem).collect(); + let (ext_receiver, output) = + Backend::spawn(move || ext_receiver.extend(&rt)).await?; + + self.state = State::UniformExtension(ext_receiver); + + Ok(output) + } + + LpnType::Regular => { + let ext_receiver = std::mem::replace(&mut self.state, State::Error) + .try_into_regular_extension()?; + + let (ext_receiver, h_and_pos) = + Backend::spawn(move || ext_receiver.pre_extend(&alphas_vec, n)).await?; + + let mut hs = vec![0usize; h_and_pos.len()]; + + let mut pos = vec![0u32; h_and_pos.len()]; + for (index, (h, p)) in h_and_pos.iter().enumerate() { + hs[index] = *h; + pos[index] = *p; + } + + self.spcot.extend(ctx, &pos, &hs).await?; + + let rt = self.spcot.check(ctx).await?; + + let rt: Vec> = rt.into_iter().map(|(elem, _)| elem).collect(); + let (ext_receiver, output) = + Backend::spawn(move || ext_receiver.extend(&rt)).await?; + + self.state = State::RegularExtension(ext_receiver); + + Ok(output) + } + } + } + + /// Complete extension. + pub(crate) fn finalize(&mut self) -> Result<(), ReceiverError> { + match self.lpn_type { + LpnType::Uniform => { + std::mem::replace(&mut self.state, State::Error).try_into_uniform_extension()?; + } + LpnType::Regular => { + std::mem::replace(&mut self.state, State::Error).try_into_regular_extension()?; + } + } + + self.spcot.finalize()?; + self.state = State::Complete; + + Ok(()) + } +} diff --git a/crates/mpz-ot/src/ferret/mpcot/sender.rs b/crates/mpz-ot/src/ferret/mpcot/sender.rs new file mode 100644 index 00000000..a0256276 --- /dev/null +++ b/crates/mpz-ot/src/ferret/mpcot/sender.rs @@ -0,0 +1,166 @@ +use crate::{ + ferret::{mpcot::error::SenderError, spcot::Sender as SpcotSender}, + RandomCOTSender, +}; +use enum_try_as_inner::EnumTryAsInner; +use mpz_common::Context; +use mpz_core::Block; +use mpz_ot_core::ferret::{ + mpcot::{ + msgs::HashSeed, + sender::{state as uniform_state, Sender as UniformSenderCore}, + sender_regular::{state as regular_state, Sender as RegularSenderCore}, + }, + LpnType, +}; +use serio::stream::IoStreamExt; +use utils_aio::non_blocking_backend::{Backend, NonBlockingBackend}; + +#[derive(Debug, EnumTryAsInner)] +#[derive_err(Debug)] +pub(crate) enum State { + UniformInitialized(UniformSenderCore), + UniformExtension(UniformSenderCore), + RegularInitialized(RegularSenderCore), + RegularExtension(RegularSenderCore), + Complete, + Error, +} + +/// MPCOT sender. +#[derive(Debug)] +pub(crate) struct Sender { + state: State, + spcot: SpcotSender, + lpn_type: LpnType, +} + +impl Sender { + /// Creates a new Sender. + /// + /// # Arguments. + /// + /// * `lpn_type` - The type of LPN. + pub(crate) fn new(lpn_type: LpnType) -> Self { + match lpn_type { + LpnType::Uniform => Self { + state: State::UniformInitialized(UniformSenderCore::new()), + spcot: crate::ferret::spcot::Sender::new(), + lpn_type, + }, + LpnType::Regular => Self { + state: State::RegularInitialized(RegularSenderCore::new()), + spcot: crate::ferret::spcot::Sender::new(), + lpn_type, + }, + } + } + + /// Performs setup with provided delta. + /// + /// # Arguments + /// + /// * `ctx` - The channel. + /// * `delta` - The delta value to use for OT extension. + /// * `rcot` - The random COT used by Sender. + pub(crate) async fn setup_with_delta( + &mut self, + ctx: &mut Ctx, + delta: Block, + rcot: RandomCOT, + ) -> Result<(), SenderError> { + match self.lpn_type { + LpnType::Uniform => { + let ext_sender = std::mem::replace(&mut self.state, State::Error) + .try_into_uniform_initialized()?; + + let hash_seed: HashSeed = ctx.io_mut().expect_next().await?; + + let ext_sender = ext_sender.setup(delta, hash_seed); + + self.state = State::UniformExtension(ext_sender); + } + + LpnType::Regular => { + let ext_sender = std::mem::replace(&mut self.state, State::Error) + .try_into_regular_initialized()?; + + let ext_sender = ext_sender.setup(delta); + + self.state = State::RegularExtension(ext_sender); + } + } + + self.spcot.setup_with_delta(delta, rcot)?; + + Ok(()) + } + + /// Performs MPCOT extension. + /// + /// + /// # Arguments. + /// + /// * `ctx` - The context. + /// * `t` - The number of queried indices. + /// * `n` - The total number of indices. + pub(crate) async fn extend( + &mut self, + ctx: &mut Ctx, + t: u32, + n: u32, + ) -> Result, SenderError> + where + RandomCOT: RandomCOTSender, + { + match self.lpn_type { + LpnType::Uniform => { + let ext_sender = std::mem::replace(&mut self.state, State::Error) + .try_into_uniform_extension()?; + + let (ext_sender, hs) = Backend::spawn(move || ext_sender.pre_extend(t, n)).await?; + + self.spcot.extend(ctx, &hs).await?; + + let st = self.spcot.check(ctx).await?; + + let (ext_sender, output) = Backend::spawn(move || ext_sender.extend(&st)).await?; + + self.state = State::UniformExtension(ext_sender); + Ok(output) + } + LpnType::Regular => { + let ext_sender = std::mem::replace(&mut self.state, State::Error) + .try_into_regular_extension()?; + + let (ext_sender, hs) = Backend::spawn(move || ext_sender.pre_extend(t, n)).await?; + + self.spcot.extend(ctx, &hs).await?; + + let st = self.spcot.check(ctx).await?; + + let (ext_sender, output) = Backend::spawn(move || ext_sender.extend(&st)).await?; + + self.state = State::RegularExtension(ext_sender); + Ok(output) + } + } + } + + /// Complete extension. + pub(crate) fn finalize(&mut self) -> Result<(), SenderError> { + match self.lpn_type { + LpnType::Uniform => { + std::mem::replace(&mut self.state, State::Error).try_into_uniform_extension()?; + } + LpnType::Regular => { + std::mem::replace(&mut self.state, State::Error).try_into_regular_extension()?; + } + } + + self.spcot.finalize()?; + self.state = State::Complete; + + Ok(()) + } +} diff --git a/crates/mpz-ot/src/ferret/receiver.rs b/crates/mpz-ot/src/ferret/receiver.rs new file mode 100644 index 00000000..520506e8 --- /dev/null +++ b/crates/mpz-ot/src/ferret/receiver.rs @@ -0,0 +1,192 @@ +use crate::{ + ferret::{mpcot::Receiver as MpcotReceiver, ReceiverError}, + RandomCOTReceiver, +}; +use enum_try_as_inner::EnumTryAsInner; +use mpz_common::Context; +use mpz_core::{prg::Prg, Block}; +use mpz_ot_core::{ + ferret::receiver::{state, Receiver as ReceiverCore}, + RCOTReceiverOutput, +}; +use serio::SinkExt; +use utils_aio::non_blocking_backend::{Backend, NonBlockingBackend}; + +use super::FerretConfig; +use crate::{async_trait, OTError}; + +#[derive(Debug, EnumTryAsInner)] +#[derive_err(Debug)] +pub(crate) enum State { + Initialized(ReceiverCore), + Extension(ReceiverCore), + Complete, + Error, +} + +/// Ferret Receiver. +#[derive(Debug)] +pub struct Receiver { + state: State, + mpcot: MpcotReceiver, + config: FerretConfig, +} + +impl Receiver +where + RandomCOT: Send + Default + Clone, + SetupRandomCOT: Send, +{ + /// Creates a new Receiver. + /// + /// # Arguments. + /// + /// * `config` - Ferret configuration. + pub fn new(config: FerretConfig) -> Self { + Self { + state: State::Initialized(ReceiverCore::new()), + mpcot: MpcotReceiver::new(config.lpn_type()), + config, + } + } + + /// Setup for receiver. + /// + /// # Arguments. + /// + /// * `ctx` - The channel context. + pub async fn setup(&mut self, ctx: &mut Ctx) -> Result<(), ReceiverError> + where + Ctx: Context, + SetupRandomCOT: RandomCOTReceiver, + { + let ext_receiver = + std::mem::replace(&mut self.state, State::Error).try_into_initialized()?; + + let rcot = self.config.rcot(); + self.mpcot.setup(ctx, rcot).await?; + + let params = self.config.lpn_parameters(); + let lpn_type = self.config.lpn_type(); + + // Get random blocks from ideal Random COT. + + let RCOTReceiverOutput { + choices: u, + msgs: w, + .. + } = self + .config + .setup_rcot() + .receive_random_correlated(ctx, params.k) + .await?; + + let seed = Prg::new().random_block(); + + let (ext_receiver, seed) = ext_receiver.setup(params, lpn_type, seed, &u, &w)?; + + ctx.io_mut().send(seed).await?; + + self.state = State::Extension(ext_receiver); + + Ok(()) + } + + /// Performs extension. + /// + /// # Arguments + /// + /// * `ctx` - The channel context. + async fn extend(&mut self, ctx: &mut Ctx) -> Result<(Vec, Vec), ReceiverError> + where + Ctx: Context, + RandomCOT: RandomCOTReceiver, + { + let mut ext_receiver = + std::mem::replace(&mut self.state, State::Error).try_into_extension()?; + + let (alphas, n) = ext_receiver.get_mpcot_query(); + + let r = self.mpcot.extend(ctx, &alphas, n as u32).await?; + + let (ext_receiver, choices, msgs) = Backend::spawn(move || { + ext_receiver + .extend(&r) + .map(|(choices, msgs)| (ext_receiver, choices, msgs)) + }) + .await?; + + self.state = State::Extension(ext_receiver); + + Ok((choices, msgs)) + } + + /// Complete extension + pub fn finalize(&mut self) -> Result<(), ReceiverError> { + std::mem::replace(&mut self.state, State::Error).try_into_extension()?; + self.state = State::Complete; + self.mpcot.finalize()?; + + Ok(()) + } +} + +#[async_trait] +impl RandomCOTReceiver + for Receiver +where + Ctx: Context, + RandomCOT: RandomCOTReceiver + Send + Clone + Default + 'static, + SetupRandomCOT: Send + 'static, +{ + async fn receive_random_correlated( + &mut self, + ctx: &mut Ctx, + count: usize, + ) -> Result, OTError> { + let (mut choices_buffer, mut msgs_buffer) = self.extend(ctx).await?; + + assert_eq!(choices_buffer.len(), msgs_buffer.len()); + + let l = choices_buffer.len(); + + let id = self + .state + .try_as_extension() + .map_err(ReceiverError::from)? + .id(); + + if count <= l { + let choices_res = choices_buffer.drain(..count).collect(); + + let msgs_res = msgs_buffer.drain(..count).collect(); + + return Ok(RCOTReceiverOutput { + id, + choices: choices_res, + msgs: msgs_res, + }); + } else { + let mut choices_res = choices_buffer; + let mut msgs_res = msgs_buffer; + + for _ in 0..count / l - 1 { + (choices_buffer, msgs_buffer) = self.extend(ctx).await?; + + choices_res.extend_from_slice(&choices_buffer); + msgs_res.extend_from_slice(&msgs_buffer); + } + + (choices_buffer, msgs_buffer) = self.extend(ctx).await?; + + choices_res.extend_from_slice(&choices_buffer[0..count % l]); + msgs_res.extend_from_slice(&msgs_buffer[0..count % l]); + + return Ok(RCOTReceiverOutput { + id, + choices: choices_res, + msgs: msgs_res, + }); + } + } +} diff --git a/crates/mpz-ot/src/ferret/sender.rs b/crates/mpz-ot/src/ferret/sender.rs new file mode 100644 index 00000000..709ff8e2 --- /dev/null +++ b/crates/mpz-ot/src/ferret/sender.rs @@ -0,0 +1,160 @@ +use crate::{ferret::mpcot::Sender as MpcotSender, RandomCOTSender}; +use enum_try_as_inner::EnumTryAsInner; +use mpz_common::Context; +use mpz_core::Block; +use mpz_ot_core::{ + ferret::sender::{state, Sender as SenderCore}, + RCOTSenderOutput, +}; +use serio::stream::IoStreamExt; +use utils_aio::non_blocking_backend::{Backend, NonBlockingBackend}; + +use super::{FerretConfig, SenderError}; +use crate::{async_trait, OTError}; + +#[derive(Debug, EnumTryAsInner)] +#[derive_err(Debug)] +pub(crate) enum State { + Initialized(SenderCore), + Extension(SenderCore), + Complete, + Error, +} + +/// Ferret Sender. +#[derive(Debug)] +pub struct Sender { + state: State, + mpcot: MpcotSender, + config: FerretConfig, +} + +impl Sender +where + RandomCOT: Send + Default + Clone, + SetupRandomCOT: Send, +{ + /// Creates a new Sender. + pub fn new(config: FerretConfig) -> Self { + Self { + state: State::Initialized(SenderCore::new()), + mpcot: MpcotSender::new(config.lpn_type()), + config, + } + } + + /// Setup with provided delta. + /// + /// # Argument + /// + /// * `ctx` - The channel context. + /// * `delta` - The provided delta used for sender. + pub async fn setup_with_delta( + &mut self, + ctx: &mut Ctx, + delta: Block, + ) -> Result<(), SenderError> + where + Ctx: Context, + SetupRandomCOT: RandomCOTSender, + { + let ext_sender = std::mem::replace(&mut self.state, State::Error).try_into_initialized()?; + + let rcot = self.config.rcot(); + + self.mpcot.setup_with_delta(ctx, delta, rcot).await?; + + let params = self.config.lpn_parameters(); + let lpn_type = self.config.lpn_type(); + + // Get random blocks from ideal Random COT. + let RCOTSenderOutput { msgs: v, .. } = self + .config + .setup_rcot() + .send_random_correlated(ctx, params.k) + .await?; + + // Get seed for LPN matrix from receiver. + let seed = ctx.io_mut().expect_next().await?; + + // Ferret core setup. + let ext_sender = ext_sender.setup(delta, params, lpn_type, seed, &v)?; + + self.state = State::Extension(ext_sender); + + Ok(()) + } + + /// Performs extension. + /// + /// # Argument + /// + /// * `ctx` - The channel context. + async fn extend(&mut self, ctx: &mut Ctx) -> Result, SenderError> + where + RandomCOT: RandomCOTSender, + { + let mut ext_sender = + std::mem::replace(&mut self.state, State::Error).try_into_extension()?; + + let (t, n) = ext_sender.get_mpcot_query(); + + let s = self.mpcot.extend(ctx, t, n).await?; + + let (ext_sender, output) = + Backend::spawn(move || ext_sender.extend(&s).map(|output| (ext_sender, output))) + .await?; + self.state = State::Extension(ext_sender); + + Ok(output) + } + + /// Complete extension + pub fn finalize(&mut self) -> Result<(), SenderError> { + std::mem::replace(&mut self.state, State::Error).try_into_extension()?; + self.state = State::Complete; + self.mpcot.finalize()?; + + Ok(()) + } +} + +#[async_trait] +impl RandomCOTSender + for Sender +where + Ctx: Context, + RandomCOT: RandomCOTSender + Send + Default + Clone + 'static, + SetupRandomCOT: Send + 'static, +{ + async fn send_random_correlated( + &mut self, + ctx: &mut Ctx, + count: usize, + ) -> Result, OTError> { + let mut buffer = self.extend(ctx).await?; + let l = buffer.len(); + + let id = self + .state + .try_as_extension() + .map_err(SenderError::from)? + .id(); + + if count <= l { + let res = buffer.drain(..count).collect(); + return Ok(RCOTSenderOutput { id, msgs: res }); + } else { + let mut res = buffer; + for _ in 0..count / l - 1 { + buffer = self.extend(ctx).await?; + res.extend_from_slice(&buffer); + } + + buffer = self.extend(ctx).await?; + res.extend_from_slice(&buffer[0..count % l]); + + return Ok(RCOTSenderOutput { id, msgs: res }); + } + } +} diff --git a/crates/mpz-ot/src/ferret/spcot/error.rs b/crates/mpz-ot/src/ferret/spcot/error.rs new file mode 100644 index 00000000..0fa9dc9c --- /dev/null +++ b/crates/mpz-ot/src/ferret/spcot/error.rs @@ -0,0 +1,59 @@ +use crate::OTError; + +/// A SPCOT sender error. +#[derive(Debug, thiserror::Error)] +#[allow(missing_docs)] +pub enum SenderError { + #[error(transparent)] + IOError(#[from] std::io::Error), + #[error(transparent)] + CoreError(#[from] mpz_ot_core::ferret::spcot::error::SenderError), + #[error(transparent)] + RandomCOTError(#[from] OTError), + #[error("{0}")] + StateError(String), +} + +impl From for OTError { + fn from(err: SenderError) -> Self { + match err { + SenderError::IOError(e) => e.into(), + e => OTError::SenderError(Box::new(e)), + } + } +} + +impl From for SenderError { + fn from(err: crate::ferret::spcot::sender::StateError) -> Self { + SenderError::StateError(err.to_string()) + } +} + +/// A SPCOT receiver error. +#[derive(Debug, thiserror::Error)] +#[allow(missing_docs)] +pub enum ReceiverError { + #[error(transparent)] + IOError(#[from] std::io::Error), + #[error(transparent)] + CoreError(#[from] mpz_ot_core::ferret::spcot::error::ReceiverError), + #[error(transparent)] + RandomCOTError(#[from] OTError), + #[error("{0}")] + StateError(String), +} + +impl From for OTError { + fn from(err: ReceiverError) -> Self { + match err { + ReceiverError::IOError(e) => e.into(), + e => OTError::ReceiverError(Box::new(e)), + } + } +} + +impl From for ReceiverError { + fn from(err: crate::ferret::spcot::receiver::StateError) -> Self { + ReceiverError::StateError(err.to_string()) + } +} diff --git a/crates/mpz-ot/src/ferret/spcot/mod.rs b/crates/mpz-ot/src/ferret/spcot/mod.rs new file mode 100644 index 00000000..6e53fd28 --- /dev/null +++ b/crates/mpz-ot/src/ferret/spcot/mod.rs @@ -0,0 +1,103 @@ +//! Implementation of the Single-Point COT (spcot) protocol in the [`Ferret`](https://eprint.iacr.org/2020/924.pdf) paper. + +mod error; +mod receiver; +mod sender; + +pub(crate) use error::{ReceiverError, SenderError}; +pub(crate) use receiver::Receiver; +pub(crate) use sender::Sender; + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + ideal::cot::{ideal_rcot, IdealCOTReceiver, IdealCOTSender}, + OTError, + }; + use futures::TryFutureExt; + use mpz_common::executor::test_st_executor; + use mpz_core::Block; + + fn setup() -> ( + Sender, + Receiver, + IdealCOTSender, + IdealCOTReceiver, + Block, + ) { + let (mut rcot_sender, rcot_receiver) = ideal_rcot(); + + let delta = rcot_sender.alice().get_mut().delta(); + + let sender = Sender::new(); + let receiver = Receiver::new(); + + (sender, receiver, rcot_sender, rcot_receiver, delta) + } + + #[tokio::test] + async fn test_spcot() { + let (mut ctx_sender, mut ctx_receiver) = test_st_executor(8); + + let (mut sender, mut receiver, rcot_sender, rcot_receiver, delta) = setup(); + + // shold set the same delta as in RCOT. + sender.setup_with_delta(delta, rcot_sender).unwrap(); + receiver.setup(rcot_receiver).unwrap(); + + let hs = [8, 4]; + let alphas = [4, 2]; + + tokio::try_join!( + sender.extend(&mut ctx_sender, &hs).map_err(OTError::from), + receiver + .extend(&mut ctx_receiver, &alphas, &hs) + .map_err(OTError::from) + ) + .unwrap(); + + let (mut output_sender, output_receiver) = tokio::try_join!( + sender.check(&mut ctx_sender).map_err(OTError::from), + receiver.check(&mut ctx_receiver).map_err(OTError::from) + ) + .unwrap(); + + assert!(output_sender + .iter_mut() + .zip(output_receiver.iter()) + .all(|(vs, (ws, alpha))| { + vs[*alpha as usize] ^= delta; + vs == ws + })); + + // extend twice. + let hs = [6, 9, 8]; + let alphas = [2, 1, 3]; + + tokio::try_join!( + sender.extend(&mut ctx_sender, &hs).map_err(OTError::from), + receiver + .extend(&mut ctx_receiver, &alphas, &hs) + .map_err(OTError::from) + ) + .unwrap(); + + let (mut output_sender, output_receiver) = tokio::try_join!( + sender.check(&mut ctx_sender).map_err(OTError::from), + receiver.check(&mut ctx_receiver).map_err(OTError::from) + ) + .unwrap(); + + assert!(output_sender + .iter_mut() + .zip(output_receiver.iter()) + .all(|(vs, (ws, alpha))| { + vs[*alpha as usize] ^= delta; + vs == ws + })); + + sender.finalize().unwrap(); + receiver.finalize().unwrap(); + } +} diff --git a/crates/mpz-ot/src/ferret/spcot/receiver.rs b/crates/mpz-ot/src/ferret/spcot/receiver.rs new file mode 100644 index 00000000..3c48bfad --- /dev/null +++ b/crates/mpz-ot/src/ferret/spcot/receiver.rs @@ -0,0 +1,164 @@ +use crate::{ferret::spcot::error::ReceiverError, RandomCOTReceiver}; +use enum_try_as_inner::EnumTryAsInner; +use mpz_common::Context; +use mpz_core::Block; +use mpz_ot_core::{ + ferret::{ + spcot::{ + msgs::ExtendFromSender, + receiver::{state, Receiver as ReceiverCore}, + }, + CSP, + }, + RCOTReceiverOutput, +}; +use serio::{stream::IoStreamExt, SinkExt}; +use utils_aio::non_blocking_backend::{Backend, NonBlockingBackend}; + +#[derive(Debug, EnumTryAsInner)] +#[derive_err(Debug)] +pub(crate) enum State { + Initialized(ReceiverCore), + Extension(Box>), + Complete, + Error, +} + +/// SPCOT Receiver. +#[derive(Debug)] +pub(crate) struct Receiver { + state: State, + rcot: RandomCOT, +} + +impl Receiver { + /// Creates a new Receiver. + pub(crate) fn new() -> Self { + Self { + state: State::Initialized(ReceiverCore::new()), + rcot: Default::default(), + } + } + + /// Performs setup for receiver. + /// + /// # Arguments. + /// + /// * `rcot` - The random COT used by the receiver. + pub(crate) fn setup(&mut self, rcot: RandomCOT) -> Result<(), ReceiverError> { + let ext_receiver = + std::mem::replace(&mut self.state, State::Error).try_into_initialized()?; + + let ext_receiver = ext_receiver.setup(); + self.state = State::Extension(Box::new(ext_receiver)); + self.rcot = rcot; + Ok(()) + } + + /// Performs spcot extension for receiver. + /// + /// # Arguments + /// + /// * `ctx` - The context. + /// * `alphas`` - The vector of chosen positions. + /// * `h` - The depth of GGM tree. + pub(crate) async fn extend( + &mut self, + ctx: &mut Ctx, + alphas: &[u32], + hs: &[usize], + ) -> Result<(), ReceiverError> + where + RandomCOT: RandomCOTReceiver, + { + let mut ext_receiver = + std::mem::replace(&mut self.state, State::Error).try_into_extension()?; + + let h = hs.iter().sum(); + let RCOTReceiverOutput { + choices: rss, + msgs: tss, + .. + } = self.rcot.receive_random_correlated(ctx, h).await?; + + // extend + let h_in = hs.to_vec(); + let alphas_in = alphas.to_vec(); + let (mut ext_receiver, masks) = Backend::spawn(move || { + ext_receiver + .extend_mask_bits(&h_in, &alphas_in, &rss) + .map(|mask| (ext_receiver, mask)) + }) + .await?; + + ctx.io_mut().send(masks).await?; + + let extendfss: Vec = ctx.io_mut().expect_next().await?; + + let h_in = hs.to_vec(); + let alphas_in = alphas.to_vec(); + let ext_receiver = Backend::spawn(move || { + ext_receiver + .extend(&h_in, &alphas_in, &tss, &extendfss) + .map(|_| ext_receiver) + }) + .await?; + + self.state = State::Extension(ext_receiver); + + Ok(()) + } + + /// Performs batch check for SPCOT extension. + /// + /// # Arguments + /// + /// * `ctx` - The context. + pub(crate) async fn check( + &mut self, + ctx: &mut Ctx, + ) -> Result, u32)>, ReceiverError> + where + RandomCOT: RandomCOTReceiver, + { + let mut ext_receiver = + std::mem::replace(&mut self.state, State::Error).try_into_extension()?; + + // batch check + let RCOTReceiverOutput { + choices: x_star, + msgs: z_star, + .. + } = self.rcot.receive_random_correlated(ctx, CSP).await?; + + let (mut ext_receiver, checkfr) = Backend::spawn(move || { + ext_receiver + .check_pre(&x_star) + .map(|checkfr| (ext_receiver, checkfr)) + }) + .await?; + + ctx.io_mut().send(checkfr).await?; + let check = ctx.io_mut().expect_next().await?; + + let (ext_receiver, output) = Backend::spawn(move || { + ext_receiver + .check(&z_star, check) + .map(|output| (ext_receiver, output)) + }) + .await?; + + self.state = State::Extension(ext_receiver); + + Ok(output) + } + + /// Complete extension. + pub(crate) fn finalize(&mut self) -> Result<(), ReceiverError> { + std::mem::replace(&mut self.state, State::Error).try_into_extension()?; + + self.state = State::Complete; + + Ok(()) + } +} diff --git a/crates/mpz-ot/src/ferret/spcot/sender.rs b/crates/mpz-ot/src/ferret/spcot/sender.rs new file mode 100644 index 00000000..9178b787 --- /dev/null +++ b/crates/mpz-ot/src/ferret/spcot/sender.rs @@ -0,0 +1,144 @@ +use crate::{ferret::spcot::error::SenderError, RandomCOTSender}; +use enum_try_as_inner::EnumTryAsInner; +use mpz_common::Context; +use mpz_core::Block; +use mpz_ot_core::{ + ferret::{ + spcot::{ + msgs::MaskBits, + sender::{state, Sender as SenderCore}, + }, + CSP, + }, + RCOTSenderOutput, +}; +use serio::{stream::IoStreamExt, SinkExt}; +use utils_aio::non_blocking_backend::{Backend, NonBlockingBackend}; + +#[derive(Debug, EnumTryAsInner)] +#[derive_err(Debug)] +pub(crate) enum State { + Initialized(SenderCore), + Extension(Box>), + Complete, + Error, +} + +/// SPCOT sender. +#[derive(Debug)] +pub(crate) struct Sender { + state: State, + rcot: RandomCOT, +} + +impl Sender { + /// Creates a new Sender. + pub(crate) fn new() -> Self { + Self { + state: State::Initialized(SenderCore::new()), + rcot: Default::default(), + } + } + + /// Performs setup with the provided delta. + /// + /// # Arguments + /// + /// * `delta` - The delta value to use for OT extension. + /// * `rcot` - The random COT used by the sender. + pub(crate) fn setup_with_delta( + &mut self, + delta: Block, + rcot: RandomCOT, + ) -> Result<(), SenderError> { + let ext_sender = std::mem::replace(&mut self.state, State::Error).try_into_initialized()?; + + let ext_sender = ext_sender.setup(delta); + + self.state = State::Extension(Box::new(ext_sender)); + self.rcot = rcot; + Ok(()) + } + + /// Performs spcot extension for sender. + /// + /// # Arguments + /// + /// * `ctx` - The context. + /// * `hs` - The depths of GGM trees. + pub(crate) async fn extend( + &mut self, + ctx: &mut Ctx, + hs: &[usize], + ) -> Result<(), SenderError> + where + RandomCOT: RandomCOTSender, + { + let mut ext_sender = + std::mem::replace(&mut self.state, State::Error).try_into_extension()?; + + let h = hs.iter().sum(); + let RCOTSenderOutput { msgs: qss, .. } = self.rcot.send_random_correlated(ctx, h).await?; + + let masks: Vec = ctx.io_mut().expect_next().await?; + + // extend + let h_in = hs.to_vec(); + let (ext_sender, extend_msg) = Backend::spawn(move || { + ext_sender + .extend(&h_in, &qss, &masks) + .map(|extend_msg| (ext_sender, extend_msg)) + }) + .await?; + + ctx.io_mut().send(extend_msg).await?; + + self.state = State::Extension(ext_sender); + + Ok(()) + } + + /// Performs batch check for SPCOT extension. + /// + /// # Arguments + /// + /// * `ctx` - The context. + pub(crate) async fn check( + &mut self, + ctx: &mut Ctx, + ) -> Result>, SenderError> + where + RandomCOT: RandomCOTSender, + { + let mut ext_sender = + std::mem::replace(&mut self.state, State::Error).try_into_extension()?; + + // batch check + let RCOTSenderOutput { msgs: y_star, .. } = + self.rcot.send_random_correlated(ctx, CSP).await?; + + let checkfr = ctx.io_mut().expect_next().await?; + + let (ext_sender, output, check_msg) = Backend::spawn(move || { + ext_sender + .check(&y_star, checkfr) + .map(|(output, check_msg)| (ext_sender, output, check_msg)) + }) + .await?; + + ctx.io_mut().send(check_msg).await?; + + self.state = State::Extension(ext_sender); + + Ok(output) + } + + /// Complete extension. + pub(crate) fn finalize(&mut self) -> Result<(), SenderError> { + std::mem::replace(&mut self.state, State::Error).try_into_extension()?; + + self.state = State::Complete; + + Ok(()) + } +} diff --git a/crates/mpz-ot/src/ideal/cot.rs b/crates/mpz-ot/src/ideal/cot.rs index b0084957..18233dfe 100644 --- a/crates/mpz-ot/src/ideal/cot.rs +++ b/crates/mpz-ot/src/ideal/cot.rs @@ -46,9 +46,16 @@ pub fn ideal_rcot() -> (IdealCOTSender, IdealCOTReceiver) { } /// Ideal COT sender. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Default)] pub struct IdealCOTSender(Alice); +impl IdealCOTSender { + /// Returns Alice. + pub fn alice(&mut self) -> &mut Alice { + &mut self.0 + } +} + #[async_trait] impl OTSetup for IdealCOTSender where @@ -98,7 +105,7 @@ impl RandomCOTSender for IdealCOTSender { } /// Ideal COT receiver. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Default)] pub struct IdealCOTReceiver(Bob); #[async_trait] diff --git a/crates/mpz-ot/src/lib.rs b/crates/mpz-ot/src/lib.rs index b9871eab..0e4d1b48 100644 --- a/crates/mpz-ot/src/lib.rs +++ b/crates/mpz-ot/src/lib.rs @@ -10,6 +10,7 @@ )] pub mod chou_orlandi; +pub mod ferret; #[cfg(any(test, feature = "ideal"))] pub mod ideal; pub mod kos; From ef72b9b88820f789400d14996c0ba24745a007a4 Mon Sep 17 00:00:00 2001 From: Xiang Xie Date: Sat, 29 Jun 2024 00:11:59 +0800 Subject: [PATCH 02/17] cargo clippy --- crates/mpz-ot/src/ferret/mpcot/error.rs | 4 ++-- crates/mpz-ot/src/ferret/spcot/error.rs | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/crates/mpz-ot/src/ferret/mpcot/error.rs b/crates/mpz-ot/src/ferret/mpcot/error.rs index 238808d0..e300bf0d 100644 --- a/crates/mpz-ot/src/ferret/mpcot/error.rs +++ b/crates/mpz-ot/src/ferret/mpcot/error.rs @@ -2,7 +2,7 @@ use crate::OTError; /// A MPCOT sender error. #[derive(Debug, thiserror::Error)] -#[allow(missing_docs)] +#[allow(missing_docs, clippy::enum_variant_names)] pub enum SenderError { #[error(transparent)] IOError(#[from] std::io::Error), @@ -31,7 +31,7 @@ impl From for SenderError { /// A MPCOT receiver error #[derive(Debug, thiserror::Error)] -#[allow(missing_docs)] +#[allow(missing_docs, clippy::enum_variant_names)] pub enum ReceiverError { #[error(transparent)] IOError(#[from] std::io::Error), diff --git a/crates/mpz-ot/src/ferret/spcot/error.rs b/crates/mpz-ot/src/ferret/spcot/error.rs index 0fa9dc9c..5f23f466 100644 --- a/crates/mpz-ot/src/ferret/spcot/error.rs +++ b/crates/mpz-ot/src/ferret/spcot/error.rs @@ -2,7 +2,7 @@ use crate::OTError; /// A SPCOT sender error. #[derive(Debug, thiserror::Error)] -#[allow(missing_docs)] +#[allow(missing_docs, clippy::enum_variant_names)] pub enum SenderError { #[error(transparent)] IOError(#[from] std::io::Error), @@ -31,7 +31,7 @@ impl From for SenderError { /// A SPCOT receiver error. #[derive(Debug, thiserror::Error)] -#[allow(missing_docs)] +#[allow(missing_docs, clippy::enum_variant_names)] pub enum ReceiverError { #[error(transparent)] IOError(#[from] std::io::Error), From a01ac1912e673d89dfa1f3e6b0a8d3dc7544c60f Mon Sep 17 00:00:00 2001 From: Xiang Xie Date: Tue, 2 Jul 2024 01:18:30 +0800 Subject: [PATCH 03/17] add vope --- crates/mpz-zk/Cargo.toml | 53 +++++++++++++++ crates/mpz-zk/src/lib.rs | 24 +++++++ crates/mpz-zk/src/vope/error.rs | 61 +++++++++++++++++ crates/mpz-zk/src/vope/mod.rs | 63 +++++++++++++++++ crates/mpz-zk/src/vope/receiver.rs | 105 +++++++++++++++++++++++++++++ crates/mpz-zk/src/vope/sender.rs | 99 +++++++++++++++++++++++++++ 6 files changed, 405 insertions(+) create mode 100644 crates/mpz-zk/Cargo.toml create mode 100644 crates/mpz-zk/src/lib.rs create mode 100644 crates/mpz-zk/src/vope/error.rs create mode 100644 crates/mpz-zk/src/vope/mod.rs create mode 100644 crates/mpz-zk/src/vope/receiver.rs create mode 100644 crates/mpz-zk/src/vope/sender.rs diff --git a/crates/mpz-zk/Cargo.toml b/crates/mpz-zk/Cargo.toml new file mode 100644 index 00000000..54b42ced --- /dev/null +++ b/crates/mpz-zk/Cargo.toml @@ -0,0 +1,53 @@ +[package] +name = "mpz-zk" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[lints] +workspace = true + +[lib] +name = "mpz_zk" + +[features] +default = ["rayon"] +rayon = ["mpz-ot-core/rayon"] +ideal = ["mpz-common/ideal"] + +[dependencies] +mpz-core.workspace = true +mpz-zk-core.workspace = true +mpz-common.workspace = true +mpz-cointoss.workspace = true +mpz-ot-core.workspace = true +mpz-ot.workspace = true + +tlsn-utils-aio.workspace = true + +async-trait.workspace = true +futures.workspace = true +rand.workspace = true +rand_core.workspace = true +rand_chacha.workspace = true +thiserror.workspace = true +rayon = { workspace = true } +itybity.workspace = true +enum-try-as-inner.workspace = true +opaque-debug.workspace = true +serde = { workspace = true, optional = true } +serio.workspace = true +cfg-if.workspace = true + +[dev-dependencies] +mpz-common = { workspace = true, features = ["test-utils", "ideal"] } +mpz-ot-core = { workspace = true, features = ["test-utils"] } +rstest = { workspace = true } +criterion = { workspace = true, features = ["async_tokio"] } +tokio = { workspace = true, features = [ + "net", + "macros", + "rt", + "rt-multi-thread", +] } diff --git a/crates/mpz-zk/src/lib.rs b/crates/mpz-zk/src/lib.rs new file mode 100644 index 00000000..4ae2e10f --- /dev/null +++ b/crates/mpz-zk/src/lib.rs @@ -0,0 +1,24 @@ +//! Implementations of zero-knowledge protocols. + +#![deny( + unsafe_code, + missing_docs, + unused_imports, + unused_must_use, + unreachable_pub, + clippy::all +)] + +pub mod vope; + +/// An oblivious transfer error. +#[derive(Debug, thiserror::Error)] +#[allow(missing_docs)] +pub enum VOPEError { + #[error(transparent)] + IOError(#[from] std::io::Error), + #[error("sender error: {0}")] + SenderError(Box), + #[error("receiver error: {0}")] + ReceiverError(Box), +} diff --git a/crates/mpz-zk/src/vope/error.rs b/crates/mpz-zk/src/vope/error.rs new file mode 100644 index 00000000..912efda1 --- /dev/null +++ b/crates/mpz-zk/src/vope/error.rs @@ -0,0 +1,61 @@ +//! Errors in VOPE + +use crate::VOPEError; + +/// A VOPE Sender error. +#[derive(Debug, thiserror::Error)] +#[allow(missing_docs)] +pub enum SenderError { + #[error(transparent)] + IOError(#[from] std::io::Error), + #[error(transparent)] + CoreError(#[from] mpz_zk_core::vope::error::SenderError), + #[error(transparent)] + RandomCOTError(#[from] mpz_ot::OTError), + #[error("{0}")] + StateError(String), +} + +/// A VOPE Receiver error. +#[derive(Debug, thiserror::Error)] +#[allow(missing_docs)] +pub enum ReceiverError { + #[error(transparent)] + IOError(#[from] std::io::Error), + #[error(transparent)] + CoreError(#[from] mpz_zk_core::vope::error::ReceiverError), + #[error(transparent)] + RandomCOTError(#[from] mpz_ot::OTError), + #[error("{0}")] + StateError(String), +} + +impl From for VOPEError { + fn from(err: SenderError) -> Self { + match err { + SenderError::IOError(e) => e.into(), + e => VOPEError::SenderError(Box::new(e)), + } + } +} + +impl From for SenderError { + fn from(err: crate::vope::sender::StateError) -> Self { + SenderError::StateError(err.to_string()) + } +} + +impl From for VOPEError { + fn from(err: ReceiverError) -> Self { + match err { + ReceiverError::IOError(e) => e.into(), + e => VOPEError::ReceiverError(Box::new(e)), + } + } +} + +impl From for ReceiverError { + fn from(err: crate::vope::receiver::StateError) -> Self { + ReceiverError::StateError(err.to_string()) + } +} diff --git a/crates/mpz-zk/src/vope/mod.rs b/crates/mpz-zk/src/vope/mod.rs new file mode 100644 index 00000000..1b34799d --- /dev/null +++ b/crates/mpz-zk/src/vope/mod.rs @@ -0,0 +1,63 @@ +//! This is the implementation of vector oblivious polynomial evaluation (VOPE) based on Figure 4 in https://eprint.iacr.org/2021/076.pdf + +pub mod error; +pub mod receiver; +pub mod sender; + +#[cfg(test)] +mod tests { + use crate::{ + vope::{receiver::Receiver, sender::Sender}, + VOPEError, + }; + use futures::TryFutureExt; + use mpz_common::executor::test_st_executor; + use mpz_core::Block; + use mpz_ot::ideal::cot::{ideal_rcot, IdealCOTReceiver, IdealCOTSender}; + use mpz_zk_core::test::poly_check; + + fn setup() -> (Sender, Receiver, Block) { + let (mut rcot_sender, rcot_receiver) = ideal_rcot(); + + let delta = rcot_sender.alice().get_mut().delta(); + + let sender = Sender::new(rcot_sender); + let receiver = Receiver::new(rcot_receiver); + + (sender, receiver, delta) + } + + #[tokio::test] + async fn test_vope() { + let (mut ctx_sender, mut ctx_receiver) = test_st_executor(8); + + let (mut sender, mut receiver, delta) = setup(); + + sender.setup_with_delta(delta).unwrap(); + receiver.setup().unwrap(); + + let d = 1; + + let (output_sender, output_receiver) = tokio::try_join!( + sender.extend(&mut ctx_sender, d).map_err(VOPEError::from), + receiver + .extend(&mut ctx_receiver, d) + .map_err(VOPEError::from) + ) + .unwrap(); + + assert!(poly_check(&output_receiver, output_sender, delta)); + + let d = 5; + + let (output_sender, output_receiver) = tokio::try_join!( + sender.extend(&mut ctx_sender, d).map_err(VOPEError::from), + receiver + .extend(&mut ctx_receiver, d) + .map_err(VOPEError::from) + ) + .unwrap(); + + assert!(poly_check(&output_receiver, output_sender, delta)); + } +} diff --git a/crates/mpz-zk/src/vope/receiver.rs b/crates/mpz-zk/src/vope/receiver.rs new file mode 100644 index 00000000..d1fb3cdd --- /dev/null +++ b/crates/mpz-zk/src/vope/receiver.rs @@ -0,0 +1,105 @@ +//! Implementation of VOPE receiver. + +use crate::vope::error::ReceiverError; +use enum_try_as_inner::EnumTryAsInner; +use mpz_common::Context; +use mpz_core::Block; +use mpz_ot::{RCOTReceiverOutput, RandomCOTReceiver}; +use mpz_zk_core::vope::{ + receiver::{state, Receiver as ReceiverCore}, + CSP, +}; +use utils_aio::non_blocking_backend::{Backend, NonBlockingBackend}; + +#[derive(Debug, EnumTryAsInner)] +#[derive_err(Debug)] +#[allow(missing_docs)] +pub enum State { + Initialized(ReceiverCore), + Extension(ReceiverCore), + Complete, + Error, +} + +/// VOPE receiver (prover) +#[derive(Debug)] +pub struct Receiver { + state: State, + rcot: RandomCOT, +} + +impl Receiver { + /// Creates a new receiver. + /// + /// # Arguments + /// + /// * `rcot` - The random COT used by the receiver. + pub fn new(rcot: RandomCOT) -> Self { + Self { + state: State::Initialized(ReceiverCore::new()), + rcot, + } + } + + /// Performs setup for receiver. + pub fn setup(&mut self) -> Result<(), ReceiverError> { + let ext_receiver = + std::mem::replace(&mut self.state, State::Error).try_into_initialized()?; + + let ext_receiver = ext_receiver.setup(); + + self.state = State::Extension(ext_receiver); + + Ok(()) + } + + /// Performs VOPE extension for receiver. + /// + /// # Arguments + /// + /// * `ctx` - The context. + /// * `d` - The polynomial degree. + pub async fn extend( + &mut self, + ctx: &mut Ctx, + d: usize, + ) -> Result, ReceiverError> + where + RandomCOT: RandomCOTReceiver, + { + let mut ext_receiver = + std::mem::replace(&mut self.state, State::Error).try_into_extension()?; + + assert!(d > 0); + + let RCOTReceiverOutput { + msgs: ms, + choices: us, + .. + } = self + .rcot + .receive_random_correlated(ctx, (2 * d - 1) * CSP) + .await?; + + // extend + let (ext_receiver, res) = Backend::spawn(move || { + ext_receiver + .extend(&ms, &us, d) + .map(|res| (ext_receiver, res)) + }) + .await?; + + self.state = State::Extension(ext_receiver); + + Ok(res) + } + + /// Complete extension. + pub fn finalize(&mut self) -> Result<(), ReceiverError> { + std::mem::replace(&mut self.state, State::Error).try_into_extension()?; + + self.state = State::Complete; + + Ok(()) + } +} diff --git a/crates/mpz-zk/src/vope/sender.rs b/crates/mpz-zk/src/vope/sender.rs new file mode 100644 index 00000000..447dcb6b --- /dev/null +++ b/crates/mpz-zk/src/vope/sender.rs @@ -0,0 +1,99 @@ +//! Implementation of VOPE sender + +use crate::vope::error::SenderError; +use enum_try_as_inner::EnumTryAsInner; +use mpz_common::Context; +use mpz_core::Block; +use mpz_ot::{RCOTSenderOutput, RandomCOTSender}; +use mpz_zk_core::vope::{ + sender::{state, Sender as SenderCore}, + CSP, +}; +use utils_aio::non_blocking_backend::{Backend, NonBlockingBackend}; + +#[derive(Debug, EnumTryAsInner)] +#[derive_err(Debug)] +#[allow(missing_docs)] +pub enum State { + Initialized(SenderCore), + Extension(SenderCore), + Complete, + Error, +} + +/// VOPE sender (verifier) +#[derive(Debug)] +pub struct Sender { + state: State, + rcot: RandomCOT, +} + +impl Sender { + /// Creates a new Sender. + /// + /// # Arguments + /// + /// * `rcot` - The random COT used by the sender. + pub fn new(rcot: RandomCOT) -> Self { + Self { + state: State::Initialized(SenderCore::new()), + rcot, + } + } + + /// Performs setup with the provided delta. + /// + /// # Arguments + /// + /// * `delta` - The delta value to use for VOPE extension. + pub fn setup_with_delta(&mut self, delta: Block) -> Result<(), SenderError> { + let ext_sender = std::mem::replace(&mut self.state, State::Error).try_into_initialized()?; + + let ext_sender = ext_sender.setup(delta); + + self.state = State::Extension(ext_sender); + + Ok(()) + } + + /// Performs VOPE extension for sender. + /// + /// # Arguments + /// + /// * `ctx` - The context. + /// * `d` - The polynomial degree. + pub async fn extend( + &mut self, + ctx: &mut Ctx, + d: usize, + ) -> Result + where + RandomCOT: RandomCOTSender, + { + let mut ext_sender = + std::mem::replace(&mut self.state, State::Error).try_into_extension()?; + + assert!(d > 0); + + let RCOTSenderOutput { msgs: ks, .. } = self + .rcot + .send_random_correlated(ctx, (2 * d - 1) * CSP) + .await?; + + let (ext_sender, res) = + Backend::spawn(move || ext_sender.extend(&ks, d).map(|res| (ext_sender, res))).await?; + + self.state = State::Extension(ext_sender); + + Ok(res) + } + + /// Complete extension. + pub fn finalize(&mut self) -> Result<(), SenderError> { + std::mem::replace(&mut self.state, State::Error).try_into_extension()?; + + self.state = State::Complete; + + Ok(()) + } +} From 0bd03daefbf7e145bd24f3d5f8c55909e900c4e4 Mon Sep 17 00:00:00 2001 From: Xiang Xie Date: Tue, 2 Jul 2024 01:24:05 +0800 Subject: [PATCH 04/17] add vope --- crates/mpz-ot/src/lib.rs | 1 - crates/mpz-zk-core/Cargo.toml | 35 +++++ crates/mpz-zk-core/src/lib.rs | 23 ++++ crates/mpz-zk-core/src/test.rs | 11 ++ crates/mpz-zk-core/src/vope/error.rs | 21 +++ crates/mpz-zk-core/src/vope/mod.rs | 65 ++++++++++ crates/mpz-zk-core/src/vope/receiver.rs | 162 ++++++++++++++++++++++++ crates/mpz-zk-core/src/vope/sender.rs | 128 +++++++++++++++++++ 8 files changed, 445 insertions(+), 1 deletion(-) create mode 100644 crates/mpz-zk-core/Cargo.toml create mode 100644 crates/mpz-zk-core/src/lib.rs create mode 100644 crates/mpz-zk-core/src/test.rs create mode 100644 crates/mpz-zk-core/src/vope/error.rs create mode 100644 crates/mpz-zk-core/src/vope/mod.rs create mode 100644 crates/mpz-zk-core/src/vope/receiver.rs create mode 100644 crates/mpz-zk-core/src/vope/sender.rs diff --git a/crates/mpz-ot/src/lib.rs b/crates/mpz-ot/src/lib.rs index 0e4d1b48..d53e322b 100644 --- a/crates/mpz-ot/src/lib.rs +++ b/crates/mpz-ot/src/lib.rs @@ -11,7 +11,6 @@ pub mod chou_orlandi; pub mod ferret; -#[cfg(any(test, feature = "ideal"))] pub mod ideal; pub mod kos; diff --git a/crates/mpz-zk-core/Cargo.toml b/crates/mpz-zk-core/Cargo.toml new file mode 100644 index 00000000..e390db57 --- /dev/null +++ b/crates/mpz-zk-core/Cargo.toml @@ -0,0 +1,35 @@ +[package] +name = "mpz-zk-core" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[lints] +workspace = true + +[lib] +name = "mpz_zk_core" + +[features] +default = ["rayon", "test-utils"] +rayon = ["dep:rayon", "itybity/rayon"] +test-utils = [] + +[dependencies] +mpz-core.workspace = true +mpz-ot-core.workspace = true +clmul.workspace = true +matrix-transpose.workspace = true + +tlsn-utils.workspace = true + +rayon = { workspace = true, optional = true } +serde = { workspace = true, features = ["derive"] } +thiserror.workspace = true +derive_builder.workspace = true +itybity.workspace = true +opaque-debug.workspace = true +cfg-if.workspace = true +bytemuck = { workspace = true, features = ["derive"] } +enum-try-as-inner.workspace = true diff --git a/crates/mpz-zk-core/src/lib.rs b/crates/mpz-zk-core/src/lib.rs new file mode 100644 index 00000000..92ca8ef1 --- /dev/null +++ b/crates/mpz-zk-core/src/lib.rs @@ -0,0 +1,23 @@ +//! Low-level crate containing core functionalities for zero-knowledge protocols. +//! +//! This crate is not intended to be used directly. Instead, use the higher-level APIs provided by +//! the `mpz-zk` crate. +//! +//! # ⚠️ Warning ⚠️ +//! +//! Some implementations make assumptions about invariants which may not be checked if using these +//! low-level APIs naively. Failing to uphold these invariants may result in security vulnerabilities. +//! +//! USE AT YOUR OWN RISK. + +#![deny( + unsafe_code, + missing_docs, + unused_imports, + unused_must_use, + unreachable_pub, + clippy::all +)] + +pub mod test; +pub mod vope; diff --git a/crates/mpz-zk-core/src/test.rs b/crates/mpz-zk-core/src/test.rs new file mode 100644 index 00000000..b22369d3 --- /dev/null +++ b/crates/mpz-zk-core/src/test.rs @@ -0,0 +1,11 @@ +//! test functions. + +use mpz_core::Block; + +/// Check polynomial relation. +pub fn poly_check(a: &[Block], b: Block, delta: Block) -> bool { + b == a + .iter() + .rev() + .fold(Block::ZERO, |acc, &x| x ^ (delta.gfmul(acc))) +} diff --git a/crates/mpz-zk-core/src/vope/error.rs b/crates/mpz-zk-core/src/vope/error.rs new file mode 100644 index 00000000..101a11b3 --- /dev/null +++ b/crates/mpz-zk-core/src/vope/error.rs @@ -0,0 +1,21 @@ +//! Errors that can occur when using VOPE. + +/// Errors that can occur when using VOPE sender (verifier). +#[derive(Debug, thiserror::Error)] +#[allow(missing_docs)] +pub enum SenderError { + #[error("invalid input: expected {0}")] + InvalidInput(String), + #[error("invalid length: expected {0}")] + InvalidLength(String), +} + +/// Errors that can occur when using VOPE receiver (prover). +#[derive(Debug, thiserror::Error)] +#[allow(missing_docs)] +pub enum ReceiverError { + #[error("invalid input: expected {0}")] + InvalidInput(String), + #[error("invalid length: expected {0}")] + InvalidLength(String), +} diff --git a/crates/mpz-zk-core/src/vope/mod.rs b/crates/mpz-zk-core/src/vope/mod.rs new file mode 100644 index 00000000..3da511be --- /dev/null +++ b/crates/mpz-zk-core/src/vope/mod.rs @@ -0,0 +1,65 @@ +//! This is the implementation of vector oblivious polynomial evaluation (VOPE) based on Figure 4 in https://eprint.iacr.org/2021/076.pdf + +pub mod error; +pub mod receiver; +pub mod sender; + +/// Security parameter +pub const CSP: usize = 128; + +#[cfg(test)] +mod tests { + use mpz_core::prg::Prg; + use mpz_ot_core::{ideal::cot::IdealCOT, RCOTReceiverOutput, RCOTSenderOutput}; + + use crate::test::poly_check; + + use super::{receiver::Receiver, sender::Sender, CSP}; + + #[test] + fn vope_test() { + let mut prg = Prg::new(); + let delta = prg.random_block(); + + let mut ideal_cot = IdealCOT::default(); + ideal_cot.set_delta(delta); + + let sender = Sender::new(); + let receiver = Receiver::new(); + + let mut sender = sender.setup(delta); + let mut receiver = receiver.setup(); + + let d = 1; + + let (sender_cot, receiver_cot) = ideal_cot.random_correlated((2 * d - 1) * CSP); + + let RCOTSenderOutput { msgs: ks, .. } = sender_cot; + let RCOTReceiverOutput { + msgs: ms, + choices: us, + .. + } = receiver_cot; + + let sender_out = sender.extend(&ks, d).unwrap(); + let receiver_out = receiver.extend(&ms, &us, d).unwrap(); + + assert!(poly_check(&receiver_out, sender_out, delta)); + + let d = 5; + + let (sender_cot, receiver_cot) = ideal_cot.random_correlated((2 * d - 1) * CSP); + + let RCOTSenderOutput { msgs: ks, .. } = sender_cot; + let RCOTReceiverOutput { + msgs: ms, + choices: us, + .. + } = receiver_cot; + + let sender_out = sender.extend(&ks, d).unwrap(); + let receiver_out = receiver.extend(&ms, &us, d).unwrap(); + + assert!(poly_check(&receiver_out, sender_out, delta)); + } +} diff --git a/crates/mpz-zk-core/src/vope/receiver.rs b/crates/mpz-zk-core/src/vope/receiver.rs new file mode 100644 index 00000000..9a0e4fce --- /dev/null +++ b/crates/mpz-zk-core/src/vope/receiver.rs @@ -0,0 +1,162 @@ +//! VOPE receiver. +use mpz_core::Block; + +use crate::vope::CSP; + +use super::error::ReceiverError; + +/// VOPE receiver +/// This is the prover in Figure 4. +#[derive(Debug, Default)] +pub struct Receiver { + state: T, +} + +impl Receiver { + /// Create a new receiver. + pub fn new() -> Self { + Receiver { + state: state::Initialized::default(), + } + } + + /// Completes the setup phase of the protocol. + /// + /// See Initialize in Figure 4. + pub fn setup(self) -> Receiver { + Receiver { + state: state::Extension { + vope_counter: 0, + exec_counter: 0, + }, + } + } +} + +impl Receiver { + /// Performs VOPE extension. + /// + /// See step 1-3 in Figure 4. + /// + /// # Arguments + /// + /// * `ms` - The blocks received by calling the COT ideal functionality. + /// * `us` - The bits received by calling the COT ideal functionality. + /// * `d` - The degree of the polynomial. + /// + /// Note that this functionality is only suitable for small d. + pub fn extend( + &mut self, + ms: &[Block], + us: &[bool], + d: usize, + ) -> Result, ReceiverError> { + if d == 0 { + return Err(ReceiverError::InvalidInput( + "the degree d should not be 0".to_string(), + )); + } + + if ms.len() != us.len() { + return Err(ReceiverError::InvalidLength( + "the length of ms and us should be equal".to_string(), + )); + } + + if ms.len() != (2 * d - 1) * CSP { + return Err(ReceiverError::InvalidLength( + "the length of ms and us should be (2 * d -1) * CSP".to_string(), + )); + } + + let mut h_ms = ms.to_vec(); + let mut h_us = us.to_vec(); + + let mut mi = vec![Block::ZERO; 2 * d - 1]; + let mut ui = vec![Block::ZERO; 2 * d - 1]; + + let base: Vec = (0..CSP) + .map(|x| bytemuck::cast((1_u128) << (CSP - 1 - x))) + .collect(); + + for i in 0..(2 * d - 1) { + let m = h_ms.split_off(CSP); + let u = h_us.split_off(CSP); + + mi[i] = Block::inn_prdt_red(&h_ms, &base); + + ui[i] = + h_us.iter().zip(base.iter()).fold( + Block::ZERO, + |acc, (b, base)| { + if *b { + acc ^ *base + } else { + acc + } + }, + ); + h_ms = m; + h_us = u; + } + + let mut gi = vec![Block::ZERO; d + 1]; + gi[0] = mi[0]; + gi[1] = ui[0]; + + for i in 0..d - 1 { + poly_update(&mut gi, mi[i + 1], ui[i + 1], i + 2); + gi[0] ^= mi[d + i]; + gi[1] ^= ui[d + i]; + } + + self.state.exec_counter += 1; + self.state.vope_counter += 1; + + Ok(gi) + } +} + +fn poly_update(g: &mut [Block], m: Block, u: Block, length: usize) { + let mut buffer = vec![Block::ZERO; length + 1]; + for i in 0..length { + buffer[i + 1] = g[i].gfmul(u); + g[i] = g[i].gfmul(m); + + g[i] ^= buffer[i]; + } + g[length] = buffer[length]; +} + +/// The receiver's state. +pub mod state { + mod sealed { + pub trait Sealed {} + impl Sealed for super::Initialized {} + impl Sealed for super::Extension {} + } + + /// The receiver's state. + pub trait State: sealed::Sealed {} + + /// The receiver's initial state. + #[derive(Default)] + pub struct Initialized {} + + impl State for Initialized {} + opaque_debug::implement!(Initialized); + + /// The receiver's state after the setup phase. + /// + /// In this state the sender performs VOPE extension. + pub struct Extension { + /// Current VOPE counter + pub(super) vope_counter: usize, + /// Current execution counter + pub(super) exec_counter: usize, + } + + impl State for Extension {} + + opaque_debug::implement!(Extension); +} diff --git a/crates/mpz-zk-core/src/vope/sender.rs b/crates/mpz-zk-core/src/vope/sender.rs new file mode 100644 index 00000000..2bca3fd6 --- /dev/null +++ b/crates/mpz-zk-core/src/vope/sender.rs @@ -0,0 +1,128 @@ +//! VOPE sender. +use mpz_core::Block; + +use crate::vope::CSP; + +use super::error::SenderError; + +/// VOPE sender +/// This is the verifier in Figure 4. +#[derive(Debug, Default)] +pub struct Sender { + state: T, +} + +impl Sender { + /// Creates a new sender. + pub fn new() -> Self { + Sender { + state: state::Initialized::default(), + } + } + + /// Completes the setup phase of the protocol. + /// + /// See Initialize in Figure 4. + /// + /// # Arguments. + /// + /// * `delta` - The sender's global secret. + pub fn setup(self, delta: Block) -> Sender { + Sender { + state: state::Extension { + delta, + vope_counter: 0, + exec_counter: 0, + }, + } + } +} + +impl Sender { + /// Performs VOPE extension. + /// + /// See step 1-3 in Figure 4. + /// + /// # Arguments + /// + /// * `ks` - The blocks received by calling the COT ideal functionality. + /// * `d` - The degree of the polynomial. + /// + /// Note that this functionality is only suitable for small d. + pub fn extend(&mut self, ks: &[Block], d: usize) -> Result { + if d == 0 { + return Err(SenderError::InvalidInput( + "the degree d should not be 0".to_string(), + )); + } + + if ks.len() != (2 * d - 1) * CSP { + return Err(SenderError::InvalidLength( + "the length of ks should be (2 * d -1) * CSP".to_string(), + )); + } + + let mut ki = vec![Block::ZERO; 2 * d - 1]; + + let base: Vec = (0..CSP) + .map(|x| bytemuck::cast((1_u128) << (CSP - 1 - x))) + .collect(); + + let mut h_ks = ks.to_vec(); + + for k in ki.iter_mut().take(2 * d - 1) { + let buf = h_ks.split_off(CSP); + *k = Block::inn_prdt_red(&h_ks, &base); + h_ks = buf; + } + + let mut b = ki[0]; + + for i in 0..d - 1 { + b = b.gfmul(ki[i + 1]) ^ ki[d + i] + } + + self.state.exec_counter += 1; + self.state.vope_counter += 1; + + Ok(b) + } +} +/// The sender's state. +pub mod state { + use super::*; + + mod sealed { + pub trait Sealed {} + impl Sealed for super::Initialized {} + impl Sealed for super::Extension {} + } + + /// The sender's state. + pub trait State: sealed::Sealed {} + + /// The sender's initial state. + #[derive(Default)] + pub struct Initialized {} + + impl State for Initialized {} + opaque_debug::implement!(Initialized); + + /// The sender's state after the setup phase. + /// + /// In this state the sender performs VOPE extension. + pub struct Extension { + /// Sender's global secret. + #[allow(dead_code)] + pub(crate) delta: Block, + + /// Current VOPE counter + pub(super) vope_counter: usize, + /// Current execution counter + pub(super) exec_counter: usize, + } + + impl State for Extension {} + + opaque_debug::implement!(Extension); +} From b87caf78bf37312837723f4d335bf75f929a8a14 Mon Sep 17 00:00:00 2001 From: Xiang Xie Date: Tue, 2 Jul 2024 01:25:28 +0800 Subject: [PATCH 05/17] fmt --- Cargo.toml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/Cargo.toml b/Cargo.toml index d1e7c492..6557b9c9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,8 @@ members = [ "crates/clmul", "crates/mpz-ole-core", "crates/mpz-ole", + "crates/mpz-zk-core", + "crates/mpz-zk", ] resolver = "2" @@ -43,6 +45,8 @@ mpz-ole = { path = "crates/mpz-ole" } mpz-ole-core = { path = "crates/mpz-ole-core" } clmul = { path = "crates/clmul" } matrix-transpose = { path = "crates/matrix-transpose" } +mpz-zk-core = { path = "crates/mpz-zk-core" } +mpz-zk = { path = "crates/mpz-zk" } tlsn-utils = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6e0be94" } tlsn-utils-aio = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6e0be94" } From 116035e929c217aae106ed253d731f169781227f Mon Sep 17 00:00:00 2001 From: "sinu.eth" <65924192+sinui0@users.noreply.github.com> Date: Thu, 15 Aug 2024 19:08:06 -0700 Subject: [PATCH 06/17] refactor(mpz-ot): ferret clean up (#173) * refactor(mpz-ot): ferret clean up * buffer OTs, setup rcot only invoked once * fix mpcot test --------- Co-authored-by: Xiang Xie --- crates/mpz-common/src/ideal.rs | 4 +- crates/mpz-ot-core/src/ferret/mod.rs | 30 +- crates/mpz-ot-core/src/ferret/receiver.rs | 51 ++- crates/mpz-ot-core/src/ferret/sender.rs | 49 ++- crates/mpz-ot-core/src/lib.rs | 2 +- crates/mpz-ot/src/ferret/error.rs | 375 ++++++++++++++++++--- crates/mpz-ot/src/ferret/mod.rs | 123 ++----- crates/mpz-ot/src/ferret/mpcot.rs | 185 ++++++++++ crates/mpz-ot/src/ferret/mpcot/error.rs | 59 ---- crates/mpz-ot/src/ferret/mpcot/mod.rs | 165 --------- crates/mpz-ot/src/ferret/mpcot/receiver.rs | 192 ----------- crates/mpz-ot/src/ferret/mpcot/sender.rs | 166 --------- crates/mpz-ot/src/ferret/receiver.rs | 265 +++++++++------ crates/mpz-ot/src/ferret/sender.rs | 298 +++++++++++----- crates/mpz-ot/src/ferret/spcot.rs | 167 +++++++++ crates/mpz-ot/src/ferret/spcot/error.rs | 59 ---- crates/mpz-ot/src/ferret/spcot/mod.rs | 103 ------ crates/mpz-ot/src/ferret/spcot/receiver.rs | 164 --------- crates/mpz-ot/src/ferret/spcot/sender.rs | 144 -------- crates/mpz-ot/src/ideal/cot.rs | 16 +- crates/mpz-ot/src/lib.rs | 13 +- 21 files changed, 1227 insertions(+), 1403 deletions(-) create mode 100644 crates/mpz-ot/src/ferret/mpcot.rs delete mode 100644 crates/mpz-ot/src/ferret/mpcot/error.rs delete mode 100644 crates/mpz-ot/src/ferret/mpcot/mod.rs delete mode 100644 crates/mpz-ot/src/ferret/mpcot/receiver.rs delete mode 100644 crates/mpz-ot/src/ferret/mpcot/sender.rs create mode 100644 crates/mpz-ot/src/ferret/spcot.rs delete mode 100644 crates/mpz-ot/src/ferret/spcot/error.rs delete mode 100644 crates/mpz-ot/src/ferret/spcot/mod.rs delete mode 100644 crates/mpz-ot/src/ferret/spcot/receiver.rs delete mode 100644 crates/mpz-ot/src/ferret/spcot/sender.rs diff --git a/crates/mpz-common/src/ideal.rs b/crates/mpz-common/src/ideal.rs index 7fcb1628..1b6b3181 100644 --- a/crates/mpz-common/src/ideal.rs +++ b/crates/mpz-common/src/ideal.rs @@ -35,7 +35,7 @@ impl Clone for Alice { impl Alice { /// Returns a lock to the ideal functionality. - pub fn get_mut(&mut self) -> MutexGuard<'_, F> { + pub fn lock(&self) -> MutexGuard<'_, F> { self.f.lock().unwrap() } @@ -96,7 +96,7 @@ impl Clone for Bob { impl Bob { /// Returns a lock to the ideal functionality. - pub fn get_mut(&mut self) -> MutexGuard<'_, F> { + pub fn lock(&self) -> MutexGuard<'_, F> { self.f.lock().unwrap() } diff --git a/crates/mpz-ot-core/src/ferret/mod.rs b/crates/mpz-ot-core/src/ferret/mod.rs index bbbf264a..0e27f0a9 100644 --- a/crates/mpz-ot-core/src/ferret/mod.rs +++ b/crates/mpz-ot-core/src/ferret/mod.rs @@ -52,9 +52,11 @@ mod tests { use receiver::Receiver; use sender::Sender; - use crate::ideal::{cot::IdealCOT, mpcot::IdealMpcot}; - use crate::test::assert_cot; - use crate::{MPCOTReceiverOutput, MPCOTSenderOutput, RCOTReceiverOutput, RCOTSenderOutput}; + use crate::{ + ideal::{cot::IdealCOT, mpcot::IdealMpcot}, + test::assert_cot, + MPCOTReceiverOutput, MPCOTSenderOutput, RCOTReceiverOutput, RCOTSenderOutput, + }; use mpz_core::{lpn::LpnParameters, prg::Prg}; const LPN_PARAMETERS_TEST: LpnParameters = LpnParameters { @@ -111,8 +113,15 @@ mod tests { let (MPCOTSenderOutput { s, .. }, MPCOTReceiverOutput { r, .. }) = ideal_mpcot.extend(&query.0, query.1); - let msgs = sender.extend(&s).unwrap(); - let (choices, received) = receiver.extend(&r).unwrap(); + sender.extend(s).unwrap(); + receiver.extend(r).unwrap(); + + let RCOTSenderOutput { msgs, .. } = sender.consume(2).unwrap(); + let RCOTReceiverOutput { + choices, + msgs: received, + .. + } = receiver.consume(2).unwrap(); assert_cot(delta, &choices, &msgs, &received); @@ -123,8 +132,15 @@ mod tests { let (MPCOTSenderOutput { s, .. }, MPCOTReceiverOutput { r, .. }) = ideal_mpcot.extend(&query.0, query.1); - let msgs = sender.extend(&s).unwrap(); - let (choices, received) = receiver.extend(&r).unwrap(); + sender.extend(s).unwrap(); + receiver.extend(r).unwrap(); + + let RCOTSenderOutput { msgs, .. } = sender.consume(sender.remaining()).unwrap(); + let RCOTReceiverOutput { + choices, + msgs: received, + .. + } = receiver.consume(receiver.remaining()).unwrap(); assert_cot(delta, &choices, &msgs, &received); } diff --git a/crates/mpz-ot-core/src/ferret/receiver.rs b/crates/mpz-ot-core/src/ferret/receiver.rs index e5939c60..1cfd1e08 100644 --- a/crates/mpz-ot-core/src/ferret/receiver.rs +++ b/crates/mpz-ot-core/src/ferret/receiver.rs @@ -1,4 +1,6 @@ //! Ferret receiver +use std::collections::VecDeque; + use mpz_core::{ lpn::{LpnEncoder, LpnParameters}, Block, @@ -6,7 +8,7 @@ use mpz_core::{ use crate::{ ferret::{error::ReceiverError, LpnType}, - TransferId, + RCOTReceiverOutput, TransferId, }; use super::msgs::LpnMatrixSeed; @@ -63,6 +65,8 @@ impl Receiver { w: w.to_vec(), e: Vec::default(), id: TransferId::default(), + choices_buffer: VecDeque::new(), + msgs_buffer: VecDeque::new(), }, }, LpnMatrixSeed { seed }, @@ -71,6 +75,16 @@ impl Receiver { } impl Receiver { + /// Returns the current transfer id. + pub fn id(&self) -> TransferId { + self.state.id + } + + /// Returns the number of remaining COTs. + pub fn remaining(&self) -> usize { + self.state.choices_buffer.len() + } + /// The prepare precedure of extension, sample error vectors and outputs information for MPCOT. /// See step 3 and 4. pub fn get_mpcot_query(&mut self) -> (Vec, usize) { @@ -100,7 +114,7 @@ impl Receiver { /// # Arguments. /// /// * `r` - The vector received from the MPCOT protocol. - pub fn extend(&mut self, r: &[Block]) -> Result<(Vec, Vec), ReceiverError> { + pub fn extend(&mut self, r: Vec) -> Result<(), ReceiverError> { if r.len() != self.state.lpn_parameters.n { return Err(ReceiverError("the length of r should be n".to_string())); } @@ -108,7 +122,7 @@ impl Receiver { self.state.id.next(); // Compute z = A * w + r. - let mut z = r.to_vec(); + let mut z = r; self.state.lpn_encoder.compute(&mut z, &self.state.w); // Compute x = A * u + e. @@ -133,12 +147,32 @@ impl Receiver { // Update counter self.state.counter += 1; - Ok((x_, z_)) + self.state.choices_buffer.extend(x_); + self.state.msgs_buffer.extend(z_); + + Ok(()) } - /// Returns id - pub fn id(&self) -> TransferId { - self.state.id + /// Consumes `count` COTs. + pub fn consume( + &mut self, + count: usize, + ) -> Result, ReceiverError> { + if count > self.state.choices_buffer.len() { + return Err(ReceiverError(format!( + "insufficient OTs: {} < {count}", + self.state.choices_buffer.len() + ))); + } + + let choices = self.state.choices_buffer.drain(0..count).collect(); + let msgs = self.state.msgs_buffer.drain(0..count).collect(); + + Ok(RCOTReceiverOutput { + id: self.state.id.next(), + choices, + msgs, + }) } } @@ -186,6 +220,9 @@ pub mod state { /// TransferID pub(super) id: TransferId, + /// Extended OTs buffers. + pub(super) choices_buffer: VecDeque, + pub(super) msgs_buffer: VecDeque, } impl State for Extension {} diff --git a/crates/mpz-ot-core/src/ferret/sender.rs b/crates/mpz-ot-core/src/ferret/sender.rs index 2af3e4ae..436d6003 100644 --- a/crates/mpz-ot-core/src/ferret/sender.rs +++ b/crates/mpz-ot-core/src/ferret/sender.rs @@ -1,4 +1,6 @@ //! Ferret sender. +use std::collections::VecDeque; + use mpz_core::{ lpn::{LpnEncoder, LpnParameters}, Block, @@ -6,7 +8,7 @@ use mpz_core::{ use crate::{ ferret::{error::SenderError, LpnType}, - TransferId, + RCOTSenderOutput, TransferId, }; use super::msgs::LpnMatrixSeed; @@ -61,12 +63,28 @@ impl Sender { lpn_encoder, v: v.to_vec(), id: TransferId::default(), + msgs_buffer: VecDeque::new(), }, }) } } impl Sender { + /// Returns the current transfer id. + pub fn id(&self) -> TransferId { + self.state.id + } + + /// Returns the number of remaining COTs. + pub fn remaining(&self) -> usize { + self.state.msgs_buffer.len() + } + + /// Returns the delta correlation. + pub fn delta(&self) -> Block { + self.state.delta + } + /// Outputs the information for MPCOT. /// /// See step 3 and 4. @@ -86,7 +104,7 @@ impl Sender { /// # Arguments. /// /// * `s` - The vector received from the MPCOT protocol. - pub fn extend(&mut self, s: &[Block]) -> Result, SenderError> { + pub fn extend(&mut self, s: Vec) -> Result<(), SenderError> { if s.len() != self.state.lpn_parameters.n { return Err(SenderError("the length of s should be n".to_string())); } @@ -94,7 +112,7 @@ impl Sender { self.state.id.next(); // Compute y = A * v + s - let mut y = s.to_vec(); + let mut y = s; self.state.lpn_encoder.compute(&mut y, &self.state.v); let y_ = y.split_off(self.state.lpn_parameters.k); @@ -104,13 +122,26 @@ impl Sender { // Update counter self.state.counter += 1; + self.state.msgs_buffer.extend(y_); - Ok(y_) + Ok(()) } - /// Returns id - pub fn id(&self) -> TransferId { - self.state.id + /// Consumes `count` COTs. + pub fn consume(&mut self, count: usize) -> Result, SenderError> { + if count > self.state.msgs_buffer.len() { + return Err(SenderError(format!( + "insufficient OTs: {} < {count}", + self.state.msgs_buffer.len() + ))); + } + + let msgs = self.state.msgs_buffer.drain(0..count).collect(); + + Ok(RCOTSenderOutput { + id: self.state.id.next(), + msgs, + }) } } @@ -159,8 +190,10 @@ pub mod state { /// Sender's COT message in the setup phase. pub(super) v: Vec, - /// TransferID. + /// Transfer ID. pub(crate) id: TransferId, + /// COT messages buffer. + pub(super) msgs_buffer: VecDeque, } impl State for Extension {} diff --git a/crates/mpz-ot-core/src/lib.rs b/crates/mpz-ot-core/src/lib.rs index 8dd77287..dcfffc59 100644 --- a/crates/mpz-ot-core/src/lib.rs +++ b/crates/mpz-ot-core/src/lib.rs @@ -45,7 +45,7 @@ impl std::fmt::Display for TransferId { impl TransferId { /// Returns the current transfer ID, incrementing `self` in-place. - pub(crate) fn next(&mut self) -> Self { + pub fn next(&mut self) -> Self { let id = *self; self.0 += 1; id diff --git a/crates/mpz-ot/src/ferret/error.rs b/crates/mpz-ot/src/ferret/error.rs index 6952f0ec..4e428a4b 100644 --- a/crates/mpz-ot/src/ferret/error.rs +++ b/crates/mpz-ot/src/ferret/error.rs @@ -1,67 +1,342 @@ -use crate::OTError; +use std::fmt::Display; -/// A Ferret sender error. +/// Ferret sender error. #[derive(Debug, thiserror::Error)] -#[allow(missing_docs)] -pub enum SenderError { - #[error(transparent)] - IOError(#[from] std::io::Error), - #[error(transparent)] - CoreError(#[from] mpz_ot_core::ferret::error::SenderError), - #[error(transparent)] - MPCOTSenderError(#[from] crate::ferret::mpcot::SenderError), - #[error(transparent)] - RandomCOTError(#[from] OTError), - #[error("{0}")] - StateError(String), - #[error("{0}")] - MPCOTSenderTypeError(String), -} - -impl From for OTError { - fn from(err: SenderError) -> Self { - match err { - SenderError::IOError(e) => e.into(), - e => OTError::SenderError(Box::new(e)), +pub struct SenderError { + kind: SenderErrorKind, + #[source] + source: Option>, +} + +impl SenderError { + pub(crate) fn state(msg: impl Into) -> Self { + Self { + kind: SenderErrorKind::State, + source: Some(msg.into().into()), + } + } + + pub(crate) fn io(msg: impl Into) -> Self { + Self { + kind: SenderErrorKind::Io, + source: Some(msg.into().into()), + } + } +} + +#[derive(Debug)] +enum SenderErrorKind { + Io, + State, + Core, + Rcot, + Mpcot, +} + +impl Display for SenderError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match &self.kind { + SenderErrorKind::Io => f.write_str("io error")?, + SenderErrorKind::State => f.write_str("state error")?, + SenderErrorKind::Core => f.write_str("core error")?, + SenderErrorKind::Rcot => f.write_str("rcot error")?, + SenderErrorKind::Mpcot => f.write_str("mpcot error")?, + } + + if let Some(source) = &self.source { + write!(f, " caused by: {}", source) + } else { + Ok(()) + } + } +} + +impl From for SenderError { + fn from(err: std::io::Error) -> Self { + Self { + kind: SenderErrorKind::Io, + source: Some(Box::new(err)), + } + } +} + +impl From for SenderError { + fn from(err: mpz_ot_core::ferret::error::SenderError) -> Self { + Self { + kind: SenderErrorKind::Core, + source: Some(Box::new(err)), + } + } +} + +impl From for SenderError { + fn from(err: crate::OTError) -> Self { + Self { + kind: SenderErrorKind::Rcot, + source: Some(Box::new(err)), + } + } +} + +impl From for SenderError { + fn from(err: MPCOTError) -> Self { + Self { + kind: SenderErrorKind::Mpcot, + source: Some(Box::new(err)), } } } -impl From for SenderError { - fn from(err: crate::ferret::sender::StateError) -> Self { - SenderError::StateError(err.to_string()) +impl From for crate::OTError { + fn from(err: SenderError) -> Self { + crate::OTError::SenderError(Box::new(err)) } } -/// A Ferret receiver error. +/// Ferret receiver error. #[derive(Debug, thiserror::Error)] -#[allow(missing_docs)] -pub enum ReceiverError { - #[error(transparent)] - IOError(#[from] std::io::Error), - #[error(transparent)] - CoreError(#[from] mpz_ot_core::ferret::error::ReceiverError), - #[error(transparent)] - MPCOTReceiverError(#[from] crate::ferret::mpcot::ReceiverError), - #[error(transparent)] - RandomCOTError(#[from] OTError), - #[error("{0}")] - StateError(String), - #[error("{0}")] - MPCOTReceiverTypeError(String), -} - -impl From for OTError { +pub struct ReceiverError { + kind: ReceiverErrorKind, + #[source] + source: Option>, +} + +impl ReceiverError { + pub(crate) fn state(msg: impl Into) -> Self { + Self { + kind: ReceiverErrorKind::State, + source: Some(msg.into().into()), + } + } + + pub(crate) fn io(msg: impl Into) -> Self { + Self { + kind: ReceiverErrorKind::Io, + source: Some(msg.into().into()), + } + } +} + +#[derive(Debug)] +enum ReceiverErrorKind { + Io, + State, + Core, + Rcot, + Mpcot, +} + +impl Display for ReceiverError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match &self.kind { + ReceiverErrorKind::Io => f.write_str("io error")?, + ReceiverErrorKind::State => f.write_str("state error")?, + ReceiverErrorKind::Core => f.write_str("core error")?, + ReceiverErrorKind::Rcot => f.write_str("rcot error")?, + ReceiverErrorKind::Mpcot => f.write_str("mpcot error")?, + } + + if let Some(source) = &self.source { + write!(f, " caused by: {}", source) + } else { + Ok(()) + } + } +} + +impl From for ReceiverError { + fn from(err: std::io::Error) -> Self { + Self { + kind: ReceiverErrorKind::Io, + source: Some(Box::new(err)), + } + } +} + +impl From for ReceiverError { + fn from(err: mpz_ot_core::ferret::error::ReceiverError) -> Self { + Self { + kind: ReceiverErrorKind::Core, + source: Some(Box::new(err)), + } + } +} + +impl From for ReceiverError { + fn from(err: crate::OTError) -> Self { + Self { + kind: ReceiverErrorKind::Rcot, + source: Some(Box::new(err)), + } + } +} + +impl From for ReceiverError { + fn from(err: MPCOTError) -> Self { + Self { + kind: ReceiverErrorKind::Mpcot, + source: Some(Box::new(err)), + } + } +} + +impl From for crate::OTError { fn from(err: ReceiverError) -> Self { - match err { - ReceiverError::IOError(e) => e.into(), - e => OTError::ReceiverError(Box::new(e)), + crate::OTError::ReceiverError(Box::new(err)) + } +} + +mod mpcot { + use super::*; + + /// MPCOT error. + #[derive(Debug, thiserror::Error)] + pub(crate) struct MPCOTError { + kind: ErrorKind, + #[source] + source: Option>, + } + + #[derive(Debug)] + enum ErrorKind { + Io, + Core, + Rcot, + Spcot, + } + + impl Display for MPCOTError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match &self.kind { + ErrorKind::Io => f.write_str("io error")?, + ErrorKind::Core => f.write_str("core error")?, + ErrorKind::Rcot => f.write_str("rcot error")?, + ErrorKind::Spcot => f.write_str("spcot error")?, + } + + if let Some(source) = &self.source { + write!(f, " caused by: {}", source) + } else { + Ok(()) + } + } + } + + impl From for MPCOTError { + fn from(err: std::io::Error) -> Self { + Self { + kind: ErrorKind::Io, + source: Some(Box::new(err)), + } + } + } + + impl From for MPCOTError { + fn from(err: mpz_ot_core::ferret::mpcot::error::SenderError) -> Self { + Self { + kind: ErrorKind::Core, + source: Some(Box::new(err)), + } + } + } + + impl From for MPCOTError { + fn from(err: mpz_ot_core::ferret::mpcot::error::ReceiverError) -> Self { + Self { + kind: ErrorKind::Core, + source: Some(Box::new(err)), + } + } + } + + impl From for MPCOTError { + fn from(err: SPCOTError) -> Self { + Self { + kind: ErrorKind::Spcot, + source: Some(Box::new(err)), + } + } + } + + impl From for MPCOTError { + fn from(err: crate::OTError) -> Self { + Self { + kind: ErrorKind::Rcot, + source: Some(Box::new(err)), + } } } } +pub(crate) use mpcot::MPCOTError; + +mod spcot { + use super::*; + + /// SPCOT error. + #[derive(Debug, thiserror::Error)] + pub(crate) struct SPCOTError { + kind: ErrorKind, + #[source] + source: Option>, + } + + #[derive(Debug)] + enum ErrorKind { + Io, + Core, + Rcot, + } + + impl Display for SPCOTError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match &self.kind { + ErrorKind::Io => f.write_str("io error")?, + ErrorKind::Core => f.write_str("core error")?, + ErrorKind::Rcot => f.write_str("rcot error")?, + } -impl From for ReceiverError { - fn from(err: crate::ferret::receiver::StateError) -> Self { - ReceiverError::StateError(err.to_string()) + if let Some(source) = &self.source { + write!(f, " caused by: {}", source) + } else { + Ok(()) + } + } + } + + impl From for SPCOTError { + fn from(err: std::io::Error) -> Self { + Self { + kind: ErrorKind::Io, + source: Some(Box::new(err)), + } + } + } + + impl From for SPCOTError { + fn from(err: mpz_ot_core::ferret::spcot::error::SenderError) -> Self { + Self { + kind: ErrorKind::Core, + source: Some(Box::new(err)), + } + } + } + + impl From for SPCOTError { + fn from(err: mpz_ot_core::ferret::spcot::error::ReceiverError) -> Self { + Self { + kind: ErrorKind::Core, + source: Some(Box::new(err)), + } + } + } + + impl From for SPCOTError { + fn from(err: crate::OTError) -> Self { + Self { + kind: ErrorKind::Rcot, + source: Some(Box::new(err)), + } + } } } +pub(crate) use spcot::SPCOTError; diff --git a/crates/mpz-ot/src/ferret/mod.rs b/crates/mpz-ot/src/ferret/mod.rs index 2b2047b9..086e5e8b 100644 --- a/crates/mpz-ot/src/ferret/mod.rs +++ b/crates/mpz-ot/src/ferret/mod.rs @@ -13,47 +13,26 @@ use mpz_core::lpn::LpnParameters; use mpz_ot_core::ferret::LpnType; /// Configuration of Ferret. -#[derive(Debug)] -pub struct FerretConfig { - rcot: RandomCOT, - setup_rcot: SetupRandomCOT, +#[derive(Debug, Clone)] +pub struct FerretConfig { lpn_parameters: LpnParameters, lpn_type: LpnType, } -impl FerretConfig { +impl FerretConfig { /// Create a new instance. /// /// # Arguments. /// - /// * `rcot` - The rcot for MPCOT. - /// * `setup_rcot` - The rcot for setup. /// * `lpn_parameters` - The parameters of LPN. /// * `lpn_type` - The type of LPN. - pub fn new( - rcot: RandomCOT, - setup_rcot: SetupRandomCOT, - lpn_parameters: LpnParameters, - lpn_type: LpnType, - ) -> Self { + pub fn new(lpn_parameters: LpnParameters, lpn_type: LpnType) -> Self { Self { - rcot, - setup_rcot, lpn_parameters, lpn_type, } } - /// Get rcot - pub fn rcot(&self) -> RandomCOT { - self.rcot.clone() - } - - /// Get the setup rcot - pub fn setup_rcot(&mut self) -> &mut SetupRandomCOT { - &mut self.setup_rcot - } - /// Get the lpn type pub fn lpn_type(&self) -> LpnType { self.lpn_type @@ -67,17 +46,14 @@ impl FerretConfig { #[cfg(test)] mod tests { - use futures::TryFutureExt; + use super::*; + use futures::TryFutureExt as _; use mpz_common::executor::test_st_executor; - use mpz_core::{lpn::LpnParameters, Block}; + use mpz_core::lpn::LpnParameters; use mpz_ot_core::{ferret::LpnType, test::assert_cot, RCOTReceiverOutput, RCOTSenderOutput}; + use rstest::*; - use crate::{ - ideal::cot::{ideal_rcot, IdealCOTReceiver, IdealCOTSender}, - OTError, RandomCOTReceiver, RandomCOTSender, - }; - - use super::*; + use crate::{ideal::cot::ideal_rcot, Correlation, OTError, RandomCOTReceiver, RandomCOTSender}; // l = n - k = 8380 const LPN_PARAMETERS_TEST: LpnParameters = LpnParameters { @@ -86,73 +62,46 @@ mod tests { t: 600, }; - fn setup() -> ( - Sender, - Receiver, - Block, - ) { - let (mut rcot_sender, rcot_receiver) = ideal_rcot(); - - let sender_config = FerretConfig::new( - rcot_sender.clone(), - rcot_sender.clone(), - LPN_PARAMETERS_TEST, - LpnType::Regular, - ); - - let receiver_config = FerretConfig::new( - rcot_receiver.clone(), - rcot_receiver, - LPN_PARAMETERS_TEST, - LpnType::Regular, - ); - - let delta = rcot_sender.alice().get_mut().delta(); - - let sender = Sender::new(sender_config); - - let receiver = Receiver::new(receiver_config); - - (sender, receiver, delta) - } - + #[rstest] + #[case::uniform(LpnType::Uniform)] + #[case::regular(LpnType::Regular)] #[tokio::test] - async fn test_ferret() { + async fn test_ferret(#[case] lpn_type: LpnType) { let (mut ctx_sender, mut ctx_receiver) = test_st_executor(8); - let (mut sender, mut receiver, delta) = setup(); + let (rcot_sender, rcot_receiver) = ideal_rcot(); + + let config = FerretConfig::new(LPN_PARAMETERS_TEST, lpn_type); + + let mut sender = Sender::new(config.clone(), rcot_sender); + let mut receiver = Receiver::new(config, rcot_receiver); tokio::try_join!( - sender - .setup_with_delta(&mut ctx_sender, delta) - .map_err(OTError::from), + sender.setup(&mut ctx_sender).map_err(OTError::from), receiver.setup(&mut ctx_receiver).map_err(OTError::from) ) .unwrap(); // extend once. - let count = 8000; - let ( - RCOTSenderOutput { - id: sender_id, - msgs: u, - }, - RCOTReceiverOutput { - id: receiver_id, - choices: b, - msgs: w, - }, - ) = tokio::try_join!( - sender.send_random_correlated(&mut ctx_sender, count), - receiver.receive_random_correlated(&mut ctx_receiver, count) + let count = LPN_PARAMETERS_TEST.k; + tokio::try_join!( + sender.extend(&mut ctx_sender, count).map_err(OTError::from), + receiver + .extend(&mut ctx_receiver, count) + .map_err(OTError::from) ) .unwrap(); - assert_eq!(sender_id, receiver_id); - assert_cot(delta, &b, &u, &w); - // extend twice - let count = 9000; + let count = 10000; + tokio::try_join!( + sender.extend(&mut ctx_sender, count).map_err(OTError::from), + receiver + .extend(&mut ctx_receiver, count) + .map_err(OTError::from) + ) + .unwrap(); + let ( RCOTSenderOutput { id: sender_id, @@ -170,6 +119,6 @@ mod tests { .unwrap(); assert_eq!(sender_id, receiver_id); - assert_cot(delta, &b, &u, &w); + assert_cot(sender.delta(), &b, &u, &w); } } diff --git a/crates/mpz-ot/src/ferret/mpcot.rs b/crates/mpz-ot/src/ferret/mpcot.rs new file mode 100644 index 00000000..be7de33a --- /dev/null +++ b/crates/mpz-ot/src/ferret/mpcot.rs @@ -0,0 +1,185 @@ +//! Implementation of the Multiple-Point COT (mpcot) protocol in the [`Ferret`](https://eprint.iacr.org/2020/924.pdf) paper. + +use mpz_common::{cpu::CpuBackend, Context}; +use mpz_core::{prg::Prg, Block}; +use mpz_ot_core::ferret::{ + mpcot::{ + msgs::HashSeed, receiver::Receiver as UniformReceiverCore, + receiver_regular::Receiver as RegularReceiverCore, sender::Sender as UniformSender, + sender_regular::Sender as RegularSender, + }, + LpnType, +}; +use serio::{stream::IoStreamExt as _, SinkExt as _}; + +use crate::{ + ferret::{error::MPCOTError as Error, spcot}, + RandomCOTReceiver, RandomCOTSender, +}; + +/// MPCOT send. +/// +/// # Arguments. +/// +/// * `ctx` - Thread context. +/// * `rcot` - Random COT sender. +/// * `delta` - Delta correlation. +/// * `lpn_type` - The type of LPN. +/// * `t` - The number of queried indices. +/// * `n` - The total number of indices. +pub(crate) async fn send>( + ctx: &mut Ctx, + rcot: &mut RandomCOT, + delta: Block, + lpn_type: LpnType, + t: u32, + n: u32, +) -> Result, Error> { + match lpn_type { + LpnType::Uniform => { + let hash_seed: HashSeed = ctx.io_mut().expect_next().await?; + + let (sender, hs) = CpuBackend::blocking(move || { + UniformSender::new() + .setup(delta, hash_seed) + .pre_extend(t, n) + }) + .await?; + + let st = spcot::send(ctx, rcot, delta, &hs).await?; + + let (_, output) = CpuBackend::blocking(move || sender.extend(&st)).await?; + + Ok(output) + } + LpnType::Regular => { + let (sender, hs) = + CpuBackend::blocking(move || RegularSender::new().setup(delta).pre_extend(t, n)) + .await?; + + let st = spcot::send(ctx, rcot, delta, &hs).await?; + + let (_, output) = CpuBackend::blocking(move || sender.extend(&st)).await?; + + Ok(output) + } + } +} + +/// MPCOT receive. +/// +/// # Arguments +/// +/// * `ctx` - Thread context. +/// * `rcot` - Random COT receiver. +/// * `lpn_type` - The type of LPN. +/// * `alphas` - The queried indices. +/// * `n` - The total number of indices. +pub(crate) async fn receive>( + ctx: &mut Ctx, + rcot: &mut RandomCOT, + lpn_type: LpnType, + alphas: Vec, + n: u32, +) -> Result, Error> { + match lpn_type { + LpnType::Uniform => { + let hash_seed = Prg::new().random_block(); + + let (receiver, hash_seed) = UniformReceiverCore::new().setup(hash_seed); + + ctx.io_mut().send(hash_seed).await?; + + let (receiver, h_and_pos) = + CpuBackend::blocking(move || receiver.pre_extend(&alphas, n)).await?; + + let mut hs = vec![0usize; h_and_pos.len()]; + + let mut pos = vec![0u32; h_and_pos.len()]; + for (index, (h, p)) in h_and_pos.iter().enumerate() { + hs[index] = *h; + pos[index] = *p; + } + + let rt = spcot::receive(ctx, rcot, &pos, &hs).await?; + let rt: Vec> = rt.into_iter().map(|(elem, _)| elem).collect(); + let (_, output) = CpuBackend::blocking(move || receiver.extend(&rt)).await?; + + Ok(output) + } + LpnType::Regular => { + let receiver = RegularReceiverCore::new().setup(); + + let (receiver, h_and_pos) = + CpuBackend::blocking(move || receiver.pre_extend(&alphas, n)).await?; + + let mut hs = vec![0usize; h_and_pos.len()]; + + let mut pos = vec![0u32; h_and_pos.len()]; + for (index, (h, p)) in h_and_pos.iter().enumerate() { + hs[index] = *h; + pos[index] = *p; + } + + let rt = spcot::receive(ctx, rcot, &pos, &hs).await?; + let rt: Vec> = rt.into_iter().map(|(elem, _)| elem).collect(); + let (_, output) = CpuBackend::blocking(move || receiver.extend(&rt)).await?; + + Ok(output) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ideal::cot::ideal_rcot; + use mpz_common::executor::test_st_executor; + use mpz_ot_core::ferret::LpnType; + use rstest::*; + + #[rstest] + #[case(LpnType::Uniform)] + #[case(LpnType::Regular)] + #[tokio::test] + async fn test_mpcot(#[case] lpn_type: LpnType) { + use crate::Correlation; + + let (mut ctx_sender, mut ctx_receiver) = test_st_executor(8); + let (mut rcot_sender, mut rcot_receiver) = ideal_rcot(); + + let alphas = match lpn_type { + LpnType::Uniform => vec![0, 1, 3, 4, 2], + LpnType::Regular => vec![0, 3, 4, 7, 9], + }; + + let t = alphas.len(); + let n = 10; + let delta = rcot_sender.delta(); + + let (mut output_sender, output_receiver) = tokio::try_join!( + send( + &mut ctx_sender, + &mut rcot_sender, + delta, + lpn_type, + t as u32, + n + ), + receive( + &mut ctx_receiver, + &mut rcot_receiver, + lpn_type, + alphas.clone(), + n + ) + ) + .unwrap(); + + for i in alphas { + output_sender[i as usize] ^= delta; + } + + assert_eq!(output_sender, output_receiver); + } +} diff --git a/crates/mpz-ot/src/ferret/mpcot/error.rs b/crates/mpz-ot/src/ferret/mpcot/error.rs deleted file mode 100644 index e300bf0d..00000000 --- a/crates/mpz-ot/src/ferret/mpcot/error.rs +++ /dev/null @@ -1,59 +0,0 @@ -use crate::OTError; - -/// A MPCOT sender error. -#[derive(Debug, thiserror::Error)] -#[allow(missing_docs, clippy::enum_variant_names)] -pub enum SenderError { - #[error(transparent)] - IOError(#[from] std::io::Error), - #[error(transparent)] - CoreError(#[from] mpz_ot_core::ferret::mpcot::error::SenderError), - #[error(transparent)] - SPCOTSenderError(#[from] crate::ferret::spcot::SenderError), - #[error("{0}")] - StateError(String), -} - -impl From for OTError { - fn from(err: SenderError) -> Self { - match err { - SenderError::IOError(e) => e.into(), - e => OTError::SenderError(Box::new(e)), - } - } -} - -impl From for SenderError { - fn from(err: crate::ferret::mpcot::sender::StateError) -> Self { - SenderError::StateError(err.to_string()) - } -} - -/// A MPCOT receiver error -#[derive(Debug, thiserror::Error)] -#[allow(missing_docs, clippy::enum_variant_names)] -pub enum ReceiverError { - #[error(transparent)] - IOError(#[from] std::io::Error), - #[error(transparent)] - CoreError(#[from] mpz_ot_core::ferret::mpcot::error::ReceiverError), - #[error(transparent)] - SpcotReceiverError(#[from] crate::ferret::spcot::ReceiverError), - #[error("{0}")] - StateError(String), -} - -impl From for OTError { - fn from(err: ReceiverError) -> Self { - match err { - ReceiverError::IOError(e) => e.into(), - e => OTError::ReceiverError(Box::new(e)), - } - } -} - -impl From for ReceiverError { - fn from(err: crate::ferret::mpcot::receiver::StateError) -> Self { - ReceiverError::StateError(err.to_string()) - } -} diff --git a/crates/mpz-ot/src/ferret/mpcot/mod.rs b/crates/mpz-ot/src/ferret/mpcot/mod.rs deleted file mode 100644 index 598b5734..00000000 --- a/crates/mpz-ot/src/ferret/mpcot/mod.rs +++ /dev/null @@ -1,165 +0,0 @@ -//! Implementation of the Multiple-Point COT (mpcot) protocol in the [`Ferret`](https://eprint.iacr.org/2020/924.pdf) paper. - -mod error; -mod receiver; -mod sender; - -pub(crate) use error::{ReceiverError, SenderError}; -pub(crate) use receiver::Receiver; -pub(crate) use sender::Sender; - -#[cfg(test)] -mod tests { - use futures::TryFutureExt; - use mpz_common::executor::test_st_executor; - use mpz_core::Block; - use mpz_ot_core::ferret::LpnType; - - use crate::{ - ideal::cot::{ideal_rcot, IdealCOTReceiver, IdealCOTSender}, - OTError, - }; - - use receiver::Receiver; - use sender::Sender; - - use super::*; - - fn setup( - lpn_type: LpnType, - ) -> ( - Sender, - Receiver, - IdealCOTSender, - IdealCOTReceiver, - Block, - ) { - let (mut rcot_sender, rcot_receiver) = ideal_rcot(); - - let delta = rcot_sender.alice().get_mut().delta(); - - let sender = Sender::new(lpn_type); - - let receiver = Receiver::new(lpn_type); - - (sender, receiver, rcot_sender, rcot_receiver, delta) - } - - #[tokio::test] - async fn test_mpcot() { - let (mut ctx_sender, mut ctx_receiver) = test_st_executor(8); - - let (mut sender, mut receiver, rcot_sender, rcot_receiver, delta) = setup(LpnType::Uniform); - - let alphas = [0, 1, 3, 4, 2]; - let t = alphas.len(); - let n = 10; - - tokio::try_join!( - sender - .setup_with_delta(&mut ctx_sender, delta, rcot_sender) - .map_err(OTError::from), - receiver - .setup(&mut ctx_receiver, rcot_receiver) - .map_err(OTError::from) - ) - .unwrap(); - - let (mut output_sender, output_receiver) = tokio::try_join!( - sender - .extend(&mut ctx_sender, t as u32, n) - .map_err(OTError::from), - receiver - .extend(&mut ctx_receiver, &alphas, n) - .map_err(OTError::from) - ) - .unwrap(); - - for i in alphas { - output_sender[i as usize] ^= delta; - } - - assert_eq!(output_sender, output_receiver); - - // extend twice - let alphas = [5, 1, 7, 2]; - let t = alphas.len(); - let n = 16; - - let (mut output_sender, output_receiver) = tokio::try_join!( - sender - .extend(&mut ctx_sender, t as u32, n) - .map_err(OTError::from), - receiver - .extend(&mut ctx_receiver, &alphas, n) - .map_err(OTError::from) - ) - .unwrap(); - - for i in alphas { - output_sender[i as usize] ^= delta; - } - - assert_eq!(output_sender, output_receiver); - - sender.finalize().unwrap(); - receiver.finalize().unwrap(); - - let (mut sender, mut receiver, rcot_sender, rcot_receiver, delta) = setup(LpnType::Regular); - - // extend once. - let alphas = [0, 3, 4, 7, 9]; - let t = alphas.len(); - let n = 10; - - tokio::try_join!( - sender - .setup_with_delta(&mut ctx_sender, delta, rcot_sender) - .map_err(OTError::from), - receiver - .setup(&mut ctx_receiver, rcot_receiver) - .map_err(OTError::from) - ) - .unwrap(); - - let (mut output_sender, output_receiver) = tokio::try_join!( - sender - .extend(&mut ctx_sender, t as u32, n) - .map_err(OTError::from), - receiver - .extend(&mut ctx_receiver, &alphas, n) - .map_err(OTError::from) - ) - .unwrap(); - - for i in alphas { - output_sender[i as usize] ^= delta; - } - - assert_eq!(output_sender, output_receiver); - - // extend twice. - let alphas = [0, 3, 7, 9, 14, 15]; - let t = alphas.len(); - let n = 16; - - let (mut output_sender, output_receiver) = tokio::try_join!( - sender - .extend(&mut ctx_sender, t as u32, n) - .map_err(OTError::from), - receiver - .extend(&mut ctx_receiver, &alphas, n) - .map_err(OTError::from) - ) - .unwrap(); - - for i in alphas { - output_sender[i as usize] ^= delta; - } - - assert_eq!(output_sender, output_receiver); - - sender.finalize().unwrap(); - receiver.finalize().unwrap(); - } -} diff --git a/crates/mpz-ot/src/ferret/mpcot/receiver.rs b/crates/mpz-ot/src/ferret/mpcot/receiver.rs deleted file mode 100644 index e2553efd..00000000 --- a/crates/mpz-ot/src/ferret/mpcot/receiver.rs +++ /dev/null @@ -1,192 +0,0 @@ -use crate::{ - ferret::{mpcot::error::ReceiverError, spcot::Receiver as SpcotReceiver}, - RandomCOTReceiver, -}; -use enum_try_as_inner::EnumTryAsInner; - -use mpz_common::Context; -use mpz_core::{prg::Prg, Block}; -use mpz_ot_core::ferret::{ - mpcot::{ - receiver::{state as uniform_state, Receiver as UniformReceiverCore}, - receiver_regular::{state as regular_state, Receiver as RegularReceiverCore}, - }, - LpnType, -}; -use serio::SinkExt; -use utils_aio::non_blocking_backend::{Backend, NonBlockingBackend}; - -#[derive(Debug, EnumTryAsInner)] -#[derive_err(Debug)] -pub(crate) enum State { - UniformInitialized(UniformReceiverCore), - UniformExtension(UniformReceiverCore), - RegularInitialized(RegularReceiverCore), - RegularExtension(RegularReceiverCore), - Complete, - Error, -} - -/// MPCOT receiver. -#[derive(Debug)] -pub(crate) struct Receiver { - state: State, - spcot: SpcotReceiver, - lpn_type: LpnType, -} - -impl Receiver { - /// Creates a new Sender. - /// - /// # Arguments. - /// - /// * `lpn_type` - The type of LPN. - pub(crate) fn new(lpn_type: LpnType) -> Self { - match lpn_type { - LpnType::Uniform => Self { - state: State::UniformInitialized(UniformReceiverCore::new()), - spcot: crate::ferret::spcot::Receiver::new(), - lpn_type, - }, - LpnType::Regular => Self { - state: State::RegularInitialized(RegularReceiverCore::new()), - spcot: crate::ferret::spcot::Receiver::new(), - lpn_type, - }, - } - } - - /// Performs setup for receiver. - /// - /// # Arguments - /// - /// * `ctx` - The context. - /// * `rcot` - The random COT used by Receiver. - pub(crate) async fn setup( - &mut self, - ctx: &mut Ctx, - rcot: RandomCOT, - ) -> Result<(), ReceiverError> { - match self.lpn_type { - LpnType::Uniform => { - let ext_receiver = std::mem::replace(&mut self.state, State::Error) - .try_into_uniform_initialized()?; - - let hash_seed = Prg::new().random_block(); - - let (ext_receiver, hash_seed) = ext_receiver.setup(hash_seed); - - ctx.io_mut().send(hash_seed).await?; - - self.state = State::UniformExtension(ext_receiver); - } - LpnType::Regular => { - let ext_receiver = std::mem::replace(&mut self.state, State::Error) - .try_into_regular_initialized()?; - - let ext_receiver = ext_receiver.setup(); - - self.state = State::RegularExtension(ext_receiver); - } - } - - self.spcot.setup(rcot)?; - - Ok(()) - } - - /// Performs MPCOT extension. - /// - /// - /// # Arguments - /// - /// * `ctx` - The context, - /// * `alphas` - The queried indices. - /// * `n` - The total number of indices. - pub(crate) async fn extend( - &mut self, - ctx: &mut Ctx, - alphas: &[u32], - n: u32, - ) -> Result, ReceiverError> - where - RandomCOT: RandomCOTReceiver, - { - let alphas_vec = alphas.to_vec(); - - match self.lpn_type { - LpnType::Uniform => { - let ext_receiver = std::mem::replace(&mut self.state, State::Error) - .try_into_uniform_extension()?; - - let (ext_receiver, h_and_pos) = - Backend::spawn(move || ext_receiver.pre_extend(&alphas_vec, n)).await?; - - let mut hs = vec![0usize; h_and_pos.len()]; - - let mut pos = vec![0u32; h_and_pos.len()]; - for (index, (h, p)) in h_and_pos.iter().enumerate() { - hs[index] = *h; - pos[index] = *p; - } - - self.spcot.extend(ctx, &pos, &hs).await?; - - let rt = self.spcot.check(ctx).await?; - - let rt: Vec> = rt.into_iter().map(|(elem, _)| elem).collect(); - let (ext_receiver, output) = - Backend::spawn(move || ext_receiver.extend(&rt)).await?; - - self.state = State::UniformExtension(ext_receiver); - - Ok(output) - } - - LpnType::Regular => { - let ext_receiver = std::mem::replace(&mut self.state, State::Error) - .try_into_regular_extension()?; - - let (ext_receiver, h_and_pos) = - Backend::spawn(move || ext_receiver.pre_extend(&alphas_vec, n)).await?; - - let mut hs = vec![0usize; h_and_pos.len()]; - - let mut pos = vec![0u32; h_and_pos.len()]; - for (index, (h, p)) in h_and_pos.iter().enumerate() { - hs[index] = *h; - pos[index] = *p; - } - - self.spcot.extend(ctx, &pos, &hs).await?; - - let rt = self.spcot.check(ctx).await?; - - let rt: Vec> = rt.into_iter().map(|(elem, _)| elem).collect(); - let (ext_receiver, output) = - Backend::spawn(move || ext_receiver.extend(&rt)).await?; - - self.state = State::RegularExtension(ext_receiver); - - Ok(output) - } - } - } - - /// Complete extension. - pub(crate) fn finalize(&mut self) -> Result<(), ReceiverError> { - match self.lpn_type { - LpnType::Uniform => { - std::mem::replace(&mut self.state, State::Error).try_into_uniform_extension()?; - } - LpnType::Regular => { - std::mem::replace(&mut self.state, State::Error).try_into_regular_extension()?; - } - } - - self.spcot.finalize()?; - self.state = State::Complete; - - Ok(()) - } -} diff --git a/crates/mpz-ot/src/ferret/mpcot/sender.rs b/crates/mpz-ot/src/ferret/mpcot/sender.rs deleted file mode 100644 index a0256276..00000000 --- a/crates/mpz-ot/src/ferret/mpcot/sender.rs +++ /dev/null @@ -1,166 +0,0 @@ -use crate::{ - ferret::{mpcot::error::SenderError, spcot::Sender as SpcotSender}, - RandomCOTSender, -}; -use enum_try_as_inner::EnumTryAsInner; -use mpz_common::Context; -use mpz_core::Block; -use mpz_ot_core::ferret::{ - mpcot::{ - msgs::HashSeed, - sender::{state as uniform_state, Sender as UniformSenderCore}, - sender_regular::{state as regular_state, Sender as RegularSenderCore}, - }, - LpnType, -}; -use serio::stream::IoStreamExt; -use utils_aio::non_blocking_backend::{Backend, NonBlockingBackend}; - -#[derive(Debug, EnumTryAsInner)] -#[derive_err(Debug)] -pub(crate) enum State { - UniformInitialized(UniformSenderCore), - UniformExtension(UniformSenderCore), - RegularInitialized(RegularSenderCore), - RegularExtension(RegularSenderCore), - Complete, - Error, -} - -/// MPCOT sender. -#[derive(Debug)] -pub(crate) struct Sender { - state: State, - spcot: SpcotSender, - lpn_type: LpnType, -} - -impl Sender { - /// Creates a new Sender. - /// - /// # Arguments. - /// - /// * `lpn_type` - The type of LPN. - pub(crate) fn new(lpn_type: LpnType) -> Self { - match lpn_type { - LpnType::Uniform => Self { - state: State::UniformInitialized(UniformSenderCore::new()), - spcot: crate::ferret::spcot::Sender::new(), - lpn_type, - }, - LpnType::Regular => Self { - state: State::RegularInitialized(RegularSenderCore::new()), - spcot: crate::ferret::spcot::Sender::new(), - lpn_type, - }, - } - } - - /// Performs setup with provided delta. - /// - /// # Arguments - /// - /// * `ctx` - The channel. - /// * `delta` - The delta value to use for OT extension. - /// * `rcot` - The random COT used by Sender. - pub(crate) async fn setup_with_delta( - &mut self, - ctx: &mut Ctx, - delta: Block, - rcot: RandomCOT, - ) -> Result<(), SenderError> { - match self.lpn_type { - LpnType::Uniform => { - let ext_sender = std::mem::replace(&mut self.state, State::Error) - .try_into_uniform_initialized()?; - - let hash_seed: HashSeed = ctx.io_mut().expect_next().await?; - - let ext_sender = ext_sender.setup(delta, hash_seed); - - self.state = State::UniformExtension(ext_sender); - } - - LpnType::Regular => { - let ext_sender = std::mem::replace(&mut self.state, State::Error) - .try_into_regular_initialized()?; - - let ext_sender = ext_sender.setup(delta); - - self.state = State::RegularExtension(ext_sender); - } - } - - self.spcot.setup_with_delta(delta, rcot)?; - - Ok(()) - } - - /// Performs MPCOT extension. - /// - /// - /// # Arguments. - /// - /// * `ctx` - The context. - /// * `t` - The number of queried indices. - /// * `n` - The total number of indices. - pub(crate) async fn extend( - &mut self, - ctx: &mut Ctx, - t: u32, - n: u32, - ) -> Result, SenderError> - where - RandomCOT: RandomCOTSender, - { - match self.lpn_type { - LpnType::Uniform => { - let ext_sender = std::mem::replace(&mut self.state, State::Error) - .try_into_uniform_extension()?; - - let (ext_sender, hs) = Backend::spawn(move || ext_sender.pre_extend(t, n)).await?; - - self.spcot.extend(ctx, &hs).await?; - - let st = self.spcot.check(ctx).await?; - - let (ext_sender, output) = Backend::spawn(move || ext_sender.extend(&st)).await?; - - self.state = State::UniformExtension(ext_sender); - Ok(output) - } - LpnType::Regular => { - let ext_sender = std::mem::replace(&mut self.state, State::Error) - .try_into_regular_extension()?; - - let (ext_sender, hs) = Backend::spawn(move || ext_sender.pre_extend(t, n)).await?; - - self.spcot.extend(ctx, &hs).await?; - - let st = self.spcot.check(ctx).await?; - - let (ext_sender, output) = Backend::spawn(move || ext_sender.extend(&st)).await?; - - self.state = State::RegularExtension(ext_sender); - Ok(output) - } - } - } - - /// Complete extension. - pub(crate) fn finalize(&mut self) -> Result<(), SenderError> { - match self.lpn_type { - LpnType::Uniform => { - std::mem::replace(&mut self.state, State::Error).try_into_uniform_extension()?; - } - LpnType::Regular => { - std::mem::replace(&mut self.state, State::Error).try_into_regular_extension()?; - } - } - - self.spcot.finalize()?; - self.state = State::Complete; - - Ok(()) - } -} diff --git a/crates/mpz-ot/src/ferret/receiver.rs b/crates/mpz-ot/src/ferret/receiver.rs index 520506e8..d04e5d29 100644 --- a/crates/mpz-ot/src/ferret/receiver.rs +++ b/crates/mpz-ot/src/ferret/receiver.rs @@ -1,52 +1,61 @@ -use crate::{ - ferret::{mpcot::Receiver as MpcotReceiver, ReceiverError}, - RandomCOTReceiver, -}; -use enum_try_as_inner::EnumTryAsInner; -use mpz_common::Context; +use std::mem; + +use async_trait::async_trait; +use mpz_common::{cpu::CpuBackend, Allocate, Context, Preprocess}; use mpz_core::{prg::Prg, Block}; use mpz_ot_core::{ - ferret::receiver::{state, Receiver as ReceiverCore}, + ferret::{ + receiver::{state, Receiver as ReceiverCore}, + LpnType, CSP, CUCKOO_HASH_NUM, + }, RCOTReceiverOutput, }; use serio::SinkExt; -use utils_aio::non_blocking_backend::{Backend, NonBlockingBackend}; -use super::FerretConfig; -use crate::{async_trait, OTError}; +use crate::{ + ferret::{mpcot, FerretConfig, ReceiverError}, + OTError, RandomCOTReceiver, +}; -#[derive(Debug, EnumTryAsInner)] -#[derive_err(Debug)] +#[derive(Debug)] pub(crate) enum State { Initialized(ReceiverCore), Extension(ReceiverCore), - Complete, Error, } +impl State { + fn take(&mut self) -> Self { + std::mem::replace(self, State::Error) + } +} + /// Ferret Receiver. #[derive(Debug)] -pub struct Receiver { +pub struct Receiver { state: State, - mpcot: MpcotReceiver, - config: FerretConfig, + config: FerretConfig, + rcot: RandomCOT, + alloc: usize, + buffer: ReceiverBuffer, + buffer_len: usize, } -impl Receiver -where - RandomCOT: Send + Default + Clone, - SetupRandomCOT: Send, -{ +impl Receiver { /// Creates a new Receiver. /// /// # Arguments. /// - /// * `config` - Ferret configuration. - pub fn new(config: FerretConfig) -> Self { + /// * `config` - The Ferret config. + /// * `rcot` - The random COT in setup. + pub fn new(config: FerretConfig, rcot: RandomCOT) -> Self { Self { state: State::Initialized(ReceiverCore::new()), - mpcot: MpcotReceiver::new(config.lpn_type()), config, + rcot, + alloc: 0, + buffer: Default::default(), + buffer_len: 0, } } @@ -58,36 +67,59 @@ where pub async fn setup(&mut self, ctx: &mut Ctx) -> Result<(), ReceiverError> where Ctx: Context, - SetupRandomCOT: RandomCOTReceiver, + RandomCOT: RandomCOTReceiver, { - let ext_receiver = - std::mem::replace(&mut self.state, State::Error).try_into_initialized()?; - - let rcot = self.config.rcot(); - self.mpcot.setup(ctx, rcot).await?; + let State::Initialized(receiver) = self.state.take() else { + return Err(ReceiverError::state("receiver not in initialized state")); + }; let params = self.config.lpn_parameters(); let lpn_type = self.config.lpn_type(); - // Get random blocks from ideal Random COT. + // Compute the number of buffered OTs. + self.buffer_len = match lpn_type { + // The number here is a rough estimation to ensure sufficient buffer. + // It is hard to precisely compute the number because of the Cuckoo hashes. + LpnType::Uniform => { + let m = (1.5 * (params.t as f32)).ceil() as usize; + m * ((2 * CUCKOO_HASH_NUM * params.n / m) + .checked_next_power_of_two() + .expect("The length should be less than usize::MAX / 2 - 1") + .ilog2() as usize) + + CSP + } + // In our chosen paramters, we always set n is divided by t and n/t is a power of 2. + LpnType::Regular => { + assert!(params.n % params.t == 0 && (params.n / params.t).is_power_of_two()); + params.t * ((params.n / params.t).ilog2() as usize) + CSP + } + }; + // Get random blocks from ideal Random COT. let RCOTReceiverOutput { - choices: u, - msgs: w, - .. + choices: mut u, + msgs: mut w, + id, } = self - .config - .setup_rcot() - .receive_random_correlated(ctx, params.k) + .rcot + .receive_random_correlated(ctx, params.k + self.buffer_len) .await?; + // Initiate buffer. + let buffer = RCOTReceiverOutput { + id, + choices: u.drain(0..self.buffer_len).collect(), + msgs: w.drain(0..self.buffer_len).collect(), + }; + self.buffer = ReceiverBuffer::new(buffer); + let seed = Prg::new().random_block(); - let (ext_receiver, seed) = ext_receiver.setup(params, lpn_type, seed, &u, &w)?; + let (receiver, seed) = receiver.setup(params, lpn_type, seed, &u, &w)?; ctx.io_mut().send(seed).await?; - self.state = State::Extension(ext_receiver); + self.state = State::Extension(receiver); Ok(()) } @@ -96,97 +128,126 @@ where /// /// # Arguments /// - /// * `ctx` - The channel context. - async fn extend(&mut self, ctx: &mut Ctx) -> Result<(Vec, Vec), ReceiverError> + /// * `ctx` - Thread context. + /// * `count` - The number of OTs to extend. + pub async fn extend(&mut self, ctx: &mut Ctx, count: usize) -> Result<(), ReceiverError> where Ctx: Context, - RandomCOT: RandomCOTReceiver, + RandomCOT: RandomCOTReceiver + Send, { - let mut ext_receiver = - std::mem::replace(&mut self.state, State::Error).try_into_extension()?; + let State::Extension(mut receiver) = self.state.take() else { + return Err(ReceiverError::state("receiver not in extension state")); + }; - let (alphas, n) = ext_receiver.get_mpcot_query(); + let lpn_type = self.config.lpn_type(); + let target = receiver.remaining() + count; + while receiver.remaining() < target { + let (alphas, n) = receiver.get_mpcot_query(); - let r = self.mpcot.extend(ctx, &alphas, n as u32).await?; + let r = mpcot::receive(ctx, &mut self.buffer, lpn_type, alphas, n as u32).await?; - let (ext_receiver, choices, msgs) = Backend::spawn(move || { - ext_receiver - .extend(&r) - .map(|(choices, msgs)| (ext_receiver, choices, msgs)) - }) - .await?; + receiver = CpuBackend::blocking(move || receiver.extend(r).map(|()| receiver)).await?; - self.state = State::Extension(ext_receiver); + // Update receiver buffer. + let buffer = receiver + .consume(self.buffer_len) + .map_err(ReceiverError::from) + .map_err(OTError::from)?; - Ok((choices, msgs)) - } + self.buffer = ReceiverBuffer::new(buffer); + } - /// Complete extension - pub fn finalize(&mut self) -> Result<(), ReceiverError> { - std::mem::replace(&mut self.state, State::Error).try_into_extension()?; - self.state = State::Complete; - self.mpcot.finalize()?; + self.state = State::Extension(receiver); Ok(()) } } #[async_trait] -impl RandomCOTReceiver - for Receiver +impl RandomCOTReceiver for Receiver where - Ctx: Context, - RandomCOT: RandomCOTReceiver + Send + Clone + Default + 'static, - SetupRandomCOT: Send + 'static, + RandomCOT: Send, { async fn receive_random_correlated( &mut self, - ctx: &mut Ctx, + _ctx: &mut Ctx, count: usize, ) -> Result, OTError> { - let (mut choices_buffer, mut msgs_buffer) = self.extend(ctx).await?; - - assert_eq!(choices_buffer.len(), msgs_buffer.len()); - - let l = choices_buffer.len(); - - let id = self - .state - .try_as_extension() - .map_err(ReceiverError::from)? - .id(); - - if count <= l { - let choices_res = choices_buffer.drain(..count).collect(); + let State::Extension(receiver) = &mut self.state else { + return Err(ReceiverError::state("receiver not in extension state").into()); + }; + + receiver + .consume(count) + .map_err(ReceiverError::from) + .map_err(OTError::from) + } +} - let msgs_res = msgs_buffer.drain(..count).collect(); +impl Allocate for Receiver { + fn alloc(&mut self, count: usize) { + self.alloc += count; + } +} - return Ok(RCOTReceiverOutput { - id, - choices: choices_res, - msgs: msgs_res, - }); - } else { - let mut choices_res = choices_buffer; - let mut msgs_res = msgs_buffer; +#[async_trait] +impl Preprocess for Receiver +where + Ctx: Context, + RandomCOT: RandomCOTReceiver + Send, +{ + type Error = ReceiverError; - for _ in 0..count / l - 1 { - (choices_buffer, msgs_buffer) = self.extend(ctx).await?; + async fn preprocess(&mut self, ctx: &mut Ctx) -> Result<(), Self::Error> { + let count = mem::take(&mut self.alloc); + self.extend(ctx, count).await + } +} - choices_res.extend_from_slice(&choices_buffer); - msgs_res.extend_from_slice(&msgs_buffer); - } +#[derive(Debug)] +struct ReceiverBuffer { + buffer: RCOTReceiverOutput, +} - (choices_buffer, msgs_buffer) = self.extend(ctx).await?; +impl ReceiverBuffer { + fn new(buffer: RCOTReceiverOutput) -> Self { + Self { buffer } + } +} - choices_res.extend_from_slice(&choices_buffer[0..count % l]); - msgs_res.extend_from_slice(&msgs_buffer[0..count % l]); +impl Default for ReceiverBuffer { + fn default() -> Self { + ReceiverBuffer { + buffer: RCOTReceiverOutput { + id: Default::default(), + choices: Vec::new(), + msgs: Vec::new(), + }, + } + } +} - return Ok(RCOTReceiverOutput { - id, - choices: choices_res, - msgs: msgs_res, - }); +#[async_trait] +impl RandomCOTReceiver for ReceiverBuffer { + async fn receive_random_correlated( + &mut self, + _ctx: &mut Ctx, + count: usize, + ) -> Result, OTError> { + if count > self.buffer.choices.len() { + return Err(ReceiverError::io(format!( + "insufficient OTs: {} < {count}", + self.buffer.choices.len() + )) + .into()); } + + let choices = self.buffer.choices.drain(0..count).collect(); + let msgs = self.buffer.msgs.drain(0..count).collect(); + Ok(RCOTReceiverOutput { + id: self.buffer.id.next(), + choices, + msgs, + }) } } diff --git a/crates/mpz-ot/src/ferret/sender.rs b/crates/mpz-ot/src/ferret/sender.rs index 709ff8e2..187c1744 100644 --- a/crates/mpz-ot/src/ferret/sender.rs +++ b/crates/mpz-ot/src/ferret/sender.rs @@ -1,45 +1,60 @@ -use crate::{ferret::mpcot::Sender as MpcotSender, RandomCOTSender}; -use enum_try_as_inner::EnumTryAsInner; -use mpz_common::Context; +use std::mem; + +use crate::{ferret::mpcot, Correlation, RandomCOTSender}; +use async_trait::async_trait; +use mpz_common::{cpu::CpuBackend, Allocate, Context, Preprocess}; use mpz_core::Block; use mpz_ot_core::{ - ferret::sender::{state, Sender as SenderCore}, + ferret::{ + sender::{state, Sender as SenderCore}, + LpnType, CSP, CUCKOO_HASH_NUM, + }, RCOTSenderOutput, }; use serio::stream::IoStreamExt; -use utils_aio::non_blocking_backend::{Backend, NonBlockingBackend}; use super::{FerretConfig, SenderError}; -use crate::{async_trait, OTError}; +use crate::OTError; -#[derive(Debug, EnumTryAsInner)] -#[derive_err(Debug)] +#[derive(Debug)] pub(crate) enum State { Initialized(SenderCore), Extension(SenderCore), - Complete, Error, } +impl State { + fn take(&mut self) -> Self { + std::mem::replace(self, State::Error) + } +} + /// Ferret Sender. #[derive(Debug)] -pub struct Sender { +pub struct Sender { state: State, - mpcot: MpcotSender, - config: FerretConfig, + config: FerretConfig, + rcot: RandomCOT, + alloc: usize, + buffer: SenderBuffer, + buffer_len: usize, } -impl Sender -where - RandomCOT: Send + Default + Clone, - SetupRandomCOT: Send, -{ +impl Sender { /// Creates a new Sender. - pub fn new(config: FerretConfig) -> Self { + /// + /// # Argument + /// + /// `config` - The Ferret config. + /// `rcot` - The random COT in setup. + pub fn new(config: FerretConfig, rcot: RandomCOT) -> Self { Self { state: State::Initialized(SenderCore::new()), - mpcot: MpcotSender::new(config.lpn_type()), config, + rcot, + alloc: 0, + buffer: Default::default(), + buffer_len: 0, } } @@ -48,39 +63,57 @@ where /// # Argument /// /// * `ctx` - The channel context. - /// * `delta` - The provided delta used for sender. - pub async fn setup_with_delta( - &mut self, - ctx: &mut Ctx, - delta: Block, - ) -> Result<(), SenderError> + pub async fn setup(&mut self, ctx: &mut Ctx) -> Result<(), SenderError> where Ctx: Context, - SetupRandomCOT: RandomCOTSender, + RandomCOT: RandomCOTSender + Correlation, { - let ext_sender = std::mem::replace(&mut self.state, State::Error).try_into_initialized()?; - - let rcot = self.config.rcot(); - - self.mpcot.setup_with_delta(ctx, delta, rcot).await?; + let State::Initialized(sender) = self.state.take() else { + return Err(SenderError::state("sender not in initialized state")); + }; let params = self.config.lpn_parameters(); let lpn_type = self.config.lpn_type(); + // Compute the number of buffered OTs. + self.buffer_len = match lpn_type { + // The number here is a rough estimation to ensure sufficient buffer. + // It is hard to precisely compute the number because of the Cuckoo hashes. + LpnType::Uniform => { + let m = (1.5 * (params.t as f32)).ceil() as usize; + m * ((2 * CUCKOO_HASH_NUM * params.n / m) + .checked_next_power_of_two() + .expect("The length should be less than usize::MAX / 2 - 1") + .ilog2() as usize) + + CSP + } + // In our chosen paramters, we always set n is divided by t and n/t is a power of 2. + LpnType::Regular => { + assert!(params.n % params.t == 0 && (params.n / params.t).is_power_of_two()); + params.t * ((params.n / params.t).ilog2() as usize) + CSP + } + }; + // Get random blocks from ideal Random COT. - let RCOTSenderOutput { msgs: v, .. } = self - .config - .setup_rcot() - .send_random_correlated(ctx, params.k) + let RCOTSenderOutput { msgs: mut v, id } = self + .rcot + .send_random_correlated(ctx, params.k + self.buffer_len) .await?; + // Initiate buffer. + let buffer = RCOTSenderOutput { + id, + msgs: v.drain(0..self.buffer_len).collect(), + }; + self.buffer = SenderBuffer::new(self.rcot.delta(), buffer); + // Get seed for LPN matrix from receiver. let seed = ctx.io_mut().expect_next().await?; // Ferret core setup. - let ext_sender = ext_sender.setup(delta, params, lpn_type, seed, &v)?; + let sender = sender.setup(self.rcot.delta(), params, lpn_type, seed, &v)?; - self.state = State::Extension(ext_sender); + self.state = State::Extension(sender); Ok(()) } @@ -89,72 +122,173 @@ where /// /// # Argument /// - /// * `ctx` - The channel context. - async fn extend(&mut self, ctx: &mut Ctx) -> Result, SenderError> + /// * `ctx` - Thread context. + /// * `count` - The number of OTs to extend. + pub async fn extend( + &mut self, + ctx: &mut Ctx, + count: usize, + ) -> Result<(), SenderError> where - RandomCOT: RandomCOTSender, + RandomCOT: RandomCOTSender + Send, { - let mut ext_sender = - std::mem::replace(&mut self.state, State::Error).try_into_extension()?; + let State::Extension(mut sender) = self.state.take() else { + return Err(SenderError::state("sender not in extension state")); + }; - let (t, n) = ext_sender.get_mpcot_query(); + let lpn_type = self.config.lpn_type(); + let delta = sender.delta(); + let target = sender.remaining() + count; + while sender.remaining() < target { + let (t, n) = sender.get_mpcot_query(); - let s = self.mpcot.extend(ctx, t, n).await?; + let s = mpcot::send(ctx, &mut self.buffer, delta, lpn_type, t, n).await?; - let (ext_sender, output) = - Backend::spawn(move || ext_sender.extend(&s).map(|output| (ext_sender, output))) - .await?; - self.state = State::Extension(ext_sender); + sender = CpuBackend::blocking(move || sender.extend(s).map(|()| sender)).await?; - Ok(output) - } + // Update sender buffer. + let buffer = sender + .consume(self.buffer_len) + .map_err(SenderError::from) + .map_err(OTError::from)?; - /// Complete extension - pub fn finalize(&mut self) -> Result<(), SenderError> { - std::mem::replace(&mut self.state, State::Error).try_into_extension()?; - self.state = State::Complete; - self.mpcot.finalize()?; + self.buffer = SenderBuffer::new(delta, buffer); + } + + self.state = State::Extension(sender); Ok(()) } } +impl Correlation for Sender +where + RandomCOT: Correlation, +{ + type Correlation = Block; + + fn delta(&self) -> Self::Correlation { + self.rcot.delta() + } +} + #[async_trait] -impl RandomCOTSender - for Sender +impl RandomCOTSender for Sender where - Ctx: Context, - RandomCOT: RandomCOTSender + Send + Default + Clone + 'static, - SetupRandomCOT: Send + 'static, + RandomCOT: Correlation + Send, { async fn send_random_correlated( &mut self, - ctx: &mut Ctx, + _ctx: &mut Ctx, count: usize, ) -> Result, OTError> { - let mut buffer = self.extend(ctx).await?; - let l = buffer.len(); - - let id = self - .state - .try_as_extension() - .map_err(SenderError::from)? - .id(); - - if count <= l { - let res = buffer.drain(..count).collect(); - return Ok(RCOTSenderOutput { id, msgs: res }); - } else { - let mut res = buffer; - for _ in 0..count / l - 1 { - buffer = self.extend(ctx).await?; - res.extend_from_slice(&buffer); - } + let State::Extension(sender) = &mut self.state else { + return Err(SenderError::state("sender not in extension state").into()); + }; + + sender + .consume(count) + .map_err(SenderError::from) + .map_err(OTError::from) + } +} + +impl Allocate for Sender { + fn alloc(&mut self, count: usize) { + self.alloc += count; + } +} + +#[async_trait] +impl Preprocess for Sender +where + Ctx: Context, + RandomCOT: RandomCOTSender + Send, +{ + type Error = SenderError; - buffer = self.extend(ctx).await?; - res.extend_from_slice(&buffer[0..count % l]); + async fn preprocess(&mut self, ctx: &mut Ctx) -> Result<(), Self::Error> { + let count = mem::take(&mut self.alloc); + self.extend(ctx, count).await + } +} + +#[derive(Debug)] +struct SenderBuffer { + delta: Block, + buffer: RCOTSenderOutput, +} + +impl SenderBuffer { + fn new(delta: Block, buffer: RCOTSenderOutput) -> Self { + Self { delta, buffer } + } +} + +impl Default for SenderBuffer { + fn default() -> Self { + let buffer = RCOTSenderOutput { + id: Default::default(), + msgs: Vec::new(), + }; + Self { + delta: Block::ZERO, + buffer, + } + } +} +impl Correlation for SenderBuffer { + type Correlation = Block; + + fn delta(&self) -> Self::Correlation { + self.delta + } +} - return Ok(RCOTSenderOutput { id, msgs: res }); +#[async_trait] +impl RandomCOTSender for SenderBuffer { + async fn send_random_correlated( + &mut self, + _ctx: &mut Ctx, + count: usize, + ) -> Result, OTError> { + if count > self.buffer.msgs.len() { + return Err(SenderError::io(format!( + "insufficient OTs: {} < {count}", + self.buffer.msgs.len() + )) + .into()); } + + let msgs = self.buffer.msgs.drain(0..count).collect(); + Ok(RCOTSenderOutput { + id: self.buffer.id.next(), + msgs, + }) + } +} + +#[derive(Debug)] +struct BootstrappedSender<'a>(&'a mut SenderCore); + +impl Correlation for BootstrappedSender<'_> { + type Correlation = Block; + + fn delta(&self) -> Block { + self.0.delta() + } +} + +#[async_trait] +impl RandomCOTSender for BootstrappedSender<'_> { + async fn send_random_correlated( + &mut self, + _ctx: &mut Ctx, + count: usize, + ) -> Result, OTError> { + self.0 + .consume(count) + .map_err(SenderError::from) + .map_err(OTError::from) } } diff --git a/crates/mpz-ot/src/ferret/spcot.rs b/crates/mpz-ot/src/ferret/spcot.rs new file mode 100644 index 00000000..bccad692 --- /dev/null +++ b/crates/mpz-ot/src/ferret/spcot.rs @@ -0,0 +1,167 @@ +//! Implementation of the Single-Point COT (spcot) protocol in the [`Ferret`](https://eprint.iacr.org/2020/924.pdf) paper. + +use mpz_common::{cpu::CpuBackend, Context}; +use mpz_core::Block; +use mpz_ot_core::{ + ferret::{ + spcot::{ + msgs::{ExtendFromSender, MaskBits}, + receiver::Receiver as ReceiverCore, + sender::Sender as SenderCore, + }, + CSP, + }, + RCOTReceiverOutput, RCOTSenderOutput, +}; +use serio::{stream::IoStreamExt as _, SinkExt as _}; + +use crate::{ferret::error::SPCOTError as Error, RandomCOTReceiver, RandomCOTSender}; + +/// SPCOT send. +/// +/// # Arguments +/// +/// * `ctx` - Thread context. +/// * `rcot` - Random COT sender. +/// * `delta` - Delta correlation. +/// * `hs` - The depth of the GGM trees. +pub(crate) async fn send>( + ctx: &mut Ctx, + rcot: &mut RandomCOT, + delta: Block, + hs: &[usize], +) -> Result>, Error> { + let mut sender = SenderCore::new().setup(delta); + + let h = hs.iter().sum(); + let RCOTSenderOutput { msgs: qss, .. } = rcot.send_random_correlated(ctx, h).await?; + + let masks: Vec = ctx.io_mut().expect_next().await?; + + // extend + let h_in = hs.to_vec(); + let (mut sender, extend_msg) = CpuBackend::blocking(move || { + sender + .extend(&h_in, &qss, &masks) + .map(|extend_msg| (sender, extend_msg)) + }) + .await?; + + ctx.io_mut().send(extend_msg).await?; + + // batch check + let RCOTSenderOutput { msgs: y_star, .. } = rcot.send_random_correlated(ctx, CSP).await?; + + let checkfr = ctx.io_mut().expect_next().await?; + + let (output, check_msg) = CpuBackend::blocking(move || { + sender + .check(&y_star, checkfr) + .map(|(output, check_msg)| (output, check_msg)) + }) + .await?; + + ctx.io_mut().send(check_msg).await?; + + Ok(output) +} + +/// SPCOT receive. +/// +/// # Arguments +/// +/// * `ctx` - Thread context. +/// * `rcot` - Random COT receiver. +/// * `alphas` - Vector of chosen positions. +/// * `hs` - The depth of the GGM trees. +pub(crate) async fn receive>( + ctx: &mut Ctx, + rcot: &mut RandomCOT, + alphas: &[u32], + hs: &[usize], +) -> Result, u32)>, Error> { + let mut receiver = ReceiverCore::new().setup(); + + let h = hs.iter().sum(); + let RCOTReceiverOutput { + choices: rss, + msgs: tss, + .. + } = rcot.receive_random_correlated(ctx, h).await?; + + // extend + let h_in = hs.to_vec(); + let alphas_in = alphas.to_vec(); + let (mut receiver, masks) = CpuBackend::blocking(move || { + receiver + .extend_mask_bits(&h_in, &alphas_in, &rss) + .map(|mask| (receiver, mask)) + }) + .await?; + + ctx.io_mut().send(masks).await?; + + let extendfss: Vec = ctx.io_mut().expect_next().await?; + + let h_in = hs.to_vec(); + let alphas_in = alphas.to_vec(); + let mut receiver = CpuBackend::blocking(move || { + receiver + .extend(&h_in, &alphas_in, &tss, &extendfss) + .map(|_| receiver) + }) + .await?; + + // batch check + let RCOTReceiverOutput { + choices: x_star, + msgs: z_star, + .. + } = rcot.receive_random_correlated(ctx, CSP).await?; + + let (mut receiver, checkfr) = CpuBackend::blocking(move || { + receiver + .check_pre(&x_star) + .map(|checkfr| (receiver, checkfr)) + }) + .await?; + + ctx.io_mut().send(checkfr).await?; + let check = ctx.io_mut().expect_next().await?; + + let output = + CpuBackend::blocking(move || receiver.check(&z_star, check).map(|output| output)).await?; + + Ok(output) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ideal::cot::ideal_rcot, Correlation}; + use mpz_common::executor::test_st_executor; + + #[tokio::test] + async fn test_spcot() { + let (mut ctx_sender, mut ctx_receiver) = test_st_executor(8); + let (mut rcot_sender, mut rcot_receiver) = ideal_rcot(); + + let hs = [8usize, 4]; + let alphas = [4u32, 2]; + let delta = rcot_sender.delta(); + + let (mut output_sender, output_receiver) = tokio::try_join!( + send(&mut ctx_sender, &mut rcot_sender, delta, &hs), + receive(&mut ctx_receiver, &mut rcot_receiver, &alphas, &hs) + ) + .unwrap(); + + assert!(output_sender + .iter_mut() + .zip(output_receiver.iter()) + .all(|(vs, (ws, alpha))| { + vs[*alpha as usize] ^= delta; + vs == ws + })); + } +} diff --git a/crates/mpz-ot/src/ferret/spcot/error.rs b/crates/mpz-ot/src/ferret/spcot/error.rs deleted file mode 100644 index 5f23f466..00000000 --- a/crates/mpz-ot/src/ferret/spcot/error.rs +++ /dev/null @@ -1,59 +0,0 @@ -use crate::OTError; - -/// A SPCOT sender error. -#[derive(Debug, thiserror::Error)] -#[allow(missing_docs, clippy::enum_variant_names)] -pub enum SenderError { - #[error(transparent)] - IOError(#[from] std::io::Error), - #[error(transparent)] - CoreError(#[from] mpz_ot_core::ferret::spcot::error::SenderError), - #[error(transparent)] - RandomCOTError(#[from] OTError), - #[error("{0}")] - StateError(String), -} - -impl From for OTError { - fn from(err: SenderError) -> Self { - match err { - SenderError::IOError(e) => e.into(), - e => OTError::SenderError(Box::new(e)), - } - } -} - -impl From for SenderError { - fn from(err: crate::ferret::spcot::sender::StateError) -> Self { - SenderError::StateError(err.to_string()) - } -} - -/// A SPCOT receiver error. -#[derive(Debug, thiserror::Error)] -#[allow(missing_docs, clippy::enum_variant_names)] -pub enum ReceiverError { - #[error(transparent)] - IOError(#[from] std::io::Error), - #[error(transparent)] - CoreError(#[from] mpz_ot_core::ferret::spcot::error::ReceiverError), - #[error(transparent)] - RandomCOTError(#[from] OTError), - #[error("{0}")] - StateError(String), -} - -impl From for OTError { - fn from(err: ReceiverError) -> Self { - match err { - ReceiverError::IOError(e) => e.into(), - e => OTError::ReceiverError(Box::new(e)), - } - } -} - -impl From for ReceiverError { - fn from(err: crate::ferret::spcot::receiver::StateError) -> Self { - ReceiverError::StateError(err.to_string()) - } -} diff --git a/crates/mpz-ot/src/ferret/spcot/mod.rs b/crates/mpz-ot/src/ferret/spcot/mod.rs deleted file mode 100644 index 6e53fd28..00000000 --- a/crates/mpz-ot/src/ferret/spcot/mod.rs +++ /dev/null @@ -1,103 +0,0 @@ -//! Implementation of the Single-Point COT (spcot) protocol in the [`Ferret`](https://eprint.iacr.org/2020/924.pdf) paper. - -mod error; -mod receiver; -mod sender; - -pub(crate) use error::{ReceiverError, SenderError}; -pub(crate) use receiver::Receiver; -pub(crate) use sender::Sender; - -#[cfg(test)] -mod tests { - use super::*; - use crate::{ - ideal::cot::{ideal_rcot, IdealCOTReceiver, IdealCOTSender}, - OTError, - }; - use futures::TryFutureExt; - use mpz_common::executor::test_st_executor; - use mpz_core::Block; - - fn setup() -> ( - Sender, - Receiver, - IdealCOTSender, - IdealCOTReceiver, - Block, - ) { - let (mut rcot_sender, rcot_receiver) = ideal_rcot(); - - let delta = rcot_sender.alice().get_mut().delta(); - - let sender = Sender::new(); - let receiver = Receiver::new(); - - (sender, receiver, rcot_sender, rcot_receiver, delta) - } - - #[tokio::test] - async fn test_spcot() { - let (mut ctx_sender, mut ctx_receiver) = test_st_executor(8); - - let (mut sender, mut receiver, rcot_sender, rcot_receiver, delta) = setup(); - - // shold set the same delta as in RCOT. - sender.setup_with_delta(delta, rcot_sender).unwrap(); - receiver.setup(rcot_receiver).unwrap(); - - let hs = [8, 4]; - let alphas = [4, 2]; - - tokio::try_join!( - sender.extend(&mut ctx_sender, &hs).map_err(OTError::from), - receiver - .extend(&mut ctx_receiver, &alphas, &hs) - .map_err(OTError::from) - ) - .unwrap(); - - let (mut output_sender, output_receiver) = tokio::try_join!( - sender.check(&mut ctx_sender).map_err(OTError::from), - receiver.check(&mut ctx_receiver).map_err(OTError::from) - ) - .unwrap(); - - assert!(output_sender - .iter_mut() - .zip(output_receiver.iter()) - .all(|(vs, (ws, alpha))| { - vs[*alpha as usize] ^= delta; - vs == ws - })); - - // extend twice. - let hs = [6, 9, 8]; - let alphas = [2, 1, 3]; - - tokio::try_join!( - sender.extend(&mut ctx_sender, &hs).map_err(OTError::from), - receiver - .extend(&mut ctx_receiver, &alphas, &hs) - .map_err(OTError::from) - ) - .unwrap(); - - let (mut output_sender, output_receiver) = tokio::try_join!( - sender.check(&mut ctx_sender).map_err(OTError::from), - receiver.check(&mut ctx_receiver).map_err(OTError::from) - ) - .unwrap(); - - assert!(output_sender - .iter_mut() - .zip(output_receiver.iter()) - .all(|(vs, (ws, alpha))| { - vs[*alpha as usize] ^= delta; - vs == ws - })); - - sender.finalize().unwrap(); - receiver.finalize().unwrap(); - } -} diff --git a/crates/mpz-ot/src/ferret/spcot/receiver.rs b/crates/mpz-ot/src/ferret/spcot/receiver.rs deleted file mode 100644 index 3c48bfad..00000000 --- a/crates/mpz-ot/src/ferret/spcot/receiver.rs +++ /dev/null @@ -1,164 +0,0 @@ -use crate::{ferret::spcot::error::ReceiverError, RandomCOTReceiver}; -use enum_try_as_inner::EnumTryAsInner; -use mpz_common::Context; -use mpz_core::Block; -use mpz_ot_core::{ - ferret::{ - spcot::{ - msgs::ExtendFromSender, - receiver::{state, Receiver as ReceiverCore}, - }, - CSP, - }, - RCOTReceiverOutput, -}; -use serio::{stream::IoStreamExt, SinkExt}; -use utils_aio::non_blocking_backend::{Backend, NonBlockingBackend}; - -#[derive(Debug, EnumTryAsInner)] -#[derive_err(Debug)] -pub(crate) enum State { - Initialized(ReceiverCore), - Extension(Box>), - Complete, - Error, -} - -/// SPCOT Receiver. -#[derive(Debug)] -pub(crate) struct Receiver { - state: State, - rcot: RandomCOT, -} - -impl Receiver { - /// Creates a new Receiver. - pub(crate) fn new() -> Self { - Self { - state: State::Initialized(ReceiverCore::new()), - rcot: Default::default(), - } - } - - /// Performs setup for receiver. - /// - /// # Arguments. - /// - /// * `rcot` - The random COT used by the receiver. - pub(crate) fn setup(&mut self, rcot: RandomCOT) -> Result<(), ReceiverError> { - let ext_receiver = - std::mem::replace(&mut self.state, State::Error).try_into_initialized()?; - - let ext_receiver = ext_receiver.setup(); - self.state = State::Extension(Box::new(ext_receiver)); - self.rcot = rcot; - Ok(()) - } - - /// Performs spcot extension for receiver. - /// - /// # Arguments - /// - /// * `ctx` - The context. - /// * `alphas`` - The vector of chosen positions. - /// * `h` - The depth of GGM tree. - pub(crate) async fn extend( - &mut self, - ctx: &mut Ctx, - alphas: &[u32], - hs: &[usize], - ) -> Result<(), ReceiverError> - where - RandomCOT: RandomCOTReceiver, - { - let mut ext_receiver = - std::mem::replace(&mut self.state, State::Error).try_into_extension()?; - - let h = hs.iter().sum(); - let RCOTReceiverOutput { - choices: rss, - msgs: tss, - .. - } = self.rcot.receive_random_correlated(ctx, h).await?; - - // extend - let h_in = hs.to_vec(); - let alphas_in = alphas.to_vec(); - let (mut ext_receiver, masks) = Backend::spawn(move || { - ext_receiver - .extend_mask_bits(&h_in, &alphas_in, &rss) - .map(|mask| (ext_receiver, mask)) - }) - .await?; - - ctx.io_mut().send(masks).await?; - - let extendfss: Vec = ctx.io_mut().expect_next().await?; - - let h_in = hs.to_vec(); - let alphas_in = alphas.to_vec(); - let ext_receiver = Backend::spawn(move || { - ext_receiver - .extend(&h_in, &alphas_in, &tss, &extendfss) - .map(|_| ext_receiver) - }) - .await?; - - self.state = State::Extension(ext_receiver); - - Ok(()) - } - - /// Performs batch check for SPCOT extension. - /// - /// # Arguments - /// - /// * `ctx` - The context. - pub(crate) async fn check( - &mut self, - ctx: &mut Ctx, - ) -> Result, u32)>, ReceiverError> - where - RandomCOT: RandomCOTReceiver, - { - let mut ext_receiver = - std::mem::replace(&mut self.state, State::Error).try_into_extension()?; - - // batch check - let RCOTReceiverOutput { - choices: x_star, - msgs: z_star, - .. - } = self.rcot.receive_random_correlated(ctx, CSP).await?; - - let (mut ext_receiver, checkfr) = Backend::spawn(move || { - ext_receiver - .check_pre(&x_star) - .map(|checkfr| (ext_receiver, checkfr)) - }) - .await?; - - ctx.io_mut().send(checkfr).await?; - let check = ctx.io_mut().expect_next().await?; - - let (ext_receiver, output) = Backend::spawn(move || { - ext_receiver - .check(&z_star, check) - .map(|output| (ext_receiver, output)) - }) - .await?; - - self.state = State::Extension(ext_receiver); - - Ok(output) - } - - /// Complete extension. - pub(crate) fn finalize(&mut self) -> Result<(), ReceiverError> { - std::mem::replace(&mut self.state, State::Error).try_into_extension()?; - - self.state = State::Complete; - - Ok(()) - } -} diff --git a/crates/mpz-ot/src/ferret/spcot/sender.rs b/crates/mpz-ot/src/ferret/spcot/sender.rs deleted file mode 100644 index 9178b787..00000000 --- a/crates/mpz-ot/src/ferret/spcot/sender.rs +++ /dev/null @@ -1,144 +0,0 @@ -use crate::{ferret::spcot::error::SenderError, RandomCOTSender}; -use enum_try_as_inner::EnumTryAsInner; -use mpz_common::Context; -use mpz_core::Block; -use mpz_ot_core::{ - ferret::{ - spcot::{ - msgs::MaskBits, - sender::{state, Sender as SenderCore}, - }, - CSP, - }, - RCOTSenderOutput, -}; -use serio::{stream::IoStreamExt, SinkExt}; -use utils_aio::non_blocking_backend::{Backend, NonBlockingBackend}; - -#[derive(Debug, EnumTryAsInner)] -#[derive_err(Debug)] -pub(crate) enum State { - Initialized(SenderCore), - Extension(Box>), - Complete, - Error, -} - -/// SPCOT sender. -#[derive(Debug)] -pub(crate) struct Sender { - state: State, - rcot: RandomCOT, -} - -impl Sender { - /// Creates a new Sender. - pub(crate) fn new() -> Self { - Self { - state: State::Initialized(SenderCore::new()), - rcot: Default::default(), - } - } - - /// Performs setup with the provided delta. - /// - /// # Arguments - /// - /// * `delta` - The delta value to use for OT extension. - /// * `rcot` - The random COT used by the sender. - pub(crate) fn setup_with_delta( - &mut self, - delta: Block, - rcot: RandomCOT, - ) -> Result<(), SenderError> { - let ext_sender = std::mem::replace(&mut self.state, State::Error).try_into_initialized()?; - - let ext_sender = ext_sender.setup(delta); - - self.state = State::Extension(Box::new(ext_sender)); - self.rcot = rcot; - Ok(()) - } - - /// Performs spcot extension for sender. - /// - /// # Arguments - /// - /// * `ctx` - The context. - /// * `hs` - The depths of GGM trees. - pub(crate) async fn extend( - &mut self, - ctx: &mut Ctx, - hs: &[usize], - ) -> Result<(), SenderError> - where - RandomCOT: RandomCOTSender, - { - let mut ext_sender = - std::mem::replace(&mut self.state, State::Error).try_into_extension()?; - - let h = hs.iter().sum(); - let RCOTSenderOutput { msgs: qss, .. } = self.rcot.send_random_correlated(ctx, h).await?; - - let masks: Vec = ctx.io_mut().expect_next().await?; - - // extend - let h_in = hs.to_vec(); - let (ext_sender, extend_msg) = Backend::spawn(move || { - ext_sender - .extend(&h_in, &qss, &masks) - .map(|extend_msg| (ext_sender, extend_msg)) - }) - .await?; - - ctx.io_mut().send(extend_msg).await?; - - self.state = State::Extension(ext_sender); - - Ok(()) - } - - /// Performs batch check for SPCOT extension. - /// - /// # Arguments - /// - /// * `ctx` - The context. - pub(crate) async fn check( - &mut self, - ctx: &mut Ctx, - ) -> Result>, SenderError> - where - RandomCOT: RandomCOTSender, - { - let mut ext_sender = - std::mem::replace(&mut self.state, State::Error).try_into_extension()?; - - // batch check - let RCOTSenderOutput { msgs: y_star, .. } = - self.rcot.send_random_correlated(ctx, CSP).await?; - - let checkfr = ctx.io_mut().expect_next().await?; - - let (ext_sender, output, check_msg) = Backend::spawn(move || { - ext_sender - .check(&y_star, checkfr) - .map(|(output, check_msg)| (ext_sender, output, check_msg)) - }) - .await?; - - ctx.io_mut().send(check_msg).await?; - - self.state = State::Extension(ext_sender); - - Ok(output) - } - - /// Complete extension. - pub(crate) fn finalize(&mut self) -> Result<(), SenderError> { - std::mem::replace(&mut self.state, State::Error).try_into_extension()?; - - self.state = State::Complete; - - Ok(()) - } -} diff --git a/crates/mpz-ot/src/ideal/cot.rs b/crates/mpz-ot/src/ideal/cot.rs index 18233dfe..bc7df0a6 100644 --- a/crates/mpz-ot/src/ideal/cot.rs +++ b/crates/mpz-ot/src/ideal/cot.rs @@ -11,7 +11,9 @@ use mpz_ot_core::{ ideal::cot::IdealCOT, COTReceiverOutput, COTSenderOutput, RCOTReceiverOutput, RCOTSenderOutput, }; -use crate::{COTReceiver, COTSender, OTError, OTSetup, RandomCOTReceiver, RandomCOTSender}; +use crate::{ + COTReceiver, COTSender, Correlation, OTError, OTSetup, RandomCOTReceiver, RandomCOTSender, +}; fn cot( f: &mut IdealCOT, @@ -82,6 +84,14 @@ where } } +impl Correlation for IdealCOTSender { + type Correlation = Block; + + fn delta(&self) -> Block { + self.0.lock().delta() + } +} + #[async_trait] impl COTSender for IdealCOTSender { async fn send_correlated( @@ -170,7 +180,7 @@ mod tests { let (mut ctx_a, mut ctx_b) = test_st_executor(8); let (mut alice, mut bob) = ideal_cot(); - let delta = alice.0.get_mut().delta(); + let delta = alice.delta(); let count = 10; let choices = (0..count).map(|_| rng.gen()).collect::>(); @@ -201,7 +211,7 @@ mod tests { let (mut ctx_a, mut ctx_b) = test_st_executor(8); let (mut alice, mut bob) = ideal_rcot(); - let delta = alice.0.get_mut().delta(); + let delta = alice.delta(); let count = 10; diff --git a/crates/mpz-ot/src/lib.rs b/crates/mpz-ot/src/lib.rs index 0e4d1b48..c1508883 100644 --- a/crates/mpz-ot/src/lib.rs +++ b/crates/mpz-ot/src/lib.rs @@ -61,9 +61,18 @@ pub trait OTSender { async fn send(&mut self, ctx: &mut Ctx, msgs: &[T]) -> Result; } +/// Correlation of COT messages. +pub trait Correlation { + /// The type of the correlation. + type Correlation; + + /// Returns the correlation. + fn delta(&self) -> Self::Correlation; +} + /// A correlated oblivious transfer sender. #[async_trait] -pub trait COTSender { +pub trait COTSender: Correlation { /// Obliviously transfers the correlated messages to the receiver. /// /// Returns the `0`-bit messages that were obliviously transferred. @@ -97,7 +106,7 @@ pub trait RandomOTSender { /// A random correlated oblivious transfer sender. #[async_trait] -pub trait RandomCOTSender { +pub trait RandomCOTSender: Correlation { /// Obliviously transfers the correlated messages to the receiver. /// /// Returns the `0`-bit messages that were obliviously transferred. From 5f4e90f904fc7973fd5eed22b2b15ef3d02f882a Mon Sep 17 00:00:00 2001 From: Xiang Xie Date: Tue, 20 Aug 2024 07:09:36 +0800 Subject: [PATCH 07/17] fix clippy --- crates/mpz-ot-core/src/chou_orlandi/receiver.rs | 2 +- crates/mpz-ot-core/src/chou_orlandi/sender.rs | 2 +- crates/mpz-ot-core/src/ferret/receiver.rs | 4 ++-- crates/mpz-ot-core/src/ferret/sender.rs | 4 ++-- crates/mpz-ot-core/src/ideal/cot.rs | 2 +- crates/mpz-ot-core/src/ideal/mpcot.rs | 2 +- crates/mpz-ot-core/src/ideal/ot.rs | 2 +- crates/mpz-ot-core/src/ideal/rot.rs | 4 ++-- crates/mpz-ot-core/src/ideal/spcot.rs | 2 +- crates/mpz-ot-core/src/kos/receiver.rs | 2 +- crates/mpz-ot-core/src/kos/sender.rs | 2 +- crates/mpz-ot-core/src/lib.rs | 2 +- crates/mpz-ot/src/ferret/receiver.rs | 10 +++++----- crates/mpz-ot/src/ferret/sender.rs | 2 +- crates/mpz-ot/src/ferret/spcot.rs | 3 +-- 15 files changed, 22 insertions(+), 23 deletions(-) diff --git a/crates/mpz-ot-core/src/chou_orlandi/receiver.rs b/crates/mpz-ot-core/src/chou_orlandi/receiver.rs index 403802f9..d9638951 100644 --- a/crates/mpz-ot-core/src/chou_orlandi/receiver.rs +++ b/crates/mpz-ot-core/src/chou_orlandi/receiver.rs @@ -153,7 +153,7 @@ impl Receiver { let SenderPayload { id, payload } = payload; // Check that the transfer id matches - let expected_id = current_id.next(); + let expected_id = current_id.next_id(); if id != expected_id { return Err(ReceiverError::IdMismatch(expected_id, id)); } diff --git a/crates/mpz-ot-core/src/chou_orlandi/sender.rs b/crates/mpz-ot-core/src/chou_orlandi/sender.rs index 09a8b5a6..328354eb 100644 --- a/crates/mpz-ot-core/src/chou_orlandi/sender.rs +++ b/crates/mpz-ot-core/src/chou_orlandi/sender.rs @@ -139,7 +139,7 @@ impl Sender { } = receiver_payload; // Check that the transfer id matches - let expected_id = current_id.next(); + let expected_id = current_id.next_id(); if id != expected_id { return Err(SenderError::IdMismatch(expected_id, id)); } diff --git a/crates/mpz-ot-core/src/ferret/receiver.rs b/crates/mpz-ot-core/src/ferret/receiver.rs index 1cfd1e08..782d2b9e 100644 --- a/crates/mpz-ot-core/src/ferret/receiver.rs +++ b/crates/mpz-ot-core/src/ferret/receiver.rs @@ -119,7 +119,7 @@ impl Receiver { return Err(ReceiverError("the length of r should be n".to_string())); } - self.state.id.next(); + self.state.id.next_id(); // Compute z = A * w + r. let mut z = r; @@ -169,7 +169,7 @@ impl Receiver { let msgs = self.state.msgs_buffer.drain(0..count).collect(); Ok(RCOTReceiverOutput { - id: self.state.id.next(), + id: self.state.id.next_id(), choices, msgs, }) diff --git a/crates/mpz-ot-core/src/ferret/sender.rs b/crates/mpz-ot-core/src/ferret/sender.rs index 436d6003..e6af6452 100644 --- a/crates/mpz-ot-core/src/ferret/sender.rs +++ b/crates/mpz-ot-core/src/ferret/sender.rs @@ -109,7 +109,7 @@ impl Sender { return Err(SenderError("the length of s should be n".to_string())); } - self.state.id.next(); + self.state.id.next_id(); // Compute y = A * v + s let mut y = s; @@ -139,7 +139,7 @@ impl Sender { let msgs = self.state.msgs_buffer.drain(0..count).collect(); Ok(RCOTSenderOutput { - id: self.state.id.next(), + id: self.state.id.next_id(), msgs, }) } diff --git a/crates/mpz-ot-core/src/ideal/cot.rs b/crates/mpz-ot-core/src/ideal/cot.rs index a28abef8..a842129d 100644 --- a/crates/mpz-ot-core/src/ideal/cot.rs +++ b/crates/mpz-ot-core/src/ideal/cot.rs @@ -76,7 +76,7 @@ impl IdealCOT { .collect(); self.counter += count; - let id = self.transfer_id.next(); + let id = self.transfer_id.next_id(); ( RCOTSenderOutput { id, msgs }, diff --git a/crates/mpz-ot-core/src/ideal/mpcot.rs b/crates/mpz-ot-core/src/ideal/mpcot.rs index 44a5595f..c038331b 100644 --- a/crates/mpz-ot-core/src/ideal/mpcot.rs +++ b/crates/mpz-ot-core/src/ideal/mpcot.rs @@ -60,7 +60,7 @@ impl IdealMpcot { self.counter += 1; } - let id = self.transfer_id.next(); + let id = self.transfer_id.next_id(); (MPCOTSenderOutput { id, s }, MPCOTReceiverOutput { id, r }) } diff --git a/crates/mpz-ot-core/src/ideal/ot.rs b/crates/mpz-ot-core/src/ideal/ot.rs index e389066e..76ebe630 100644 --- a/crates/mpz-ot-core/src/ideal/ot.rs +++ b/crates/mpz-ot-core/src/ideal/ot.rs @@ -55,7 +55,7 @@ impl IdealOT { self.counter += choices.len(); self.choices.extend(choices); - let id = self.transfer_id.next(); + let id = self.transfer_id.next_id(); (OTSenderOutput { id }, OTReceiverOutput { id, msgs: chosen }) } diff --git a/crates/mpz-ot-core/src/ideal/rot.rs b/crates/mpz-ot-core/src/ideal/rot.rs index 8a8b5d68..e29b9204 100644 --- a/crates/mpz-ot-core/src/ideal/rot.rs +++ b/crates/mpz-ot-core/src/ideal/rot.rs @@ -68,7 +68,7 @@ impl IdealROT { .collect(); self.counter += count; - let id = self.transfer_id.next(); + let id = self.transfer_id.next_id(); ( ROTSenderOutput { id, msgs }, @@ -103,7 +103,7 @@ impl IdealROT { .collect(); self.counter += choices.len(); - let id = self.transfer_id.next(); + let id = self.transfer_id.next_id(); ( ROTSenderOutput { id, msgs }, diff --git a/crates/mpz-ot-core/src/ideal/spcot.rs b/crates/mpz-ot-core/src/ideal/spcot.rs index 12c5f829..93b3c720 100644 --- a/crates/mpz-ot-core/src/ideal/spcot.rs +++ b/crates/mpz-ot-core/src/ideal/spcot.rs @@ -61,7 +61,7 @@ impl IdealSpcot { self.counter += n; } - let id = self.transfer_id.next(); + let id = self.transfer_id.next_id(); (SPCOTSenderOutput { id, v }, SPCOTReceiverOutput { id, w }) } diff --git a/crates/mpz-ot-core/src/kos/receiver.rs b/crates/mpz-ot-core/src/kos/receiver.rs index fdcad328..127c4f1d 100644 --- a/crates/mpz-ot-core/src/kos/receiver.rs +++ b/crates/mpz-ot-core/src/kos/receiver.rs @@ -330,7 +330,7 @@ impl Receiver { )); } - let id = self.state.transfer_id.next(); + let id = self.state.transfer_id.next_id(); let index = self.state.index - self.state.keys.len(); Ok(ReceiverKeys { diff --git a/crates/mpz-ot-core/src/kos/sender.rs b/crates/mpz-ot-core/src/kos/sender.rs index 24917940..23edff5c 100644 --- a/crates/mpz-ot-core/src/kos/sender.rs +++ b/crates/mpz-ot-core/src/kos/sender.rs @@ -294,7 +294,7 @@ impl Sender { return Err(SenderError::InsufficientSetup(count, self.state.keys.len())); } - let id = self.state.transfer_id.next(); + let id = self.state.transfer_id.next_id(); Ok(SenderKeys { id, diff --git a/crates/mpz-ot-core/src/lib.rs b/crates/mpz-ot-core/src/lib.rs index dcfffc59..b0b69260 100644 --- a/crates/mpz-ot-core/src/lib.rs +++ b/crates/mpz-ot-core/src/lib.rs @@ -45,7 +45,7 @@ impl std::fmt::Display for TransferId { impl TransferId { /// Returns the current transfer ID, incrementing `self` in-place. - pub fn next(&mut self) -> Self { + pub fn next_id(&mut self) -> Self { let id = *self; self.0 += 1; id diff --git a/crates/mpz-ot/src/ferret/receiver.rs b/crates/mpz-ot/src/ferret/receiver.rs index d04e5d29..fbbb38eb 100644 --- a/crates/mpz-ot/src/ferret/receiver.rs +++ b/crates/mpz-ot/src/ferret/receiver.rs @@ -19,8 +19,8 @@ use crate::{ #[derive(Debug)] pub(crate) enum State { - Initialized(ReceiverCore), - Extension(ReceiverCore), + Initialized(Box>), + Extension(Box>), Error, } @@ -50,7 +50,7 @@ impl Receiver { /// * `rcot` - The random COT in setup. pub fn new(config: FerretConfig, rcot: RandomCOT) -> Self { Self { - state: State::Initialized(ReceiverCore::new()), + state: State::Initialized(Box::new(ReceiverCore::new())), config, rcot, alloc: 0, @@ -119,7 +119,7 @@ impl Receiver { ctx.io_mut().send(seed).await?; - self.state = State::Extension(receiver); + self.state = State::Extension(Box::new(receiver)); Ok(()) } @@ -245,7 +245,7 @@ impl RandomCOTReceiver for ReceiverBuffer { let choices = self.buffer.choices.drain(0..count).collect(); let msgs = self.buffer.msgs.drain(0..count).collect(); Ok(RCOTReceiverOutput { - id: self.buffer.id.next(), + id: self.buffer.id.next_id(), choices, msgs, }) diff --git a/crates/mpz-ot/src/ferret/sender.rs b/crates/mpz-ot/src/ferret/sender.rs index 187c1744..02884b2c 100644 --- a/crates/mpz-ot/src/ferret/sender.rs +++ b/crates/mpz-ot/src/ferret/sender.rs @@ -262,7 +262,7 @@ impl RandomCOTSender for SenderBuffer { let msgs = self.buffer.msgs.drain(0..count).collect(); Ok(RCOTSenderOutput { - id: self.buffer.id.next(), + id: self.buffer.id.next_id(), msgs, }) } diff --git a/crates/mpz-ot/src/ferret/spcot.rs b/crates/mpz-ot/src/ferret/spcot.rs index bccad692..5fcb6e6c 100644 --- a/crates/mpz-ot/src/ferret/spcot.rs +++ b/crates/mpz-ot/src/ferret/spcot.rs @@ -57,7 +57,6 @@ pub(crate) async fn send>( let (output, check_msg) = CpuBackend::blocking(move || { sender .check(&y_star, checkfr) - .map(|(output, check_msg)| (output, check_msg)) }) .await?; @@ -130,7 +129,7 @@ pub(crate) async fn receive Date: Tue, 20 Aug 2024 08:01:39 +0800 Subject: [PATCH 08/17] params --- crates/mpz-ot-core/src/ferret/mod.rs | 12 ++++++------ crates/mpz-ot/examples/ferret.rs | 1 + 2 files changed, 7 insertions(+), 6 deletions(-) create mode 100644 crates/mpz-ot/examples/ferret.rs diff --git a/crates/mpz-ot-core/src/ferret/mod.rs b/crates/mpz-ot-core/src/ferret/mod.rs index 0e27f0a9..6b478b06 100644 --- a/crates/mpz-ot-core/src/ferret/mod.rs +++ b/crates/mpz-ot-core/src/ferret/mod.rs @@ -22,17 +22,17 @@ pub const CUCKOO_TRIAL_NUM: usize = 100; /// LPN parameters with regular noise. /// Derived from https://github.com/emp-toolkit/emp-ot/blob/master/emp-ot/ferret/constants.h pub const LPN_PARAMETERS_REGULAR: LpnParameters = LpnParameters { - n: 10180608, - k: 124000, - t: 4971, + n: 10_180_608, + k: 124_000, + t: 4_971, }; /// LPN parameters with uniform noise. /// Derived from Table 2. pub const LPN_PARAMETERS_UNIFORM: LpnParameters = LpnParameters { - n: 10616092, - k: 588160, - t: 1324, + n: 10_616_092, + k: 588_160, + t: 1_324, }; /// The type of Lpn parameters. diff --git a/crates/mpz-ot/examples/ferret.rs b/crates/mpz-ot/examples/ferret.rs new file mode 100644 index 00000000..f328e4d9 --- /dev/null +++ b/crates/mpz-ot/examples/ferret.rs @@ -0,0 +1 @@ +fn main() {} From 84399bd0976fa89b7203f2612a5de069b0c558ba Mon Sep 17 00:00:00 2001 From: Xiang Xie Date: Wed, 21 Aug 2024 10:08:18 +0800 Subject: [PATCH 09/17] add default ferret configs --- crates/mpz-ot-core/src/ferret/mod.rs | 19 ---- crates/mpz-ot/src/ferret/mod.rs | 132 +++++++++++++++++++++++++++ crates/mpz-ot/src/ferret/spcot.rs | 9 +- 3 files changed, 134 insertions(+), 26 deletions(-) diff --git a/crates/mpz-ot-core/src/ferret/mod.rs b/crates/mpz-ot-core/src/ferret/mod.rs index 6b478b06..ac73c005 100644 --- a/crates/mpz-ot-core/src/ferret/mod.rs +++ b/crates/mpz-ot-core/src/ferret/mod.rs @@ -1,7 +1,4 @@ //! An implementation of the [`Ferret`](https://eprint.iacr.org/2020/924.pdf) protocol. - -use mpz_core::lpn::LpnParameters; - pub mod cuckoo; pub mod error; pub mod mpcot; @@ -19,22 +16,6 @@ pub const CUCKOO_HASH_NUM: usize = 3; /// Trial numbers in Cuckoo hash insertion. pub const CUCKOO_TRIAL_NUM: usize = 100; -/// LPN parameters with regular noise. -/// Derived from https://github.com/emp-toolkit/emp-ot/blob/master/emp-ot/ferret/constants.h -pub const LPN_PARAMETERS_REGULAR: LpnParameters = LpnParameters { - n: 10_180_608, - k: 124_000, - t: 4_971, -}; - -/// LPN parameters with uniform noise. -/// Derived from Table 2. -pub const LPN_PARAMETERS_UNIFORM: LpnParameters = LpnParameters { - n: 10_616_092, - k: 588_160, - t: 1_324, -}; - /// The type of Lpn parameters. #[derive(Debug, Clone, Copy, Default)] pub enum LpnType { diff --git a/crates/mpz-ot/src/ferret/mod.rs b/crates/mpz-ot/src/ferret/mod.rs index 086e5e8b..9d421885 100644 --- a/crates/mpz-ot/src/ferret/mod.rs +++ b/crates/mpz-ot/src/ferret/mod.rs @@ -44,6 +44,138 @@ impl FerretConfig { } } +/// Ferret config with regular LPN parameters. +/// Parameters for setup with small extension output. +pub const FERRET_REGULAR_SETUP_SMALL: FerretConfig = FerretConfig { + lpn_parameters: LpnParameters { + n: 102_400, + k: 6_750, + t: 1_600, + }, + lpn_type: LpnType::Regular, +}; + +/// Ferret config with regular LPN parameters. +/// Parameters for extension with small extension output. +pub const FERRET_REGULAR_EXTENSION_SMALL: FerretConfig = FerretConfig { + lpn_parameters: LpnParameters { + n: 1_740_800, + k: 66_400, + t: 1700, + }, + lpn_type: LpnType::Regular, +}; + +/// Ferret config with regular LPN parameters. +/// Parameters for setup with medium extension output. +pub const FERRET_REGULAR_SETUP_MEDIUM: FerretConfig = FerretConfig { + lpn_parameters: LpnParameters { + n: 283_648, + k: 18_584, + t: 1_108, + }, + lpn_type: LpnType::Regular, +}; + +/// Ferret config with regular LPN parameters. +/// Parameters for extension with medium extension output. +pub const FERRET_REGULAR_EXTENSION_MEDIUM: FerretConfig = FerretConfig { + lpn_parameters: LpnParameters { + n: 5_324_800, + k: 240_000, + t: 1_300, + }, + lpn_type: LpnType::Regular, +}; + +/// Ferret config with regular LPN parameters. +/// Parameters for setup with large extension output. +pub const FERRET_REGULAR_SETUP_LARGE: FerretConfig = FerretConfig { + lpn_parameters: LpnParameters { + n: 518_656, + k: 34_643, + t: 1_013, + }, + lpn_type: LpnType::Regular, +}; + +/// Ferret config with regular LPN parameters. +/// Parameters for extension with large extension output. +pub const FERRET_REGULAR_EXTENSION_LARGE: FerretConfig = FerretConfig { + lpn_parameters: LpnParameters { + n: 10_485_760, + k: 458_000, + t: 1280, + }, + lpn_type: LpnType::Regular, +}; + +/// Ferret config with uniform LPN parameters. +/// Parameters for setup with small extension output. +pub const FERRET_UNIFORM_SETUP_SMALL: FerretConfig = FerretConfig { + lpn_parameters: LpnParameters { + n: 98_000, + k: 4_450, + t: 1_600, + }, + lpn_type: LpnType::Uniform, +}; + +/// Ferret config with uniform LPN parameters. +/// Parameters for extension with small extension output. +pub const FERRET_UNIFORM_EXTENSION_SMALL: FerretConfig = FerretConfig { + lpn_parameters: LpnParameters { + n: 1_071_888, + k: 40_800, + t: 1720, + }, + lpn_type: LpnType::Uniform, +}; + +/// Ferret config with uniform LPN parameters. +/// Parameters for setup with medium extension output. +pub const FERRET_UNIFORM_SETUP_MEDIUM: FerretConfig = FerretConfig { + lpn_parameters: LpnParameters { + n: 283_648, + k: 18_584, + t: 1_108, + }, + lpn_type: LpnType::Uniform, +}; + +/// Ferret config with uniform LPN parameters. +/// Parameters for extension with medium extension output. +pub const FERRET_UNIFORM_EXTENSION_MEDIUM: FerretConfig = FerretConfig { + lpn_parameters: LpnParameters { + n: 5_324_800, + k: 240_000, + t: 1_300, + }, + lpn_type: LpnType::Uniform, +}; + +/// Ferret config with uniform LPN parameters. +/// Parameters for setup with large extension output. +pub const FERRET_UNIFORM_SETUP_LARGE: FerretConfig = FerretConfig { + lpn_parameters: LpnParameters { + n: 545_656, + k: 34_643, + t: 1_050, + }, + lpn_type: LpnType::Uniform, +}; + +/// Ferret config with uniform LPN parameters. +/// Parameters for extension with large extension output. +pub const FERRET_UNIFORM_EXTENSION_LARGE: FerretConfig = FerretConfig { + lpn_parameters: LpnParameters { + n: 10_488_928, + k: 458_000, + t: 1_280, + }, + lpn_type: LpnType::Uniform, +}; + #[cfg(test)] mod tests { use super::*; diff --git a/crates/mpz-ot/src/ferret/spcot.rs b/crates/mpz-ot/src/ferret/spcot.rs index 5fcb6e6c..e63a1aa9 100644 --- a/crates/mpz-ot/src/ferret/spcot.rs +++ b/crates/mpz-ot/src/ferret/spcot.rs @@ -54,11 +54,7 @@ pub(crate) async fn send>( let checkfr = ctx.io_mut().expect_next().await?; - let (output, check_msg) = CpuBackend::blocking(move || { - sender - .check(&y_star, checkfr) - }) - .await?; + let (output, check_msg) = CpuBackend::blocking(move || sender.check(&y_star, checkfr)).await?; ctx.io_mut().send(check_msg).await?; @@ -128,8 +124,7 @@ pub(crate) async fn receive Date: Wed, 4 Sep 2024 18:49:56 +0800 Subject: [PATCH 10/17] init mpz-zk-core --- crates/mpz-core/src/block.rs | 14 ++ crates/mpz-zk-core/Cargo.toml | 3 + crates/mpz-zk-core/src/ideal/mod.rs | 3 + crates/mpz-zk-core/src/ideal/vope.rs | 102 +++++++++++ crates/mpz-zk-core/src/lib.rs | 22 +++ crates/mpz-zk-core/src/quicksilver/error.rs | 22 +++ crates/mpz-zk-core/src/quicksilver/mod.rs | 22 +++ crates/mpz-zk-core/src/quicksilver/prover.rs | 173 ++++++++++++++++++ .../mpz-zk-core/src/quicksilver/verifier.rs | 164 +++++++++++++++++ crates/mpz-zk/src/vope/mod.rs | 9 +- 10 files changed, 531 insertions(+), 3 deletions(-) create mode 100644 crates/mpz-zk-core/src/ideal/mod.rs create mode 100644 crates/mpz-zk-core/src/ideal/vope.rs create mode 100644 crates/mpz-zk-core/src/quicksilver/error.rs create mode 100644 crates/mpz-zk-core/src/quicksilver/mod.rs create mode 100644 crates/mpz-zk-core/src/quicksilver/prover.rs create mode 100644 crates/mpz-zk-core/src/quicksilver/verifier.rs diff --git a/crates/mpz-core/src/block.rs b/crates/mpz-core/src/block.rs index 2f7a0105..1e1d67ac 100644 --- a/crates/mpz-core/src/block.rs +++ b/crates/mpz-core/src/block.rs @@ -7,6 +7,7 @@ use generic_array::{typenum::consts::U16, GenericArray}; use itybity::{BitIterable, BitLength, GetBit, Lsb0, Msb0}; use rand::{distributions::Standard, prelude::Distribution, CryptoRng, Rng}; use serde::{Deserialize, Serialize}; +use std::iter::successors; /// A block of 128 bits #[repr(transparent)] @@ -22,6 +23,11 @@ impl Block { pub const ONE: Self = Self([1u8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]); /// A block with all bits set to 1 pub const ONES: Self = Self([0xff; 16]); + /// A block with all 1 bits excect the lsb. + pub const MINIS_ONE: Block = Self([ + 0xfe, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, + ]); /// A length 2 array of zero and one blocks pub const SELECT_MASK: [Self; 2] = [Self::ZERO, Self::ONES]; @@ -123,6 +129,14 @@ impl Block { bytemuck::cast([x[1], x[0]]) } + /// Generate the powers of the seed. + #[inline(always)] + pub fn powers(seed: Self, size: usize) -> Vec { + successors(Some(Block::ONE), |pow| Some(pow.gfmul(seed))) + .take(size) + .collect() + } + /// Converts a block to a [`GenericArray`](cipher::generic_array::GenericArray) /// from the [`generic-array`](https://docs.rs/generic-array/latest/generic_array/) crate. #[allow(dead_code)] diff --git a/crates/mpz-zk-core/Cargo.toml b/crates/mpz-zk-core/Cargo.toml index e390db57..69b2672e 100644 --- a/crates/mpz-zk-core/Cargo.toml +++ b/crates/mpz-zk-core/Cargo.toml @@ -19,6 +19,7 @@ test-utils = [] [dependencies] mpz-core.workspace = true mpz-ot-core.workspace = true +mpz-circuits.workspace = true clmul.workspace = true matrix-transpose.workspace = true @@ -33,3 +34,5 @@ opaque-debug.workspace = true cfg-if.workspace = true bytemuck = { workspace = true, features = ["derive"] } enum-try-as-inner.workspace = true +blake3.workspace = true +rand_core.workspace = true \ No newline at end of file diff --git a/crates/mpz-zk-core/src/ideal/mod.rs b/crates/mpz-zk-core/src/ideal/mod.rs new file mode 100644 index 00000000..c088fde3 --- /dev/null +++ b/crates/mpz-zk-core/src/ideal/mod.rs @@ -0,0 +1,3 @@ +//! Ideal functionalities. + +pub mod vope; \ No newline at end of file diff --git a/crates/mpz-zk-core/src/ideal/vope.rs b/crates/mpz-zk-core/src/ideal/vope.rs new file mode 100644 index 00000000..c1238c22 --- /dev/null +++ b/crates/mpz-zk-core/src/ideal/vope.rs @@ -0,0 +1,102 @@ +//! Ideal VOPE functionality. + +use mpz_core::{prg::Prg, Block}; +use mpz_ot_core::TransferId; +use rand_core::SeedableRng; + +use crate::{VOPEReceiverOutput, VOPESenderOutput}; + +/// The ideal VOPE functionality. +#[derive(Debug)] +pub struct IdealVOPE { + delta: Block, + transfer_id: TransferId, + counter: usize, + prg: Prg, +} + +impl IdealVOPE { + /// Creates a new ideal VOPE functionality. + /// + /// # Arguments + /// + /// * `seed` - The seed for the PRG. + /// * `delta` - The correlation. + pub fn new(seed: Block, delta: Block) -> Self { + Self { + delta, + transfer_id: TransferId::default(), + counter: 0, + prg: Prg::from_seed(seed), + } + } + + /// Returns the correlation, delta. + pub fn delta(&self) -> Block { + self.delta + } + + /// Sets the correlation, delta. + pub fn set_delta(&mut self, delta: Block) { + self.delta = delta; + } + + /// Returns the current transfer id. + pub fn transfer_id(&self) -> TransferId { + self.transfer_id + } + + /// Returns the number of VOPE executed. + pub fn count(&self) -> usize { + self.counter + } + + /// Executes the VOPE. + /// + /// # Arguments + /// + /// * `degree` - The degree of the polynomnial. + pub fn random_correlated( + &mut self, + degree: usize, + ) -> (VOPESenderOutput, VOPEReceiverOutput) { + let mut coeff = vec![Block::ZERO; degree + 1]; + self.prg.random_blocks(&mut coeff); + + let powers = Block::powers(self.delta, degree + 1); + + let eval = Block::inn_prdt_red(&coeff, &powers); + + self.counter += 1; + let id = self.transfer_id.next_id(); + + ( + VOPESenderOutput { id, eval }, + VOPEReceiverOutput { id, coeff }, + ) + } +} + +impl Default for IdealVOPE { + fn default() -> Self { + let mut rng = Prg::from_seed(Block::ZERO); + Self::new(rng.random_block(), rng.random_block()) + } +} + +#[cfg(test)] +mod tests { + use crate::{test::poly_check, VOPEReceiverOutput, VOPESenderOutput}; + + use super::IdealVOPE; + + #[test] + fn test_ideal_vope() { + let mut ideal = IdealVOPE::default(); + + let (VOPESenderOutput { eval, .. }, VOPEReceiverOutput { coeff, .. }) = + ideal.random_correlated(10); + + assert!(poly_check(&coeff, eval, ideal.delta())); + } +} diff --git a/crates/mpz-zk-core/src/lib.rs b/crates/mpz-zk-core/src/lib.rs index 92ca8ef1..847b6f7c 100644 --- a/crates/mpz-zk-core/src/lib.rs +++ b/crates/mpz-zk-core/src/lib.rs @@ -19,5 +19,27 @@ clippy::all )] +use mpz_ot_core::TransferId; + +pub mod quicksilver; pub mod test; pub mod vope; +pub mod ideal; + +/// The output the receiver receives from the VOPE functionality. +#[derive(Debug)] +pub struct VOPEReceiverOutput { + /// The transfer id. + pub id: TransferId, + /// The coefficients. + pub coeff: Vec, +} + +/// The output the sender receives from the VOPE functinality. +#[derive(Debug)] +pub struct VOPESenderOutput { + /// The transfer id. + pub id: TransferId, + /// The evaluation value. + pub eval: T, +} diff --git a/crates/mpz-zk-core/src/quicksilver/error.rs b/crates/mpz-zk-core/src/quicksilver/error.rs new file mode 100644 index 00000000..4909a652 --- /dev/null +++ b/crates/mpz-zk-core/src/quicksilver/error.rs @@ -0,0 +1,22 @@ +//! Errors in QuickSilver. +use mpz_circuits::CircuitError; + +/// Errors that can occur during proving +#[derive(Debug, thiserror::Error)] +#[allow(missing_docs)] +pub enum QsProverError { + #[error(transparent)] + CircuitError(#[from] CircuitError), + #[error("invalid inputs")] + InvalidInputs, +} + +/// Errors that can occur during verifying +#[derive(Debug, thiserror::Error)] +#[allow(missing_docs)] +pub enum QsVerifierError { + #[error(transparent)] + CircuitError(#[from] CircuitError), + #[error("invalid inputs")] + InvalidInputs, +} diff --git a/crates/mpz-zk-core/src/quicksilver/mod.rs b/crates/mpz-zk-core/src/quicksilver/mod.rs new file mode 100644 index 00000000..d8f77f68 --- /dev/null +++ b/crates/mpz-zk-core/src/quicksilver/mod.rs @@ -0,0 +1,22 @@ +//! This is the implementation of QuickSilver (https://eprint.iacr.org/2021/076.pdf). + +mod error; +mod prover; +mod verifier; + +pub use error::*; +pub use prover::Prover; +pub use verifier::Verifier; + +/// Buffer size of each check. +pub const CHECK_BUFFER_SIZE: usize = 1024 * 1024; + +#[inline] +fn bools_to_bytes(bv: &[bool]) -> Vec { + let offset = if bv.len() % 8 == 0 { 0 } else { 1 }; + let mut v = vec![0u8; bv.len() / 8 + offset]; + for (i, b) in bv.iter().enumerate() { + v[i / 8] |= (*b as u8) << (7 - (i % 8)); + } + v +} diff --git a/crates/mpz-zk-core/src/quicksilver/prover.rs b/crates/mpz-zk-core/src/quicksilver/prover.rs new file mode 100644 index 00000000..ed93f37b --- /dev/null +++ b/crates/mpz-zk-core/src/quicksilver/prover.rs @@ -0,0 +1,173 @@ +use mpz_core::Block; +use mpz_ot_core::RCOTReceiverOutput; +use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator}; + +use super::{bools_to_bytes, QsProverError, CHECK_BUFFER_SIZE}; + +/// QuickSilver Prover. +#[derive(Debug)] +pub struct Prover { + /// Buffer for left wire label. + buf_left: Vec, + /// Buffer for right wire label. + buf_right: Vec, + /// Buffer for output wire label. + buf_out: Vec, + /// Counter for check. + counter: usize, + /// Hasher. + hasher: blake3::Hasher, + /// Hash buffer for the bools. + buf_hash: Vec, +} + +impl Prover { + /// Create a new instance + pub fn new() -> Self { + Self { + buf_left: vec![Block::ZERO; CHECK_BUFFER_SIZE], + buf_right: vec![Block::ZERO; CHECK_BUFFER_SIZE], + buf_out: vec![Block::ZERO; CHECK_BUFFER_SIZE], + counter: 0, + hasher: blake3::Hasher::new(), + buf_hash: vec![false; CHECK_BUFFER_SIZE], + } + } + + /// Compute authenticated bits for inputs. + /// See step 4 in Figure 5 + /// + /// # Arguments + /// + /// * `inputs` - The input bits. + /// * `cot` - The COT mask received from Ideal COT as the receiver. + pub fn compute_input_bits( + &mut self, + inputs: &[bool], + cot: RCOTReceiverOutput, + ) -> Result<(Vec, Vec), QsProverError> { + if cot.choices.len() != inputs.len() { + return Err(QsProverError::InvalidInputs); + } + + let RCOTReceiverOutput { + choices: bits, + msgs: blks, + .. + } = cot; + + let res: (Vec, Vec) = bits + .iter() + .zip(inputs.iter()) + .zip(blks.iter()) + .map(|((mask, b), blk)| (b ^ mask, Self::set_value(*blk, *b))) + .unzip(); + + // Hash the bools. + self.hasher.update(&bools_to_bytes(&res.0)); + + Ok(res) + } + + /// Compute authenticated and gate. + /// See step 6 in Figure 5. + /// + /// # Arguments. + /// + /// * `ma` - The MAC of wire a. + /// * `mb` - The MAC of wire b. + /// * `cot` - The COT mask received from Ideal COT as the receiver. + pub fn compute_and_gate( + &mut self, + ma: Block, + mb: Block, + cot: RCOTReceiverOutput, + ) -> Result { + if cot.choices.len() != 1 { + return Err(QsProverError::InvalidInputs); + } + + assert!(self.counter < CHECK_BUFFER_SIZE); + + self.buf_left[self.counter] = ma; + self.buf_right[self.counter] = mb; + + let RCOTReceiverOutput { + choices: s, + msgs: blks, + .. + } = cot; + + // Compute wa * wb + let v = ma.lsb() & mb.lsb() == 1; + // Compute the mask of v with s. + let d = v ^ s[0]; + + self.buf_out[self.counter] = Self::set_value(blks[0], v); + self.buf_hash[self.counter] = d; + self.counter += 1; + + Ok(d) + } + + /// Check and gate. + /// See step 6, 7 in Figure 5. + /// + /// # Arguments. + /// + /// * `vope` - The mask blocks received from ideal VOPE. + pub fn check_and_gate(&mut self, vope: (Block, Block)) -> (Block, Block) { + assert!(self.counter <= CHECK_BUFFER_SIZE); + cfg_if::cfg_if! { + if #[cfg(feature = "rayon")]{ + let iter = self.buf_left[..self.counter] + .par_iter() + .zip(self.buf_right[..self.counter].par_iter()) + .zip(self.buf_out[..self.counter].par_iter()); + } else{ + let iter = self.buf_left[..self.counter] + .iter() + .zip(self.buf_right[..self.counter].iter()) + .zip(self.buf_out[..self.counter].iter()) + } + } + + // Compute A0 and A1. + let blocks: (Vec, Vec) = iter + .map(|((a, b), c)| { + let tmp0 = if a.lsb() == 1 { *b } else { Block::ZERO }; + let tmp1 = if b.lsb() == 1 { *a } else { Block::ZERO }; + + (a.gfmul(*b), tmp0 ^ tmp1 ^ *c) + }) + .unzip(); + + // Compute chi and powers. + self.hasher + .update(&bools_to_bytes(&self.buf_hash[..self.counter])); + let seed = *self.hasher.finalize().as_bytes(); + let seed = Block::try_from(&seed[0..16]).unwrap(); + let chis = Block::powers(seed, self.counter); + + // Compute the inner product. + let u = Block::inn_prdt_red(&blocks.0, &chis); + let v = Block::inn_prdt_red(&blocks.1, &chis); + + // Mask the results. + let u = u ^ vope.0; + let v = v ^ vope.1; + + // Update the hasher + self.hasher.update(&u.to_bytes()); + self.hasher.update(&v.to_bytes()); + self.counter = 0; + + (u, v) + } + + // set the LSB of the block to as the bit. + #[inline] + fn set_value(block: Block, b: bool) -> Block { + (block & Block::MINIS_ONE) ^ (if b { Block::ONE } else { Block::ZERO }) + } +} diff --git a/crates/mpz-zk-core/src/quicksilver/verifier.rs b/crates/mpz-zk-core/src/quicksilver/verifier.rs new file mode 100644 index 00000000..fc2bfce7 --- /dev/null +++ b/crates/mpz-zk-core/src/quicksilver/verifier.rs @@ -0,0 +1,164 @@ +use mpz_core::Block; +use mpz_ot_core::RCOTSenderOutput; +use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator}; + +use super::{bools_to_bytes, QsVerifierError, CHECK_BUFFER_SIZE}; + +/// QuickSilver Verifier. +#[derive(Debug)] +pub struct Verifier { + /// Global secret. + delta: Block, + /// Buffer for left wire label. + buf_left: Vec, + /// Buffer for right wire label. + buf_right: Vec, + /// Buffer for output wire label. + buf_out: Vec, + /// Counter for check. + counter: usize, + /// Hasher. + hasher: blake3::Hasher, + /// Hash buffer for the bools. + buf_hash: Vec, + /// Indicate the and_gate check passes or not. + and_gate_checked: bool, +} + +impl Verifier { + /// Create a new instance + /// + /// # Arguments. + /// + /// * `delta` - The global secret. + pub fn new(delta: Block) -> Self { + Self { + delta, + buf_left: vec![Block::ZERO; CHECK_BUFFER_SIZE], + buf_right: vec![Block::ZERO; CHECK_BUFFER_SIZE], + buf_out: vec![Block::ZERO; CHECK_BUFFER_SIZE], + counter: 0, + hasher: blake3::Hasher::new(), + buf_hash: vec![false; CHECK_BUFFER_SIZE], + and_gate_checked: true, + } + } + /// Compute authenticated bits for inputs. + /// See step 4 in Figure 5 + /// # Arguments + /// + /// * `masks` - The mask bits sent from the prover. + /// * `cot` - The COT mask received from Ideal COT as the sender. + pub fn compute_input_bits( + &mut self, + masks: &[bool], + cot: RCOTSenderOutput, + // The mask bits sent by prover. + ) -> Result, QsVerifierError> { + if masks.len() != cot.msgs.len() { + return Err(QsVerifierError::InvalidInputs); + } + + // Hash the bools. + self.hasher.update(&bools_to_bytes(masks)); + + let RCOTSenderOutput { msgs: blks, .. } = cot; + + let res = blks + .iter() + .zip(masks.iter()) + .map(|(blk, mask)| { + let block = *blk ^ (if *mask { self.delta } else { Block::ZERO }); + Self::set_zero(block) + }) + .collect(); + + Ok(res) + } + + /// Compute authenticated and gate. + /// See step 6 in Figure 5. + /// + /// # Arguments. + /// + /// * `ka` - The KEY of wire a. + /// * `kb` - The KEY of wire b. + /// * `mask` - The mask sent by the prover. + /// * `cot` - The COT mask received from Ideal COT as the sender. + pub fn compute_and_gate( + &mut self, + ka: Block, + kb: Block, + mask: bool, + cot: RCOTSenderOutput, + ) -> Result<(), QsVerifierError> { + if cot.msgs.len() != 1 { + return Err(QsVerifierError::InvalidInputs); + } + + assert!(self.counter < CHECK_BUFFER_SIZE); + + self.buf_left[self.counter] = ka; + self.buf_right[self.counter] = kb; + self.buf_hash[self.counter] = mask; + + let RCOTSenderOutput { msgs: blks, .. } = cot; + + let block = blks[0] ^ if mask { self.delta } else { Block::ZERO }; + self.buf_out[self.counter] = Self::set_zero(block); + self.counter += 1; + + Ok(()) + } + + /// Check and gate. + /// See step 6, 7 in Figure 5. + /// + /// # Arguments. + /// + /// * `vope` - The mask block received from ideal VOPE. + /// * `u` - The block sent by the prover. + /// * `v` - The block sent by the prover. + pub fn check_and_gate(&mut self, vope: Block, u: Block, v: Block) { + assert!(self.counter <= CHECK_BUFFER_SIZE); + cfg_if::cfg_if! { + if #[cfg(feature = "rayon")]{ + let iter = self.buf_left[..self.counter] + .par_iter() + .zip(self.buf_right[..self.counter].par_iter()) + .zip(self.buf_out[..self.counter].par_iter()); + } else{ + let iter = self.buf_left[..self.counter] + .iter() + .zip(self.buf_right[..self.counter].iter()) + .zip(self.buf_out[..self.counter].iter()) + } + } + + // Compute B. + let block: Vec = iter + .map(|((a, b), c)| a.gfmul(*b) ^ c.gfmul(self.delta)) + .collect(); + + // Compute chi and powers. + self.hasher + .update(&bools_to_bytes(&self.buf_hash[..=self.counter])); + let seed = *self.hasher.finalize().as_bytes(); + let seed = Block::try_from(&seed[0..16]).unwrap(); + let chis = Block::powers(seed, self.counter); + + // Compute the inner product. + let w = Block::inn_prdt_red(&block, &chis); + self.and_gate_checked &= (w ^ vope) == u ^ v.gfmul(self.delta); + + self.hasher.update(&u.to_bytes()); + self.hasher.update(&v.to_bytes()); + self.counter = 0; + } + + // Set the lsb of the block to zero. + #[inline] + fn set_zero(block: Block) -> Block { + block & Block::MINIS_ONE + } +} diff --git a/crates/mpz-zk/src/vope/mod.rs b/crates/mpz-zk/src/vope/mod.rs index 1b34799d..789a131b 100644 --- a/crates/mpz-zk/src/vope/mod.rs +++ b/crates/mpz-zk/src/vope/mod.rs @@ -13,13 +13,16 @@ mod tests { use futures::TryFutureExt; use mpz_common::executor::test_st_executor; use mpz_core::Block; - use mpz_ot::ideal::cot::{ideal_rcot, IdealCOTReceiver, IdealCOTSender}; + use mpz_ot::{ + ideal::cot::{ideal_rcot, IdealCOTReceiver, IdealCOTSender}, + Correlation, + }; use mpz_zk_core::test::poly_check; fn setup() -> (Sender, Receiver, Block) { - let (mut rcot_sender, rcot_receiver) = ideal_rcot(); + let (rcot_sender, rcot_receiver) = ideal_rcot(); - let delta = rcot_sender.alice().get_mut().delta(); + let delta = rcot_sender.delta(); let sender = Sender::new(rcot_sender); let receiver = Receiver::new(rcot_receiver); From 52bcc7a169bca144847a2f5849ba1c77ff9714cc Mon Sep 17 00:00:00 2001 From: Xiang Xie Date: Wed, 4 Sep 2024 19:56:42 +0800 Subject: [PATCH 11/17] add mpz-zk-core test --- crates/mpz-zk-core/src/quicksilver/mod.rs | 58 +++++++++++++++++++ crates/mpz-zk-core/src/quicksilver/prover.rs | 3 +- .../mpz-zk-core/src/quicksilver/verifier.rs | 6 ++ 3 files changed, 66 insertions(+), 1 deletion(-) diff --git a/crates/mpz-zk-core/src/quicksilver/mod.rs b/crates/mpz-zk-core/src/quicksilver/mod.rs index d8f77f68..f4296da4 100644 --- a/crates/mpz-zk-core/src/quicksilver/mod.rs +++ b/crates/mpz-zk-core/src/quicksilver/mod.rs @@ -20,3 +20,61 @@ fn bools_to_bytes(bv: &[bool]) -> Vec { } v } + +#[cfg(test)] +mod tests { + use mpz_core::prg::Prg; + use mpz_ot_core::{ideal::cot::IdealCOT, test::assert_cot}; + + use crate::{ideal::vope::IdealVOPE, VOPEReceiverOutput, VOPESenderOutput}; + + use super::{Prover, Verifier}; + + #[test] + fn test_qs_core() { + let mut prg = Prg::new(); + let mut input = vec![false; 100]; + prg.random_bools(&mut input); + + let mut ideal_cot = IdealCOT::default(); + let mut ideal_vope = IdealVOPE::default(); + + let mut delta = ideal_cot.delta(); + delta.set_lsb(); + + ideal_vope.set_delta(delta); + + let mut prover = Prover::new(); + let mut verifier = Verifier::new(delta); + + let (cot_sender, cot_receiver) = ideal_cot.random_correlated(input.len()); + + let (masks, prover_labels) = prover.compute_input_bits(&input, cot_receiver).unwrap(); + + let verifier_labels = verifier.compute_input_bits(&masks, cot_sender).unwrap(); + + assert_cot(delta, &input, &prover_labels, &verifier_labels); + + prover_labels + .iter() + .zip(verifier_labels.iter()) + .for_each(|(&mac, &key)| { + let (cot_sender, cot_receiver) = ideal_cot.random_correlated(1); + + let mask = prover.compute_and_gate(mac, mac, cot_receiver).unwrap(); + + verifier + .compute_and_gate(key, key, mask, cot_sender) + .unwrap(); + }); + + let (VOPESenderOutput { eval, .. }, VOPEReceiverOutput { coeff, .. }) = + ideal_vope.random_correlated(1); + + let (u, v) = prover.check_and_gate((coeff[0], coeff[1])); + + verifier.check_and_gate(eval, u, v); + + assert!(verifier.checked()); + } +} diff --git a/crates/mpz-zk-core/src/quicksilver/prover.rs b/crates/mpz-zk-core/src/quicksilver/prover.rs index ed93f37b..9d452222 100644 --- a/crates/mpz-zk-core/src/quicksilver/prover.rs +++ b/crates/mpz-zk-core/src/quicksilver/prover.rs @@ -165,7 +165,8 @@ impl Prover { (u, v) } - // set the LSB of the block to as the bit. + // Set the LSB of the block to as the bit. + // This assumes the lsb of delta is 1. #[inline] fn set_value(block: Block, b: bool) -> Block { (block & Block::MINIS_ONE) ^ (if b { Block::ONE } else { Block::ZERO }) diff --git a/crates/mpz-zk-core/src/quicksilver/verifier.rs b/crates/mpz-zk-core/src/quicksilver/verifier.rs index fc2bfce7..0afa54b7 100644 --- a/crates/mpz-zk-core/src/quicksilver/verifier.rs +++ b/crates/mpz-zk-core/src/quicksilver/verifier.rs @@ -156,7 +156,13 @@ impl Verifier { self.counter = 0; } + /// Returns the and_check results. + pub fn checked(&self) -> bool { + self.and_gate_checked + } + // Set the lsb of the block to zero. + // This assumes the lsb of delta is 1. #[inline] fn set_zero(block: Block) -> Block { block & Block::MINIS_ONE From d3a50dd89dcc7889b665f5beb0e98e1de6aa722f Mon Sep 17 00:00:00 2001 From: Xiang Xie Date: Mon, 9 Sep 2024 23:08:41 +0800 Subject: [PATCH 12/17] refine core --- crates/mpz-zk-core/Cargo.toml | 3 +- crates/mpz-zk-core/src/ideal/mod.rs | 2 +- crates/mpz-zk-core/src/lib.rs | 2 +- crates/mpz-zk-core/src/quicksilver/error.rs | 8 +- crates/mpz-zk-core/src/quicksilver/mod.rs | 50 ++++++++++--- crates/mpz-zk-core/src/quicksilver/prover.rs | 74 ++++++++++--------- .../mpz-zk-core/src/quicksilver/verifier.rs | 68 +++++++++-------- crates/mpz-zk/src/lib.rs | 1 + crates/mpz-zk/src/quicksilver/error.rs | 1 + crates/mpz-zk/src/quicksilver/mod.rs | 5 ++ crates/mpz-zk/src/quicksilver/prover.rs | 1 + crates/mpz-zk/src/quicksilver/verifier.rs | 1 + value.rs | 0 13 files changed, 133 insertions(+), 83 deletions(-) create mode 100644 crates/mpz-zk/src/quicksilver/error.rs create mode 100644 crates/mpz-zk/src/quicksilver/mod.rs create mode 100644 crates/mpz-zk/src/quicksilver/prover.rs create mode 100644 crates/mpz-zk/src/quicksilver/verifier.rs create mode 100644 value.rs diff --git a/crates/mpz-zk-core/Cargo.toml b/crates/mpz-zk-core/Cargo.toml index 69b2672e..936f9dc8 100644 --- a/crates/mpz-zk-core/Cargo.toml +++ b/crates/mpz-zk-core/Cargo.toml @@ -25,6 +25,7 @@ matrix-transpose.workspace = true tlsn-utils.workspace = true +serde_arrays.workspace = true rayon = { workspace = true, optional = true } serde = { workspace = true, features = ["derive"] } thiserror.workspace = true @@ -35,4 +36,4 @@ cfg-if.workspace = true bytemuck = { workspace = true, features = ["derive"] } enum-try-as-inner.workspace = true blake3.workspace = true -rand_core.workspace = true \ No newline at end of file +rand_core.workspace = true diff --git a/crates/mpz-zk-core/src/ideal/mod.rs b/crates/mpz-zk-core/src/ideal/mod.rs index c088fde3..b8436c2b 100644 --- a/crates/mpz-zk-core/src/ideal/mod.rs +++ b/crates/mpz-zk-core/src/ideal/mod.rs @@ -1,3 +1,3 @@ //! Ideal functionalities. -pub mod vope; \ No newline at end of file +pub mod vope; diff --git a/crates/mpz-zk-core/src/lib.rs b/crates/mpz-zk-core/src/lib.rs index 847b6f7c..de438fb6 100644 --- a/crates/mpz-zk-core/src/lib.rs +++ b/crates/mpz-zk-core/src/lib.rs @@ -21,10 +21,10 @@ use mpz_ot_core::TransferId; +pub mod ideal; pub mod quicksilver; pub mod test; pub mod vope; -pub mod ideal; /// The output the receiver receives from the VOPE functionality. #[derive(Debug)] diff --git a/crates/mpz-zk-core/src/quicksilver/error.rs b/crates/mpz-zk-core/src/quicksilver/error.rs index 4909a652..7bcc0b5a 100644 --- a/crates/mpz-zk-core/src/quicksilver/error.rs +++ b/crates/mpz-zk-core/src/quicksilver/error.rs @@ -5,10 +5,14 @@ use mpz_circuits::CircuitError; #[derive(Debug, thiserror::Error)] #[allow(missing_docs)] pub enum QsProverError { - #[error(transparent)] - CircuitError(#[from] CircuitError), #[error("invalid inputs")] InvalidInputs, + #[error("prover not finished")] + NotFinished, + #[error("not enough COT")] + NotEnoughCOT, + #[error(transparent)] + CircuitError(#[from] CircuitError), } /// Errors that can occur during verifying diff --git a/crates/mpz-zk-core/src/quicksilver/mod.rs b/crates/mpz-zk-core/src/quicksilver/mod.rs index f4296da4..3e4792f6 100644 --- a/crates/mpz-zk-core/src/quicksilver/mod.rs +++ b/crates/mpz-zk-core/src/quicksilver/mod.rs @@ -8,8 +8,13 @@ pub use error::*; pub use prover::Prover; pub use verifier::Verifier; +use serde::{Deserialize, Serialize}; + /// Buffer size of each check. -pub const CHECK_BUFFER_SIZE: usize = 1024 * 1024; +pub(crate) const CHECK_BUFFER_SIZE: usize = 1024 * 1024; + +/// Default amount of authenticated gates per batch. +pub(crate) const DEFAULT_BATCH_SIZE: usize = 128; #[inline] fn bools_to_bytes(bv: &[bool]) -> Vec { @@ -21,10 +26,29 @@ fn bools_to_bytes(bv: &[bool]) -> Vec { v } +/// A batch of bit masks. +/// +/// # Parameters +/// +/// - `N`: The size of a batch +#[derive(Debug, Serialize, Deserialize)] +pub struct MaskBitBatch( + #[serde(with = "serde_arrays")] [bool; N], +); + +impl MaskBitBatch { + /// Create a new batch of mask bits. + pub fn new(batch: [bool; N]) -> Self { + Self(batch) + } +} + #[cfg(test)] mod tests { use mpz_core::prg::Prg; - use mpz_ot_core::{ideal::cot::IdealCOT, test::assert_cot}; + use mpz_ot_core::{ + ideal::cot::IdealCOT, test::assert_cot, RCOTReceiverOutput, RCOTSenderOutput, + }; use crate::{ideal::vope::IdealVOPE, VOPEReceiverOutput, VOPESenderOutput}; @@ -49,9 +73,9 @@ mod tests { let (cot_sender, cot_receiver) = ideal_cot.random_correlated(input.len()); - let (masks, prover_labels) = prover.compute_input_bits(&input, cot_receiver).unwrap(); + let (masks, prover_labels) = prover.auth_input_bits(&input, cot_receiver).unwrap(); - let verifier_labels = verifier.compute_input_bits(&masks, cot_sender).unwrap(); + let verifier_labels = verifier.auth_input_bits(&masks, cot_sender).unwrap(); assert_cot(delta, &input, &prover_labels, &verifier_labels); @@ -61,19 +85,25 @@ mod tests { .for_each(|(&mac, &key)| { let (cot_sender, cot_receiver) = ideal_cot.random_correlated(1); - let mask = prover.compute_and_gate(mac, mac, cot_receiver).unwrap(); + let RCOTReceiverOutput { + choices: s, + msgs: blks, + .. + } = cot_receiver; + + let (mask, _) = prover.auth_and_gate(mac, mac, (s[0], blks[0])).unwrap(); + + let RCOTSenderOutput { msgs: blks, .. } = cot_sender; - verifier - .compute_and_gate(key, key, mask, cot_sender) - .unwrap(); + verifier.auth_and_gate(key, key, mask, blks[0]).unwrap(); }); let (VOPESenderOutput { eval, .. }, VOPEReceiverOutput { coeff, .. }) = ideal_vope.random_correlated(1); - let (u, v) = prover.check_and_gate((coeff[0], coeff[1])); + let (u, v) = prover.check_and_gates((coeff[0], coeff[1])); - verifier.check_and_gate(eval, u, v); + verifier.check_and_gates(eval, u, v); assert!(verifier.checked()); } diff --git a/crates/mpz-zk-core/src/quicksilver/prover.rs b/crates/mpz-zk-core/src/quicksilver/prover.rs index 9d452222..dc549c0f 100644 --- a/crates/mpz-zk-core/src/quicksilver/prover.rs +++ b/crates/mpz-zk-core/src/quicksilver/prover.rs @@ -4,8 +4,8 @@ use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIter use super::{bools_to_bytes, QsProverError, CHECK_BUFFER_SIZE}; -/// QuickSilver Prover. -#[derive(Debug)] +/// Internal QuickSilver Prover. +#[derive(Debug, Default)] pub struct Prover { /// Buffer for left wire label. buf_left: Vec, @@ -14,7 +14,7 @@ pub struct Prover { /// Buffer for output wire label. buf_out: Vec, /// Counter for check. - counter: usize, + check_counter: usize, /// Hasher. hasher: blake3::Hasher, /// Hash buffer for the bools. @@ -28,7 +28,7 @@ impl Prover { buf_left: vec![Block::ZERO; CHECK_BUFFER_SIZE], buf_right: vec![Block::ZERO; CHECK_BUFFER_SIZE], buf_out: vec![Block::ZERO; CHECK_BUFFER_SIZE], - counter: 0, + check_counter: 0, hasher: blake3::Hasher::new(), buf_hash: vec![false; CHECK_BUFFER_SIZE], } @@ -41,7 +41,7 @@ impl Prover { /// /// * `inputs` - The input bits. /// * `cot` - The COT mask received from Ideal COT as the receiver. - pub fn compute_input_bits( + pub fn auth_input_bits( &mut self, inputs: &[bool], cot: RCOTReceiverOutput, @@ -77,37 +77,31 @@ impl Prover { /// * `ma` - The MAC of wire a. /// * `mb` - The MAC of wire b. /// * `cot` - The COT mask received from Ideal COT as the receiver. - pub fn compute_and_gate( + pub fn auth_and_gate( &mut self, ma: Block, mb: Block, - cot: RCOTReceiverOutput, - ) -> Result { - if cot.choices.len() != 1 { - return Err(QsProverError::InvalidInputs); - } + cot: (bool, Block), + ) -> Result<(bool, Block), QsProverError> { + assert!(self.check_counter < CHECK_BUFFER_SIZE); - assert!(self.counter < CHECK_BUFFER_SIZE); + self.buf_left[self.check_counter] = ma; + self.buf_right[self.check_counter] = mb; - self.buf_left[self.counter] = ma; - self.buf_right[self.counter] = mb; - - let RCOTReceiverOutput { - choices: s, - msgs: blks, - .. - } = cot; + let s = cot.0; + let blk = cot.1; // Compute wa * wb let v = ma.lsb() & mb.lsb() == 1; // Compute the mask of v with s. - let d = v ^ s[0]; + let d = v ^ s; - self.buf_out[self.counter] = Self::set_value(blks[0], v); - self.buf_hash[self.counter] = d; - self.counter += 1; + let mc = Self::set_value(blk, v); + self.buf_out[self.check_counter] = mc; + self.buf_hash[self.check_counter] = d; + self.check_counter += 1; - Ok(d) + Ok((d, mc)) } /// Check and gate. @@ -116,19 +110,19 @@ impl Prover { /// # Arguments. /// /// * `vope` - The mask blocks received from ideal VOPE. - pub fn check_and_gate(&mut self, vope: (Block, Block)) -> (Block, Block) { - assert!(self.counter <= CHECK_BUFFER_SIZE); + pub fn check_and_gates(&mut self, vope: (Block, Block)) -> (Block, Block) { + assert!(self.check_counter <= CHECK_BUFFER_SIZE); cfg_if::cfg_if! { if #[cfg(feature = "rayon")]{ - let iter = self.buf_left[..self.counter] + let iter = self.buf_left[..self.check_counter] .par_iter() - .zip(self.buf_right[..self.counter].par_iter()) - .zip(self.buf_out[..self.counter].par_iter()); + .zip(self.buf_right[..self.check_counter].par_iter()) + .zip(self.buf_out[..self.check_counter].par_iter()); } else{ - let iter = self.buf_left[..self.counter] + let iter = self.buf_left[..self.check_counter] .iter() - .zip(self.buf_right[..self.counter].iter()) - .zip(self.buf_out[..self.counter].iter()) + .zip(self.buf_right[..self.check_counter].iter()) + .zip(self.buf_out[..self.check_counter].iter()) } } @@ -144,10 +138,10 @@ impl Prover { // Compute chi and powers. self.hasher - .update(&bools_to_bytes(&self.buf_hash[..self.counter])); + .update(&bools_to_bytes(&self.buf_hash[..self.check_counter])); let seed = *self.hasher.finalize().as_bytes(); let seed = Block::try_from(&seed[0..16]).unwrap(); - let chis = Block::powers(seed, self.counter); + let chis = Block::powers(seed, self.check_counter); // Compute the inner product. let u = Block::inn_prdt_red(&blocks.0, &chis); @@ -160,11 +154,19 @@ impl Prover { // Update the hasher self.hasher.update(&u.to_bytes()); self.hasher.update(&v.to_bytes()); - self.counter = 0; + self.check_counter = 0; (u, v) } + /// Enable and check or not. + /// If the check_counter is set into the default number, + /// we enable the check protocol. + #[inline] + pub fn enable_check(&self) -> bool { + self.check_counter == CHECK_BUFFER_SIZE + } + // Set the LSB of the block to as the bit. // This assumes the lsb of delta is 1. #[inline] diff --git a/crates/mpz-zk-core/src/quicksilver/verifier.rs b/crates/mpz-zk-core/src/quicksilver/verifier.rs index 0afa54b7..747f2ed3 100644 --- a/crates/mpz-zk-core/src/quicksilver/verifier.rs +++ b/crates/mpz-zk-core/src/quicksilver/verifier.rs @@ -5,18 +5,18 @@ use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIter use super::{bools_to_bytes, QsVerifierError, CHECK_BUFFER_SIZE}; /// QuickSilver Verifier. -#[derive(Debug)] +#[derive(Debug, Default)] pub struct Verifier { /// Global secret. delta: Block, - /// Buffer for left wire label. + /// Buffer for left wire KEY. buf_left: Vec, - /// Buffer for right wire label. + /// Buffer for right wire KEY. buf_right: Vec, - /// Buffer for output wire label. + /// Buffer for output wire KEY. buf_out: Vec, /// Counter for check. - counter: usize, + check_counter: usize, /// Hasher. hasher: blake3::Hasher, /// Hash buffer for the bools. @@ -37,7 +37,7 @@ impl Verifier { buf_left: vec![Block::ZERO; CHECK_BUFFER_SIZE], buf_right: vec![Block::ZERO; CHECK_BUFFER_SIZE], buf_out: vec![Block::ZERO; CHECK_BUFFER_SIZE], - counter: 0, + check_counter: 0, hasher: blake3::Hasher::new(), buf_hash: vec![false; CHECK_BUFFER_SIZE], and_gate_checked: true, @@ -49,7 +49,7 @@ impl Verifier { /// /// * `masks` - The mask bits sent from the prover. /// * `cot` - The COT mask received from Ideal COT as the sender. - pub fn compute_input_bits( + pub fn auth_input_bits( &mut self, masks: &[bool], cot: RCOTSenderOutput, @@ -85,30 +85,25 @@ impl Verifier { /// * `kb` - The KEY of wire b. /// * `mask` - The mask sent by the prover. /// * `cot` - The COT mask received from Ideal COT as the sender. - pub fn compute_and_gate( + pub fn auth_and_gate( &mut self, ka: Block, kb: Block, mask: bool, - cot: RCOTSenderOutput, - ) -> Result<(), QsVerifierError> { - if cot.msgs.len() != 1 { - return Err(QsVerifierError::InvalidInputs); - } + cot: Block, + ) -> Result { + assert!(self.check_counter < CHECK_BUFFER_SIZE); - assert!(self.counter < CHECK_BUFFER_SIZE); - - self.buf_left[self.counter] = ka; - self.buf_right[self.counter] = kb; - self.buf_hash[self.counter] = mask; - - let RCOTSenderOutput { msgs: blks, .. } = cot; + self.buf_left[self.check_counter] = ka; + self.buf_right[self.check_counter] = kb; + self.buf_hash[self.check_counter] = mask; - let block = blks[0] ^ if mask { self.delta } else { Block::ZERO }; - self.buf_out[self.counter] = Self::set_zero(block); - self.counter += 1; + let block = cot ^ if mask { self.delta } else { Block::ZERO }; + let kc = Self::set_zero(block); + self.buf_out[self.check_counter] = kc; + self.check_counter += 1; - Ok(()) + Ok(kc) } /// Check and gate. @@ -119,14 +114,14 @@ impl Verifier { /// * `vope` - The mask block received from ideal VOPE. /// * `u` - The block sent by the prover. /// * `v` - The block sent by the prover. - pub fn check_and_gate(&mut self, vope: Block, u: Block, v: Block) { - assert!(self.counter <= CHECK_BUFFER_SIZE); + pub fn check_and_gates(&mut self, vope: Block, u: Block, v: Block) { + assert!(self.check_counter <= CHECK_BUFFER_SIZE); cfg_if::cfg_if! { if #[cfg(feature = "rayon")]{ - let iter = self.buf_left[..self.counter] + let iter = self.buf_left[..self.check_counter] .par_iter() - .zip(self.buf_right[..self.counter].par_iter()) - .zip(self.buf_out[..self.counter].par_iter()); + .zip(self.buf_right[..self.check_counter].par_iter()) + .zip(self.buf_out[..self.check_counter].par_iter()); } else{ let iter = self.buf_left[..self.counter] .iter() @@ -142,10 +137,10 @@ impl Verifier { // Compute chi and powers. self.hasher - .update(&bools_to_bytes(&self.buf_hash[..=self.counter])); + .update(&bools_to_bytes(&self.buf_hash[..=self.check_counter])); let seed = *self.hasher.finalize().as_bytes(); let seed = Block::try_from(&seed[0..16]).unwrap(); - let chis = Block::powers(seed, self.counter); + let chis = Block::powers(seed, self.check_counter); // Compute the inner product. let w = Block::inn_prdt_red(&block, &chis); @@ -153,10 +148,19 @@ impl Verifier { self.hasher.update(&u.to_bytes()); self.hasher.update(&v.to_bytes()); - self.counter = 0; + self.check_counter = 0; + } + + /// Enable and check or not. + /// If check_counter is set to the default buffer size, + /// we enable the check protocol. + #[inline] + pub fn enable_check(&self) -> bool { + self.check_counter == CHECK_BUFFER_SIZE } /// Returns the and_check results. + #[inline] pub fn checked(&self) -> bool { self.and_gate_checked } diff --git a/crates/mpz-zk/src/lib.rs b/crates/mpz-zk/src/lib.rs index 4ae2e10f..ed1aa96b 100644 --- a/crates/mpz-zk/src/lib.rs +++ b/crates/mpz-zk/src/lib.rs @@ -9,6 +9,7 @@ clippy::all )] +pub mod quicksilver; pub mod vope; /// An oblivious transfer error. diff --git a/crates/mpz-zk/src/quicksilver/error.rs b/crates/mpz-zk/src/quicksilver/error.rs new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/crates/mpz-zk/src/quicksilver/error.rs @@ -0,0 +1 @@ + diff --git a/crates/mpz-zk/src/quicksilver/mod.rs b/crates/mpz-zk/src/quicksilver/mod.rs new file mode 100644 index 00000000..2e109d81 --- /dev/null +++ b/crates/mpz-zk/src/quicksilver/mod.rs @@ -0,0 +1,5 @@ +//! Implementation of QuickSilver (https://eprint.iacr.org/2021/076.pdf). + +mod error; +mod prover; +mod verifier; diff --git a/crates/mpz-zk/src/quicksilver/prover.rs b/crates/mpz-zk/src/quicksilver/prover.rs new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/crates/mpz-zk/src/quicksilver/prover.rs @@ -0,0 +1 @@ + diff --git a/crates/mpz-zk/src/quicksilver/verifier.rs b/crates/mpz-zk/src/quicksilver/verifier.rs new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/crates/mpz-zk/src/quicksilver/verifier.rs @@ -0,0 +1 @@ + diff --git a/value.rs b/value.rs new file mode 100644 index 00000000..e69de29b From 30d3a11f495bc85029e51d24dd36721dd8b01b7d Mon Sep 17 00:00:00 2001 From: Xiang Xie Date: Mon, 9 Sep 2024 23:22:38 +0800 Subject: [PATCH 13/17] refine core --- crates/mpz-zk-core/src/quicksilver/error.rs | 23 +++------------- crates/mpz-zk-core/src/quicksilver/mod.rs | 26 ++----------------- crates/mpz-zk-core/src/quicksilver/prover.rs | 11 +++----- .../mpz-zk-core/src/quicksilver/verifier.rs | 6 ++--- 4 files changed, 12 insertions(+), 54 deletions(-) diff --git a/crates/mpz-zk-core/src/quicksilver/error.rs b/crates/mpz-zk-core/src/quicksilver/error.rs index 7bcc0b5a..468a7d34 100644 --- a/crates/mpz-zk-core/src/quicksilver/error.rs +++ b/crates/mpz-zk-core/src/quicksilver/error.rs @@ -1,26 +1,11 @@ //! Errors in QuickSilver. -use mpz_circuits::CircuitError; /// Errors that can occur during proving #[derive(Debug, thiserror::Error)] -#[allow(missing_docs)] -pub enum QsProverError { - #[error("invalid inputs")] - InvalidInputs, - #[error("prover not finished")] - NotFinished, - #[error("not enough COT")] - NotEnoughCOT, - #[error(transparent)] - CircuitError(#[from] CircuitError), -} +#[error("invalid inputs: expect {0}")] +pub struct QsProverError(pub String); /// Errors that can occur during verifying #[derive(Debug, thiserror::Error)] -#[allow(missing_docs)] -pub enum QsVerifierError { - #[error(transparent)] - CircuitError(#[from] CircuitError), - #[error("invalid inputs")] - InvalidInputs, -} +#[error("invalid inputs: expect {0}")] +pub struct QsVerifierError(pub String); diff --git a/crates/mpz-zk-core/src/quicksilver/mod.rs b/crates/mpz-zk-core/src/quicksilver/mod.rs index 3e4792f6..3232c2a2 100644 --- a/crates/mpz-zk-core/src/quicksilver/mod.rs +++ b/crates/mpz-zk-core/src/quicksilver/mod.rs @@ -8,14 +8,9 @@ pub use error::*; pub use prover::Prover; pub use verifier::Verifier; -use serde::{Deserialize, Serialize}; - /// Buffer size of each check. pub(crate) const CHECK_BUFFER_SIZE: usize = 1024 * 1024; -/// Default amount of authenticated gates per batch. -pub(crate) const DEFAULT_BATCH_SIZE: usize = 128; - #[inline] fn bools_to_bytes(bv: &[bool]) -> Vec { let offset = if bv.len() % 8 == 0 { 0 } else { 1 }; @@ -26,23 +21,6 @@ fn bools_to_bytes(bv: &[bool]) -> Vec { v } -/// A batch of bit masks. -/// -/// # Parameters -/// -/// - `N`: The size of a batch -#[derive(Debug, Serialize, Deserialize)] -pub struct MaskBitBatch( - #[serde(with = "serde_arrays")] [bool; N], -); - -impl MaskBitBatch { - /// Create a new batch of mask bits. - pub fn new(batch: [bool; N]) -> Self { - Self(batch) - } -} - #[cfg(test)] mod tests { use mpz_core::prg::Prg; @@ -91,11 +69,11 @@ mod tests { .. } = cot_receiver; - let (mask, _) = prover.auth_and_gate(mac, mac, (s[0], blks[0])).unwrap(); + let (mask, _) = prover.auth_and_gate(mac, mac, (s[0], blks[0])); let RCOTSenderOutput { msgs: blks, .. } = cot_sender; - verifier.auth_and_gate(key, key, mask, blks[0]).unwrap(); + verifier.auth_and_gate(key, key, mask, blks[0]); }); let (VOPESenderOutput { eval, .. }, VOPEReceiverOutput { coeff, .. }) = diff --git a/crates/mpz-zk-core/src/quicksilver/prover.rs b/crates/mpz-zk-core/src/quicksilver/prover.rs index dc549c0f..b08ae577 100644 --- a/crates/mpz-zk-core/src/quicksilver/prover.rs +++ b/crates/mpz-zk-core/src/quicksilver/prover.rs @@ -47,7 +47,7 @@ impl Prover { cot: RCOTReceiverOutput, ) -> Result<(Vec, Vec), QsProverError> { if cot.choices.len() != inputs.len() { - return Err(QsProverError::InvalidInputs); + return Err(QsProverError(format!("lengths not match"))); } let RCOTReceiverOutput { @@ -77,12 +77,7 @@ impl Prover { /// * `ma` - The MAC of wire a. /// * `mb` - The MAC of wire b. /// * `cot` - The COT mask received from Ideal COT as the receiver. - pub fn auth_and_gate( - &mut self, - ma: Block, - mb: Block, - cot: (bool, Block), - ) -> Result<(bool, Block), QsProverError> { + pub fn auth_and_gate(&mut self, ma: Block, mb: Block, cot: (bool, Block)) -> (bool, Block) { assert!(self.check_counter < CHECK_BUFFER_SIZE); self.buf_left[self.check_counter] = ma; @@ -101,7 +96,7 @@ impl Prover { self.buf_hash[self.check_counter] = d; self.check_counter += 1; - Ok((d, mc)) + (d, mc) } /// Check and gate. diff --git a/crates/mpz-zk-core/src/quicksilver/verifier.rs b/crates/mpz-zk-core/src/quicksilver/verifier.rs index 747f2ed3..cdd1b27c 100644 --- a/crates/mpz-zk-core/src/quicksilver/verifier.rs +++ b/crates/mpz-zk-core/src/quicksilver/verifier.rs @@ -56,7 +56,7 @@ impl Verifier { // The mask bits sent by prover. ) -> Result, QsVerifierError> { if masks.len() != cot.msgs.len() { - return Err(QsVerifierError::InvalidInputs); + return Err(QsVerifierError(format!("lengths not match"))); } // Hash the bools. @@ -91,7 +91,7 @@ impl Verifier { kb: Block, mask: bool, cot: Block, - ) -> Result { + ) -> Block { assert!(self.check_counter < CHECK_BUFFER_SIZE); self.buf_left[self.check_counter] = ka; @@ -103,7 +103,7 @@ impl Verifier { self.buf_out[self.check_counter] = kc; self.check_counter += 1; - Ok(kc) + kc } /// Check and gate. From c99d7f6b6a6171b294684a9486a1e971d3ed55dc Mon Sep 17 00:00:00 2001 From: Xiang Xie Date: Tue, 10 Sep 2024 17:49:13 +0800 Subject: [PATCH 14/17] qs with io --- crates/mpz-zk-core/src/quicksilver/mod.rs | 24 ++- crates/mpz-zk-core/src/quicksilver/prover.rs | 25 ++- .../mpz-zk-core/src/quicksilver/verifier.rs | 59 ++++-- crates/mpz-zk-core/src/test.rs | 13 ++ crates/mpz-zk/Cargo.toml | 2 + crates/mpz-zk/src/lib.rs | 14 +- crates/mpz-zk/src/quicksilver/error.rs | 52 +++++ crates/mpz-zk/src/quicksilver/mod.rs | 57 ++++++ crates/mpz-zk/src/quicksilver/prover.rs | 170 +++++++++++++++++ crates/mpz-zk/src/quicksilver/verifier.rs | 177 ++++++++++++++++++ crates/mpz-zk/src/vope/mod.rs | 44 ++--- crates/mpz-zk/src/vope/receiver.rs | 48 +++-- crates/mpz-zk/src/vope/sender.rs | 57 +++--- 13 files changed, 634 insertions(+), 108 deletions(-) diff --git a/crates/mpz-zk-core/src/quicksilver/mod.rs b/crates/mpz-zk-core/src/quicksilver/mod.rs index 3232c2a2..aac81cef 100644 --- a/crates/mpz-zk-core/src/quicksilver/mod.rs +++ b/crates/mpz-zk-core/src/quicksilver/mod.rs @@ -11,8 +11,9 @@ pub use verifier::Verifier; /// Buffer size of each check. pub(crate) const CHECK_BUFFER_SIZE: usize = 1024 * 1024; +/// Convert bool vector to byte vector. #[inline] -fn bools_to_bytes(bv: &[bool]) -> Vec { +pub fn bools_to_bytes(bv: &[bool]) -> Vec { let offset = if bv.len() % 8 == 0 { 0 } else { 1 }; let mut v = vec![0u8; bv.len() / 8 + offset]; for (i, b) in bv.iter().enumerate() { @@ -21,6 +22,18 @@ fn bools_to_bytes(bv: &[bool]) -> Vec { v } +/// Convert byte vector to bool vector. +#[inline] +pub fn bytes_to_bools(v: &[u8]) -> Vec { + let mut bv = Vec::with_capacity(v.len() * 8); + for byte in v.iter() { + for i in 0..8 { + bv.push(((byte >> (7 - i)) & 1) != 0); + } + } + bv +} + #[cfg(test)] mod tests { use mpz_core::prg::Prg; @@ -28,7 +41,7 @@ mod tests { ideal::cot::IdealCOT, test::assert_cot, RCOTReceiverOutput, RCOTSenderOutput, }; - use crate::{ideal::vope::IdealVOPE, VOPEReceiverOutput, VOPESenderOutput}; + use crate::ideal::vope::IdealVOPE; use super::{Prover, Verifier}; @@ -76,12 +89,11 @@ mod tests { verifier.auth_and_gate(key, key, mask, blks[0]); }); - let (VOPESenderOutput { eval, .. }, VOPEReceiverOutput { coeff, .. }) = - ideal_vope.random_correlated(1); + let (vope_sender, vope_receiver) = ideal_vope.random_correlated(1); - let (u, v) = prover.check_and_gates((coeff[0], coeff[1])); + let (u, v) = prover.check_and_gates(vope_receiver); - verifier.check_and_gates(eval, u, v); + verifier.check_and_gates(vope_sender, u, v); assert!(verifier.checked()); } diff --git a/crates/mpz-zk-core/src/quicksilver/prover.rs b/crates/mpz-zk-core/src/quicksilver/prover.rs index b08ae577..91f7b510 100644 --- a/crates/mpz-zk-core/src/quicksilver/prover.rs +++ b/crates/mpz-zk-core/src/quicksilver/prover.rs @@ -1,7 +1,9 @@ -use mpz_core::Block; +use mpz_core::{hash::Hash, serialize::CanonicalSerialize, utils::blake3, Block}; use mpz_ot_core::RCOTReceiverOutput; use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator}; +use crate::VOPEReceiverOutput; + use super::{bools_to_bytes, QsProverError, CHECK_BUFFER_SIZE}; /// Internal QuickSilver Prover. @@ -105,7 +107,7 @@ impl Prover { /// # Arguments. /// /// * `vope` - The mask blocks received from ideal VOPE. - pub fn check_and_gates(&mut self, vope: (Block, Block)) -> (Block, Block) { + pub fn check_and_gates(&mut self, vope: VOPEReceiverOutput) -> (Block, Block) { assert!(self.check_counter <= CHECK_BUFFER_SIZE); cfg_if::cfg_if! { if #[cfg(feature = "rayon")]{ @@ -143,8 +145,8 @@ impl Prover { let v = Block::inn_prdt_red(&blocks.1, &chis); // Mask the results. - let u = u ^ vope.0; - let v = v ^ vope.1; + let u = u ^ vope.coeff[0]; + let v = v ^ vope.coeff[1]; // Update the hasher self.hasher.update(&u.to_bytes()); @@ -155,13 +157,26 @@ impl Prover { } /// Enable and check or not. - /// If the check_counter is set into the default number, + /// If check_counter is set into the default number, /// we enable the check protocol. #[inline] pub fn enable_check(&self) -> bool { self.check_counter == CHECK_BUFFER_SIZE } + /// Enable the final check or not. + /// if check_counter is zero, then no need to check. + #[inline] + pub fn enable_final_check(&self) -> bool { + self.check_counter != 0 + } + + /// Hash the output macs + #[inline] + pub fn finish(&self, macs: &[Block]) -> Hash { + Hash::from(blake3(&macs.to_bytes())) + } + // Set the LSB of the block to as the bit. // This assumes the lsb of delta is 1. #[inline] diff --git a/crates/mpz-zk-core/src/quicksilver/verifier.rs b/crates/mpz-zk-core/src/quicksilver/verifier.rs index cdd1b27c..d3e39b37 100644 --- a/crates/mpz-zk-core/src/quicksilver/verifier.rs +++ b/crates/mpz-zk-core/src/quicksilver/verifier.rs @@ -1,7 +1,9 @@ -use mpz_core::Block; +use mpz_core::{hash::Hash, serialize::CanonicalSerialize, utils::blake3, Block}; use mpz_ot_core::RCOTSenderOutput; use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator}; +use crate::VOPESenderOutput; + use super::{bools_to_bytes, QsVerifierError, CHECK_BUFFER_SIZE}; /// QuickSilver Verifier. @@ -21,8 +23,8 @@ pub struct Verifier { hasher: blake3::Hasher, /// Hash buffer for the bools. buf_hash: Vec, - /// Indicate the and_gate check passes or not. - and_gate_checked: bool, + /// Indicate the checks pass or not. + checked: bool, } impl Verifier { @@ -40,7 +42,7 @@ impl Verifier { check_counter: 0, hasher: blake3::Hasher::new(), buf_hash: vec![false; CHECK_BUFFER_SIZE], - and_gate_checked: true, + checked: true, } } /// Compute authenticated bits for inputs. @@ -85,13 +87,7 @@ impl Verifier { /// * `kb` - The KEY of wire b. /// * `mask` - The mask sent by the prover. /// * `cot` - The COT mask received from Ideal COT as the sender. - pub fn auth_and_gate( - &mut self, - ka: Block, - kb: Block, - mask: bool, - cot: Block, - ) -> Block { + pub fn auth_and_gate(&mut self, ka: Block, kb: Block, mask: bool, cot: Block) -> Block { assert!(self.check_counter < CHECK_BUFFER_SIZE); self.buf_left[self.check_counter] = ka; @@ -114,7 +110,7 @@ impl Verifier { /// * `vope` - The mask block received from ideal VOPE. /// * `u` - The block sent by the prover. /// * `v` - The block sent by the prover. - pub fn check_and_gates(&mut self, vope: Block, u: Block, v: Block) { + pub fn check_and_gates(&mut self, vope: VOPESenderOutput, u: Block, v: Block) { assert!(self.check_counter <= CHECK_BUFFER_SIZE); cfg_if::cfg_if! { if #[cfg(feature = "rayon")]{ @@ -144,7 +140,7 @@ impl Verifier { // Compute the inner product. let w = Block::inn_prdt_red(&block, &chis); - self.and_gate_checked &= (w ^ vope) == u ^ v.gfmul(self.delta); + self.checked &= (w ^ vope.eval) == u ^ v.gfmul(self.delta); self.hasher.update(&u.to_bytes()); self.hasher.update(&v.to_bytes()); @@ -159,10 +155,45 @@ impl Verifier { self.check_counter == CHECK_BUFFER_SIZE } + /// Enable the final check or not. + /// if check_counter is zero, then no need to check. + #[inline] + pub fn enable_final_check(&self) -> bool { + self.check_counter != 0 + } + + /// Hash the output keys with the outputs. + pub fn finish( + &mut self, + hash: Hash, + keys: &[Block], + outputs: &[bool], + ) -> Result<(), QsVerifierError> { + if keys.len() != outputs.len() { + return Err(QsVerifierError(format!("lengths not match"))); + } + + let pre_hash: Vec = keys + .iter() + .zip(outputs.iter()) + .map(|(&k, &o)| if o { k ^ self.delta } else { k }) + .collect(); + + let expected_hash = Hash::from(blake3(&pre_hash.to_bytes())); + self.checked &= hash == expected_hash; + + Ok(()) + } + /// Returns the and_check results. #[inline] pub fn checked(&self) -> bool { - self.and_gate_checked + self.checked + } + + /// Returns delta. + pub fn delta(&self) -> Block { + self.delta } // Set the lsb of the block to zero. diff --git a/crates/mpz-zk-core/src/test.rs b/crates/mpz-zk-core/src/test.rs index b22369d3..c02a81ee 100644 --- a/crates/mpz-zk-core/src/test.rs +++ b/crates/mpz-zk-core/src/test.rs @@ -2,6 +2,8 @@ use mpz_core::Block; +use crate::{VOPEReceiverOutput, VOPESenderOutput}; + /// Check polynomial relation. pub fn poly_check(a: &[Block], b: Block, delta: Block) -> bool { b == a @@ -9,3 +11,14 @@ pub fn poly_check(a: &[Block], b: Block, delta: Block) -> bool { .rev() .fold(Block::ZERO, |acc, &x| x ^ (delta.gfmul(acc))) } + +/// Assert VOPE relation. +pub fn assert_vope( + send: VOPESenderOutput, + recv: VOPEReceiverOutput, + delta: Block, +) -> bool { + let send = send.eval; + let recv = recv.coeff; + poly_check(&recv, send, delta) +} diff --git a/crates/mpz-zk/Cargo.toml b/crates/mpz-zk/Cargo.toml index 54b42ced..cc147603 100644 --- a/crates/mpz-zk/Cargo.toml +++ b/crates/mpz-zk/Cargo.toml @@ -23,6 +23,7 @@ mpz-common.workspace = true mpz-cointoss.workspace = true mpz-ot-core.workspace = true mpz-ot.workspace = true +mpz-circuits.workspace = true tlsn-utils-aio.workspace = true @@ -51,3 +52,4 @@ tokio = { workspace = true, features = [ "rt", "rt-multi-thread", ] } +aes = { workspace = true, features = [] } diff --git a/crates/mpz-zk/src/lib.rs b/crates/mpz-zk/src/lib.rs index ed1aa96b..6a336b8e 100644 --- a/crates/mpz-zk/src/lib.rs +++ b/crates/mpz-zk/src/lib.rs @@ -12,7 +12,7 @@ pub mod quicksilver; pub mod vope; -/// An oblivious transfer error. +/// A vope error. #[derive(Debug, thiserror::Error)] #[allow(missing_docs)] pub enum VOPEError { @@ -23,3 +23,15 @@ pub enum VOPEError { #[error("receiver error: {0}")] ReceiverError(Box), } + +/// A zk error. +#[derive(Debug, thiserror::Error)] +#[allow(missing_docs)] +pub enum ZKError { + #[error(transparent)] + IOError(#[from] std::io::Error), + #[error("prover error: {0}")] + ProverError(Box), + #[error("verifier error: {0}")] + VerifierError(Box), +} diff --git a/crates/mpz-zk/src/quicksilver/error.rs b/crates/mpz-zk/src/quicksilver/error.rs index 8b137891..294e6b6c 100644 --- a/crates/mpz-zk/src/quicksilver/error.rs +++ b/crates/mpz-zk/src/quicksilver/error.rs @@ -1 +1,53 @@ +use mpz_circuits::CircuitError; +use crate::ZKError; + +/// Prover error. +#[derive(Debug, thiserror::Error)] +#[allow(missing_docs)] +pub enum ProverError { + #[error(transparent)] + IOError(#[from] std::io::Error), + #[error(transparent)] + CoreError(#[from] mpz_zk_core::quicksilver::QsProverError), + #[error(transparent)] + OTError(#[from] mpz_ot::OTError), + #[error(transparent)] + CircuitError(#[from] CircuitError), + #[error(transparent)] + VOPEError(#[from] crate::vope::error::ReceiverError), +} + +/// Verifier error. +#[derive(Debug, thiserror::Error)] +#[allow(missing_docs)] +pub enum VerifierError { + #[error(transparent)] + IOError(#[from] std::io::Error), + #[error(transparent)] + CoreError(#[from] mpz_zk_core::quicksilver::QsVerifierError), + #[error(transparent)] + OTError(#[from] mpz_ot::OTError), + #[error(transparent)] + CircuitError(#[from] CircuitError), + #[error(transparent)] + VOPEError(#[from] crate::vope::error::SenderError), +} + +impl From for ZKError { + fn from(err: ProverError) -> Self { + match err { + ProverError::IOError(e) => e.into(), + e => ZKError::ProverError(Box::new(e)), + } + } +} + +impl From for ZKError { + fn from(err: VerifierError) -> Self { + match err { + VerifierError::IOError(e) => e.into(), + e => ZKError::VerifierError(Box::new(e)), + } + } +} diff --git a/crates/mpz-zk/src/quicksilver/mod.rs b/crates/mpz-zk/src/quicksilver/mod.rs index 2e109d81..99b4be92 100644 --- a/crates/mpz-zk/src/quicksilver/mod.rs +++ b/crates/mpz-zk/src/quicksilver/mod.rs @@ -3,3 +3,60 @@ mod error; mod prover; mod verifier; + +pub use error::{ProverError, VerifierError}; +pub use prover::Prover; +pub use verifier::Verifier; + +#[cfg(test)] +mod tests { + use crate::{ + quicksilver::{Prover, Verifier}, + ZKError, + }; + use aes::{ + cipher::{BlockEncrypt, KeyInit}, + Aes128, + }; + use futures::TryFutureExt; + use mpz_circuits::circuits::AES128; + use mpz_common::executor::test_st_executor; + use mpz_ot::{ideal::cot::ideal_cot, Correlation}; + + #[tokio::test] + async fn test_qs() { + let (mut ctx_sender, mut ctx_receiver) = test_st_executor(8); + + let (mut rcot_sender, mut rcot_receiver) = ideal_cot(); + + let mut delta = rcot_sender.delta(); + delta.set_lsb(); + + let mut prover = Prover::new(); + let mut verifier = Verifier::new(delta); + + let key = [69u8; 16]; + let msg = [42u8; 16]; + + let expected: [u8; 16] = { + let cipher = Aes128::new_from_slice(&key).unwrap(); + let mut out = msg.into(); + cipher.encrypt_block(&mut out); + out.into() + }; + + let input_value = [key, msg].concat(); + + tokio::try_join!( + prover + .prove(&mut ctx_sender, &AES128, input_value, &mut rcot_receiver) + .map_err(ZKError::from), + verifier + .verify(&mut ctx_receiver, &AES128, expected, &mut rcot_sender) + .map_err(ZKError::from) + ) + .unwrap(); + + assert!(verifier.checked()) + } +} diff --git a/crates/mpz-zk/src/quicksilver/prover.rs b/crates/mpz-zk/src/quicksilver/prover.rs index 8b137891..68421a17 100644 --- a/crates/mpz-zk/src/quicksilver/prover.rs +++ b/crates/mpz-zk/src/quicksilver/prover.rs @@ -1 +1,171 @@ +use itybity::IntoBits; +use mpz_circuits::{types::Value, Circuit, CircuitError, Gate}; +use mpz_common::{cpu::CpuBackend, Context}; +use mpz_core::Block; +use mpz_ot::{RCOTReceiverOutput, RandomCOTReceiver}; +use mpz_zk_core::quicksilver::{bools_to_bytes, Prover as ProverCore}; +use serio::SinkExt; +use super::error::ProverError; + +/// QuickSilver Prover. +pub struct Prover { + macs: Vec, + prover_core: ProverCore, +} + +impl Prover { + /// Create a new instance. + pub fn new() -> Self { + Self { + macs: Vec::default(), + prover_core: ProverCore::new(), + } + } + /// Authenticate inputs. + async fn auth_inputs( + &mut self, + ctx: &mut Ctx, + inputs: &[bool], + rcot: &mut RCOT, + ) -> Result, ProverError> + where + Ctx: Context, + RCOT: RandomCOTReceiver, + { + let cot = rcot.receive_random_correlated(ctx, inputs.len()).await?; + + let (bits, macs) = self.prover_core.auth_input_bits(&inputs, cot)?; + + ctx.io_mut().send(bools_to_bytes(&bits)).await?; + + Ok(macs) + } + + /// Prove. + pub async fn prove( + &mut self, + ctx: &mut Ctx, + circ: &Circuit, + input_value: impl Into, + rcot: &mut RCOT, + ) -> Result<(), ProverError> + where + Ctx: Context, + RCOT: RandomCOTReceiver, + { + let len: usize = circ.inputs().iter().map(|v| v.len()).sum(); + + let input_value = input_value.into().into_lsb0_vec(); + if input_value.len() != len { + return Err(CircuitError::InvalidInputCount(len, input_value.len()))?; + } + + if circ.feed_count() > self.macs.len() { + self.macs.resize(circ.feed_count(), Default::default()); + } + + // Handle inputs. + let input_macs = self.auth_inputs(ctx, &input_value, rcot).await?; + + for (mac, node) in input_macs + .iter() + .zip(circ.inputs().iter().flat_map(|v| v.iter())) + { + self.macs[node.id()] = *mac; + } + + // Authenticate the circuit. + while let Some(gate) = circ.gates().iter().next() { + match gate { + Gate::Xor { + x: node_x, + y: node_y, + z: node_z, + } => { + let x_0 = self.macs[node_x.id()]; + let y_0 = self.macs[node_y.id()]; + self.macs[node_z.id()] = x_0 ^ y_0; + } + Gate::And { + x: node_x, + y: node_y, + z: node_z, + } => { + // Check the batched authenticated and gates. + if self.prover_core.enable_check() { + self.check_and_gates(ctx, rcot).await?; + } + + let x_0 = self.macs[node_x.id()]; + let y_0 = self.macs[node_y.id()]; + + let RCOTReceiverOutput { + choices: bit, + msgs: blk, + .. + } = rcot.receive_random_correlated(ctx, 1).await?; + + let (d, z_0) = self.prover_core.auth_and_gate(x_0, y_0, (bit[0], blk[0])); + + ctx.io_mut().send(d).await?; + + self.macs[node_z.id()] = z_0; + } + Gate::Inv { + x: node_x, + z: node_z, + } => { + let x_0 = self.macs[node_x.id()]; + self.macs[node_z.id()] = x_0 ^ Block::ONE; + } + } + } + + // Handle final check. + if self.prover_core.enable_final_check() { + self.check_and_gates(ctx, rcot).await?; + } + + // Handle outputs. + let output_macs: Vec = circ + .outputs() + .iter() + .flat_map(|v| v.iter()) + .map(|node| self.macs[node.id()]) + .collect(); + + // Send the hash of the output macs. + let hash = self.prover_core.finish(&output_macs); + ctx.io_mut().send(hash).await?; + + Ok(()) + } + + // Check the and gates. + async fn check_and_gates( + &mut self, + ctx: &mut Ctx, + rcot: &mut RCOT, + ) -> Result<(), ProverError> + where + Ctx: Context, + RCOT: RandomCOTReceiver, + { + let mut vope = crate::vope::receiver::Receiver::new(); + vope.setup()?; + + let v = vope.receive(ctx, rcot, 1).await?; + + let mut prover_core = std::mem::replace(&mut self.prover_core, ProverCore::default()); + + let (u, prover_core) = + CpuBackend::blocking(move || (prover_core.check_and_gates(v), prover_core)).await; + + // Send (U, V) + ctx.io_mut().send(u).await?; + + self.prover_core = prover_core; + Ok(()) + } +} diff --git a/crates/mpz-zk/src/quicksilver/verifier.rs b/crates/mpz-zk/src/quicksilver/verifier.rs index 8b137891..d508c036 100644 --- a/crates/mpz-zk/src/quicksilver/verifier.rs +++ b/crates/mpz-zk/src/quicksilver/verifier.rs @@ -1 +1,178 @@ +use itybity::IntoBits; +use mpz_circuits::{types::Value, Circuit, CircuitError, Gate}; +use mpz_common::{cpu::CpuBackend, Context}; +use mpz_core::Block; +use mpz_ot::{RCOTSenderOutput, RandomCOTSender}; +use mpz_zk_core::quicksilver::{bytes_to_bools, Verifier as VerifierCore}; +use serio::stream::IoStreamExt; +use super::error::VerifierError; + +/// QuickSilver Verifier. +pub struct Verifier { + keys: Vec, + verifier_core: VerifierCore, +} + +impl Verifier { + /// Create a new instance. + pub fn new(delta: Block) -> Self { + Self { + keys: Vec::default(), + verifier_core: VerifierCore::new(delta), + } + } + + /// Authenticate inputs. + async fn auth_inputs( + &mut self, + len: usize, + ctx: &mut Ctx, + rcot: &mut RCOT, + ) -> Result, VerifierError> + where + Ctx: Context, + RCOT: RandomCOTSender, + { + let cot = rcot.send_random_correlated(ctx, len).await?; + + let bytes: Vec = ctx.io_mut().expect_next().await?; + + let masks = bytes_to_bools(&bytes); + assert_eq!(masks.len(), len); + + let blks = self.verifier_core.auth_input_bits(&masks, cot)?; + + Ok(blks) + } + + /// Verify + pub async fn verify( + &mut self, + ctx: &mut Ctx, + circ: &Circuit, + output_value: impl Into, + rcot: &mut RCOT, + ) -> Result<(), VerifierError> + where + Ctx: Context, + RCOT: RandomCOTSender, + { + let len: usize = circ.outputs().iter().map(|v| v.len()).sum(); + let output_value = output_value.into().into_lsb0_vec(); + if output_value.len() != len { + return Err(CircuitError::InvalidOutputCount(len, output_value.len()))?; + } + + if circ.feed_count() > self.keys.len() { + self.keys.resize(circ.feed_count(), Default::default()); + } + + let input_len: usize = circ.inputs().iter().map(|v| v.len()).sum(); + // Handle inputs. + let input_keys = self.auth_inputs(input_len, ctx, rcot).await?; + + for (key, node) in input_keys + .iter() + .zip(circ.inputs().iter().flat_map(|v| v.iter())) + { + self.keys[node.id()] = *key; + } + + // Authenticate the circuit. + while let Some(gate) = circ.gates().iter().next() { + match gate { + Gate::Xor { + x: node_x, + y: node_y, + z: node_z, + } => { + let x_0 = self.keys[node_x.id()]; + let y_0 = self.keys[node_y.id()]; + self.keys[node_z.id()] = x_0 ^ y_0; + } + Gate::And { + x: node_x, + y: node_y, + z: node_z, + } => { + // Check the batched authenticated and gats. + if self.verifier_core.enable_check() { + self.check_and_gates(ctx, rcot).await?; + } + + let x_0 = self.keys[node_x.id()]; + let y_0 = self.keys[node_y.id()]; + + let RCOTSenderOutput { msgs: blk, .. } = + rcot.send_random_correlated(ctx, 1).await?; + + let mask = ctx.io_mut().expect_next().await?; + let z_0 = self.verifier_core.auth_and_gate(x_0, y_0, mask, blk[0]); + + self.keys[node_z.id()] = z_0; + } + Gate::Inv { + x: node_x, + z: node_z, + } => { + let x_0 = self.keys[node_x.id()]; + self.keys[node_z.id()] = x_0 ^ self.verifier_core.delta() ^ Block::ONE; + } + } + } + + // Handle final check. + if self.verifier_core.enable_final_check() { + self.check_and_gates(ctx, rcot).await?; + } + + // Handle outputs. + let output_keys: Vec = circ + .outputs() + .iter() + .flat_map(|v| v.iter()) + .map(|node| self.keys[node.id()]) + .collect(); + + let hash = ctx.io_mut().expect_next().await?; + self.verifier_core + .finish(hash, &output_keys, &output_value)?; + + Ok(()) + } + + // Check the and gates. + async fn check_and_gates( + &mut self, + ctx: &mut Ctx, + rcot: &mut RCOT, + ) -> Result<(), VerifierError> + where + Ctx: Context, + RCOT: RandomCOTSender, + { + let mut vope = crate::vope::sender::Sender::new(); + vope.setup(self.verifier_core.delta())?; + + let v = vope.send(ctx, rcot, 1).await?; + + let u: (Block, Block) = ctx.io_mut().expect_next().await?; + + let mut verifier_core = std::mem::replace(&mut self.verifier_core, VerifierCore::default()); + + let (_, verifier_core) = CpuBackend::blocking(move || { + (verifier_core.check_and_gates(v, u.0, u.1), verifier_core) + }) + .await; + + self.verifier_core = verifier_core; + Ok(()) + } + + /// Returns checked over not. + #[inline] + pub fn checked(&self) -> bool { + self.verifier_core.checked() + } +} diff --git a/crates/mpz-zk/src/vope/mod.rs b/crates/mpz-zk/src/vope/mod.rs index 789a131b..e395fa0a 100644 --- a/crates/mpz-zk/src/vope/mod.rs +++ b/crates/mpz-zk/src/vope/mod.rs @@ -12,55 +12,49 @@ mod tests { }; use futures::TryFutureExt; use mpz_common::executor::test_st_executor; - use mpz_core::Block; - use mpz_ot::{ - ideal::cot::{ideal_rcot, IdealCOTReceiver, IdealCOTSender}, - Correlation, - }; - use mpz_zk_core::test::poly_check; - - fn setup() -> (Sender, Receiver, Block) { - let (rcot_sender, rcot_receiver) = ideal_rcot(); - - let delta = rcot_sender.delta(); - - let sender = Sender::new(rcot_sender); - let receiver = Receiver::new(rcot_receiver); - - (sender, receiver, delta) - } + use mpz_ot::{ideal::cot::ideal_rcot, Correlation}; + use mpz_zk_core::test::assert_vope; #[tokio::test] async fn test_vope() { let (mut ctx_sender, mut ctx_receiver) = test_st_executor(8); - let (mut sender, mut receiver, delta) = setup(); + let (mut rcot_sender, mut rcot_receiver) = ideal_rcot(); + + let mut sender = Sender::new(); + let mut receiver = Receiver::new(); + + let delta = rcot_sender.delta(); - sender.setup_with_delta(delta).unwrap(); + sender.setup(delta).unwrap(); receiver.setup().unwrap(); let d = 1; let (output_sender, output_receiver) = tokio::try_join!( - sender.extend(&mut ctx_sender, d).map_err(VOPEError::from), + sender + .send(&mut ctx_sender, &mut rcot_sender, d) + .map_err(VOPEError::from), receiver - .extend(&mut ctx_receiver, d) + .receive(&mut ctx_receiver, &mut rcot_receiver, d) .map_err(VOPEError::from) ) .unwrap(); - assert!(poly_check(&output_receiver, output_sender, delta)); + assert!(assert_vope(output_sender, output_receiver, delta)); let d = 5; let (output_sender, output_receiver) = tokio::try_join!( - sender.extend(&mut ctx_sender, d).map_err(VOPEError::from), + sender + .send(&mut ctx_sender, &mut rcot_sender, d) + .map_err(VOPEError::from), receiver - .extend(&mut ctx_receiver, d) + .receive(&mut ctx_receiver, &mut rcot_receiver, d) .map_err(VOPEError::from) ) .unwrap(); - assert!(poly_check(&output_receiver, output_sender, delta)); + assert!(assert_vope(output_sender, output_receiver, delta)); } } diff --git a/crates/mpz-zk/src/vope/receiver.rs b/crates/mpz-zk/src/vope/receiver.rs index d1fb3cdd..4344d93f 100644 --- a/crates/mpz-zk/src/vope/receiver.rs +++ b/crates/mpz-zk/src/vope/receiver.rs @@ -4,10 +4,13 @@ use crate::vope::error::ReceiverError; use enum_try_as_inner::EnumTryAsInner; use mpz_common::Context; use mpz_core::Block; -use mpz_ot::{RCOTReceiverOutput, RandomCOTReceiver}; -use mpz_zk_core::vope::{ - receiver::{state, Receiver as ReceiverCore}, - CSP, +use mpz_ot::{RCOTReceiverOutput, RandomCOTReceiver, TransferId}; +use mpz_zk_core::{ + vope::{ + receiver::{state, Receiver as ReceiverCore}, + CSP, + }, + VOPEReceiverOutput, }; use utils_aio::non_blocking_backend::{Backend, NonBlockingBackend}; @@ -17,27 +20,26 @@ use utils_aio::non_blocking_backend::{Backend, NonBlockingBackend}; pub enum State { Initialized(ReceiverCore), Extension(ReceiverCore), - Complete, Error, } /// VOPE receiver (prover) #[derive(Debug)] -pub struct Receiver { +pub struct Receiver { state: State, - rcot: RandomCOT, + id: TransferId, } -impl Receiver { +impl Receiver { /// Creates a new receiver. /// /// # Arguments /// /// * `rcot` - The random COT used by the receiver. - pub fn new(rcot: RandomCOT) -> Self { + pub fn new() -> Self { Self { state: State::Initialized(ReceiverCore::new()), - rcot, + id: TransferId::default(), } } @@ -58,14 +60,17 @@ impl Receiver { /// # Arguments /// /// * `ctx` - The context. + /// * `rcot` - The ideal random COT. /// * `d` - The polynomial degree. - pub async fn extend( + pub async fn receive( &mut self, ctx: &mut Ctx, + rcot: &mut RCOT, d: usize, - ) -> Result, ReceiverError> + ) -> Result, ReceiverError> where - RandomCOT: RandomCOTReceiver, + Ctx: Context, + RCOT: RandomCOTReceiver, { let mut ext_receiver = std::mem::replace(&mut self.state, State::Error).try_into_extension()?; @@ -76,8 +81,7 @@ impl Receiver { msgs: ms, choices: us, .. - } = self - .rcot + } = rcot .receive_random_correlated(ctx, (2 * d - 1) * CSP) .await?; @@ -91,15 +95,9 @@ impl Receiver { self.state = State::Extension(ext_receiver); - Ok(res) - } - - /// Complete extension. - pub fn finalize(&mut self) -> Result<(), ReceiverError> { - std::mem::replace(&mut self.state, State::Error).try_into_extension()?; - - self.state = State::Complete; - - Ok(()) + Ok(VOPEReceiverOutput { + id: self.id.next_id(), + coeff: res, + }) } } diff --git a/crates/mpz-zk/src/vope/sender.rs b/crates/mpz-zk/src/vope/sender.rs index 447dcb6b..460b58cd 100644 --- a/crates/mpz-zk/src/vope/sender.rs +++ b/crates/mpz-zk/src/vope/sender.rs @@ -4,10 +4,13 @@ use crate::vope::error::SenderError; use enum_try_as_inner::EnumTryAsInner; use mpz_common::Context; use mpz_core::Block; -use mpz_ot::{RCOTSenderOutput, RandomCOTSender}; -use mpz_zk_core::vope::{ - sender::{state, Sender as SenderCore}, - CSP, +use mpz_ot::{RCOTSenderOutput, RandomCOTSender, TransferId}; +use mpz_zk_core::{ + vope::{ + sender::{state, Sender as SenderCore}, + CSP, + }, + VOPESenderOutput, }; use utils_aio::non_blocking_backend::{Backend, NonBlockingBackend}; @@ -17,27 +20,22 @@ use utils_aio::non_blocking_backend::{Backend, NonBlockingBackend}; pub enum State { Initialized(SenderCore), Extension(SenderCore), - Complete, Error, } /// VOPE sender (verifier) #[derive(Debug)] -pub struct Sender { +pub struct Sender { state: State, - rcot: RandomCOT, + id: TransferId, } -impl Sender { +impl Sender { /// Creates a new Sender. - /// - /// # Arguments - /// - /// * `rcot` - The random COT used by the sender. - pub fn new(rcot: RandomCOT) -> Self { + pub fn new() -> Self { Self { state: State::Initialized(SenderCore::new()), - rcot, + id: TransferId::default(), } } @@ -46,7 +44,7 @@ impl Sender { /// # Arguments /// /// * `delta` - The delta value to use for VOPE extension. - pub fn setup_with_delta(&mut self, delta: Block) -> Result<(), SenderError> { + pub fn setup(&mut self, delta: Block) -> Result<(), SenderError> { let ext_sender = std::mem::replace(&mut self.state, State::Error).try_into_initialized()?; let ext_sender = ext_sender.setup(delta); @@ -61,39 +59,34 @@ impl Sender { /// # Arguments /// /// * `ctx` - The context. + /// * `rcot` - The ideal random COT. /// * `d` - The polynomial degree. - pub async fn extend( + pub async fn send( &mut self, ctx: &mut Ctx, + rcot: &mut RCOT, d: usize, - ) -> Result + ) -> Result, SenderError> where - RandomCOT: RandomCOTSender, + Ctx: Context, + RCOT: RandomCOTSender, { let mut ext_sender = std::mem::replace(&mut self.state, State::Error).try_into_extension()?; assert!(d > 0); - let RCOTSenderOutput { msgs: ks, .. } = self - .rcot - .send_random_correlated(ctx, (2 * d - 1) * CSP) - .await?; + let RCOTSenderOutput { msgs: ks, .. } = + rcot.send_random_correlated(ctx, (2 * d - 1) * CSP).await?; let (ext_sender, res) = Backend::spawn(move || ext_sender.extend(&ks, d).map(|res| (ext_sender, res))).await?; self.state = State::Extension(ext_sender); - Ok(res) - } - - /// Complete extension. - pub fn finalize(&mut self) -> Result<(), SenderError> { - std::mem::replace(&mut self.state, State::Error).try_into_extension()?; - - self.state = State::Complete; - - Ok(()) + Ok(VOPESenderOutput { + id: self.id.next_id(), + eval: res, + }) } } From f8d070124d7fedc5af2510f90c7e5bf3ff17342d Mon Sep 17 00:00:00 2001 From: Xiang Xie Date: Thu, 12 Sep 2024 18:56:44 +0800 Subject: [PATCH 15/17] add qs with io --- crates/mpz-core/src/block.rs | 3 +- crates/mpz-ot/src/ideal/cot.rs | 9 +++++ crates/mpz-zk-core/src/ideal/vope.rs | 6 +++- crates/mpz-zk-core/src/quicksilver/mod.rs | 32 +++++++++++------ crates/mpz-zk-core/src/quicksilver/prover.rs | 2 +- .../mpz-zk-core/src/quicksilver/verifier.rs | 6 ++-- crates/mpz-zk/src/quicksilver/mod.rs | 18 ++++++---- crates/mpz-zk/src/quicksilver/prover.rs | 31 ++++++++++++---- crates/mpz-zk/src/quicksilver/verifier.rs | 36 +++++++++++-------- crates/mpz-zk/src/vope/receiver.rs | 11 +++--- crates/mpz-zk/src/vope/sender.rs | 7 ++++ 11 files changed, 112 insertions(+), 49 deletions(-) diff --git a/crates/mpz-core/src/block.rs b/crates/mpz-core/src/block.rs index 1e1d67ac..b191eea9 100644 --- a/crates/mpz-core/src/block.rs +++ b/crates/mpz-core/src/block.rs @@ -130,9 +130,10 @@ impl Block { } /// Generate the powers of the seed. + /// Starting with seed. #[inline(always)] pub fn powers(seed: Self, size: usize) -> Vec { - successors(Some(Block::ONE), |pow| Some(pow.gfmul(seed))) + successors(Some(seed), |pow| Some(pow.gfmul(seed))) .take(size) .collect() } diff --git a/crates/mpz-ot/src/ideal/cot.rs b/crates/mpz-ot/src/ideal/cot.rs index bc7df0a6..aa441c60 100644 --- a/crates/mpz-ot/src/ideal/cot.rs +++ b/crates/mpz-ot/src/ideal/cot.rs @@ -47,6 +47,15 @@ pub fn ideal_rcot() -> (IdealCOTSender, IdealCOTReceiver) { (IdealCOTSender(alice), IdealCOTReceiver(bob)) } +/// Returns an ideal random COT sender and receiver with a given delta. +pub fn ideal_rcot_with_delta(delta: Block) -> (IdealCOTSender, IdealCOTReceiver) { + let (alice, bob) = ideal_f2p(IdealCOT::new( + mpz_core::prg::Prg::new().random_block(), + delta, + )); + (IdealCOTSender(alice), IdealCOTReceiver(bob)) +} + /// Ideal COT sender. #[derive(Debug, Clone, Default)] pub struct IdealCOTSender(Alice); diff --git a/crates/mpz-zk-core/src/ideal/vope.rs b/crates/mpz-zk-core/src/ideal/vope.rs index c1238c22..14cc2a79 100644 --- a/crates/mpz-zk-core/src/ideal/vope.rs +++ b/crates/mpz-zk-core/src/ideal/vope.rs @@ -1,5 +1,7 @@ //! Ideal VOPE functionality. +use std::iter::successors; + use mpz_core::{prg::Prg, Block}; use mpz_ot_core::TransferId; use rand_core::SeedableRng; @@ -63,7 +65,9 @@ impl IdealVOPE { let mut coeff = vec![Block::ZERO; degree + 1]; self.prg.random_blocks(&mut coeff); - let powers = Block::powers(self.delta, degree + 1); + let powers: Vec = successors(Some(Block::ONE), |pow| Some(pow.gfmul(self.delta))) + .take(degree + 1) + .collect(); let eval = Block::inn_prdt_red(&coeff, &powers); diff --git a/crates/mpz-zk-core/src/quicksilver/mod.rs b/crates/mpz-zk-core/src/quicksilver/mod.rs index aac81cef..89783a9d 100644 --- a/crates/mpz-zk-core/src/quicksilver/mod.rs +++ b/crates/mpz-zk-core/src/quicksilver/mod.rs @@ -36,7 +36,7 @@ pub fn bytes_to_bools(v: &[u8]) -> Vec { #[cfg(test)] mod tests { - use mpz_core::prg::Prg; + use mpz_core::{prg::Prg, Block}; use mpz_ot_core::{ ideal::cot::IdealCOT, test::assert_cot, RCOTReceiverOutput, RCOTSenderOutput, }; @@ -47,17 +47,15 @@ mod tests { #[test] fn test_qs_core() { + const N: usize = 200; let mut prg = Prg::new(); - let mut input = vec![false; 100]; + let mut input = vec![false; N]; prg.random_bools(&mut input); - - let mut ideal_cot = IdealCOT::default(); - let mut ideal_vope = IdealVOPE::default(); - - let mut delta = ideal_cot.delta(); + let mut delta = prg.random_block(); delta.set_lsb(); - ideal_vope.set_delta(delta); + let mut ideal_cot = IdealCOT::new(Block::ZERO, delta); + let mut ideal_vope = IdealVOPE::new(Block::ZERO, delta); let mut prover = Prover::new(); let mut verifier = Verifier::new(delta); @@ -67,13 +65,19 @@ mod tests { let (masks, prover_labels) = prover.auth_input_bits(&input, cot_receiver).unwrap(); let verifier_labels = verifier.auth_input_bits(&masks, cot_sender).unwrap(); + let input_exp: Vec = prover_labels.iter().map(|x| x.lsb() == 1).collect(); + assert_eq!(input, input_exp); assert_cot(delta, &input, &prover_labels, &verifier_labels); + let mut output_macs = vec![Block::ZERO; N]; + let mut output_keys = vec![Block::ZERO; N]; prover_labels .iter() .zip(verifier_labels.iter()) - .for_each(|(&mac, &key)| { + .zip(output_macs.iter_mut()) + .zip(output_keys.iter_mut()) + .for_each(|(((&mac, &key), output_mac), output_key)| { let (cot_sender, cot_receiver) = ideal_cot.random_correlated(1); let RCOTReceiverOutput { @@ -82,19 +86,25 @@ mod tests { .. } = cot_receiver; - let (mask, _) = prover.auth_and_gate(mac, mac, (s[0], blks[0])); + let (mask, tmp) = prover.auth_and_gate(mac, mac, (s[0], blks[0])); + *output_mac = tmp; let RCOTSenderOutput { msgs: blks, .. } = cot_sender; - verifier.auth_and_gate(key, key, mask, blks[0]); + *output_key = verifier.auth_and_gate(key, key, mask, blks[0]); }); + assert_cot(delta, &input, &output_macs, &output_keys); + let (vope_sender, vope_receiver) = ideal_vope.random_correlated(1); let (u, v) = prover.check_and_gates(vope_receiver); verifier.check_and_gates(vope_sender, u, v); + let hash = prover.finish(&output_macs); + verifier.finish(hash, &output_keys, &input).unwrap(); + assert!(verifier.checked()); } } diff --git a/crates/mpz-zk-core/src/quicksilver/prover.rs b/crates/mpz-zk-core/src/quicksilver/prover.rs index 91f7b510..0a095b9a 100644 --- a/crates/mpz-zk-core/src/quicksilver/prover.rs +++ b/crates/mpz-zk-core/src/quicksilver/prover.rs @@ -49,7 +49,7 @@ impl Prover { cot: RCOTReceiverOutput, ) -> Result<(Vec, Vec), QsProverError> { if cot.choices.len() != inputs.len() { - return Err(QsProverError(format!("lengths not match"))); + return Err(QsProverError("lengths not match".to_string())); } let RCOTReceiverOutput { diff --git a/crates/mpz-zk-core/src/quicksilver/verifier.rs b/crates/mpz-zk-core/src/quicksilver/verifier.rs index d3e39b37..7120c78d 100644 --- a/crates/mpz-zk-core/src/quicksilver/verifier.rs +++ b/crates/mpz-zk-core/src/quicksilver/verifier.rs @@ -58,7 +58,7 @@ impl Verifier { // The mask bits sent by prover. ) -> Result, QsVerifierError> { if masks.len() != cot.msgs.len() { - return Err(QsVerifierError(format!("lengths not match"))); + return Err(QsVerifierError("lengths not match".to_string())); } // Hash the bools. @@ -133,7 +133,7 @@ impl Verifier { // Compute chi and powers. self.hasher - .update(&bools_to_bytes(&self.buf_hash[..=self.check_counter])); + .update(&bools_to_bytes(&self.buf_hash[..self.check_counter])); let seed = *self.hasher.finalize().as_bytes(); let seed = Block::try_from(&seed[0..16]).unwrap(); let chis = Block::powers(seed, self.check_counter); @@ -170,7 +170,7 @@ impl Verifier { outputs: &[bool], ) -> Result<(), QsVerifierError> { if keys.len() != outputs.len() { - return Err(QsVerifierError(format!("lengths not match"))); + return Err(QsVerifierError("lengths not match".to_string())); } let pre_hash: Vec = keys diff --git a/crates/mpz-zk/src/quicksilver/mod.rs b/crates/mpz-zk/src/quicksilver/mod.rs index 99b4be92..eb26621d 100644 --- a/crates/mpz-zk/src/quicksilver/mod.rs +++ b/crates/mpz-zk/src/quicksilver/mod.rs @@ -19,25 +19,28 @@ mod tests { Aes128, }; use futures::TryFutureExt; - use mpz_circuits::circuits::AES128; + use mpz_circuits::{circuits::AES128, evaluate}; use mpz_common::executor::test_st_executor; - use mpz_ot::{ideal::cot::ideal_cot, Correlation}; + use mpz_core::prg::Prg; + use mpz_ot::ideal::cot::ideal_rcot_with_delta; #[tokio::test] async fn test_qs() { let (mut ctx_sender, mut ctx_receiver) = test_st_executor(8); - let (mut rcot_sender, mut rcot_receiver) = ideal_cot(); - - let mut delta = rcot_sender.delta(); + let mut delta = Prg::new().random_block(); delta.set_lsb(); + let (mut rcot_sender, mut rcot_receiver) = ideal_rcot_with_delta(delta); + let mut prover = Prover::new(); let mut verifier = Verifier::new(delta); let key = [69u8; 16]; let msg = [42u8; 16]; + let output = evaluate!(AES128, fn(key, msg) -> [u8; 16]).unwrap(); + let expected: [u8; 16] = { let cipher = Aes128::new_from_slice(&key).unwrap(); let mut out = msg.into(); @@ -45,14 +48,15 @@ mod tests { out.into() }; + assert_eq!(output, expected); + let input_value = [key, msg].concat(); - tokio::try_join!( prover .prove(&mut ctx_sender, &AES128, input_value, &mut rcot_receiver) .map_err(ZKError::from), verifier - .verify(&mut ctx_receiver, &AES128, expected, &mut rcot_sender) + .verify(&mut ctx_receiver, &AES128, output, &mut rcot_sender) .map_err(ZKError::from) ) .unwrap(); diff --git a/crates/mpz-zk/src/quicksilver/prover.rs b/crates/mpz-zk/src/quicksilver/prover.rs index 68421a17..be867c24 100644 --- a/crates/mpz-zk/src/quicksilver/prover.rs +++ b/crates/mpz-zk/src/quicksilver/prover.rs @@ -3,7 +3,7 @@ use mpz_circuits::{types::Value, Circuit, CircuitError, Gate}; use mpz_common::{cpu::CpuBackend, Context}; use mpz_core::Block; use mpz_ot::{RCOTReceiverOutput, RandomCOTReceiver}; -use mpz_zk_core::quicksilver::{bools_to_bytes, Prover as ProverCore}; +use mpz_zk_core::quicksilver::Prover as ProverCore; use serio::SinkExt; use super::error::ProverError; @@ -22,7 +22,8 @@ impl Prover { prover_core: ProverCore::new(), } } - /// Authenticate inputs. + + // Authenticate inputs. async fn auth_inputs( &mut self, ctx: &mut Ctx, @@ -35,14 +36,22 @@ impl Prover { { let cot = rcot.receive_random_correlated(ctx, inputs.len()).await?; - let (bits, macs) = self.prover_core.auth_input_bits(&inputs, cot)?; + let (bits, macs) = self.prover_core.auth_input_bits(inputs, cot)?; - ctx.io_mut().send(bools_to_bytes(&bits)).await?; + // TODO: optimize sending bools. + ctx.io_mut().send(bits).await?; Ok(macs) } - /// Prove. + /// Prove a circuit. + /// + /// # Arguments. + /// + /// * `ctx` - The context. + /// * `circ` - The circuit. + /// * `input_value` - The witness hold by the prover. + /// * `rcot` - The ideal RCOT functionality. pub async fn prove( &mut self, ctx: &mut Ctx, @@ -76,7 +85,7 @@ impl Prover { } // Authenticate the circuit. - while let Some(gate) = circ.gates().iter().next() { + for gate in circ.gates() { match gate { Gate::Xor { x: node_x, @@ -108,6 +117,7 @@ impl Prover { let (d, z_0) = self.prover_core.auth_and_gate(x_0, y_0, (bit[0], blk[0])); + // TODO: optimize sending bool. ctx.io_mut().send(d).await?; self.macs[node_z.id()] = z_0; @@ -157,7 +167,7 @@ impl Prover { let v = vope.receive(ctx, rcot, 1).await?; - let mut prover_core = std::mem::replace(&mut self.prover_core, ProverCore::default()); + let mut prover_core = std::mem::take(&mut self.prover_core); let (u, prover_core) = CpuBackend::blocking(move || (prover_core.check_and_gates(v), prover_core)).await; @@ -169,3 +179,10 @@ impl Prover { Ok(()) } } + +impl Default for Prover { + #[inline] + fn default() -> Self { + Self::new() + } +} diff --git a/crates/mpz-zk/src/quicksilver/verifier.rs b/crates/mpz-zk/src/quicksilver/verifier.rs index d508c036..84f0ab65 100644 --- a/crates/mpz-zk/src/quicksilver/verifier.rs +++ b/crates/mpz-zk/src/quicksilver/verifier.rs @@ -3,7 +3,7 @@ use mpz_circuits::{types::Value, Circuit, CircuitError, Gate}; use mpz_common::{cpu::CpuBackend, Context}; use mpz_core::Block; use mpz_ot::{RCOTSenderOutput, RandomCOTSender}; -use mpz_zk_core::quicksilver::{bytes_to_bools, Verifier as VerifierCore}; +use mpz_zk_core::quicksilver::Verifier as VerifierCore; use serio::stream::IoStreamExt; use super::error::VerifierError; @@ -23,11 +23,11 @@ impl Verifier { } } - /// Authenticate inputs. + // Authenticate inputs. async fn auth_inputs( &mut self, - len: usize, ctx: &mut Ctx, + len: usize, rcot: &mut RCOT, ) -> Result, VerifierError> where @@ -36,17 +36,23 @@ impl Verifier { { let cot = rcot.send_random_correlated(ctx, len).await?; - let bytes: Vec = ctx.io_mut().expect_next().await?; + let masks: Vec = ctx.io_mut().expect_next().await?; - let masks = bytes_to_bools(&bytes); assert_eq!(masks.len(), len); - let blks = self.verifier_core.auth_input_bits(&masks, cot)?; + let keys = self.verifier_core.auth_input_bits(&masks, cot)?; - Ok(blks) + Ok(keys) } - /// Verify + /// Verify a circuit. + /// + /// # Arguments. + /// + /// * `ctx` - The context. + /// * `circ` - The circuit. + /// * `output_value` - The public output value hold by the verifier and prover. + /// * `rcot` - The ideal RCOT functionality. pub async fn verify( &mut self, ctx: &mut Ctx, @@ -59,6 +65,7 @@ impl Verifier { RCOT: RandomCOTSender, { let len: usize = circ.outputs().iter().map(|v| v.len()).sum(); + let output_value = output_value.into().into_lsb0_vec(); if output_value.len() != len { return Err(CircuitError::InvalidOutputCount(len, output_value.len()))?; @@ -68,9 +75,9 @@ impl Verifier { self.keys.resize(circ.feed_count(), Default::default()); } - let input_len: usize = circ.inputs().iter().map(|v| v.len()).sum(); // Handle inputs. - let input_keys = self.auth_inputs(input_len, ctx, rcot).await?; + let input_len: usize = circ.inputs().iter().map(|v| v.len()).sum(); + let input_keys = self.auth_inputs(ctx, input_len, rcot).await?; for (key, node) in input_keys .iter() @@ -80,7 +87,7 @@ impl Verifier { } // Authenticate the circuit. - while let Some(gate) = circ.gates().iter().next() { + for gate in circ.gates() { match gate { Gate::Xor { x: node_x, @@ -96,7 +103,7 @@ impl Verifier { y: node_y, z: node_z, } => { - // Check the batched authenticated and gats. + // Check the batched authenticated and gates. if self.verifier_core.enable_check() { self.check_and_gates(ctx, rcot).await?; } @@ -135,6 +142,7 @@ impl Verifier { .map(|node| self.keys[node.id()]) .collect(); + // Receive the hash of output macs and verify. let hash = ctx.io_mut().expect_next().await?; self.verifier_core .finish(hash, &output_keys, &output_value)?; @@ -159,7 +167,7 @@ impl Verifier { let u: (Block, Block) = ctx.io_mut().expect_next().await?; - let mut verifier_core = std::mem::replace(&mut self.verifier_core, VerifierCore::default()); + let mut verifier_core = std::mem::take(&mut self.verifier_core); let (_, verifier_core) = CpuBackend::blocking(move || { (verifier_core.check_and_gates(v, u.0, u.1), verifier_core) @@ -170,7 +178,7 @@ impl Verifier { Ok(()) } - /// Returns checked over not. + /// Returns checked or not. #[inline] pub fn checked(&self) -> bool { self.verifier_core.checked() diff --git a/crates/mpz-zk/src/vope/receiver.rs b/crates/mpz-zk/src/vope/receiver.rs index 4344d93f..f34958a1 100644 --- a/crates/mpz-zk/src/vope/receiver.rs +++ b/crates/mpz-zk/src/vope/receiver.rs @@ -32,10 +32,6 @@ pub struct Receiver { impl Receiver { /// Creates a new receiver. - /// - /// # Arguments - /// - /// * `rcot` - The random COT used by the receiver. pub fn new() -> Self { Self { state: State::Initialized(ReceiverCore::new()), @@ -101,3 +97,10 @@ impl Receiver { }) } } + +impl Default for Receiver { + #[inline] + fn default() -> Self { + Self::new() + } +} diff --git a/crates/mpz-zk/src/vope/sender.rs b/crates/mpz-zk/src/vope/sender.rs index 460b58cd..2643ad5a 100644 --- a/crates/mpz-zk/src/vope/sender.rs +++ b/crates/mpz-zk/src/vope/sender.rs @@ -90,3 +90,10 @@ impl Sender { }) } } + +impl Default for Sender { + #[inline] + fn default() -> Self { + Self::new() + } +} From 537d97904f3be4e68e959607123b530e58f43bf9 Mon Sep 17 00:00:00 2001 From: Xiang Xie Date: Thu, 12 Sep 2024 19:04:59 +0800 Subject: [PATCH 16/17] fmt --- crates/mpz-zk/src/quicksilver/prover.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/mpz-zk/src/quicksilver/prover.rs b/crates/mpz-zk/src/quicksilver/prover.rs index be867c24..e4a9dc86 100644 --- a/crates/mpz-zk/src/quicksilver/prover.rs +++ b/crates/mpz-zk/src/quicksilver/prover.rs @@ -22,7 +22,7 @@ impl Prover { prover_core: ProverCore::new(), } } - + // Authenticate inputs. async fn auth_inputs( &mut self, From 84411b323c14a798de1f4dfa7201841fcec11efa Mon Sep 17 00:00:00 2001 From: Xiang Xie Date: Thu, 12 Sep 2024 19:15:30 +0800 Subject: [PATCH 17/17] remove empty files --- value.rs | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 value.rs diff --git a/value.rs b/value.rs deleted file mode 100644 index e69de29b..00000000