From 8ea4e75e714788b91f09f7cd596246b8037ce2fe Mon Sep 17 00:00:00 2001 From: sinu <65924192+sinui0@users.noreply.github.com> Date: Wed, 18 Oct 2023 15:30:18 -0700 Subject: [PATCH 1/5] implement pre-garbling --- garble/mpz-garble-core/src/circuit.rs | 11 ++ garble/mpz-garble-core/src/lib.rs | 2 +- garble/mpz-garble/src/evaluator/error.rs | 2 + garble/mpz-garble/src/evaluator/mod.rs | 154 ++++++++++++++--- garble/mpz-garble/src/generator/error.rs | 4 +- garble/mpz-garble/src/generator/mod.rs | 131 +++++++++++---- garble/mpz-garble/src/lib.rs | 25 +++ garble/mpz-garble/src/protocol/deap/error.rs | 11 +- garble/mpz-garble/src/protocol/deap/memory.rs | 4 +- garble/mpz-garble/src/protocol/deap/mod.rs | 156 +++++++++++++++++- garble/mpz-garble/src/protocol/deap/vm.rs | 23 ++- garble/mpz-garble/src/value.rs | 7 + garble/mpz-garble/tests/offline-garble.rs | 136 +++++++++++++++ garble/mpz-garble/tests/semihonest.rs | 19 ++- 14 files changed, 610 insertions(+), 75 deletions(-) create mode 100644 garble/mpz-garble/tests/offline-garble.rs diff --git a/garble/mpz-garble-core/src/circuit.rs b/garble/mpz-garble-core/src/circuit.rs index 3852e909..14dffd8b 100644 --- a/garble/mpz-garble-core/src/circuit.rs +++ b/garble/mpz-garble-core/src/circuit.rs @@ -3,6 +3,8 @@ use std::ops::Index; use mpz_core::Block; use serde::{Deserialize, Serialize}; +use crate::EncodingCommitment; + /// Encrypted gate truth table /// /// For the half-gate garbling scheme a truth table will typically have 2 rows, except for in @@ -32,3 +34,12 @@ impl Index for EncryptedGate { &self.0[index] } } + +/// A garbled circuit +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GarbledCircuit { + /// Encrypted gates of the circuit + pub gates: Vec, + /// Encoding commitments of the circuit outputs + pub commitments: Option>, +} diff --git a/garble/mpz-garble-core/src/lib.rs b/garble/mpz-garble-core/src/lib.rs index f747a4fe..6de97cee 100644 --- a/garble/mpz-garble-core/src/lib.rs +++ b/garble/mpz-garble-core/src/lib.rs @@ -59,7 +59,7 @@ mod evaluator; mod generator; pub mod msg; -pub use circuit::EncryptedGate; +pub use circuit::{EncryptedGate, GarbledCircuit}; pub use encoding::{ state as encoding_state, ChaChaEncoder, Decoding, Delta, Encode, EncodedValue, Encoder, EncodingCommitment, EqualityCheck, Label, ValueError, diff --git a/garble/mpz-garble/src/evaluator/error.rs b/garble/mpz-garble/src/evaluator/error.rs index 3cff8974..cb13341d 100644 --- a/garble/mpz-garble/src/evaluator/error.rs +++ b/garble/mpz-garble/src/evaluator/error.rs @@ -21,6 +21,8 @@ pub enum EvaluatorError { EncodingRegistryError(#[from] crate::memory::EncodingMemoryError), #[error("missing active encoding for value")] MissingEncoding(ValueRef), + #[error("duplicate garbled circuit")] + DuplicateCircuit, #[error("duplicate decoding for value: {0:?}")] DuplicateDecoding(ValueId), #[error(transparent)] diff --git a/garble/mpz-garble/src/evaluator/mod.rs b/garble/mpz-garble/src/evaluator/mod.rs index fa933888..c4871fb2 100644 --- a/garble/mpz-garble/src/evaluator/mod.rs +++ b/garble/mpz-garble/src/evaluator/mod.rs @@ -17,6 +17,7 @@ use mpz_circuits::{ use mpz_core::hash::Hash; use mpz_garble_core::{ encoding_state, msg::GarbleMessage, Decoding, EncodedValue, Evaluator as EvaluatorCore, + GarbledCircuit, }; use utils::iter::FilterDrain; use utils_aio::{ @@ -27,7 +28,7 @@ use utils_aio::{ use crate::{ memory::EncodingMemory, ot::{OTReceiveEncoding, OTVerifyEncoding}, - value::{ValueId, ValueRef}, + value::{CircuitRefs, ValueId, ValueRef}, AssignedValues, Generator, GeneratorConfigBuilder, }; @@ -60,6 +61,10 @@ struct State { received_values: HashMap, /// Values which have been decoded decoded_values: HashSet, + /// Pre-transferred garbled circuits + /// + /// (inputs, outputs) => garbled circuit + garbled_circuits: HashMap, /// OT logs ot_log: HashMap>, /// Garbled circuit logs @@ -258,7 +263,68 @@ impl Evaluator { Ok(()) } - /// Evaluate a garbled circuit, receiving the encrypted gates in batches from the provided stream. + /// Receives a garbled circuit from the generator, storing it for later evaluation. + /// + /// # Arguments + /// + /// * `circ` - The circuit to receive + /// * `inputs` - The inputs to the circuit + /// * `outputs` - The outputs from the circuit + /// * `stream` - The stream from the generator + pub async fn receive_garbled_circuit< + S: Stream> + Unpin, + >( + &self, + circ: Arc, + inputs: &[ValueRef], + outputs: &[ValueRef], + stream: &mut S, + ) -> Result<(), EvaluatorError> { + let refs = CircuitRefs { + inputs: inputs.to_vec(), + outputs: outputs.to_vec(), + }; + + if self.state().garbled_circuits.contains_key(&refs) { + return Err(EvaluatorError::DuplicateCircuit); + } + + let gate_count = circ.and_count(); + let mut gates = Vec::with_capacity(gate_count); + while gates.len() < gate_count { + let encrypted_gates = expect_msg_or_err!(stream, GarbleMessage::EncryptedGates)?; + gates.extend(encrypted_gates); + } + + // If configured, expect the output encoding commitments + let encoding_commitments = if self.config.encoding_commitments { + let commitments = expect_msg_or_err!(stream, GarbleMessage::EncodingCommitments)?; + + // Make sure the generator sent the expected number of commitments. + if commitments.len() != circ.outputs().len() { + return Err(EvaluatorError::IncorrectValueCount { + expected: circ.outputs().len(), + actual: commitments.len(), + }); + } + + Some(commitments) + } else { + None + }; + + self.state().garbled_circuits.insert( + refs, + GarbledCircuit { + gates, + commitments: encoding_commitments, + }, + ); + + Ok(()) + } + + /// Evaluate a circuit. /// /// Returns the encoded outputs of the evaluated circuit. /// @@ -275,6 +341,11 @@ impl Evaluator { outputs: &[ValueRef], stream: &mut S, ) -> Result>, EvaluatorError> { + let refs = CircuitRefs { + inputs: inputs.to_vec(), + outputs: outputs.to_vec(), + }; + let encoded_inputs = { let state = self.state(); inputs @@ -294,11 +365,19 @@ impl Evaluator { EvaluatorCore::new(circ.clone(), &encoded_inputs)? }; - while !ev.is_complete() { - let encrypted_gates = expect_msg_or_err!(stream, GarbleMessage::EncryptedGates)?; - - for batch in encrypted_gates.chunks(self.config.batch_size) { - let batch = batch.to_vec(); + let existing_garbled_circuit = self.state().garbled_circuits.remove(&refs); + + // If we've already received the garbled circuit, we evaluate it, otherwise we stream the encrypted gates + // from the generator. + let encoded_outputs = if let Some(GarbledCircuit { + mut gates, + commitments, + }) = existing_garbled_circuit + { + while !ev.is_complete() { + let batch = gates + .drain(..gates.len().min(self.config.batch_size)) + .collect::>(); // Move the evaluator to a new thread to process the batch then send it back ev = Backend::spawn(move || { ev.evaluate(batch.iter()); @@ -306,27 +385,53 @@ impl Evaluator { }) .await; } - } - - let encoded_outputs = ev.outputs()?; - // If configured, expect the output encoding commitments - // from the generator and verify them. - if self.config.encoding_commitments { - let commitments = expect_msg_or_err!(stream, GarbleMessage::EncodingCommitments)?; + let encoded_outputs = ev.outputs()?; + if self.config.encoding_commitments { + for (output, commitment) in encoded_outputs + .iter() + .zip(commitments.expect("commitments were checked to be present")) + { + commitment.verify(output)?; + } + } - // Make sure the generator sent the expected number of commitments. - if commitments.len() != encoded_outputs.len() { - return Err(EvaluatorError::IncorrectValueCount { - expected: encoded_outputs.len(), - actual: commitments.len(), - }); + encoded_outputs + } else { + while !ev.is_complete() { + let mut gates = expect_msg_or_err!(stream, GarbleMessage::EncryptedGates)?; + while !gates.is_empty() { + let batch = gates + .drain(..gates.len().min(self.config.batch_size)) + .collect::>(); + // Move the evaluator to a new thread to process the batch then send it back + ev = Backend::spawn(move || { + ev.evaluate(batch.iter()); + ev + }) + .await; + } } - for (output, commitment) in encoded_outputs.iter().zip(commitments) { - commitment.verify(output)?; + let encoded_outputs = ev.outputs()?; + if self.config.encoding_commitments { + let commitments = expect_msg_or_err!(stream, GarbleMessage::EncodingCommitments)?; + + // Make sure the generator sent the expected number of commitments. + if commitments.len() != encoded_outputs.len() { + return Err(EvaluatorError::IncorrectValueCount { + expected: encoded_outputs.len(), + actual: commitments.len(), + }); + } + + for (output, commitment) in encoded_outputs.iter().zip(commitments) { + commitment.verify(output)?; + } } - } + + encoded_outputs + }; // Add the output encodings to the memory. let mut state = self.state(); @@ -417,8 +522,7 @@ impl Evaluator { // Generate encodings for all received values let received_values: Vec<(ValueId, ValueType)> = self.state().received_values.drain().collect(); - gen.generate_encodings(&received_values) - .map_err(VerificationError::from)?; + gen.generate_input_encodings_by_id(&received_values); // Verify all OTs in the log let mut ot_futs: FuturesUnordered<_> = self diff --git a/garble/mpz-garble/src/generator/error.rs b/garble/mpz-garble/src/generator/error.rs index 92d58165..7a7de38b 100644 --- a/garble/mpz-garble/src/generator/error.rs +++ b/garble/mpz-garble/src/generator/error.rs @@ -15,7 +15,9 @@ pub enum GeneratorError { IOError(#[from] std::io::Error), #[error(transparent)] ValueError(#[from] ValueError), - #[error("missing encoding for value")] + #[error("duplicate encoding for value: {0:?}")] + DuplicateEncoding(ValueRef), + #[error("missing encoding for value: {0:?}")] MissingEncoding(ValueRef), #[error(transparent)] EncodingRegistryError(#[from] crate::memory::EncodingMemoryError), diff --git a/garble/mpz-garble/src/generator/mod.rs b/garble/mpz-garble/src/generator/mod.rs index d2eaaa7d..5fdf62f2 100644 --- a/garble/mpz-garble/src/generator/mod.rs +++ b/garble/mpz-garble/src/generator/mod.rs @@ -4,7 +4,7 @@ mod config; mod error; use std::{ - collections::HashSet, + collections::{HashMap, HashSet}, ops::DerefMut, sync::{Arc, Mutex}, }; @@ -24,7 +24,7 @@ use utils_aio::non_blocking_backend::{Backend, NonBlockingBackend}; use crate::{ memory::EncodingMemory, ot::OTSendEncoding, - value::{ValueId, ValueRef}, + value::{CircuitRefs, ValueId, ValueRef}, AssignedValues, }; @@ -44,6 +44,8 @@ struct State { encoder: ChaChaEncoder, /// Encodings of values memory: EncodingMemory, + /// Transferred garbled circuits + garbled: HashMap>, /// The set of values that are currently active. /// /// A value is considered active when it has been encoded and sent to the evaluator. @@ -88,18 +90,29 @@ impl Generator { .collect::>>() } - /// Generate encodings for a slice of values - pub(crate) fn generate_encodings( - &self, - values: &[(ValueId, ValueType)], - ) -> Result<(), GeneratorError> { - let mut state = self.state(); + /// Generates encoding for the provided input value. + /// + /// If an encoding for a value have already been generated, it is ignored. + /// + /// # Panics + /// + /// If the provided value type does not match the value reference. + pub fn generate_input_encoding(&self, value: &ValueRef, typ: &ValueType) { + self.state().encode(value, typ); + } - for (id, ty) in values { - _ = state.encode_by_id(id, ty)?; + /// Generates encodings for the provided input values. + /// + /// If encodings for a value have already been generated, it is ignored. + /// + /// # Panics + /// + /// If the provided value type is an array + pub(crate) fn generate_input_encodings_by_id(&self, values: &[(ValueId, ValueType)]) { + let mut state = self.state(); + for (value_id, value_typ) in values { + state.encode_by_id(value_id, value_typ); } - - Ok(()) } /// Transfer active encodings for the provided assigned values. @@ -139,7 +152,7 @@ impl Generator { /// - `id` - The ID of this operation /// - `values` - The values to send /// - `ot` - The OT sender - pub async fn ot_send_active_encodings( + pub(crate) async fn ot_send_active_encodings( &self, id: &str, values: &[(ValueId, ValueType)], @@ -151,16 +164,16 @@ impl Generator { let full_encodings = { let mut state = self.state(); - // Filter out any values that are already active, setting them active otherwise. + // Filter out any values that are already active let mut values = values .iter() - .filter(|(id, _)| state.active.insert(id.clone())) + .filter(|(id, _)| !state.active.contains(id)) .collect::>(); - values.sort_by_key(|(id, _)| id.clone()); + values.sort_by(|(id_a, _), (id_b, _)| id_a.cmp(id_b)); values .iter() - .map(|(id, ty)| state.encode_by_id(id, ty)) + .map(|(id, _)| state.activate_encoding(id)) .collect::, GeneratorError>>()? }; @@ -175,7 +188,7 @@ impl Generator { /// /// - `values` - The values to send /// - `sink` - The sink to send the encodings to the evaluator - pub async fn direct_send_active_encodings< + pub(crate) async fn direct_send_active_encodings< S: Sink + Unpin, >( &self, @@ -188,17 +201,17 @@ impl Generator { let active_encodings = { let mut state = self.state(); - // Filter out any values that are already active, setting them active otherwise. + // Filter out any values that are already active let mut values = values .iter() - .filter(|(id, _)| state.active.insert(id.clone())) + .filter(|(id, _)| !state.active.contains(id)) .collect::>(); - values.sort_by_key(|(id, _)| id.clone()); + values.sort_by(|(id_a, _), (id_b, _)| id_a.cmp(id_b)); values .iter() .map(|(id, value)| { - let full_encoding = state.encode_by_id(id, &value.value_type())?; + let full_encoding = state.activate_encoding(id)?; Ok(full_encoding.select(value.clone())?) }) .collect::, GeneratorError>>()? @@ -229,8 +242,29 @@ impl Generator { sink: &mut S, hash: bool, ) -> Result<(Vec>, Option), GeneratorError> { + let refs = CircuitRefs { + inputs: inputs.to_vec(), + outputs: outputs.to_vec(), + }; let (delta, inputs) = { let state = self.state(); + + // If the circuit has already been garbled, return early + if let Some(hash) = state.garbled.get(&refs) { + return Ok(( + outputs + .iter() + .map(|output| { + state + .memory + .get_encoding(output) + .expect("encoding exists if circuit is garbled already") + }) + .collect(), + *hash, + )); + } + let delta = state.encoder.delta(); let inputs = inputs .iter() @@ -289,6 +323,8 @@ impl Generator { }); } + state.garbled.insert(refs, hash); + Ok((encoded_outputs, hash)) } @@ -331,12 +367,14 @@ impl State { } } - #[allow(dead_code)] - fn encode( - &mut self, - value: &ValueRef, - ty: &ValueType, - ) -> Result, GeneratorError> { + /// Generates an encoding for a value + /// + /// If an encoding for the value already exists, it is returned instead. + /// + /// # Panics + /// + /// If the provided value type does not match the value reference. + fn encode(&mut self, value: &ValueRef, ty: &ValueType) -> EncodedValue { match (value, ty) { (ValueRef::Value { id }, ty) if !ty.is_array() => self.encode_by_id(id, ty), (ValueRef::Array(array), ValueType::Array(elem_ty, len)) if array.len() == *len => { @@ -344,23 +382,44 @@ impl State { .ids() .iter() .map(|id| self.encode_by_id(id, elem_ty)) - .collect::, _>>()?; + .collect(); - Ok(EncodedValue::Array(encodings)) + EncodedValue::Array(encodings) } _ => panic!("invalid value and type combination: {:?} {:?}", value, ty), } } - fn encode_by_id( + /// Generates an encoding for a value + /// + /// If an encoding for the value already exists, it is returned instead. + fn encode_by_id(&mut self, id: &ValueId, ty: &ValueType) -> EncodedValue { + if let Some(encoding) = self.memory.get_encoding_by_id(id) { + encoding + } else { + let encoding = self.encoder.encode_by_type(id.to_u64(), ty); + self.memory + .set_encoding_by_id(id, encoding.clone()) + .expect("encoding does not already exist"); + encoding + } + } + + fn activate_encoding( &mut self, id: &ValueId, - ty: &ValueType, ) -> Result, GeneratorError> { - let encoding = self.encoder.encode_by_type(id.to_u64(), ty); - - // Returns error if the encoding already exists - self.memory.set_encoding_by_id(id, encoding.clone())?; + let encoding = self + .memory + .get_encoding_by_id(id) + .ok_or_else(|| GeneratorError::MissingEncoding(ValueRef::Value { id: id.clone() }))?; + + // Returns error if the encoding is already active + if !self.active.insert(id.clone()) { + return Err(GeneratorError::DuplicateEncoding(ValueRef::Value { + id: id.clone(), + })); + } Ok(encoding) } diff --git a/garble/mpz-garble/src/lib.rs b/garble/mpz-garble/src/lib.rs index a5527b38..8884095a 100644 --- a/garble/mpz-garble/src/lib.rs +++ b/garble/mpz-garble/src/lib.rs @@ -98,6 +98,17 @@ pub enum AssignmentError { }, } +/// Errors that can occur when loading a circuit. +#[derive(Debug, thiserror::Error)] +pub enum LoadError { + /// IO error. + #[error(transparent)] + IOError(#[from] std::io::Error), + /// Protocol error. + #[error(transparent)] + ProtocolError(#[from] Box), +} + /// Errors that can occur when executing a circuit. #[derive(Debug, thiserror::Error)] #[allow(missing_docs)] @@ -298,6 +309,20 @@ pub trait Memory { } } +/// This trait provides methods for loading a circuit. +/// +/// Implementations may perform pre-processing prior to execution. +#[async_trait] +pub trait Load { + /// Loads a circuit with the provided inputs and output values. + async fn load( + &mut self, + circ: Arc, + inputs: &[ValueRef], + outputs: &[ValueRef], + ) -> Result<(), LoadError>; +} + /// This trait provides methods for executing a circuit. #[async_trait] pub trait Execute { diff --git a/garble/mpz-garble/src/protocol/deap/error.rs b/garble/mpz-garble/src/protocol/deap/error.rs index 99783f73..6f880860 100644 --- a/garble/mpz-garble/src/protocol/deap/error.rs +++ b/garble/mpz-garble/src/protocol/deap/error.rs @@ -1,6 +1,6 @@ use mpz_garble_core::{msg::GarbleMessage, ValueError}; -use crate::{value::ValueRef, DecodeError, ExecutionError, ProveError, VerifyError}; +use crate::{value::ValueRef, DecodeError, ExecutionError, LoadError, ProveError, VerifyError}; /// Errors that can occur during the DEAP protocol. #[derive(Debug, thiserror::Error)] @@ -52,6 +52,15 @@ pub enum PeerEncodingsError { EncodingNotAvailable(ValueRef), } +impl From for LoadError { + fn from(err: DEAPError) -> Self { + match err { + DEAPError::IOError(err) => LoadError::IOError(err), + err => LoadError::ProtocolError(Box::new(err)), + } + } +} + impl From for ExecutionError { fn from(err: DEAPError) -> Self { match err { diff --git a/garble/mpz-garble/src/protocol/deap/memory.rs b/garble/mpz-garble/src/protocol/deap/memory.rs index 737a0d7b..902a3205 100644 --- a/garble/mpz-garble/src/protocol/deap/memory.rs +++ b/garble/mpz-garble/src/protocol/deap/memory.rs @@ -11,7 +11,9 @@ impl Memory for DEAP { typ: ValueType, visibility: Visibility, ) -> Result { - self.state().memory.new_input(id, typ, visibility) + let value_ref = self.state().memory.new_input(id, typ.clone(), visibility)?; + self.gen.generate_input_encoding(&value_ref, &typ); + Ok(value_ref) } fn new_output_with_type(&self, id: &str, typ: ValueType) -> Result { diff --git a/garble/mpz-garble/src/protocol/deap/mod.rs b/garble/mpz-garble/src/protocol/deap/mod.rs index 99c2c806..7ec418d2 100644 --- a/garble/mpz-garble/src/protocol/deap/mod.rs +++ b/garble/mpz-garble/src/protocol/deap/mod.rs @@ -126,6 +126,41 @@ impl DEAP { self.state.lock().unwrap() } + /// Performs pre-processing for executing the provided circuit. + /// + /// # Arguments + /// + /// * `circ` - The circuit to load. + /// * `inputs` - The inputs to the circuit. + /// * `outputs` - The outputs of the circuit. + /// * `sink` - The sink to send messages to. + /// * `stream` - The stream to receive messages from. + pub async fn load( + &self, + circ: Arc, + inputs: &[ValueRef], + outputs: &[ValueRef], + sink: &mut T, + stream: &mut U, + ) -> Result<(), DEAPError> + where + T: Sink + Unpin, + U: Stream> + Unpin, + { + // Generate and receive concurrently. + // Drop the encoded outputs, we don't need them here + _ = futures::try_join!( + self.gen + .generate(circ.clone(), inputs, outputs, sink, false) + .map_err(DEAPError::from), + self.ev + .receive_garbled_circuit(circ.clone(), inputs, outputs, stream) + .map_err(DEAPError::from) + )?; + + Ok(()) + } + /// Executes a circuit. /// /// # Arguments @@ -464,7 +499,7 @@ impl DEAP { state.new_private_otp(&format!("{id}/{idx}/otp"), value); let otp_typ = otp_value.value_type(); let mask_ref = state.new_output_mask(&format!("{id}/{idx}/mask"), value); - + self.gen.generate_input_encoding(&otp_ref, &otp_typ); (((otp_ref, otp_typ), otp_value), mask_ref) }) .unzip() @@ -520,7 +555,7 @@ impl DEAP { .map(|(idx, value)| { let (otp_ref, otp_typ) = state.new_blind_otp(&format!("{id}/{idx}/otp"), value); let mask_ref = state.new_output_mask(&format!("{id}/{idx}/mask"), value); - + self.gen.generate_input_encoding(&otp_ref, &otp_typ); ((otp_ref, otp_typ), mask_ref) }) .unzip() @@ -590,6 +625,8 @@ impl DEAP { } }; let mask_ref = state.new_output_mask(&format!("{id}/{idx}/mask"), value); + self.gen.generate_input_encoding(&otp_0_ref, &otp_typ); + self.gen.generate_input_encoding(&otp_1_ref, &otp_typ); ((((otp_0_ref, otp_1_ref), otp_typ), otp_value), mask_ref) }) .unzip() @@ -942,6 +979,121 @@ mod tests { assert_eq!(leader_output, follower_output); } + #[tokio::test] + async fn test_deap_load() { + let (leader_channel, follower_channel) = MemoryDuplex::::new(); + let (leader_ot_send, follower_ot_recv) = mock_ot_shared_pair(); + let (follower_ot_send, leader_ot_recv) = mock_ot_shared_pair(); + + let mut leader = DEAP::new(Role::Leader, [42u8; 32]); + let mut follower = DEAP::new(Role::Follower, [69u8; 32]); + + let key = [42u8; 16]; + let msg = [69u8; 16]; + + let leader_fut = { + let (mut sink, mut stream) = leader_channel.split(); + + let key_ref = leader.new_private_input::<[u8; 16]>("key").unwrap(); + let msg_ref = leader.new_blind_input::<[u8; 16]>("msg").unwrap(); + let ciphertext_ref = leader.new_output::<[u8; 16]>("ciphertext").unwrap(); + + async move { + leader + .load( + AES128.clone(), + &[key_ref.clone(), msg_ref.clone()], + &[ciphertext_ref.clone()], + &mut sink, + &mut stream, + ) + .await + .unwrap(); + + leader.assign(&key_ref, key).unwrap(); + + leader + .execute( + "test", + AES128.clone(), + &[key_ref, msg_ref], + &[ciphertext_ref.clone()], + &mut sink, + &mut stream, + &leader_ot_send, + &leader_ot_recv, + ) + .await + .unwrap(); + + let outputs = leader + .decode("test", &[ciphertext_ref], &mut sink, &mut stream) + .await + .unwrap(); + + leader + .finalize(&mut sink, &mut stream, &leader_ot_recv) + .await + .unwrap(); + + outputs + } + }; + + let follower_fut = { + let (mut sink, mut stream) = follower_channel.split(); + + let key_ref = follower.new_blind_input::<[u8; 16]>("key").unwrap(); + let msg_ref = follower.new_private_input::<[u8; 16]>("msg").unwrap(); + let ciphertext_ref = follower.new_output::<[u8; 16]>("ciphertext").unwrap(); + + async move { + follower + .load( + AES128.clone(), + &[key_ref.clone(), msg_ref.clone()], + &[ciphertext_ref.clone()], + &mut sink, + &mut stream, + ) + .await + .unwrap(); + + follower.assign(&msg_ref, msg).unwrap(); + + follower + .execute( + "test", + AES128.clone(), + &[key_ref, msg_ref], + &[ciphertext_ref.clone()], + &mut sink, + &mut stream, + &follower_ot_send, + &follower_ot_recv, + ) + .await + .unwrap(); + + let outputs = follower + .decode("test", &[ciphertext_ref], &mut sink, &mut stream) + .await + .unwrap(); + + follower + .finalize(&mut sink, &mut stream, &follower_ot_recv) + .await + .unwrap(); + + outputs + } + }; + + let (leader_output, follower_output) = tokio::join!(leader_fut, follower_fut); + + assert_eq!(leader_output, follower_output); + } + #[tokio::test] async fn test_deap_decode_private() { let (leader_channel, follower_channel) = MemoryDuplex::::new(); diff --git a/garble/mpz-garble/src/protocol/deap/vm.rs b/garble/mpz-garble/src/protocol/deap/vm.rs index 93750592..0e2831ca 100644 --- a/garble/mpz-garble/src/protocol/deap/vm.rs +++ b/garble/mpz-garble/src/protocol/deap/vm.rs @@ -21,8 +21,8 @@ use crate::{ config::{Role, Visibility}, ot::{VerifiableOTReceiveEncoding, VerifiableOTSendEncoding}, value::ValueRef, - Decode, DecodeError, DecodePrivate, Execute, ExecutionError, Memory, MemoryError, Prove, - ProveError, Thread, Verify, VerifyError, Vm, VmError, + Decode, DecodeError, DecodePrivate, Execute, ExecutionError, Load, LoadError, Memory, + MemoryError, Prove, ProveError, Thread, Verify, VerifyError, Vm, VmError, }; use super::{ @@ -237,6 +237,25 @@ impl Memory for DEAPThread { } } +#[async_trait] +impl Load for DEAPThread +where + OTS: VerifiableOTSendEncoding + Send + Sync, + OTR: VerifiableOTReceiveEncoding + Send + Sync, +{ + async fn load( + &mut self, + circ: Arc, + inputs: &[ValueRef], + outputs: &[ValueRef], + ) -> Result<(), LoadError> { + self.deap() + .load(circ, inputs, outputs, &mut self.sink, &mut self.stream) + .map_err(LoadError::from) + .await + } +} + #[async_trait] impl Execute for DEAPThread where diff --git a/garble/mpz-garble/src/value.rs b/garble/mpz-garble/src/value.rs index c6a14e8b..db6c60a4 100644 --- a/garble/mpz-garble/src/value.rs +++ b/garble/mpz-garble/src/value.rs @@ -150,3 +150,10 @@ impl<'a> Iterator for ValueRefIter<'a> { } } } + +/// References to the inputs and outputs of a circuit. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub(crate) struct CircuitRefs { + pub(crate) inputs: Vec, + pub(crate) outputs: Vec, +} diff --git a/garble/mpz-garble/tests/offline-garble.rs b/garble/mpz-garble/tests/offline-garble.rs new file mode 100644 index 00000000..c3e049d2 --- /dev/null +++ b/garble/mpz-garble/tests/offline-garble.rs @@ -0,0 +1,136 @@ +use mpz_circuits::{circuits::AES128, types::StaticValueType}; +use mpz_garble_core::msg::GarbleMessage; +use mpz_ot::mock::mock_ot_shared_pair; +use utils_aio::duplex::MemoryDuplex; + +use mpz_garble::{config::Visibility, Evaluator, Generator, GeneratorConfigBuilder, ValueMemory}; + +#[tokio::test] +async fn test_offline_garble() { + let (mut gen_channel, mut ev_channel) = MemoryDuplex::::new(); + let (ot_send, ot_recv) = mock_ot_shared_pair(); + + let gen = Generator::new( + GeneratorConfigBuilder::default().build().unwrap(), + [0u8; 32], + ); + let ev = Evaluator::default(); + + let key = [69u8; 16]; + let msg = [42u8; 16]; + + let key_typ = <[u8; 16]>::value_type(); + let msg_typ = <[u8; 16]>::value_type(); + let ciphertext_typ = <[u8; 16]>::value_type(); + + let gen_fut = async { + let mut memory = ValueMemory::default(); + + let key_ref = memory + .new_input("key", key_typ.clone(), Visibility::Private) + .unwrap(); + let msg_ref = memory + .new_input("msg", msg_typ.clone(), Visibility::Blind) + .unwrap(); + let ciphertext_ref = memory + .new_output("ciphertext", ciphertext_typ.clone()) + .unwrap(); + + gen.generate_input_encoding(&key_ref, &key_typ); + gen.generate_input_encoding(&msg_ref, &msg_typ); + + gen.generate( + AES128.clone(), + &[key_ref.clone(), msg_ref.clone()], + &[ciphertext_ref.clone()], + &mut gen_channel, + false, + ) + .await + .unwrap(); + + memory.assign(&key_ref, key.into()).unwrap(); + + gen.setup_assigned_values( + "test", + &memory.drain_assigned(&[key_ref.clone(), msg_ref.clone()]), + &mut gen_channel, + &ot_send, + ) + .await + .unwrap(); + + gen.get_encoding(&ciphertext_ref).unwrap() + }; + + let ev_fut = async { + let mut memory = ValueMemory::default(); + + let key_ref = memory + .new_input("key", key_typ.clone(), Visibility::Blind) + .unwrap(); + let msg_ref = memory + .new_input("msg", msg_typ.clone(), Visibility::Private) + .unwrap(); + let ciphertext_ref = memory + .new_output("ciphertext", ciphertext_typ.clone()) + .unwrap(); + + ev.receive_garbled_circuit( + AES128.clone(), + &[key_ref.clone(), msg_ref.clone()], + &[ciphertext_ref.clone()], + &mut ev_channel, + ) + .await + .unwrap(); + + memory.assign(&msg_ref, msg.into()).unwrap(); + + ev.setup_assigned_values( + "test", + &memory.drain_assigned(&[key_ref.clone(), msg_ref.clone()]), + &mut ev_channel, + &ot_recv, + ) + .await + .unwrap(); + + _ = ev + .evaluate( + AES128.clone(), + &[key_ref.clone(), msg_ref.clone()], + &[ciphertext_ref.clone()], + &mut ev_channel, + ) + .await + .unwrap(); + + ev.get_encoding(&ciphertext_ref).unwrap() + }; + + let (ciphertext_full_encoding, ciphertext_active_encoding) = tokio::join!(gen_fut, ev_fut); + + let decoding = ciphertext_full_encoding.decoding(); + let ciphertext: [u8; 16] = ciphertext_active_encoding + .decode(&decoding) + .unwrap() + .try_into() + .unwrap(); + + let expected: [u8; 16] = { + use aes::{ + cipher::{BlockEncrypt, KeyInit}, + Aes128, + }; + + let mut msg = msg.into(); + + let cipher = Aes128::new_from_slice(&key).unwrap(); + cipher.encrypt_block(&mut msg); + + msg.into() + }; + + assert_eq!(ciphertext, expected) +} diff --git a/garble/mpz-garble/tests/semihonest.rs b/garble/mpz-garble/tests/semihonest.rs index c0b449c8..a3f9c71d 100644 --- a/garble/mpz-garble/tests/semihonest.rs +++ b/garble/mpz-garble/tests/semihonest.rs @@ -19,21 +19,28 @@ async fn test_semi_honest() { let key = [69u8; 16]; let msg = [42u8; 16]; + let key_typ = <[u8; 16]>::value_type(); + let msg_typ = <[u8; 16]>::value_type(); + let ciphertext_typ = <[u8; 16]>::value_type(); + let gen_fut = async { let mut memory = ValueMemory::default(); let key_ref = memory - .new_input("key", <[u8; 16]>::value_type(), Visibility::Private) + .new_input("key", key_typ.clone(), Visibility::Private) .unwrap(); let msg_ref = memory - .new_input("msg", <[u8; 16]>::value_type(), Visibility::Blind) + .new_input("msg", msg_typ.clone(), Visibility::Blind) .unwrap(); let ciphertext_ref = memory - .new_output("ciphertext", <[u8; 16]>::value_type()) + .new_output("ciphertext", ciphertext_typ.clone()) .unwrap(); memory.assign(&key_ref, key.into()).unwrap(); + gen.generate_input_encoding(&key_ref, &key_typ); + gen.generate_input_encoding(&msg_ref, &msg_typ); + gen.setup_assigned_values( "test", &memory.drain_assigned(&[key_ref.clone(), msg_ref.clone()]), @@ -60,13 +67,13 @@ async fn test_semi_honest() { let mut memory = ValueMemory::default(); let key_ref = memory - .new_input("key", <[u8; 16]>::value_type(), Visibility::Blind) + .new_input("key", key_typ.clone(), Visibility::Blind) .unwrap(); let msg_ref = memory - .new_input("msg", <[u8; 16]>::value_type(), Visibility::Private) + .new_input("msg", msg_typ.clone(), Visibility::Private) .unwrap(); let ciphertext_ref = memory - .new_output("ciphertext", <[u8; 16]>::value_type()) + .new_output("ciphertext", ciphertext_typ.clone()) .unwrap(); memory.assign(&msg_ref, msg.into()).unwrap(); From bc81d8ca279a406c162d1d540ceb9e79d45f37b1 Mon Sep 17 00:00:00 2001 From: sinu <65924192+sinui0@users.noreply.github.com> Date: Wed, 18 Oct 2023 18:37:02 -0700 Subject: [PATCH 2/5] fix massive regression in evaluator --- garble/mpz-garble/src/evaluator/mod.rs | 32 +++++++++++--------------- 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/garble/mpz-garble/src/evaluator/mod.rs b/garble/mpz-garble/src/evaluator/mod.rs index c4871fb2..65ca17d2 100644 --- a/garble/mpz-garble/src/evaluator/mod.rs +++ b/garble/mpz-garble/src/evaluator/mod.rs @@ -369,21 +369,19 @@ impl Evaluator { // If we've already received the garbled circuit, we evaluate it, otherwise we stream the encrypted gates // from the generator. - let encoded_outputs = if let Some(GarbledCircuit { - mut gates, - commitments, - }) = existing_garbled_circuit + let encoded_outputs = if let Some(GarbledCircuit { gates, commitments }) = + existing_garbled_circuit { while !ev.is_complete() { - let batch = gates - .drain(..gates.len().min(self.config.batch_size)) - .collect::>(); - // Move the evaluator to a new thread to process the batch then send it back - ev = Backend::spawn(move || { - ev.evaluate(batch.iter()); - ev - }) - .await; + for batch in gates.chunks(self.config.batch_size) { + let batch = batch.to_vec(); + // Move the evaluator to a new thread to process the batch then send it back + ev = Backend::spawn(move || { + ev.evaluate(batch.iter()); + ev + }) + .await; + } } let encoded_outputs = ev.outputs()?; @@ -399,11 +397,9 @@ impl Evaluator { encoded_outputs } else { while !ev.is_complete() { - let mut gates = expect_msg_or_err!(stream, GarbleMessage::EncryptedGates)?; - while !gates.is_empty() { - let batch = gates - .drain(..gates.len().min(self.config.batch_size)) - .collect::>(); + let gates = expect_msg_or_err!(stream, GarbleMessage::EncryptedGates)?; + for batch in gates.chunks(self.config.batch_size) { + let batch = batch.to_vec(); // Move the evaluator to a new thread to process the batch then send it back ev = Backend::spawn(move || { ev.evaluate(batch.iter()); From 9d324a76350a1debe97296ff7d5515dd4d0897a7 Mon Sep 17 00:00:00 2001 From: sinu <65924192+sinui0@users.noreply.github.com> Date: Wed, 18 Oct 2023 18:50:16 -0700 Subject: [PATCH 3/5] remove unnecessary while loop --- garble/mpz-garble/src/evaluator/mod.rs | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/garble/mpz-garble/src/evaluator/mod.rs b/garble/mpz-garble/src/evaluator/mod.rs index 65ca17d2..30aa67ce 100644 --- a/garble/mpz-garble/src/evaluator/mod.rs +++ b/garble/mpz-garble/src/evaluator/mod.rs @@ -372,16 +372,14 @@ impl Evaluator { let encoded_outputs = if let Some(GarbledCircuit { gates, commitments }) = existing_garbled_circuit { - while !ev.is_complete() { - for batch in gates.chunks(self.config.batch_size) { - let batch = batch.to_vec(); - // Move the evaluator to a new thread to process the batch then send it back - ev = Backend::spawn(move || { - ev.evaluate(batch.iter()); - ev - }) - .await; - } + for batch in gates.chunks(self.config.batch_size) { + let batch = batch.to_vec(); + // Move the evaluator to a new thread to process the batch then send it back + ev = Backend::spawn(move || { + ev.evaluate(batch.iter()); + ev + }) + .await; } let encoded_outputs = ev.outputs()?; From c87d88ea014ab40bc853db0b037e4f13a3fb40cf Mon Sep 17 00:00:00 2001 From: sinu <65924192+sinui0@users.noreply.github.com> Date: Thu, 19 Oct 2023 08:48:36 -0700 Subject: [PATCH 4/5] remove batching from evaluator --- garble/mpz-garble/src/evaluator/config.rs | 3 --- garble/mpz-garble/src/evaluator/mod.rs | 28 ++++++++--------------- 2 files changed, 10 insertions(+), 21 deletions(-) diff --git a/garble/mpz-garble/src/evaluator/config.rs b/garble/mpz-garble/src/evaluator/config.rs index f84161ef..b0fdfd43 100644 --- a/garble/mpz-garble/src/evaluator/config.rs +++ b/garble/mpz-garble/src/evaluator/config.rs @@ -12,9 +12,6 @@ pub struct EvaluatorConfig { /// Whether to log decodings. #[builder(default = "false", setter(custom))] pub(crate) log_decodings: bool, - /// The number of encrypted gates to evaluate per batch. - #[builder(default = "1024")] - pub(crate) batch_size: usize, } impl EvaluatorConfig { diff --git a/garble/mpz-garble/src/evaluator/mod.rs b/garble/mpz-garble/src/evaluator/mod.rs index 30aa67ce..0d5fc102 100644 --- a/garble/mpz-garble/src/evaluator/mod.rs +++ b/garble/mpz-garble/src/evaluator/mod.rs @@ -372,15 +372,11 @@ impl Evaluator { let encoded_outputs = if let Some(GarbledCircuit { gates, commitments }) = existing_garbled_circuit { - for batch in gates.chunks(self.config.batch_size) { - let batch = batch.to_vec(); - // Move the evaluator to a new thread to process the batch then send it back - ev = Backend::spawn(move || { - ev.evaluate(batch.iter()); - ev - }) - .await; - } + ev = Backend::spawn(move || { + ev.evaluate(gates.iter()); + ev + }) + .await; let encoded_outputs = ev.outputs()?; if self.config.encoding_commitments { @@ -396,15 +392,11 @@ impl Evaluator { } else { while !ev.is_complete() { let gates = expect_msg_or_err!(stream, GarbleMessage::EncryptedGates)?; - for batch in gates.chunks(self.config.batch_size) { - let batch = batch.to_vec(); - // Move the evaluator to a new thread to process the batch then send it back - ev = Backend::spawn(move || { - ev.evaluate(batch.iter()); - ev - }) - .await; - } + ev = Backend::spawn(move || { + ev.evaluate(gates.iter()); + ev + }) + .await; } let encoded_outputs = ev.outputs()?; From c1e39f9e15cf5341637433643690ee659bf09a72 Mon Sep 17 00:00:00 2001 From: sinu <65924192+sinui0@users.noreply.github.com> Date: Tue, 24 Oct 2023 12:55:17 -0700 Subject: [PATCH 5/5] small fixes from review --- garble/mpz-garble/src/evaluator/mod.rs | 2 +- garble/mpz-garble/src/generator/mod.rs | 6 ++---- garble/mpz-garble/src/lib.rs | 2 +- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/garble/mpz-garble/src/evaluator/mod.rs b/garble/mpz-garble/src/evaluator/mod.rs index 0d5fc102..efdd2021 100644 --- a/garble/mpz-garble/src/evaluator/mod.rs +++ b/garble/mpz-garble/src/evaluator/mod.rs @@ -63,7 +63,7 @@ struct State { decoded_values: HashSet, /// Pre-transferred garbled circuits /// - /// (inputs, outputs) => garbled circuit + /// A map used to look up a garbled circuit by its unique (inputs, outputs) reference. garbled_circuits: HashMap, /// OT logs ot_log: HashMap>, diff --git a/garble/mpz-garble/src/generator/mod.rs b/garble/mpz-garble/src/generator/mod.rs index 5fdf62f2..22584a4b 100644 --- a/garble/mpz-garble/src/generator/mod.rs +++ b/garble/mpz-garble/src/generator/mod.rs @@ -45,6 +45,8 @@ struct State { /// Encodings of values memory: EncodingMemory, /// Transferred garbled circuits + /// + /// Each circuit is uniquely identified by its (input, output) references. Optionally, the garbled circuit may have been hashed. garbled: HashMap>, /// The set of values that are currently active. /// @@ -370,10 +372,6 @@ impl State { /// Generates an encoding for a value /// /// If an encoding for the value already exists, it is returned instead. - /// - /// # Panics - /// - /// If the provided value type does not match the value reference. fn encode(&mut self, value: &ValueRef, ty: &ValueType) -> EncodedValue { match (value, ty) { (ValueRef::Value { id }, ty) if !ty.is_array() => self.encode_by_id(id, ty), diff --git a/garble/mpz-garble/src/lib.rs b/garble/mpz-garble/src/lib.rs index 8884095a..042235a2 100644 --- a/garble/mpz-garble/src/lib.rs +++ b/garble/mpz-garble/src/lib.rs @@ -314,7 +314,7 @@ pub trait Memory { /// Implementations may perform pre-processing prior to execution. #[async_trait] pub trait Load { - /// Loads a circuit with the provided inputs and output values. + /// Loads a circuit with the provided input and output values. async fn load( &mut self, circ: Arc,