Skip to content

Commit

Permalink
clean up code
Browse files Browse the repository at this point in the history
  • Loading branch information
xiangxiecrypto committed Oct 20, 2023
1 parent 6561413 commit 92ef8f8
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 38 deletions.
2 changes: 1 addition & 1 deletion ot/mpz-ot-core/src/ferret/spcot/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ mod tests {
.iter_mut()
.zip(output_receiver.iter())
.all(|(vs, (ws, alpha))| {
vs[*alpha] ^= delta;
vs[*alpha as usize] ^= delta;
vs == ws
});
}
Expand Down
46 changes: 22 additions & 24 deletions ot/mpz-ot-core/src/ferret/spcot/receiver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ impl Receiver<state::Extension> {
pub fn extend_mask_bits(
&mut self,
h: usize,
alpha: usize,
alpha: u32,
rs: &[bool],
) -> Result<MaskBits, ReceiverError> {
if self.state.extended {
Expand All @@ -77,14 +77,13 @@ impl Receiver<state::Extension> {
}

// Step 4 in Figure 6
let mut alpha_vec = alpha.to_msb0_vec();
alpha_vec.drain(0..alpha_vec.len() - h);

let bs: Vec<bool> = alpha_vec
.iter()
let bs: Vec<bool> = alpha
.iter_msb0()
.skip(u32::BITS as usize - h)
// Computes alpha_i XOR r_i XOR 1.
.zip(rs.iter())
.map(|(alpha, r)| alpha == r)
.map(|(alpha, &r)| alpha == r)
.collect();

// Updates hasher.
Expand All @@ -106,7 +105,7 @@ impl Receiver<state::Extension> {
pub fn extend(
&mut self,
h: usize,
alpha: usize,
alpha: u32,
ts: &[Block],
extendfs: ExtendFromSender,
) -> Result<(), ReceiverError> {
Expand Down Expand Up @@ -139,10 +138,11 @@ impl Receiver<state::Extension> {
self.state.hasher.update(&ms.to_bytes());
self.state.hasher.update(&sum.to_bytes());

let mut alpha_bar_vec = alpha.to_msb0_vec();
alpha_bar_vec.drain(0..alpha_bar_vec.len() - h);

alpha_bar_vec.iter_mut().for_each(|a| *a = !*a);
let alpha_bar_vec: Vec<bool> = alpha
.iter_msb0()
.skip(u32::BITS as usize - h)
.map(|a| !a)
.collect();

// Setp 5 in Figure 6.
let k: Vec<Block> = ms
Expand All @@ -168,7 +168,7 @@ impl Receiver<state::Extension> {
ggm_tree.reconstruct(&mut tree, &k, &alpha_bar_vec);

// Sets `tree[alpha]`, which is `ws[alpha]`.
tree[alpha] = tree.iter().fold(sum, |acc, &x| acc ^ x);
tree[alpha as usize] = tree.iter().fold(sum, |acc, &x| acc ^ x);

self.state.unchecked_ws.extend_from_slice(&tree);
self.state.alphas_and_length.push((alpha, 1 << h));
Expand All @@ -187,9 +187,9 @@ impl Receiver<state::Extension> {
/// * `x_star` - The message from COT ideal functionality for the receiver. Only the random bits are used.
pub fn check_pre(&mut self, x_star: &[bool]) -> Result<CheckFromReceiver, ReceiverError> {
if x_star.len() != CSP {
return Err(ReceiverError::InvalidLength(
"the length of x* should be 128".to_string(),
));
return Err(ReceiverError::InvalidLength(format!(
"the length of x* should be {CSP}"
)));
}

let seed = *self.state.hasher.finalize().as_bytes();
Expand All @@ -201,13 +201,12 @@ impl Receiver<state::Extension> {
for (alpha, n) in &self.state.alphas_and_length {
let mut chis = vec![Block::ZERO; *n];
prg.random_blocks(&mut chis);
sum_chi_alpha ^= chis[*alpha];
sum_chi_alpha ^= chis[*alpha as usize];
self.state.chis.extend_from_slice(&chis);
}

let x_prime: Vec<bool> = sum_chi_alpha
.to_lsb0_vec()
.into_iter()
.iter_lsb0()
.zip(x_star)
.map(|(x, &x_star)| x != x_star)
.collect();
Expand All @@ -227,13 +226,13 @@ impl Receiver<state::Extension> {
&mut self,
z_star: &[Block],
check: CheckFromSender,
) -> Result<Vec<(Vec<Block>, usize)>, ReceiverError> {
) -> Result<Vec<(Vec<Block>, u32)>, ReceiverError> {
let CheckFromSender { hashed_v } = check;

if z_star.len() != CSP {
return Err(ReceiverError::InvalidLength(
"the length of z* should be 128".to_string(),
));
return Err(ReceiverError::InvalidLength(format!(
"the length of z* should be {CSP}"
)));
}

// Computes the base X^i
Expand Down Expand Up @@ -281,7 +280,6 @@ pub mod state {

/// The receiver's initial state.
#[derive(Default)]
#[allow(missing_docs)]
pub struct Initialized {}

impl State for Initialized {}
Expand All @@ -297,7 +295,7 @@ pub mod state {
/// Receiver's random challenges chis.
pub(super) chis: Vec<Block>,
/// Stores the alpha and the length in each extend phase.
pub(super) alphas_and_length: Vec<(usize, usize)>,
pub(super) alphas_and_length: Vec<(u32, usize)>,

/// Current COT counter
pub(super) cot_counter: usize,
Expand Down
26 changes: 13 additions & 13 deletions ot/mpz-ot-core/src/ferret/spcot/sender.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,14 @@ impl Sender<state::Extension> {
})
.collect();

ms.iter_mut().enumerate().for_each(|(i, blks)| {
let tweak: Block = bytemuck::cast([i, self.state.exec_counter]);
let tweaks = [tweak, tweak];
FIXED_KEY_AES.tccr_many(&tweaks, blks);
});

ms.iter_mut()
.enumerate()
.map(|(i, blks)| {
let tweak: Block = bytemuck::cast([i, self.state.exec_counter]);
let tweaks = [tweak, tweak];
FIXED_KEY_AES.tccr_many(&tweaks, blks);
blks
})
.zip(k0.iter().zip(k1.iter()))
.for_each(|([m0, m1], (k0, k1))| {
*m0 ^= *k0;
Expand Down Expand Up @@ -154,15 +155,15 @@ impl Sender<state::Extension> {
let CheckFromReceiver { x_prime } = checkfr;

if y_star.len() != CSP {
return Err(SenderError::InvalidLength(
"the length of y* should be 128".to_string(),
));
return Err(SenderError::InvalidLength(format!(
"the length of y* should be {CSP}"
)));
}

if x_prime.len() != CSP {
return Err(SenderError::InvalidLength(
"the length of x' should be 128".to_string(),
));
return Err(SenderError::InvalidLength(format!(
"the length of x' should be {CSP}"
)));
}

// Step 8 in Figure 6.
Expand Down Expand Up @@ -225,7 +226,6 @@ pub mod state {

/// The sender's initial state.
#[derive(Default)]
#[allow(missing_docs)]
pub struct Initialized {}

impl State for Initialized {}
Expand Down

0 comments on commit 92ef8f8

Please sign in to comment.