diff --git a/garble/mpz-garble/Cargo.toml b/garble/mpz-garble/Cargo.toml index f746c6a6..89d6b699 100644 --- a/garble/mpz-garble/Cargo.toml +++ b/garble/mpz-garble/Cargo.toml @@ -31,6 +31,7 @@ aes = { workspace = true } rayon = { workspace = true } derive_builder.workspace = true itybity.workspace = true +opaque-debug.workspace = true [dev-dependencies] mpz-ot = { workspace = true, features = ["mock"] } diff --git a/garble/mpz-garble/benches/deap.rs b/garble/mpz-garble/benches/deap.rs index d4eb2a79..febcb078 100644 --- a/garble/mpz-garble/benches/deap.rs +++ b/garble/mpz-garble/benches/deap.rs @@ -16,14 +16,12 @@ async fn bench_deap() { let msg = [0u8; 16]; let leader_fut = { - let key_ref = leader_thread - .new_private_input::<[u8; 16]>("key", Some(key)) - .unwrap(); - let msg_ref = leader_thread - .new_private_input::<[u8; 16]>("msg", None) - .unwrap(); + let key_ref = leader_thread.new_private_input::<[u8; 16]>("key").unwrap(); + let msg_ref = leader_thread.new_blind_input::<[u8; 16]>("msg").unwrap(); let ciphertext_ref = leader_thread.new_output::<[u8; 16]>("ciphertext").unwrap(); + leader_thread.assign(&key_ref, key).unwrap(); + async { leader_thread .execute( @@ -41,16 +39,16 @@ async fn bench_deap() { }; let follower_fut = { - let key_ref = follower_thread - .new_private_input::<[u8; 16]>("key", None) - .unwrap(); + let key_ref = follower_thread.new_blind_input::<[u8; 16]>("key").unwrap(); let msg_ref = follower_thread - .new_private_input::<[u8; 16]>("msg", Some(msg)) + .new_private_input::<[u8; 16]>("msg") .unwrap(); let ciphertext_ref = follower_thread .new_output::<[u8; 16]>("ciphertext") .unwrap(); + follower_thread.assign(&msg_ref, msg).unwrap(); + async { follower_thread .execute( @@ -75,10 +73,13 @@ fn bench_aes_leader( block: usize, ) -> Pin> + Send + '_>> { Box::pin(async move { - let key = thread.new_private_input(&format!("key/{block}"), Some([0u8; 16]))?; - let msg = thread.new_private_input(&format!("msg/{block}"), Some([0u8; 16]))?; + let key = thread.new_private_input::<[u8; 16]>(&format!("key/{block}"))?; + let msg = thread.new_private_input::<[u8; 16]>(&format!("msg/{block}"))?; let ciphertext = thread.new_output::<[u8; 16]>(&format!("ciphertext/{block}"))?; + thread.assign(&key, [0u8; 16])?; + thread.assign(&msg, [0u8; 16])?; + thread .execute(AES128.clone(), &[key, msg], &[ciphertext.clone()]) .await?; @@ -94,8 +95,8 @@ fn bench_aes_follower( block: usize, ) -> Pin> + Send + '_>> { Box::pin(async move { - let key = thread.new_private_input::<[u8; 16]>(&format!("key/{block}"), None)?; - let msg = thread.new_private_input::<[u8; 16]>(&format!("msg/{block}"), None)?; + let key = thread.new_blind_input::<[u8; 16]>(&format!("key/{block}"))?; + let msg = thread.new_blind_input::<[u8; 16]>(&format!("msg/{block}"))?; let ciphertext = thread.new_output::<[u8; 16]>(&format!("ciphertext/{block}"))?; thread diff --git a/garble/mpz-garble/src/config.rs b/garble/mpz-garble/src/config.rs index 3a005d6e..179a68a8 100644 --- a/garble/mpz-garble/src/config.rs +++ b/garble/mpz-garble/src/config.rs @@ -1,8 +1,5 @@ //! Various configuration used in the protocol -use mpz_circuits::types::{StaticValueType, Value, ValueType}; -use mpz_core::value::{ValueId, ValueRef}; - /// Role in 2PC. #[derive(Debug, Clone, Copy, PartialEq)] #[allow(missing_docs)] @@ -11,224 +8,13 @@ pub enum Role { Follower, } -#[derive(Debug)] -#[allow(missing_docs)] -#[allow(dead_code)] -pub struct ValueConfigError { - value_ref: ValueRef, - ty: ValueType, - value: Option, - visibility: Visibility, -} - +/// Visibility of a value #[derive(Debug, Clone, Copy)] -pub(crate) enum Visibility { +pub enum Visibility { + /// A value known to all parties Public, + /// A private value known to this party. Private, -} - -/// Configuration of a value -#[derive(Debug, Clone)] -#[allow(missing_docs)] -pub enum ValueConfig { - /// A value known to all parties - Public { - value_ref: ValueRef, - ty: ValueType, - value: Value, - }, - /// A private value - Private { - value_ref: ValueRef, - ty: ValueType, - value: Option, - }, -} - -/// Configuration of a value -#[derive(Debug, Clone)] -#[allow(missing_docs)] -pub enum ValueIdConfig { - /// A value known to all parties - Public { - id: ValueId, - ty: ValueType, - value: Value, - }, - /// A private value - Private { - id: ValueId, - ty: ValueType, - value: Option, - }, -} - -impl ValueConfig { - /// Creates a new public value config - pub fn new_public( - value_ref: ValueRef, - value: impl Into, - ) -> Result { - let value = value.into(); - let ty = value.value_type(); - Self::new(value_ref, ty, Some(value), Visibility::Public) - } - - /// Creates a new public array value config - pub fn new_public_array( - value_ref: ValueRef, - value: Vec, - ) -> Result - where - Vec: Into, - { - let value = value.into(); - let ty = value.value_type(); - Self::new(value_ref, ty, Some(value), Visibility::Public) - } - - /// Creates a new private value config - pub fn new_private( - value_ref: ValueRef, - value: Option, - ) -> Result { - let ty = T::value_type(); - let value = value.map(|value| value.into()); - Self::new(value_ref, ty, value, Visibility::Private) - } - - /// Creates a new private array value config - pub fn new_private_array( - value_ref: ValueRef, - value: Option>, - len: usize, - ) -> Result - where - Vec: Into, - { - let ty = ValueType::new_array::(len); - let value = value.map(|value| value.into()); - Self::new(value_ref, ty, value, Visibility::Private) - } - - /// Creates a new value config - pub(crate) fn new( - value_ref: ValueRef, - ty: ValueType, - value: Option, - visibility: Visibility, - ) -> Result { - // invariants: - // - public values are always set - // - types and lengths are consistent across `value_ref`, `ty`, and `value` - // - // the outer context must ensure that the provided `ty` is correct for the - // provided `value_ref`. - let is_ok = if !value_ref.is_array() && !ty.is_array() { - true - } else if let (ValueRef::Array(ids), ValueType::Array(_, len)) = (&value_ref, &ty) { - ids.len() == *len - } else { - false - }; - - match visibility { - Visibility::Public if is_ok && value.is_some() => Ok(Self::Public { - value_ref, - ty, - value: value.unwrap(), - }), - Visibility::Private if is_ok => Ok(Self::Private { - value_ref, - ty, - value, - }), - _ => Err(ValueConfigError { - value_ref, - ty, - value, - visibility, - }), - } - } - - /// Flattens to a vector of `ValueIdConfig` - pub fn flatten(self) -> Vec { - match self { - ValueConfig::Public { - value_ref, - ty, - value, - } => match value_ref { - ValueRef::Value { id } => { - vec![ValueIdConfig::Public { id, ty, value }] - } - ValueRef::Array(ids) => { - let ValueType::Array(elem_ty, _) = ty else { - panic!("expected array type"); - }; - - let elem_ty = *elem_ty; - - let Value::Array(value) = value else { - panic!("expected array value"); - }; - - ids.into_iter() - .zip(value) - .map(|(id, value)| ValueIdConfig::Public { - id, - ty: elem_ty.clone(), - value, - }) - .collect() - } - }, - ValueConfig::Private { - value_ref, - ty, - value, - } => match value_ref { - ValueRef::Value { id } => { - vec![ValueIdConfig::Private { id, ty, value }] - } - ValueRef::Array(ids) => { - let ValueType::Array(elem_ty, _) = ty else { - panic!("expected array type"); - }; - - let elem_ty = *elem_ty; - - let values = if let Some(value) = value { - let Value::Array(value) = value else { - panic!("expected array value"); - }; - - value.into_iter().map(Option::Some).collect() - } else { - vec![None; ids.len()] - }; - - ids.into_iter() - .zip(values) - .map(|(id, value)| ValueIdConfig::Private { - id, - ty: elem_ty.clone(), - value, - }) - .collect() - } - }, - } - } -} - -impl ValueIdConfig { - /// Returns the ID of the value - pub(crate) fn id(&self) -> &ValueId { - match self { - ValueIdConfig::Public { id, .. } => id, - ValueIdConfig::Private { id, .. } => id, - } - } + /// A private value not known to this party. + Blind, } diff --git a/garble/mpz-garble/src/evaluator/error.rs b/garble/mpz-garble/src/evaluator/error.rs index 2c95bbba..3cff8974 100644 --- a/garble/mpz-garble/src/evaluator/error.rs +++ b/garble/mpz-garble/src/evaluator/error.rs @@ -1,4 +1,4 @@ -use mpz_core::value::{ValueId, ValueRef}; +use crate::value::{ValueId, ValueRef}; /// Errors that can occur while performing the role of an evaluator #[derive(Debug, thiserror::Error)] @@ -18,7 +18,7 @@ pub enum EvaluatorError { #[error(transparent)] ValueError(#[from] mpz_garble_core::ValueError), #[error(transparent)] - EncodingRegistryError(#[from] crate::registry::EncodingRegistryError), + EncodingRegistryError(#[from] crate::memory::EncodingMemoryError), #[error("missing active encoding for value")] MissingEncoding(ValueRef), #[error("duplicate decoding for value: {0:?}")] diff --git a/garble/mpz-garble/src/evaluator/mod.rs b/garble/mpz-garble/src/evaluator/mod.rs index 4e752140..fa933888 100644 --- a/garble/mpz-garble/src/evaluator/mod.rs +++ b/garble/mpz-garble/src/evaluator/mod.rs @@ -14,10 +14,7 @@ use mpz_circuits::{ types::{TypeError, Value, ValueType}, Circuit, }; -use mpz_core::{ - hash::Hash, - value::{ValueId, ValueRef}, -}; +use mpz_core::hash::Hash; use mpz_garble_core::{ encoding_state, msg::GarbleMessage, Decoding, EncodedValue, Evaluator as EvaluatorCore, }; @@ -28,10 +25,10 @@ use utils_aio::{ }; use crate::{ - config::ValueIdConfig, + memory::EncodingMemory, ot::{OTReceiveEncoding, OTVerifyEncoding}, - registry::EncodingRegistry, - Generator, GeneratorConfigBuilder, + value::{ValueId, ValueRef}, + AssignedValues, Generator, GeneratorConfigBuilder, }; pub use config::{EvaluatorConfig, EvaluatorConfigBuilder}; @@ -58,7 +55,7 @@ impl Default for Evaluator { #[derive(Debug, Default)] struct State { /// Encodings of values - encoding_registry: EncodingRegistry, + memory: EncodingMemory, /// Encoded values which were received either directly or via OT received_values: HashMap, /// Values which have been decoded @@ -106,7 +103,7 @@ impl Evaluator { /// Returns the encoding for a value. pub fn get_encoding(&self, value: &ValueRef) -> Option> { - self.state().encoding_registry.get_encoding(value) + self.state().memory.get_encoding(value) } /// Adds a decoding log entry. @@ -114,57 +111,47 @@ impl Evaluator { self.state().decoding_logs.insert(value.clone(), decoding); } - /// Setup input values by receiving the encodings from the generator - /// either directly or via oblivious transfer. + /// Transfer encodings for the provided assigned values. /// /// # Arguments /// - /// * `id` - The id of this operation - /// * `input_configs` - The inputs to setup - /// * `stream` - The stream of messages from the generator - /// * `ot` - The oblivious transfer receiver - pub async fn setup_inputs< + /// - `id` - The id of this operation + /// - `values` - The assigned values + /// - `stream` - The stream to receive the encodings from the generator + /// - `ot` - The OT receiver + pub async fn setup_assigned_values< S: Stream> + Unpin, OT: OTReceiveEncoding, >( &self, id: &str, - input_configs: &[ValueIdConfig], + values: &AssignedValues, stream: &mut S, ot: &OT, ) -> Result<(), EvaluatorError> { - let (ot_recv_values, direct_recv_values) = { + // Filter out any values that are already active. + let (mut ot_recv_values, mut direct_recv_values) = { let state = self.state(); - - // Filter out any values that are already active. - let mut input_configs: Vec = input_configs + let ot_recv_values = values + .private .iter() - .filter(|config| !state.encoding_registry.contains(config.id())) + .filter(|(id, _)| !state.memory.contains(id)) .cloned() - .collect(); - - input_configs.sort_by_key(|config| config.id().clone()); - - let mut ot_recv_values = Vec::new(); - let mut direct_recv_values = Vec::new(); - for config in input_configs.into_iter() { - match config { - ValueIdConfig::Public { id, ty, .. } => { - direct_recv_values.push((id, ty)); - } - ValueIdConfig::Private { id, ty, value } => { - if let Some(value) = value { - ot_recv_values.push((id, value)); - } else { - direct_recv_values.push((id, ty)); - } - } - } - } + .collect::>(); + let direct_recv_values = values + .public + .iter() + .map(|(id, value)| (id.clone(), value.value_type())) + .chain(values.blind.clone()) + .filter(|(id, _)| !state.memory.contains(id)) + .collect::>(); (ot_recv_values, direct_recv_values) }; + ot_recv_values.sort_by(|(id1, _), (id2, _)| id1.cmp(id2)); + direct_recv_values.sort_by(|(id1, _), (id2, _)| id1.cmp(id2)); + futures::try_join!( self.ot_receive_active_encodings(id, &ot_recv_values, ot), self.direct_receive_active_encodings(&direct_recv_values, stream) @@ -179,7 +166,7 @@ impl Evaluator { /// - `id` - The id of this operation /// - `values` - The values to receive via oblivious transfer. /// - `ot` - The oblivious transfer receiver - async fn ot_receive_active_encodings( + pub async fn ot_receive_active_encodings( &self, id: &str, values: &[(ValueId, Value)], @@ -218,10 +205,8 @@ impl Evaluator { actual: active_encoding.value_type(), })?; } - // Add the received values to the encoding registry. - state - .encoding_registry - .set_encoding_by_id(id, active_encoding)?; + // Add the received values to the memory. + state.memory.set_encoding_by_id(id, active_encoding)?; state.received_values.insert(id.clone(), expected_ty); } @@ -233,7 +218,7 @@ impl Evaluator { /// # Arguments /// - `values` - The values and types expected to be received /// - `stream` - The stream of messages from the generator - async fn direct_receive_active_encodings< + pub async fn direct_receive_active_encodings< S: Stream> + Unpin, >( &self, @@ -263,10 +248,8 @@ impl Evaluator { actual: active_encoding.value_type(), })?; } - // Add the received values to the encoding registry. - state - .encoding_registry - .set_encoding_by_id(id, active_encoding)?; + // Add the received values to the memory. + state.memory.set_encoding_by_id(id, active_encoding)?; state .received_values .insert(id.clone(), expected_ty.clone()); @@ -298,7 +281,7 @@ impl Evaluator { .iter() .map(|value_ref| { state - .encoding_registry + .memory .get_encoding(value_ref) .ok_or_else(|| EvaluatorError::MissingEncoding(value_ref.clone())) }) @@ -345,12 +328,10 @@ impl Evaluator { } } - // Add the output encodings to the encoding registry. + // Add the output encodings to the memory. let mut state = self.state(); for (output, encoding) in outputs.iter().zip(encoded_outputs.iter()) { - state - .encoding_registry - .set_encoding(output, encoding.clone())?; + state.memory.set_encoding(output, encoding.clone())?; } // If configured, log the circuit evaluation diff --git a/garble/mpz-garble/src/generator/error.rs b/garble/mpz-garble/src/generator/error.rs index 9a58eb73..92d58165 100644 --- a/garble/mpz-garble/src/generator/error.rs +++ b/garble/mpz-garble/src/generator/error.rs @@ -1,6 +1,7 @@ -use mpz_core::value::ValueRef; use mpz_garble_core::ValueError; +use crate::value::ValueRef; + /// Errors that can occur while performing the role of a generator #[derive(Debug, thiserror::Error)] #[allow(missing_docs)] @@ -17,7 +18,7 @@ pub enum GeneratorError { #[error("missing encoding for value")] MissingEncoding(ValueRef), #[error(transparent)] - EncodingRegistryError(#[from] crate::registry::EncodingRegistryError), + EncodingRegistryError(#[from] crate::memory::EncodingMemoryError), } impl From for GeneratorError { diff --git a/garble/mpz-garble/src/generator/mod.rs b/garble/mpz-garble/src/generator/mod.rs index 8cb323c2..d2eaaa7d 100644 --- a/garble/mpz-garble/src/generator/mod.rs +++ b/garble/mpz-garble/src/generator/mod.rs @@ -14,17 +14,19 @@ use mpz_circuits::{ types::{Value, ValueType}, Circuit, }; -use mpz_core::{ - hash::Hash, - value::{ValueId, ValueRef}, -}; +use mpz_core::hash::Hash; use mpz_garble_core::{ encoding_state, msg::GarbleMessage, ChaChaEncoder, EncodedValue, Encoder, Generator as GeneratorCore, }; use utils_aio::non_blocking_backend::{Backend, NonBlockingBackend}; -use crate::{config::ValueIdConfig, ot::OTSendEncoding, registry::EncodingRegistry}; +use crate::{ + memory::EncodingMemory, + ot::OTSendEncoding, + value::{ValueId, ValueRef}, + AssignedValues, +}; pub use config::{GeneratorConfig, GeneratorConfigBuilder}; pub use error::GeneratorError; @@ -41,7 +43,7 @@ struct State { /// The encoder used to encode values encoder: ChaChaEncoder, /// Encodings of values - encoding_registry: EncodingRegistry, + memory: EncodingMemory, /// The set of values that are currently active. /// /// A value is considered active when it has been encoded and sent to the evaluator. @@ -72,7 +74,7 @@ impl Generator { /// Returns the encoding for a value. pub fn get_encoding(&self, value: &ValueRef) -> Option> { - self.state().encoding_registry.get_encoding(value) + self.state().memory.get_encoding(value) } pub(crate) fn get_encodings_by_id( @@ -82,7 +84,7 @@ impl Generator { let state = self.state(); ids.iter() - .map(|id| state.encoding_registry.get_encoding_by_id(id)) + .map(|id| state.memory.get_encoding_by_id(id)) .collect::>>() } @@ -100,41 +102,27 @@ impl Generator { Ok(()) } - /// Setup input values by transferring the encodings to the evaluator - /// either directly or via oblivious transfer. + /// Transfer active encodings for the provided assigned values. /// /// # Arguments /// - /// * `id` - The ID of this operation - /// * `input_configs` - The inputs to set up - /// * `sink` - The sink to send the encodings to the evaluator - /// * `ot` - The OT sender. - pub async fn setup_inputs< + /// - `id` - The ID of this operation + /// - `values` - The assigned values + /// - `sink` - The sink to send the encodings to the evaluator + /// - `ot` - The OT sender + pub async fn setup_assigned_values< S: Sink + Unpin, OT: OTSendEncoding, >( &self, id: &str, - input_configs: &[ValueIdConfig], + values: &AssignedValues, sink: &mut S, ot: &OT, ) -> Result<(), GeneratorError> { - let mut ot_send_values = Vec::new(); - let mut direct_send_values = Vec::new(); - for config in input_configs.iter().cloned() { - match config { - ValueIdConfig::Public { id, value, .. } => { - direct_send_values.push((id, value)); - } - ValueIdConfig::Private { id, value, ty } => { - if let Some(value) = value { - direct_send_values.push((id, value)); - } else { - ot_send_values.push((id, ty)); - } - } - } - } + let ot_send_values = values.blind.clone(); + let mut direct_send_values = values.public.clone(); + direct_send_values.extend(values.private.iter().cloned()); futures::try_join!( self.ot_send_active_encodings(id, &ot_send_values, ot), @@ -151,7 +139,7 @@ impl Generator { /// - `id` - The ID of this operation /// - `values` - The values to send /// - `ot` - The OT sender - async fn ot_send_active_encodings( + pub async fn ot_send_active_encodings( &self, id: &str, values: &[(ValueId, ValueType)], @@ -187,7 +175,7 @@ impl Generator { /// /// - `values` - The values to send /// - `sink` - The sink to send the encodings to the evaluator - async fn direct_send_active_encodings< + pub async fn direct_send_active_encodings< S: Sink + Unpin, >( &self, @@ -248,7 +236,7 @@ impl Generator { .iter() .map(|value| { state - .encoding_registry + .memory .get_encoding(value) .ok_or(GeneratorError::MissingEncoding(value.clone())) }) @@ -292,12 +280,10 @@ impl Generator { .await?; } - // Add the outputs to the encoding registry and set as active. + // Add the outputs to the memory and set as active. let mut state = self.state(); for (output, encoding) in outputs.iter().zip(encoded_outputs.iter()) { - state - .encoding_registry - .set_encoding(output, encoding.clone())?; + state.memory.set_encoding(output, encoding.clone())?; output.iter().for_each(|id| { state.active.insert(id.clone()); }); @@ -323,7 +309,7 @@ impl Generator { .iter() .map(|value| { state - .encoding_registry + .memory .get_encoding(value) .ok_or(GeneratorError::MissingEncoding(value.clone())) .map(|encoding| encoding.decoding()) @@ -353,8 +339,9 @@ impl State { ) -> Result, GeneratorError> { match (value, ty) { (ValueRef::Value { id }, ty) if !ty.is_array() => self.encode_by_id(id, ty), - (ValueRef::Array(ids), ValueType::Array(elem_ty, len)) if ids.len() == *len => { - let encodings = ids + (ValueRef::Array(array), ValueType::Array(elem_ty, len)) if array.len() == *len => { + let encodings = array + .ids() .iter() .map(|id| self.encode_by_id(id, elem_ty)) .collect::, _>>()?; @@ -373,8 +360,7 @@ impl State { let encoding = self.encoder.encode_by_type(id.to_u64(), ty); // Returns error if the encoding already exists - self.encoding_registry - .set_encoding_by_id(id, encoding.clone())?; + self.memory.set_encoding_by_id(id, encoding.clone())?; Ok(encoding) } diff --git a/garble/mpz-garble/src/lib.rs b/garble/mpz-garble/src/lib.rs index 5d18a5ee..a5527b38 100644 --- a/garble/mpz-garble/src/lib.rs +++ b/garble/mpz-garble/src/lib.rs @@ -8,27 +8,29 @@ use std::sync::Arc; use async_trait::async_trait; +use config::Visibility; use mpz_circuits::{ - types::{StaticValueType, Value, ValueType}, + types::{PrimitiveType, StaticValueType, Value, ValueType}, Circuit, }; -pub use mpz_core::value::{ValueId, ValueRef}; pub mod config; pub(crate) mod evaluator; pub(crate) mod generator; pub(crate) mod internal_circuits; +pub(crate) mod memory; pub mod ot; pub mod protocol; -pub(crate) mod registry; mod threadpool; +pub mod value; pub use evaluator::{Evaluator, EvaluatorConfig, EvaluatorConfigBuilder, EvaluatorError}; pub use generator::{Generator, GeneratorConfig, GeneratorConfigBuilder, GeneratorError}; -pub use registry::ValueRegistry; +pub use memory::{AssignedValues, ValueMemory}; pub use threadpool::ThreadPool; use utils::id::NestedId; +use value::{ArrayRef, ValueId, ValueRef}; /// Errors that can occur when using an implementation of [`Vm`]. #[derive(Debug, thiserror::Error)] @@ -64,10 +66,36 @@ pub enum MemoryError { DuplicateValueId(ValueId), #[error("duplicate value: {0:?}")] DuplicateValue(ValueRef), + #[error("value with id {0} has not been defined")] + Undefined(String), + #[error("attempted to create an invalid array: {0}")] + InvalidArray(String), #[error(transparent)] - TypeError(#[from] mpz_circuits::types::TypeError), - #[error("invalid value type {1:?} for {0:?}")] - InvalidType(ValueId, mpz_circuits::types::ValueType), + Assignment(#[from] AssignmentError), +} + +/// Errors that can occur when assigning values. +#[derive(Debug, thiserror::Error)] +pub enum AssignmentError { + /// The value is already assigned. + #[error("value already assigned: {0:?}")] + Duplicate(ValueId), + /// Can not assign to a blind input value. + #[error("can not assign to a blind input value: {0:?}")] + BlindInput(ValueId), + /// Can not assign to an output value. + #[error("can not assign to an output value: {0:?}")] + Output(ValueId), + /// Attempted to assign a value with an invalid type. + #[error("invalid value type {actual:?} for {value:?}, expected {expected:?}")] + Type { + /// The value reference. + value: ValueRef, + /// The expected type. + expected: ValueType, + /// The actual type. + actual: ValueType, + }, } /// Errors that can occur when executing a circuit. @@ -144,70 +172,130 @@ pub trait Thread: Memory {} /// This trait provides methods for interacting with values in memory. pub trait Memory { - /// Adds a new public input value, returning a reference to it. - fn new_public_input( + /// Adds a new input value, returning a reference to it. + fn new_input_with_type( &self, id: &str, - value: T, + typ: ValueType, + visibility: Visibility, ) -> Result; - /// Adds a new public array input value, returning a reference to it. - fn new_public_array_input( + /// Adds a new input value, returning a reference to it. + fn new_input( &self, id: &str, - value: Vec, - ) -> Result - where - Vec: Into; + visibility: Visibility, + ) -> Result { + self.new_input_with_type(id, T::value_type(), visibility) + } /// Adds a new public input value, returning a reference to it. - fn new_public_input_by_type(&self, id: &str, value: Value) -> Result; + fn new_public_input(&self, id: &str) -> Result { + self.new_input::(id, Visibility::Public) + } - /// Adds a new private input value, returning a reference to it. - fn new_private_input( + /// Adds a new public array input value, returning a reference to it. + fn new_public_array_input( &self, id: &str, - value: Option, - ) -> Result; + len: usize, + ) -> Result { + self.new_input_with_type(id, ValueType::new_array::(len), Visibility::Public) + } + + /// Adds a new private input value, returning a reference to it. + fn new_private_input(&self, id: &str) -> Result { + self.new_input::(id, Visibility::Private) + } /// Adds a new private array input value, returning a reference to it. - fn new_private_array_input( + fn new_private_array_input( &self, id: &str, - value: Option>, len: usize, - ) -> Result - where - Vec: Into; + ) -> Result { + self.new_input_with_type(id, ValueType::new_array::(len), Visibility::Private) + } - /// Adds a new private input value, returning a reference to it. - fn new_private_input_by_type( + /// Adds a new blind input value, returning a reference to it. + fn new_blind_input(&self, id: &str) -> Result { + self.new_input::(id, Visibility::Blind) + } + + /// Adds a new blind array input value, returning a reference to it. + fn new_blind_array_input( &self, id: &str, - ty: &ValueType, - value: Option, - ) -> Result; + len: usize, + ) -> Result { + self.new_input_with_type(id, ValueType::new_array::(len), Visibility::Blind) + } - /// Creates a new output value, returning a reference to it. - fn new_output(&self, id: &str) -> Result; + /// Adds a new output value, returning a reference to it. + fn new_output_with_type(&self, id: &str, typ: ValueType) -> Result; + + /// Adds a new output value, returning a reference to it. + fn new_output(&self, id: &str) -> Result { + self.new_output_with_type(id, T::value_type()) + } /// Creates a new array output value, returning a reference to it. - fn new_array_output( + fn new_array_output( &self, id: &str, len: usize, - ) -> Result - where - Vec: Into; + ) -> Result { + self.new_output_with_type(id, ValueType::new_array::(len)) + } + + /// Assigns a value. + fn assign(&self, value_ref: &ValueRef, value: impl Into) -> Result<(), MemoryError>; - /// Creates a new output value, returning a reference to it. - fn new_output_by_type(&self, id: &str, ty: &ValueType) -> Result; + /// Assigns a value. + fn assign_by_id(&self, id: &str, value: impl Into) -> Result<(), MemoryError>; /// Returns a value if it exists. fn get_value(&self, id: &str) -> Option; + /// Returns the type of a value. + fn get_value_type(&self, value_ref: &ValueRef) -> ValueType; + /// Returns the type of a value if it exists. - fn get_value_type(&self, id: &str) -> Option; + fn get_value_type_by_id(&self, id: &str) -> Option; + + /// Creates an array from the provided values. + /// + /// All values must be of the same primitive type. + fn array_from_values(&self, values: &[ValueRef]) -> Result { + if values.is_empty() { + return Err(MemoryError::InvalidArray( + "cannot create an array with no values".to_string(), + )); + } + + let mut ids = Vec::with_capacity(values.len()); + let elem_typ = self.get_value_type(&values[0]); + for value in values { + let ValueRef::Value { id } = value else { + return Err(MemoryError::InvalidArray( + "an array can only contain primitive types".to_string(), + )); + }; + + let value_typ = self.get_value_type(value); + + if value_typ != elem_typ { + return Err(MemoryError::InvalidArray(format!( + "all values in an array must have the same type, expected {:?}, got {:?}", + elem_typ, value_typ + ))); + }; + + ids.push(id.clone()); + } + + Ok(ValueRef::Array(ArrayRef::new(ids))) + } } /// This trait provides methods for executing a circuit. diff --git a/garble/mpz-garble/src/memory.rs b/garble/mpz-garble/src/memory.rs new file mode 100644 index 00000000..5bf154d3 --- /dev/null +++ b/garble/mpz-garble/src/memory.rs @@ -0,0 +1,560 @@ +use std::collections::{HashMap, HashSet}; + +use mpz_circuits::types::{Value, ValueType}; +use mpz_garble_core::{encoding_state::LabelState, EncodedValue}; + +use crate::{ + config::Visibility, + value::{ArrayRef, ValueId, ValueRef}, + AssignmentError, MemoryError, +}; + +/// Collection of assigned values. +#[derive(Debug)] +pub struct AssignedValues { + /// Public values. + pub public: Vec<(ValueId, Value)>, + /// Private values. + pub private: Vec<(ValueId, Value)>, + /// Blind values. + pub blind: Vec<(ValueId, ValueType)>, +} + +enum AssignedValue { + Public(Value), + Private(Value), + Blind(ValueType), +} + +enum ValueDetails { + Input { + typ: ValueType, + visibility: Visibility, + }, + Output { + typ: ValueType, + }, +} + +impl ValueDetails { + fn typ(&self) -> &ValueType { + match self { + ValueDetails::Input { typ, .. } => typ, + ValueDetails::Output { typ } => typ, + } + } +} + +/// A memory for storing values. +#[derive(Default)] +pub struct ValueMemory { + /// IDs for each reference + id_to_ref: HashMap, + /// References for each ID + ref_to_id: HashMap, + /// Details for each value + details: HashMap, + /// Values that have been assigned and blind values + assigned: HashSet, + /// Buffer containing assigned values + assigned_buffer: HashMap, +} + +opaque_debug::implement!(ValueMemory); + +impl ValueMemory { + /// Adds a new input value to the memory. + /// + /// # Arguments + /// + /// * `id` - The ID of the value. + /// * `typ` - The type of the value. + /// * `visibility` - The visibility of the value. + pub fn new_input( + &mut self, + id: &str, + typ: ValueType, + visibility: Visibility, + ) -> Result { + let value_id = ValueId::new(id); + let value_ref = if let ValueType::Array(typ, len) = typ { + let typ = *typ; + let mut ids = Vec::with_capacity(len); + for i in 0..len { + let elem_id = value_id.append_counter(i); + + if self.details.contains_key(&elem_id) { + return Err(MemoryError::DuplicateValueId(elem_id)); + } + + self.details.insert( + elem_id.clone(), + ValueDetails::Input { + typ: typ.clone(), + visibility, + }, + ); + ids.push(elem_id); + } + + if let Visibility::Blind = visibility { + for id in &ids { + self.assigned.insert(id.clone()); + self.assigned_buffer + .insert(id.clone(), AssignedValue::Blind(typ.clone())); + } + } + + ValueRef::Array(ArrayRef::new(ids)) + } else { + if self.details.contains_key(&value_id) { + return Err(MemoryError::DuplicateValueId(value_id)); + } + + self.details.insert( + value_id.clone(), + ValueDetails::Input { + typ: typ.clone(), + visibility, + }, + ); + + if let Visibility::Blind = visibility { + self.assigned.insert(value_id.clone()); + self.assigned_buffer + .insert(value_id.clone(), AssignedValue::Blind(typ.clone())); + } + + ValueRef::Value { id: value_id } + }; + + self.id_to_ref.insert(id.to_string(), value_ref.clone()); + self.ref_to_id.insert(value_ref.clone(), id.to_string()); + + Ok(value_ref) + } + + /// Adds a new output value to the memory. + /// + /// # Arguments + /// + /// * `id` - The ID of the value. + /// * `typ` - The type of the value. + pub fn new_output(&mut self, id: &str, typ: ValueType) -> Result { + let value_id = ValueId::new(id); + let value_ref = if let ValueType::Array(typ, len) = typ { + let typ = *typ; + let mut ids = Vec::with_capacity(len); + for i in 0..len { + let elem_id = value_id.append_counter(i); + + if self.details.contains_key(&elem_id) { + return Err(MemoryError::DuplicateValueId(elem_id)); + } + + self.details + .insert(elem_id.clone(), ValueDetails::Output { typ: typ.clone() }); + + ids.push(elem_id); + } + + ValueRef::Array(ArrayRef::new(ids)) + } else { + if self.details.contains_key(&value_id) { + return Err(MemoryError::DuplicateValueId(value_id)); + } + + self.details + .insert(value_id.clone(), ValueDetails::Output { typ }); + + ValueRef::Value { id: value_id } + }; + + self.id_to_ref.insert(id.to_string(), value_ref.clone()); + self.ref_to_id.insert(value_ref.clone(), id.to_string()); + + Ok(value_ref) + } + + /// Assigns a value to a value reference. + /// + /// # Arguments + /// + /// * `value_ref` - The value reference. + /// * `value` - The value to assign. + pub fn assign(&mut self, value_ref: &ValueRef, value: Value) -> Result<(), MemoryError> { + match value_ref { + ValueRef::Array(array) => { + let elem_details = self + .details + .get(&array.ids()[0]) + .expect("value is defined if reference exists"); + + let expected_typ = + ValueType::Array(Box::new(elem_details.typ().clone()), array.len()); + let actual_typ = value.value_type(); + if expected_typ != actual_typ { + Err(AssignmentError::Type { + value: value_ref.clone(), + expected: expected_typ, + actual: actual_typ, + })? + } + + let Value::Array(elems) = value else { + unreachable!("value type is checked above"); + }; + + for (id, elem) in array.ids().iter().zip(elems) { + self.assign(&ValueRef::Value { id: id.clone() }, elem)?; + } + } + ValueRef::Value { id } => { + let details = self + .details + .get(id) + .expect("value is defined if reference exists"); + + let ValueDetails::Input { typ, visibility } = details else { + Err(AssignmentError::Output(id.clone()))? + }; + + if typ != &value.value_type() { + Err(AssignmentError::Type { + value: value_ref.clone(), + expected: typ.clone(), + actual: value.value_type(), + })? + } + + let value = match visibility { + Visibility::Public => AssignedValue::Public(value), + Visibility::Private => AssignedValue::Private(value), + Visibility::Blind => Err(AssignmentError::BlindInput(id.clone()))?, + }; + + if self.assigned.contains(id) { + Err(AssignmentError::Duplicate(id.clone()))? + } + + self.assigned_buffer.insert(id.clone(), value); + self.assigned.insert(id.clone()); + } + } + + Ok(()) + } + + /// Returns a value reference by ID if it exists. + pub fn get_ref_by_id(&self, id: &str) -> Option<&ValueRef> { + self.id_to_ref.get(id) + } + + /// Returns a value ID by reference if it exists. + pub fn get_id_by_ref(&self, value_ref: &ValueRef) -> Option<&str> { + self.ref_to_id.get(value_ref).map(|id| id.as_str()) + } + + /// Returns the type of value of a value reference. + pub fn get_value_type(&self, value_ref: &ValueRef) -> ValueType { + match value_ref { + ValueRef::Array(array) => { + let details = self + .details + .get(&array.ids()[0]) + .expect("value is defined if reference exists"); + + ValueType::Array(Box::new(details.typ().clone()), array.len()) + } + ValueRef::Value { id } => self + .details + .get(id) + .expect("value is defined if reference exists") + .typ() + .clone(), + } + } + + /// Drains assigned values from buffer if they are present. + /// + /// Returns a tuple of public, private, and blind values. + pub fn drain_assigned(&mut self, values: &[ValueRef]) -> AssignedValues { + let mut public = Vec::new(); + let mut private = Vec::new(); + let mut blind = Vec::new(); + for id in values.iter().flat_map(|value| value.iter()) { + if let Some(value) = self.assigned_buffer.remove(id) { + match value { + AssignedValue::Public(v) => public.push((id.clone(), v)), + AssignedValue::Private(v) => private.push((id.clone(), v)), + AssignedValue::Blind(v) => blind.push((id.clone(), v)), + } + } + } + + AssignedValues { + public, + private, + blind, + } + } +} + +/// A unique ID for an encoding. +/// +/// # Warning +/// +/// The internal representation for this type is a `u64` and is computed using a hash function. +/// As such, it is not guaranteed to be unique and collisions may occur. Contexts using these +/// IDs should be aware of this and handle collisions appropriately. +/// +/// For example, an encoding should never be used for more than one value as this will compromise +/// the security of the MPC protocol. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Ord, PartialOrd, Hash)] +pub(crate) struct EncodingId(u64); + +impl EncodingId { + /// Create a new encoding ID. + pub(crate) fn new(id: u64) -> Self { + Self(id) + } +} + +impl From for EncodingId { + fn from(id: u64) -> Self { + Self::new(id) + } +} + +/// Errors which can occur when registering an encoding. +#[derive(Debug, thiserror::Error)] +pub enum EncodingMemoryError { + #[error("encoding for value {0:?} is already defined")] + DuplicateId(ValueId), +} + +/// Memory for encodings. +/// +/// This is used to store encodings for values. +/// +/// It enforces that an encoding for a value is only set once. +#[derive(Debug)] +pub(crate) struct EncodingMemory +where + T: LabelState, +{ + encodings: HashMap>, +} + +impl Default for EncodingMemory +where + T: LabelState, +{ + fn default() -> Self { + Self { + encodings: HashMap::new(), + } + } +} + +impl EncodingMemory +where + T: LabelState, +{ + /// Set the encoding for a value id. + pub(crate) fn set_encoding_by_id( + &mut self, + id: &ValueId, + encoding: EncodedValue, + ) -> Result<(), EncodingMemoryError> { + let encoding_id = EncodingId::new(id.to_u64()); + if self.encodings.contains_key(&encoding_id) { + return Err(EncodingMemoryError::DuplicateId(id.clone())); + } + + self.encodings.insert(encoding_id, encoding); + + Ok(()) + } + + /// Set the encoding for a value. + /// + /// # Panics + /// + /// Panics if the encoding for the value has already been set, or if the value + /// type does not match the encoding type. + pub(crate) fn set_encoding( + &mut self, + value: &ValueRef, + encoding: EncodedValue, + ) -> Result<(), EncodingMemoryError> { + match (value, encoding) { + (ValueRef::Value { id }, encoding) => self.set_encoding_by_id(id, encoding)?, + (ValueRef::Array(array), EncodedValue::Array(encodings)) + if array.len() == encodings.len() => + { + for (id, encoding) in array.ids().iter().zip(encodings) { + self.set_encoding_by_id(id, encoding)? + } + } + _ => panic!("value type {:?} does not match encoding type", value), + } + + Ok(()) + } + + /// Get the encoding for a value id if it exists. + pub(crate) fn get_encoding_by_id(&self, id: &ValueId) -> Option> { + self.encodings.get(&id.to_u64().into()).cloned() + } + + /// Get the encoding for a value if it exists. + /// + /// # Panics + /// + /// Panics if the value is an array and if the type of its elements is not consistent. + pub(crate) fn get_encoding(&self, value: &ValueRef) -> Option> { + match value { + ValueRef::Value { id, .. } => self.encodings.get(&id.to_u64().into()).cloned(), + ValueRef::Array(array) => { + let encodings = array + .ids() + .iter() + .map(|id| self.encodings.get(&id.to_u64().into()).cloned()) + .collect::>>()?; + + assert!( + encodings + .windows(2) + .all(|window| window[0].value_type() == window[1].value_type()), + "inconsistent element types in array {:?}", + value + ); + + Some(EncodedValue::Array(encodings)) + } + } + } + + /// Returns whether an encoding is present for a value id. + pub(crate) fn contains(&self, id: &ValueId) -> bool { + self.encodings.contains_key(&id.to_u64().into()) + } +} + +#[cfg(test)] +mod tests { + use std::marker::PhantomData; + + use super::*; + + use mpz_circuits::types::{StaticValueType, ValueType}; + use mpz_garble_core::{encoding_state, ChaChaEncoder, Encoder}; + use rstest::*; + + #[fixture] + fn encoder() -> ChaChaEncoder { + ChaChaEncoder::new([0; 32]) + } + + fn generate_encoding( + encoder: ChaChaEncoder, + value: &ValueRef, + ty: &ValueType, + ) -> EncodedValue { + match (value, ty) { + (ValueRef::Value { id }, ty) => encoder.encode_by_type(id.to_u64(), ty), + (ValueRef::Array(array), ValueType::Array(elem_ty, _)) => EncodedValue::Array( + array + .ids() + .iter() + .map(|id| encoder.encode_by_type(id.to_u64(), elem_ty)) + .collect(), + ), + _ => panic!(), + } + } + + #[rstest] + #[case::bit(PhantomData::)] + #[case::u8(PhantomData::)] + #[case::u16(PhantomData::)] + #[case::u64(PhantomData::)] + #[case::u64(PhantomData::)] + #[case::u128(PhantomData::)] + #[case::bit_array(PhantomData::<[bool; 16]>)] + #[case::u8_array(PhantomData::<[u8; 16]>)] + #[case::u16_array(PhantomData::<[u16; 16]>)] + #[case::u32_array(PhantomData::<[u32; 16]>)] + #[case::u64_array(PhantomData::<[u64; 16]>)] + #[case::u128_array(PhantomData::<[u128; 16]>)] + fn test_value_memory_duplicate_fails(#[case] _ty: PhantomData) + where + T: StaticValueType + Default + std::fmt::Debug, + { + let mut memory = ValueMemory::default(); + + let _ = memory + .new_input("test", T::value_type(), Visibility::Private) + .unwrap(); + + let err = memory + .new_input("test", T::value_type(), Visibility::Private) + .unwrap_err(); + + assert!(matches!(err, MemoryError::DuplicateValueId(_))); + } + + #[rstest] + #[case::bit(PhantomData::)] + #[case::u8(PhantomData::)] + #[case::u16(PhantomData::)] + #[case::u64(PhantomData::)] + #[case::u64(PhantomData::)] + #[case::u128(PhantomData::)] + #[case::bit_array(PhantomData::<[bool; 16]>)] + #[case::u8_array(PhantomData::<[u8; 16]>)] + #[case::u16_array(PhantomData::<[u16; 16]>)] + #[case::u32_array(PhantomData::<[u32; 16]>)] + #[case::u64_array(PhantomData::<[u64; 16]>)] + #[case::u128_array(PhantomData::<[u128; 16]>)] + fn test_encoding_memory_set_duplicate_fails( + encoder: ChaChaEncoder, + #[case] _ty: PhantomData, + ) where + T: StaticValueType + Default + std::fmt::Debug, + { + let mut memory = ValueMemory::default(); + let mut full_encoding_memory = EncodingMemory::::default(); + let mut active_encoding_memory = EncodingMemory::::default(); + + let typ = T::value_type(); + let value = memory + .new_input("test", typ.clone(), Visibility::Private) + .unwrap(); + + let encoding = generate_encoding(encoder, &value, &typ); + + full_encoding_memory + .set_encoding(&value, encoding.clone()) + .unwrap(); + + let err = full_encoding_memory + .set_encoding(&value, encoding.clone()) + .unwrap_err(); + + assert!(matches!(err, EncodingMemoryError::DuplicateId(_))); + + let encoding = encoding.select(T::default()).unwrap(); + + active_encoding_memory + .set_encoding(&value, encoding.clone()) + .unwrap(); + + let err = active_encoding_memory + .set_encoding(&value, encoding) + .unwrap_err(); + + assert!(matches!(err, EncodingMemoryError::DuplicateId(_))); + } +} diff --git a/garble/mpz-garble/src/protocol/deap/error.rs b/garble/mpz-garble/src/protocol/deap/error.rs index 31d200b8..99783f73 100644 --- a/garble/mpz-garble/src/protocol/deap/error.rs +++ b/garble/mpz-garble/src/protocol/deap/error.rs @@ -1,7 +1,6 @@ -use mpz_core::value::ValueRef; use mpz_garble_core::{msg::GarbleMessage, ValueError}; -use crate::{DecodeError, ExecutionError, ProveError, VerifyError}; +use crate::{value::ValueRef, DecodeError, ExecutionError, ProveError, VerifyError}; /// Errors that can occur during the DEAP protocol. #[derive(Debug, thiserror::Error)] @@ -47,7 +46,7 @@ pub enum FinalizationError { pub enum PeerEncodingsError { #[error("Encodings not available since DEAP instance already finalized")] AlreadyFinalized, - #[error("Value id was not found in registry: {0:?}")] + #[error("Value id was not found in memory: {0:?}")] ValueIdNotFound(String), #[error("Encoding is not available for value: {0:?}")] EncodingNotAvailable(ValueRef), diff --git a/garble/mpz-garble/src/protocol/deap/memory.rs b/garble/mpz-garble/src/protocol/deap/memory.rs index fa6d5198..737a0d7b 100644 --- a/garble/mpz-garble/src/protocol/deap/memory.rs +++ b/garble/mpz-garble/src/protocol/deap/memory.rs @@ -1,176 +1,48 @@ -use mpz_circuits::types::{StaticValueType, TypeError, Value, ValueType}; -use mpz_core::value::ValueRef; +use mpz_circuits::types::{Value, ValueType}; -use crate::{ - config::{ValueConfig, Visibility}, - Memory, MemoryError, -}; +use crate::{config::Visibility, value::ValueRef, Memory, MemoryError}; use super::DEAP; impl Memory for DEAP { - fn new_public_input( + fn new_input_with_type( &self, id: &str, - value: T, - ) -> Result { - let mut state = self.state(); - - let ty = T::value_type(); - let value_ref = state.value_registry.add_value(id, ty)?; - - state.add_input_config( - &value_ref, - ValueConfig::new_public::(value_ref.clone(), value).expect("config is valid"), - ); - - Ok(value_ref) - } - - fn new_public_array_input( - &self, - id: &str, - value: Vec, - ) -> Result - where - Vec: Into, - { - let mut state = self.state(); - - let value: Value = value.into(); - let ty = value.value_type(); - let value_ref = state.value_registry.add_value(id, ty)?; - - state.add_input_config( - &value_ref, - ValueConfig::new_public::(value_ref.clone(), value).expect("config is valid"), - ); - - Ok(value_ref) - } - - fn new_public_input_by_type(&self, id: &str, value: Value) -> Result { - let mut state = self.state(); - - let ty = value.value_type(); - let value_ref = state.value_registry.add_value(id, ty.clone())?; - - state.add_input_config( - &value_ref, - ValueConfig::new(value_ref.clone(), ty, Some(value), Visibility::Public) - .expect("config is valid"), - ); - - Ok(value_ref) - } - - fn new_private_input( - &self, - id: &str, - value: Option, - ) -> Result { - let mut state = self.state(); - - let ty = T::value_type(); - let value_ref = state.value_registry.add_value(id, ty)?; - - state.add_input_config( - &value_ref, - ValueConfig::new_private::(value_ref.clone(), value).expect("config is valid"), - ); - - Ok(value_ref) - } - - fn new_private_array_input( - &self, - id: &str, - value: Option>, - len: usize, - ) -> Result - where - Vec: Into, - { - let mut state = self.state(); - - let ty = ValueType::new_array::(len); - let value_ref = state.value_registry.add_value(id, ty)?; - - state.add_input_config( - &value_ref, - ValueConfig::new_private_array::(value_ref.clone(), value, len) - .expect("config is valid"), - ); - - Ok(value_ref) - } - - fn new_private_input_by_type( - &self, - id: &str, - ty: &ValueType, - value: Option, + typ: ValueType, + visibility: Visibility, ) -> Result { - if let Some(value) = &value { - if &value.value_type() != ty { - return Err(TypeError::UnexpectedType { - expected: ty.clone(), - actual: value.value_type(), - })?; - } - } - - let mut state = self.state(); - - let value_ref = state.value_registry.add_value(id, ty.clone())?; - - state.add_input_config( - &value_ref, - ValueConfig::new(value_ref.clone(), ty.clone(), value, Visibility::Private) - .expect("config is valid"), - ); - - Ok(value_ref) + self.state().memory.new_input(id, typ, visibility) } - fn new_output(&self, id: &str) -> Result { - let mut state = self.state(); - - let ty = T::value_type(); - let value_ref = state.value_registry.add_value(id, ty)?; - - Ok(value_ref) + fn new_output_with_type(&self, id: &str, typ: ValueType) -> Result { + self.state().memory.new_output(id, typ) } - fn new_array_output( - &self, - id: &str, - len: usize, - ) -> Result - where - Vec: Into, - { - let mut state = self.state(); - - let ty = ValueType::new_array::(len); - let value_ref = state.value_registry.add_value(id, ty)?; - - Ok(value_ref) + fn assign(&self, value_ref: &ValueRef, value: impl Into) -> Result<(), MemoryError> { + self.state().memory.assign(value_ref, value.into()) } - fn new_output_by_type(&self, id: &str, ty: &ValueType) -> Result { + fn assign_by_id(&self, id: &str, value: impl Into) -> Result<(), MemoryError> { let mut state = self.state(); - - let value_ref = state.value_registry.add_value(id, ty.clone())?; - - Ok(value_ref) + let value_ref = state + .memory + .get_ref_by_id(id) + .ok_or_else(|| MemoryError::Undefined(id.to_string()))? + .clone(); + state.memory.assign(&value_ref, value.into()) } fn get_value(&self, id: &str) -> Option { - self.state().value_registry.get_value(id) + self.state().memory.get_ref_by_id(id).cloned() + } + + fn get_value_type(&self, value_ref: &ValueRef) -> ValueType { + self.state().memory.get_value_type(value_ref) } - fn get_value_type(&self, id: &str) -> Option { - self.state().value_registry.get_value_type(id) + fn get_value_type_by_id(&self, id: &str) -> Option { + let state = self.state(); + let value_ref = state.memory.get_ref_by_id(id)?; + Some(state.memory.get_value_type(value_ref)) } } diff --git a/garble/mpz-garble/src/protocol/deap/mod.rs b/garble/mpz-garble/src/protocol/deap/mod.rs index ab41a81c..99c2c806 100644 --- a/garble/mpz-garble/src/protocol/deap/mod.rs +++ b/garble/mpz-garble/src/protocol/deap/mod.rs @@ -14,23 +14,26 @@ use std::{ }; use futures::{Sink, SinkExt, Stream, StreamExt, TryFutureExt}; -use mpz_circuits::{types::Value, Circuit}; +use mpz_circuits::{ + types::{Value, ValueType}, + Circuit, +}; use mpz_core::{ commit::{Decommitment, HashCommit}, hash::{Hash, SecureHash}, - value::{ValueId, ValueRef}, }; use mpz_garble_core::{msg::GarbleMessage, EqualityCheck}; use rand::thread_rng; use utils_aio::expect_msg_or_err; use crate::{ - config::{Role, ValueConfig, ValueIdConfig, Visibility}, + config::{Role, Visibility}, evaluator::{Evaluator, EvaluatorConfigBuilder}, generator::{Generator, GeneratorConfigBuilder}, internal_circuits::{build_otp_circuit, build_otp_shared_circuit}, + memory::ValueMemory, ot::{OTReceiveEncoding, OTSendEncoding, OTVerifyEncoding}, - registry::ValueRegistry, + value::ValueRef, }; pub use error::{DEAPError, PeerEncodingsError}; @@ -50,11 +53,7 @@ pub struct DEAP { #[derive(Debug, Default)] struct State { - /// A registry of all values - value_registry: ValueRegistry, - /// An internal buffer for value configurations which get - /// drained and set up prior to execution. - input_buffer: HashMap, + memory: ValueMemory, /// Equality check decommitments withheld by the leader /// prior to finalization @@ -157,7 +156,7 @@ impl DEAP { OTS: OTSendEncoding, OTR: OTReceiveEncoding, { - let input_configs = self.state().remove_input_configs(inputs); + let assigned_values = self.state().memory.drain_assigned(inputs); let id_0 = format!("{}/0", id); let id_1 = format!("{}/1", id); @@ -170,10 +169,10 @@ impl DEAP { // Setup inputs concurrently. futures::try_join!( self.gen - .setup_inputs(&gen_id, &input_configs, sink, ot_send) + .setup_assigned_values(&gen_id, &assigned_values, sink, ot_send) .map_err(DEAPError::from), self.ev - .setup_inputs(&ev_id, &input_configs, stream, ot_recv) + .setup_assigned_values(&ev_id, &assigned_values, stream, ot_recv) .map_err(DEAPError::from) )?; @@ -232,12 +231,12 @@ impl DEAP { ))?; } - let input_configs = self.state().remove_input_configs(inputs); + let assigned_values = self.state().memory.drain_assigned(inputs); // The prover only acts as the evaluator for ZKPs instead of // dual-execution. self.ev - .setup_inputs(id, &input_configs, stream, ot_recv) + .setup_assigned_values(id, &assigned_values, stream, ot_recv) .map_err(DEAPError::from) .await?; @@ -303,12 +302,12 @@ impl DEAP { ))?; } - let input_configs = self.state().remove_input_configs(inputs); + let assigned_values = self.state().memory.drain_assigned(inputs); // The verifier only acts as the generator for ZKPs instead of // dual-execution. self.gen - .setup_inputs(id, &input_configs, sink, ot_send) + .setup_assigned_values(id, &assigned_values, sink, ot_send) .map_err(DEAPError::from) .await?; @@ -454,49 +453,25 @@ impl DEAP { OTS: OTSendEncoding, OTR: OTReceiveEncoding, { - let (otp_refs, masked_refs): (Vec<_>, Vec<_>) = values - .iter() - .map(|value| (value.append_id("otp"), value.append_id("masked"))) - .unzip(); - - let (otp_tys, otp_values) = { + let (((otp_refs, otp_typs), otp_values), mask_refs): (((Vec<_>, Vec<_>), Vec<_>), Vec<_>) = { let mut state = self.state(); - let otp_tys = values + values .iter() - .map(|value| { - state - .value_registry - .get_value_type_with_ref(value) - .ok_or_else(|| DEAPError::ValueDoesNotExist(value.clone())) + .enumerate() + .map(|(idx, value)| { + let (otp_ref, otp_value) = + state.new_private_otp(&format!("{id}/{idx}/otp"), value); + let otp_typ = otp_value.value_type(); + let mask_ref = state.new_output_mask(&format!("{id}/{idx}/mask"), value); + + (((otp_ref, otp_typ), otp_value), mask_ref) }) - .collect::, _>>()?; - - let otp_values = otp_tys - .iter() - .map(|ty| Value::random(&mut thread_rng(), ty)) - .collect::>(); - - for ((otp_ref, otp_ty), otp_value) in - otp_refs.iter().zip(otp_tys.iter()).zip(otp_values.iter()) - { - state.add_input_config( - otp_ref, - ValueConfig::new( - otp_ref.clone(), - otp_ty.clone(), - Some(otp_value.clone()), - Visibility::Private, - ) - .expect("config is valid"), - ); - } - - (otp_tys, otp_values) + .unzip() }; // Apply OTPs to values - let circ = build_otp_circuit(&otp_tys); + let circ = build_otp_circuit(&otp_typs); let inputs = values .iter() @@ -506,19 +481,12 @@ impl DEAP { .collect::>(); self.execute( - id, - circ, - &inputs, - &masked_refs, - sink, - stream, - ot_send, - ot_recv, + id, circ, &inputs, &mask_refs, sink, stream, ot_send, ot_recv, ) .await?; // Decode masked values - let masked_values = self.decode(id, &masked_refs, sink, stream).await?; + let masked_values = self.decode(id, &mask_refs, sink, stream).await?; // Remove OTPs, returning plaintext values Ok(masked_values @@ -543,37 +511,23 @@ impl DEAP { OTS: OTSendEncoding, OTR: OTReceiveEncoding, { - let (otp_refs, masked_refs): (Vec<_>, Vec<_>) = values - .iter() - .map(|value| (value.append_id("otp"), value.append_id("masked"))) - .unzip(); - - let otp_tys = { + let ((otp_refs, otp_typs), mask_refs): ((Vec<_>, Vec<_>), Vec<_>) = { let mut state = self.state(); - let otp_tys = values + values .iter() - .map(|value| { - state - .value_registry - .get_value_type_with_ref(value) - .ok_or_else(|| DEAPError::ValueDoesNotExist(value.clone())) - }) - .collect::, _>>()?; - - for (otp_ref, otp_ty) in otp_refs.iter().zip(otp_tys.iter()) { - state.add_input_config( - otp_ref, - ValueConfig::new(otp_ref.clone(), otp_ty.clone(), None, Visibility::Private) - .expect("config is valid"), - ); - } + .enumerate() + .map(|(idx, value)| { + let (otp_ref, otp_typ) = state.new_blind_otp(&format!("{id}/{idx}/otp"), value); + let mask_ref = state.new_output_mask(&format!("{id}/{idx}/mask"), value); - otp_tys + ((otp_ref, otp_typ), mask_ref) + }) + .unzip() }; // Apply OTPs to values - let circ = build_otp_circuit(&otp_tys); + let circ = build_otp_circuit(&otp_typs); let inputs = values .iter() @@ -583,19 +537,12 @@ impl DEAP { .collect::>(); self.execute( - id, - circ, - &inputs, - &masked_refs, - sink, - stream, - ot_send, - ot_recv, + id, circ, &inputs, &mask_refs, sink, stream, ot_send, ot_recv, ) .await?; // Discard masked values - _ = self.decode(id, &masked_refs, sink, stream).await?; + _ = self.decode(id, &mask_refs, sink, stream).await?; Ok(()) } @@ -615,83 +562,41 @@ impl DEAP { OTS: OTSendEncoding, OTR: OTReceiveEncoding, { - let (otp_0_refs, (otp_1_refs, masked_refs)): (Vec<_>, (Vec<_>, Vec<_>)) = values - .iter() - .map(|value| { - ( - value.append_id("otp_0"), - (value.append_id("otp_1"), value.append_id("masked")), - ) - }) - .unzip(); - - let (otp_tys, otp_values) = { + #[allow(clippy::type_complexity)] + let ((((otp_0_refs, otp_1_refs), otp_typs), otp_values), mask_refs): ( + (((Vec<_>, Vec<_>), Vec<_>), Vec<_>), + Vec<_>, + ) = { let mut state = self.state(); - let otp_tys = values + values .iter() - .map(|value| { - state - .value_registry - .get_value_type_with_ref(value) - .ok_or_else(|| DEAPError::ValueDoesNotExist(value.clone())) + .enumerate() + .map(|(idx, value)| { + let (otp_0_ref, otp_1_ref, otp_value, otp_typ) = match self.role { + Role::Leader => { + let (otp_0_ref, otp_value) = + state.new_private_otp(&format!("{id}/{idx}/otp_0"), value); + let (otp_1_ref, otp_typ) = + state.new_blind_otp(&format!("{id}/{idx}/otp_1"), value); + (otp_0_ref, otp_1_ref, otp_value, otp_typ) + } + Role::Follower => { + let (otp_0_ref, otp_typ) = + state.new_blind_otp(&format!("{id}/{idx}/otp_0"), value); + let (otp_1_ref, otp_value) = + state.new_private_otp(&format!("{id}/{idx}/otp_1"), value); + (otp_0_ref, otp_1_ref, otp_value, otp_typ) + } + }; + let mask_ref = state.new_output_mask(&format!("{id}/{idx}/mask"), value); + ((((otp_0_ref, otp_1_ref), otp_typ), otp_value), mask_ref) }) - .collect::, _>>()?; - - let otp_values = otp_tys - .iter() - .map(|ty| Value::random(&mut thread_rng(), ty)) - .collect::>(); - - for (((otp_0_ref, opt_1_ref), otp_ty), otp_value) in otp_0_refs - .iter() - .zip(&otp_1_refs) - .zip(&otp_tys) - .zip(&otp_values) - { - let (otp_0_config, otp_1_config) = match self.role { - Role::Leader => ( - ValueConfig::new( - otp_0_ref.clone(), - otp_ty.clone(), - Some(otp_value.clone()), - Visibility::Private, - ) - .expect("config is valid"), - ValueConfig::new( - opt_1_ref.clone(), - otp_ty.clone(), - None, - Visibility::Private, - ) - .expect("config is valid"), - ), - Role::Follower => ( - ValueConfig::new( - otp_0_ref.clone(), - otp_ty.clone(), - None, - Visibility::Private, - ) - .expect("config is valid"), - ValueConfig::new( - opt_1_ref.clone(), - otp_ty.clone(), - Some(otp_value.clone()), - Visibility::Private, - ) - .expect("config is valid"), - ), - }; - state.add_input_config(otp_0_ref, otp_0_config); - state.add_input_config(opt_1_ref, otp_1_config); - } - - (otp_tys, otp_values) + .unzip() }; // Apply OTPs to values - let circ = build_otp_shared_circuit(&otp_tys); + let circ = build_otp_shared_circuit(&otp_typs); let inputs = values .iter() @@ -702,19 +607,12 @@ impl DEAP { .collect::>(); self.execute( - id, - circ, - &inputs, - &masked_refs, - sink, - stream, - ot_send, - ot_recv, + id, circ, &inputs, &mask_refs, sink, stream, ot_send, ot_recv, ) .await?; // Decode masked values - let masked_values = self.decode(id, &masked_refs, sink, stream).await?; + let masked_values = self.decode(id, &mask_refs, sink, stream).await?; match self.role { Role::Leader => { @@ -861,21 +759,40 @@ impl DEAP { } impl State { - /// Adds input configs to the buffer. - fn add_input_config(&mut self, value: &ValueRef, config: ValueConfig) { - value - .iter() - .zip(config.flatten()) - .for_each(|(id, config)| _ = self.input_buffer.insert(id.clone(), config)); + pub(crate) fn new_private_otp(&mut self, id: &str, value_ref: &ValueRef) -> (ValueRef, Value) { + let typ = self.memory.get_value_type(value_ref); + let value = Value::random(&mut thread_rng(), &typ); + + let value_ref = self + .memory + .new_input(id, typ, Visibility::Private) + .expect("otp id is unique"); + + self.memory + .assign(&value_ref, value.clone()) + .expect("value should assign"); + + (value_ref, value) } - /// Returns input configs from the buffer. - fn remove_input_configs(&mut self, values: &[ValueRef]) -> Vec { - values - .iter() - .flat_map(|value| value.iter()) - .filter_map(|id| self.input_buffer.remove(id)) - .collect::>() + pub(crate) fn new_blind_otp( + &mut self, + id: &str, + value_ref: &ValueRef, + ) -> (ValueRef, ValueType) { + let typ = self.memory.get_value_type(value_ref); + + ( + self.memory + .new_input(id, typ.clone(), Visibility::Blind) + .expect("otp id is unique"), + typ, + ) + } + + pub(crate) fn new_output_mask(&mut self, id: &str, value_ref: &ValueRef) -> ValueRef { + let typ = self.memory.get_value_type(value_ref); + self.memory.new_output(id, typ).expect("mask id is unique") } /// Drain the states to be finalized. @@ -947,10 +864,12 @@ mod tests { let leader_fut = { let (mut sink, mut stream) = leader_channel.split(); - let key_ref = leader.new_private_input("key", Some(key)).unwrap(); - let msg_ref = leader.new_private_input::<[u8; 16]>("msg", None).unwrap(); + let key_ref = leader.new_private_input::<[u8; 16]>("key").unwrap(); + let msg_ref = leader.new_blind_input::<[u8; 16]>("msg").unwrap(); let ciphertext_ref = leader.new_output::<[u8; 16]>("ciphertext").unwrap(); + leader.assign(&key_ref, key).unwrap(); + async move { leader .execute( @@ -983,10 +902,12 @@ mod tests { let follower_fut = { let (mut sink, mut stream) = follower_channel.split(); - let key_ref = follower.new_private_input::<[u8; 16]>("key", None).unwrap(); - let msg_ref = follower.new_private_input("msg", Some(msg)).unwrap(); + let key_ref = follower.new_blind_input::<[u8; 16]>("key").unwrap(); + let msg_ref = follower.new_private_input::<[u8; 16]>("msg").unwrap(); let ciphertext_ref = follower.new_output::<[u8; 16]>("ciphertext").unwrap(); + follower.assign(&msg_ref, msg).unwrap(); + async move { follower .execute( @@ -1039,10 +960,12 @@ mod tests { let leader_fut = { let (mut sink, mut stream) = leader_channel.split(); let circ = circ.clone(); - let a_ref = leader.new_private_input("a", Some(a)).unwrap(); - let b_ref = leader.new_private_input::("b", None).unwrap(); + let a_ref = leader.new_private_input::("a").unwrap(); + let b_ref = leader.new_blind_input::("b").unwrap(); let c_ref = leader.new_output::("c").unwrap(); + leader.assign(&a_ref, a).unwrap(); + async move { leader .execute( @@ -1082,10 +1005,12 @@ mod tests { let follower_fut = { let (mut sink, mut stream) = follower_channel.split(); - let a_ref = follower.new_private_input::("a", None).unwrap(); - let b_ref = follower.new_private_input("b", Some(b)).unwrap(); + let a_ref = follower.new_blind_input::("a").unwrap(); + let b_ref = follower.new_private_input::("b").unwrap(); let c_ref = follower.new_output::("c").unwrap(); + follower.assign(&b_ref, b).unwrap(); + async move { follower .execute( @@ -1143,10 +1068,12 @@ mod tests { let leader_fut = { let (mut sink, mut stream) = leader_channel.split(); let circ = circ.clone(); - let a_ref = leader.new_private_input("a", Some(a)).unwrap(); - let b_ref = leader.new_private_input::("b", None).unwrap(); + let a_ref = leader.new_private_input::("a").unwrap(); + let b_ref = leader.new_blind_input::("b").unwrap(); let c_ref = leader.new_output::("c").unwrap(); + leader.assign(&a_ref, a).unwrap(); + async move { leader .execute( @@ -1186,10 +1113,12 @@ mod tests { let follower_fut = { let (mut sink, mut stream) = follower_channel.split(); - let a_ref = follower.new_private_input::("a", None).unwrap(); - let b_ref = follower.new_private_input("b", Some(b)).unwrap(); + let a_ref = follower.new_blind_input::("a").unwrap(); + let b_ref = follower.new_private_input::("b").unwrap(); let c_ref = follower.new_output::("c").unwrap(); + follower.assign(&b_ref, b).unwrap(); + async move { follower .execute( @@ -1270,12 +1199,12 @@ mod tests { let leader_fut = { let (mut sink, mut stream) = leader_channel.split(); - let key_ref = leader - .new_private_input::<[u8; 16]>("key", Some(key)) - .unwrap(); - let msg_ref = leader.new_private_input::<[u8; 16]>("msg", None).unwrap(); + let key_ref = leader.new_private_input::<[u8; 16]>("key").unwrap(); + let msg_ref = leader.new_blind_input::<[u8; 16]>("msg").unwrap(); let ciphertext_ref = leader.new_output::<[u8; 16]>("ciphertext").unwrap(); + leader.assign(&key_ref, key).unwrap(); + async move { leader .defer_prove( @@ -1299,12 +1228,12 @@ mod tests { let follower_fut = { let (mut sink, mut stream) = follower_channel.split(); - let key_ref = follower.new_private_input::<[u8; 16]>("key", None).unwrap(); - let msg_ref = follower - .new_private_input::<[u8; 16]>("msg", Some(msg)) - .unwrap(); + let key_ref = follower.new_blind_input::<[u8; 16]>("key").unwrap(); + let msg_ref = follower.new_private_input::<[u8; 16]>("msg").unwrap(); let ciphertext_ref = follower.new_output::<[u8; 16]>("ciphertext").unwrap(); + follower.assign(&msg_ref, msg).unwrap(); + async move { follower .defer_verify( diff --git a/garble/mpz-garble/src/protocol/deap/vm.rs b/garble/mpz-garble/src/protocol/deap/vm.rs index 40d7452c..93750592 100644 --- a/garble/mpz-garble/src/protocol/deap/vm.rs +++ b/garble/mpz-garble/src/protocol/deap/vm.rs @@ -10,17 +10,17 @@ use futures::{ }; use mpz_circuits::{ - types::{StaticValueType, Value, ValueType}, + types::{Value, ValueType}, Circuit, }; -use mpz_core::value::ValueRef; use mpz_garble_core::{encoding_state::Active, msg::GarbleMessage, EncodedValue}; use utils::id::NestedId; use utils_aio::{duplex::Duplex, mux::MuxChannel}; use crate::{ - config::Role, + config::{Role, Visibility}, ot::{VerifiableOTReceiveEncoding, VerifiableOTSendEncoding}, + value::ValueRef, Decode, DecodeError, DecodePrivate, Execute, ExecutionError, Memory, MemoryError, Prove, ProveError, Thread, Verify, VerifyError, Vm, VmError, }; @@ -202,85 +202,38 @@ where impl Thread for DEAPThread {} -#[async_trait] impl Memory for DEAPThread { - fn new_public_input( - &self, - id: &str, - value: T, - ) -> Result { - self.deap().new_public_input(id, value) - } - - fn new_public_array_input( - &self, - id: &str, - value: Vec, - ) -> Result - where - Vec: Into, - { - self.deap().new_public_array_input(id, value) - } - - fn new_public_input_by_type(&self, id: &str, value: Value) -> Result { - self.deap().new_public_input_by_type(id, value) - } - - fn new_private_input( - &self, - id: &str, - value: Option, - ) -> Result { - self.deap().new_private_input(id, value) - } - - fn new_private_array_input( - &self, - id: &str, - value: Option>, - len: usize, - ) -> Result - where - Vec: Into, - { - self.deap().new_private_array_input(id, value, len) - } - - fn new_private_input_by_type( + fn new_input_with_type( &self, id: &str, - ty: &ValueType, - value: Option, + typ: ValueType, + visibility: Visibility, ) -> Result { - self.deap().new_private_input_by_type(id, ty, value) + self.deap().new_input_with_type(id, typ, visibility) } - fn new_output(&self, id: &str) -> Result { - self.deap().new_output::(id) + fn new_output_with_type(&self, id: &str, typ: ValueType) -> Result { + self.deap().new_output_with_type(id, typ) } - fn new_array_output( - &self, - id: &str, - len: usize, - ) -> Result - where - Vec: Into, - { - self.deap().new_array_output::(id, len) + fn assign(&self, value_ref: &ValueRef, value: impl Into) -> Result<(), MemoryError> { + self.deap().assign(value_ref, value) } - fn new_output_by_type(&self, id: &str, ty: &ValueType) -> Result { - self.deap().new_output_by_type(id, ty) + fn assign_by_id(&self, id: &str, value: impl Into) -> Result<(), MemoryError> { + self.deap().assign_by_id(id, value) } fn get_value(&self, id: &str) -> Option { self.deap().get_value(id) } - fn get_value_type(&self, id: &str) -> Option { - self.deap().get_value_type(id) + fn get_value_type(&self, value_ref: &ValueRef) -> ValueType { + self.deap().get_value_type(value_ref) + } + + fn get_value_type_by_id(&self, id: &str) -> Option { + self.deap().get_value_type_by_id(id) } } @@ -511,14 +464,12 @@ mod tests { let msg = [69u8; 16]; let leader_fut = { - let key_ref = leader_thread - .new_private_input::<[u8; 16]>("key", Some(key)) - .unwrap(); - let msg_ref = leader_thread - .new_private_input::<[u8; 16]>("msg", None) - .unwrap(); + let key_ref = leader_thread.new_private_input::<[u8; 16]>("key").unwrap(); + let msg_ref = leader_thread.new_blind_input::<[u8; 16]>("msg").unwrap(); let ciphertext_ref = leader_thread.new_output::<[u8; 16]>("ciphertext").unwrap(); + leader_thread.assign(&key_ref, key).unwrap(); + async move { leader_thread .execute( @@ -534,16 +485,16 @@ mod tests { }; let follower_fut = { - let key_ref = follower_thread - .new_private_input::<[u8; 16]>("key", None) - .unwrap(); + let key_ref = follower_thread.new_blind_input::<[u8; 16]>("key").unwrap(); let msg_ref = follower_thread - .new_private_input::<[u8; 16]>("msg", Some(msg)) + .new_private_input::<[u8; 16]>("msg") .unwrap(); let ciphertext_ref = follower_thread .new_output::<[u8; 16]>("ciphertext") .unwrap(); + follower_thread.assign(&msg_ref, msg).unwrap(); + async move { follower_thread .execute( diff --git a/garble/mpz-garble/src/registry.rs b/garble/mpz-garble/src/registry.rs deleted file mode 100644 index 64e1d97d..00000000 --- a/garble/mpz-garble/src/registry.rs +++ /dev/null @@ -1,359 +0,0 @@ -use std::collections::HashMap; - -use mpz_circuits::types::ValueType; -use mpz_core::value::{ValueId, ValueRef}; -use mpz_garble_core::{encoding_state::LabelState, EncodedValue}; - -use crate::MemoryError; - -/// A registry of values. -/// -/// This registry is used to track all the values that exist in a session. -/// -/// It enforces that a value is only defined once, returning an error otherwise. -#[derive(Debug, Default)] -pub struct ValueRegistry { - /// A map of value IDs to their types. - values: HashMap, - /// A map of value IDs to their references. - refs: HashMap, -} - -impl ValueRegistry { - /// Adds a value to the registry. - pub fn add_value(&mut self, id: &str, ty: ValueType) -> Result { - self.add_value_with_offset(id, ty, 0) - } - - /// Adds a value to the registry, applying an offset to the ids of the elements if the - /// value is an array. - pub fn add_value_with_offset( - &mut self, - id: &str, - ty: ValueType, - offset: usize, - ) -> Result { - let value_ref = match ty { - ValueType::Array(elem_ty, len) => ValueRef::Array( - (0..len) - .map(|idx| { - let id = ValueId::new(&format!("{}/{}", id, idx + offset)); - self.add_value_id(id.clone(), (*elem_ty).clone())?; - Ok(id) - }) - .collect::, MemoryError>>()?, - ), - _ => { - let id = ValueId::new(id); - self.add_value_id(id.clone(), ty)?; - ValueRef::Value { id } - } - }; - - self.refs.insert(id.to_string(), value_ref.clone()); - - Ok(value_ref) - } - - fn add_value_id(&mut self, id: ValueId, ty: ValueType) -> Result<(), MemoryError> { - // Ensure that the value is not a collection. - if matches!(ty, ValueType::Array(_, _)) { - return Err(MemoryError::InvalidType(id, ty)); - } - - // Ensure that the value is not already defined. - if self.values.contains_key(&id) { - return Err(MemoryError::DuplicateValueId(id)); - } - - self.values.insert(id, ty); - - Ok(()) - } - - /// Returns a reference to the value with the given ID. - pub(crate) fn get_value(&self, id: &str) -> Option { - self.refs.get(id).cloned() - } - - /// Returns the type of the value with the given ID. - pub(crate) fn get_value_type(&self, id: &str) -> Option { - let value_ref = self.get_value(id)?; - - self.get_value_type_with_ref(&value_ref) - } - - pub(crate) fn get_value_type_with_ref(&self, value: &ValueRef) -> Option { - match value { - ValueRef::Value { id } => self.values.get(id).cloned(), - ValueRef::Array(values) => { - let elem_tys = values - .iter() - .map(|id| self.values.get(id).cloned()) - .collect::>>()?; - - // Ensure that all the elements have the same type. - if elem_tys.windows(2).any(|window| window[0] != window[1]) { - return None; - } - - Some(ValueType::Array( - Box::new(elem_tys[0].clone()), - values.len(), - )) - } - } - } -} - -/// A unique ID for an encoding. -/// -/// # Warning -/// -/// The internal representation for this type is a `u64` and is computed using a hash function. -/// As such, it is not guaranteed to be unique and collisions may occur. Contexts using these -/// IDs should be aware of this and handle collisions appropriately. -/// -/// For example, an encoding should never be used for more than one value as this will compromise -/// the security of the MPC protocol. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Ord, PartialOrd, Hash)] -pub(crate) struct EncodingId(u64); - -impl EncodingId { - /// Create a new encoding ID. - pub(crate) fn new(id: u64) -> Self { - Self(id) - } -} - -impl From for EncodingId { - fn from(id: u64) -> Self { - Self::new(id) - } -} - -/// Errors which can occur when registering an encoding. -#[derive(Debug, thiserror::Error)] -pub enum EncodingRegistryError { - #[error("encoding for value {0:?} is already defined")] - DuplicateId(ValueId), -} - -/// A registry of encodings. -/// -/// This registry is used to store encodings for values. -/// -/// It enforces that an encoding for a value is only set once. -#[derive(Debug)] -pub(crate) struct EncodingRegistry -where - T: LabelState, -{ - encodings: HashMap>, -} - -impl Default for EncodingRegistry -where - T: LabelState, -{ - fn default() -> Self { - Self { - encodings: HashMap::new(), - } - } -} - -impl EncodingRegistry -where - T: LabelState, -{ - /// Set the encoding for a value id. - pub(crate) fn set_encoding_by_id( - &mut self, - id: &ValueId, - encoding: EncodedValue, - ) -> Result<(), EncodingRegistryError> { - let encoding_id = EncodingId::new(id.to_u64()); - if self.encodings.contains_key(&encoding_id) { - return Err(EncodingRegistryError::DuplicateId(id.clone())); - } - - self.encodings.insert(encoding_id, encoding); - - Ok(()) - } - - /// Set the encoding for a value. - /// - /// # Panics - /// - /// Panics if the encoding for the value has already been set, or if the value - /// type does not match the encoding type. - pub(crate) fn set_encoding( - &mut self, - value: &ValueRef, - encoding: EncodedValue, - ) -> Result<(), EncodingRegistryError> { - match (value, encoding) { - (ValueRef::Value { id }, encoding) => self.set_encoding_by_id(id, encoding)?, - (ValueRef::Array(ids), EncodedValue::Array(encodings)) - if ids.len() == encodings.len() => - { - for (id, encoding) in ids.iter().zip(encodings) { - self.set_encoding_by_id(id, encoding)? - } - } - _ => panic!("value type {:?} does not match encoding type", value), - } - - Ok(()) - } - - /// Get the encoding for a value id if it exists. - pub(crate) fn get_encoding_by_id(&self, id: &ValueId) -> Option> { - self.encodings.get(&id.to_u64().into()).cloned() - } - - /// Get the encoding for a value if it exists. - /// - /// # Panics - /// - /// Panics if the value is an array and if the type of its elements are not consistent. - pub(crate) fn get_encoding(&self, value: &ValueRef) -> Option> { - match value { - ValueRef::Value { id, .. } => self.encodings.get(&id.to_u64().into()).cloned(), - ValueRef::Array(ids) => { - let encodings = ids - .iter() - .map(|id| self.encodings.get(&id.to_u64().into()).cloned()) - .collect::>>()?; - - assert!( - encodings - .windows(2) - .all(|window| window[0].value_type() == window[1].value_type()), - "inconsistent element types in array {:?}", - value - ); - - Some(EncodedValue::Array(encodings)) - } - } - } - - /// Returns whether an encoding is present for a value id. - pub(crate) fn contains(&self, id: &ValueId) -> bool { - self.encodings.contains_key(&id.to_u64().into()) - } -} - -#[cfg(test)] -mod tests { - use std::marker::PhantomData; - - use super::*; - - use mpz_circuits::types::StaticValueType; - use mpz_garble_core::{encoding_state, ChaChaEncoder, Encoder}; - use rstest::*; - - #[fixture] - fn encoder() -> ChaChaEncoder { - ChaChaEncoder::new([0; 32]) - } - - fn generate_encoding( - encoder: ChaChaEncoder, - value: &ValueRef, - ty: &ValueType, - ) -> EncodedValue { - match (value, ty) { - (ValueRef::Value { id }, ty) => encoder.encode_by_type(id.to_u64(), ty), - (ValueRef::Array(ids), ValueType::Array(elem_ty, _)) => EncodedValue::Array( - ids.iter() - .map(|id| encoder.encode_by_type(id.to_u64(), elem_ty)) - .collect(), - ), - _ => panic!(), - } - } - - #[rstest] - #[case::bit(PhantomData::)] - #[case::u8(PhantomData::)] - #[case::u16(PhantomData::)] - #[case::u64(PhantomData::)] - #[case::u64(PhantomData::)] - #[case::u128(PhantomData::)] - #[case::bit_array(PhantomData::<[bool; 16]>)] - #[case::u8_array(PhantomData::<[u8; 16]>)] - #[case::u16_array(PhantomData::<[u16; 16]>)] - #[case::u32_array(PhantomData::<[u32; 16]>)] - #[case::u64_array(PhantomData::<[u64; 16]>)] - #[case::u128_array(PhantomData::<[u128; 16]>)] - fn test_value_registry_duplicate_fails(#[case] _ty: PhantomData) - where - T: StaticValueType + Default + std::fmt::Debug, - { - let mut value_registry = ValueRegistry::default(); - - let _ = value_registry.add_value("test", T::value_type()).unwrap(); - - let err = value_registry - .add_value("test", T::value_type()) - .unwrap_err(); - - assert!(matches!(err, MemoryError::DuplicateValueId(_))); - } - - #[rstest] - #[case::bit(PhantomData::)] - #[case::u8(PhantomData::)] - #[case::u16(PhantomData::)] - #[case::u64(PhantomData::)] - #[case::u64(PhantomData::)] - #[case::u128(PhantomData::)] - #[case::bit_array(PhantomData::<[bool; 16]>)] - #[case::u8_array(PhantomData::<[u8; 16]>)] - #[case::u16_array(PhantomData::<[u16; 16]>)] - #[case::u32_array(PhantomData::<[u32; 16]>)] - #[case::u64_array(PhantomData::<[u64; 16]>)] - #[case::u128_array(PhantomData::<[u128; 16]>)] - fn test_encoding_registry_set_duplicate_fails( - encoder: ChaChaEncoder, - #[case] _ty: PhantomData, - ) where - T: StaticValueType + Default + std::fmt::Debug, - { - let mut value_registry = ValueRegistry::default(); - let mut full_encoding_registry = EncodingRegistry::::default(); - let mut active_encoding_registry = EncodingRegistry::::default(); - - let ty = T::value_type(); - let value = value_registry.add_value("test", ty.clone()).unwrap(); - - let encoding = generate_encoding(encoder, &value, &ty); - - full_encoding_registry - .set_encoding(&value, encoding.clone()) - .unwrap(); - - let err = full_encoding_registry - .set_encoding(&value, encoding.clone()) - .unwrap_err(); - - assert!(matches!(err, EncodingRegistryError::DuplicateId(_))); - - let encoding = encoding.select(T::default()).unwrap(); - - active_encoding_registry - .set_encoding(&value, encoding.clone()) - .unwrap(); - - let err = active_encoding_registry - .set_encoding(&value, encoding) - .unwrap_err(); - - assert!(matches!(err, EncodingRegistryError::DuplicateId(_))); - } -} diff --git a/garble/mpz-garble/src/threadpool.rs b/garble/mpz-garble/src/threadpool.rs index 79a178c1..81feda2b 100644 --- a/garble/mpz-garble/src/threadpool.rs +++ b/garble/mpz-garble/src/threadpool.rs @@ -123,10 +123,13 @@ mod tests { thread: &mut T, n: usize, ) -> Result<[u8; 16], VmError> { - let key = thread.new_private_input(&format!("key/{n}"), Some([0u8; 16]))?; - let msg = thread.new_private_input(&format!("msg/{n}"), Some([0u8; 16]))?; + let key = thread.new_private_input::<[u8; 16]>(&format!("key/{n}"))?; + let msg = thread.new_private_input::<[u8; 16]>(&format!("msg/{n}"))?; let ciphertext = thread.new_output::<[u8; 16]>(&format!("ciphertext/{n}"))?; + thread.assign(&key, [0u8; 16])?; + thread.assign(&msg, [0u8; 16])?; + thread .execute(AES128.clone(), &[key, msg], &[ciphertext.clone()]) .await?; @@ -140,8 +143,8 @@ mod tests { thread: &mut T, n: usize, ) -> Result<[u8; 16], VmError> { - let key = thread.new_private_input::<[u8; 16]>(&format!("key/{n}"), None)?; - let msg = thread.new_private_input::<[u8; 16]>(&format!("msg/{n}"), None)?; + let key = thread.new_blind_input::<[u8; 16]>(&format!("key/{n}"))?; + let msg = thread.new_blind_input::<[u8; 16]>(&format!("msg/{n}"))?; let ciphertext = thread.new_output::<[u8; 16]>(&format!("ciphertext/{n}"))?; thread diff --git a/mpz-core/src/value.rs b/garble/mpz-garble/src/value.rs similarity index 54% rename from mpz-core/src/value.rs rename to garble/mpz-garble/src/value.rs index 88beacfd..c6a14e8b 100644 --- a/mpz-core/src/value.rs +++ b/garble/mpz-garble/src/value.rs @@ -2,7 +2,7 @@ use std::sync::Arc; -use crate::utils::blake3; +use mpz_core::utils::blake3; /// A unique ID for a value. #[derive(Debug, Clone, PartialEq, Eq, Ord, PartialOrd, Hash)] @@ -19,6 +19,11 @@ impl ValueId { Self::new(&format!("{}/{}", self.0, id)) } + /// Returns a new value ID with the provided counter appended. + pub fn append_counter(&self, counter: usize) -> Self { + Self::new(&format!("{}/{}", self.0, counter)) + } + /// Returns the u64 representation of the value ID. /// /// # Warning @@ -38,6 +43,39 @@ impl AsRef for ValueId { } } +/// A reference to an array value. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct ArrayRef { + ids: Vec, +} + +impl ArrayRef { + /// Creates a new array reference. + /// + /// # Invariants + /// + /// The outer context must enforce the following invariants: + /// + /// * The array must have at least one value. + /// * All values in the array must have the same type. + pub(crate) fn new(ids: Vec) -> Self { + assert!(!ids.is_empty(), "cannot create an array with no values"); + + Self { ids } + } + + /// Returns the value IDs. + pub(crate) fn ids(&self) -> &[ValueId] { + &self.ids + } + + /// Returns the number of values. + #[allow(clippy::len_without_is_empty)] + pub fn len(&self) -> usize { + self.ids.len() + } +} + /// A reference to a value. /// /// Every single value is assigned a unique ID. Whereas, arrays are @@ -47,8 +85,8 @@ impl AsRef for ValueId { pub enum ValueRef { /// A single value. Value { id: ValueId }, - /// An array of values. - Array(Vec), + /// A reference to an array of values. + Array(ArrayRef), } impl ValueRef { @@ -57,7 +95,7 @@ impl ValueRef { pub fn len(&self) -> usize { match self { ValueRef::Value { .. } => 1, - ValueRef::Array(values) => values.len(), + ValueRef::Array(values) => values.ids.len(), } } @@ -69,12 +107,13 @@ impl ValueRef { ValueRef::Value { id: value_id } => ValueRef::Value { id: value_id.append_id(id), }, - ValueRef::Array(values) => ValueRef::Array( - values + ValueRef::Array(values) => ValueRef::Array(ArrayRef { + ids: values + .ids .iter() .map(|value_id| value_id.append_id(id)) .collect(), - ), + }), } } @@ -84,10 +123,30 @@ impl ValueRef { } /// Returns an iterator of the value IDs. - pub fn iter(&self) -> Box + '_> { + pub fn iter(&self) -> ValueRefIter<'_> { + match self { + ValueRef::Value { id } => ValueRefIter::Value(std::iter::once(id)), + ValueRef::Array(values) => ValueRefIter::Array(values.ids.iter()), + } + } +} + +/// An iterator over value IDs of a reference. +pub enum ValueRefIter<'a> { + /// A single value. + Value(std::iter::Once<&'a ValueId>), + /// An array of values. + Array(std::slice::Iter<'a, ValueId>), +} + +impl<'a> Iterator for ValueRefIter<'a> { + type Item = &'a ValueId; + + #[inline] + fn next(&mut self) -> Option { match self { - ValueRef::Value { id, .. } => Box::new(std::iter::once(id)), - ValueRef::Array(values) => Box::new(values.iter()), + ValueRefIter::Value(iter) => iter.next(), + ValueRefIter::Array(iter) => iter.next(), } } } diff --git a/garble/mpz-garble/tests/semihonest.rs b/garble/mpz-garble/tests/semihonest.rs index a4a5513d..c0b449c8 100644 --- a/garble/mpz-garble/tests/semihonest.rs +++ b/garble/mpz-garble/tests/semihonest.rs @@ -3,9 +3,7 @@ use mpz_garble_core::msg::GarbleMessage; use mpz_ot::mock::mock_ot_shared_pair; use utils_aio::duplex::MemoryDuplex; -use mpz_garble::{ - config::ValueConfig, Evaluator, Generator, GeneratorConfigBuilder, ValueRegistry, -}; +use mpz_garble::{config::Visibility, Evaluator, Generator, GeneratorConfigBuilder, ValueMemory}; #[tokio::test] async fn test_semi_honest() { @@ -18,36 +16,33 @@ async fn test_semi_honest() { ); let ev = Evaluator::default(); - let mut value_registry = ValueRegistry::default(); - let key = [69u8; 16]; let msg = [42u8; 16]; - let key_ref = value_registry - .add_value("key", <[u8; 16]>::value_type()) - .unwrap(); - let msg_ref = value_registry - .add_value("msg", <[u8; 16]>::value_type()) - .unwrap(); - let ciphertext_ref = value_registry - .add_value("ciphertext", <[u8; 16]>::value_type()) - .unwrap(); - let gen_fut = async { - let value_configs = [ - ValueConfig::new_private::<[u8; 16]>(key_ref.clone(), Some(key)) - .unwrap() - .flatten(), - ValueConfig::new_private::<[u8; 16]>(msg_ref.clone(), None) - .unwrap() - .flatten(), - ] - .concat(); - - gen.setup_inputs("test", &value_configs, &mut gen_channel, &ot_send) - .await + let mut memory = ValueMemory::default(); + + let key_ref = memory + .new_input("key", <[u8; 16]>::value_type(), Visibility::Private) + .unwrap(); + let msg_ref = memory + .new_input("msg", <[u8; 16]>::value_type(), Visibility::Blind) + .unwrap(); + let ciphertext_ref = memory + .new_output("ciphertext", <[u8; 16]>::value_type()) .unwrap(); + memory.assign(&key_ref, key.into()).unwrap(); + + gen.setup_assigned_values( + "test", + &memory.drain_assigned(&[key_ref.clone(), msg_ref.clone()]), + &mut gen_channel, + &ot_send, + ) + .await + .unwrap(); + gen.generate( AES128.clone(), &[key_ref.clone(), msg_ref.clone()], @@ -57,22 +52,33 @@ async fn test_semi_honest() { ) .await .unwrap(); + + gen.get_encoding(&ciphertext_ref).unwrap() }; let ev_fut = async { - let value_configs = [ - ValueConfig::new_private::<[u8; 16]>(key_ref.clone(), None) - .unwrap() - .flatten(), - ValueConfig::new_private::<[u8; 16]>(msg_ref.clone(), Some(msg)) - .unwrap() - .flatten(), - ] - .concat(); - - ev.setup_inputs("test", &value_configs, &mut ev_channel, &ot_recv) - .await + let mut memory = ValueMemory::default(); + + let key_ref = memory + .new_input("key", <[u8; 16]>::value_type(), Visibility::Blind) .unwrap(); + let msg_ref = memory + .new_input("msg", <[u8; 16]>::value_type(), Visibility::Private) + .unwrap(); + let ciphertext_ref = memory + .new_output("ciphertext", <[u8; 16]>::value_type()) + .unwrap(); + + memory.assign(&msg_ref, msg.into()).unwrap(); + + ev.setup_assigned_values( + "test", + &memory.drain_assigned(&[key_ref.clone(), msg_ref.clone()]), + &mut ev_channel, + &ot_recv, + ) + .await + .unwrap(); _ = ev .evaluate( @@ -83,12 +89,11 @@ async fn test_semi_honest() { ) .await .unwrap(); - }; - tokio::join!(gen_fut, ev_fut); + ev.get_encoding(&ciphertext_ref).unwrap() + }; - let ciphertext_full_encoding = gen.get_encoding(&ciphertext_ref).unwrap(); - let ciphertext_active_encoding = ev.get_encoding(&ciphertext_ref).unwrap(); + let (ciphertext_full_encoding, ciphertext_active_encoding) = tokio::join!(gen_fut, ev_fut); let decoding = ciphertext_full_encoding.decoding(); let ciphertext: [u8; 16] = ciphertext_active_encoding diff --git a/mpz-circuits/src/types.rs b/mpz-circuits/src/types.rs index 18738b6b..dd68eaa3 100644 --- a/mpz-circuits/src/types.rs +++ b/mpz-circuits/src/types.rs @@ -41,6 +41,11 @@ pub trait StaticValueType: Into { fn value_type() -> ValueType; } +/// A primitive type. +/// +/// For example, `u8` is a primitive type, but `[u8; 4]` is not. +pub trait PrimitiveType: StaticValueType + BinaryLength {} + /// A type that has a constant bit length. pub trait BinaryLength { /// The length of the type in bits. @@ -220,6 +225,8 @@ macro_rules! define_binary_value { const LEN: usize = $len; } + impl PrimitiveType for $ty {} + impl BinaryLength for [$ty; N] { const LEN: usize = $len * N; } diff --git a/mpz-core/src/lib.rs b/mpz-core/src/lib.rs index c5eda19d..59afa9aa 100644 --- a/mpz-core/src/lib.rs +++ b/mpz-core/src/lib.rs @@ -14,7 +14,6 @@ pub mod prp; pub mod serialize; pub mod tkprp; pub mod utils; -pub mod value; pub use block::{Block, BlockSerialize};