Skip to content

Commit

Permalink
refactor(mpz-ot): ferret clean up (#173)
Browse files Browse the repository at this point in the history
* refactor(mpz-ot): ferret clean up

* buffer OTs, setup rcot only invoked once

* fix mpcot test

---------

Co-authored-by: Xiang Xie <xiexiangiscas@gmail.com>
  • Loading branch information
sinui0 and xiangxiecrypto authored Aug 16, 2024
1 parent ef72b9b commit 116035e
Show file tree
Hide file tree
Showing 21 changed files with 1,227 additions and 1,403 deletions.
4 changes: 2 additions & 2 deletions crates/mpz-common/src/ideal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ impl<F> Clone for Alice<F> {

impl<F> Alice<F> {
/// Returns a lock to the ideal functionality.
pub fn get_mut(&mut self) -> MutexGuard<'_, F> {
pub fn lock(&self) -> MutexGuard<'_, F> {
self.f.lock().unwrap()
}

Expand Down Expand Up @@ -96,7 +96,7 @@ impl<F> Clone for Bob<F> {

impl<F> Bob<F> {
/// Returns a lock to the ideal functionality.
pub fn get_mut(&mut self) -> MutexGuard<'_, F> {
pub fn lock(&self) -> MutexGuard<'_, F> {
self.f.lock().unwrap()
}

Expand Down
30 changes: 23 additions & 7 deletions crates/mpz-ot-core/src/ferret/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,11 @@ mod tests {
use receiver::Receiver;
use sender::Sender;

use crate::ideal::{cot::IdealCOT, mpcot::IdealMpcot};
use crate::test::assert_cot;
use crate::{MPCOTReceiverOutput, MPCOTSenderOutput, RCOTReceiverOutput, RCOTSenderOutput};
use crate::{
ideal::{cot::IdealCOT, mpcot::IdealMpcot},
test::assert_cot,
MPCOTReceiverOutput, MPCOTSenderOutput, RCOTReceiverOutput, RCOTSenderOutput,
};
use mpz_core::{lpn::LpnParameters, prg::Prg};

const LPN_PARAMETERS_TEST: LpnParameters = LpnParameters {
Expand Down Expand Up @@ -111,8 +113,15 @@ mod tests {
let (MPCOTSenderOutput { s, .. }, MPCOTReceiverOutput { r, .. }) =
ideal_mpcot.extend(&query.0, query.1);

let msgs = sender.extend(&s).unwrap();
let (choices, received) = receiver.extend(&r).unwrap();
sender.extend(s).unwrap();
receiver.extend(r).unwrap();

let RCOTSenderOutput { msgs, .. } = sender.consume(2).unwrap();
let RCOTReceiverOutput {
choices,
msgs: received,
..
} = receiver.consume(2).unwrap();

assert_cot(delta, &choices, &msgs, &received);

Expand All @@ -123,8 +132,15 @@ mod tests {
let (MPCOTSenderOutput { s, .. }, MPCOTReceiverOutput { r, .. }) =
ideal_mpcot.extend(&query.0, query.1);

let msgs = sender.extend(&s).unwrap();
let (choices, received) = receiver.extend(&r).unwrap();
sender.extend(s).unwrap();
receiver.extend(r).unwrap();

let RCOTSenderOutput { msgs, .. } = sender.consume(sender.remaining()).unwrap();
let RCOTReceiverOutput {
choices,
msgs: received,
..
} = receiver.consume(receiver.remaining()).unwrap();

assert_cot(delta, &choices, &msgs, &received);
}
Expand Down
51 changes: 44 additions & 7 deletions crates/mpz-ot-core/src/ferret/receiver.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
//! Ferret receiver
use std::collections::VecDeque;

use mpz_core::{
lpn::{LpnEncoder, LpnParameters},
Block,
};

use crate::{
ferret::{error::ReceiverError, LpnType},
TransferId,
RCOTReceiverOutput, TransferId,
};

use super::msgs::LpnMatrixSeed;
Expand Down Expand Up @@ -63,6 +65,8 @@ impl Receiver {
w: w.to_vec(),
e: Vec::default(),
id: TransferId::default(),
choices_buffer: VecDeque::new(),
msgs_buffer: VecDeque::new(),
},
},
LpnMatrixSeed { seed },
Expand All @@ -71,6 +75,16 @@ impl Receiver {
}

impl Receiver<state::Extension> {
/// Returns the current transfer id.
pub fn id(&self) -> TransferId {
self.state.id
}

/// Returns the number of remaining COTs.
pub fn remaining(&self) -> usize {
self.state.choices_buffer.len()
}

/// The prepare precedure of extension, sample error vectors and outputs information for MPCOT.
/// See step 3 and 4.
pub fn get_mpcot_query(&mut self) -> (Vec<u32>, usize) {
Expand Down Expand Up @@ -100,15 +114,15 @@ impl Receiver<state::Extension> {
/// # Arguments.
///
/// * `r` - The vector received from the MPCOT protocol.
pub fn extend(&mut self, r: &[Block]) -> Result<(Vec<bool>, Vec<Block>), ReceiverError> {
pub fn extend(&mut self, r: Vec<Block>) -> Result<(), ReceiverError> {
if r.len() != self.state.lpn_parameters.n {
return Err(ReceiverError("the length of r should be n".to_string()));
}

self.state.id.next();

// Compute z = A * w + r.
let mut z = r.to_vec();
let mut z = r;
self.state.lpn_encoder.compute(&mut z, &self.state.w);

// Compute x = A * u + e.
Expand All @@ -133,12 +147,32 @@ impl Receiver<state::Extension> {
// Update counter
self.state.counter += 1;

Ok((x_, z_))
self.state.choices_buffer.extend(x_);
self.state.msgs_buffer.extend(z_);

Ok(())
}

/// Returns id
pub fn id(&self) -> TransferId {
self.state.id
/// Consumes `count` COTs.
pub fn consume(
&mut self,
count: usize,
) -> Result<RCOTReceiverOutput<bool, Block>, ReceiverError> {
if count > self.state.choices_buffer.len() {
return Err(ReceiverError(format!(
"insufficient OTs: {} < {count}",
self.state.choices_buffer.len()
)));
}

let choices = self.state.choices_buffer.drain(0..count).collect();
let msgs = self.state.msgs_buffer.drain(0..count).collect();

Ok(RCOTReceiverOutput {
id: self.state.id.next(),
choices,
msgs,
})
}
}

Expand Down Expand Up @@ -186,6 +220,9 @@ pub mod state {

/// TransferID
pub(super) id: TransferId,
/// Extended OTs buffers.
pub(super) choices_buffer: VecDeque<bool>,
pub(super) msgs_buffer: VecDeque<Block>,
}

impl State for Extension {}
Expand Down
49 changes: 41 additions & 8 deletions crates/mpz-ot-core/src/ferret/sender.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
//! Ferret sender.
use std::collections::VecDeque;

use mpz_core::{
lpn::{LpnEncoder, LpnParameters},
Block,
};

use crate::{
ferret::{error::SenderError, LpnType},
TransferId,
RCOTSenderOutput, TransferId,
};

use super::msgs::LpnMatrixSeed;
Expand Down Expand Up @@ -61,12 +63,28 @@ impl Sender {
lpn_encoder,
v: v.to_vec(),
id: TransferId::default(),
msgs_buffer: VecDeque::new(),
},
})
}
}

impl Sender<state::Extension> {
/// Returns the current transfer id.
pub fn id(&self) -> TransferId {
self.state.id
}

/// Returns the number of remaining COTs.
pub fn remaining(&self) -> usize {
self.state.msgs_buffer.len()
}

/// Returns the delta correlation.
pub fn delta(&self) -> Block {
self.state.delta
}

/// Outputs the information for MPCOT.
///
/// See step 3 and 4.
Expand All @@ -86,15 +104,15 @@ impl Sender<state::Extension> {
/// # Arguments.
///
/// * `s` - The vector received from the MPCOT protocol.
pub fn extend(&mut self, s: &[Block]) -> Result<Vec<Block>, SenderError> {
pub fn extend(&mut self, s: Vec<Block>) -> Result<(), SenderError> {
if s.len() != self.state.lpn_parameters.n {
return Err(SenderError("the length of s should be n".to_string()));
}

self.state.id.next();

// Compute y = A * v + s
let mut y = s.to_vec();
let mut y = s;
self.state.lpn_encoder.compute(&mut y, &self.state.v);

let y_ = y.split_off(self.state.lpn_parameters.k);
Expand All @@ -104,13 +122,26 @@ impl Sender<state::Extension> {

// Update counter
self.state.counter += 1;
self.state.msgs_buffer.extend(y_);

Ok(y_)
Ok(())
}

/// Returns id
pub fn id(&self) -> TransferId {
self.state.id
/// Consumes `count` COTs.
pub fn consume(&mut self, count: usize) -> Result<RCOTSenderOutput<Block>, SenderError> {
if count > self.state.msgs_buffer.len() {
return Err(SenderError(format!(
"insufficient OTs: {} < {count}",
self.state.msgs_buffer.len()
)));
}

let msgs = self.state.msgs_buffer.drain(0..count).collect();

Ok(RCOTSenderOutput {
id: self.state.id.next(),
msgs,
})
}
}

Expand Down Expand Up @@ -159,8 +190,10 @@ pub mod state {
/// Sender's COT message in the setup phase.
pub(super) v: Vec<Block>,

/// TransferID.
/// Transfer ID.
pub(crate) id: TransferId,
/// COT messages buffer.
pub(super) msgs_buffer: VecDeque<Block>,
}

impl State for Extension {}
Expand Down
2 changes: 1 addition & 1 deletion crates/mpz-ot-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ impl std::fmt::Display for TransferId {

impl TransferId {
/// Returns the current transfer ID, incrementing `self` in-place.
pub(crate) fn next(&mut self) -> Self {
pub fn next(&mut self) -> Self {
let id = *self;
self.0 += 1;
id
Expand Down
Loading

0 comments on commit 116035e

Please sign in to comment.