Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: decouple zk execution from proving output #83

Merged
merged 1 commit into from
Nov 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions garble/mpz-garble/src/evaluator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,24 @@ impl Evaluator {
self.state().memory.get_encoding(value)
}

/// Returns the encodings for a slice of values.
pub fn get_encodings(
&self,
values: &[ValueRef],
) -> Result<Vec<EncodedValue<encoding_state::Active>>, EvaluatorError> {
let state = self.state();

values
.iter()
.map(|value| {
state
.memory
.get_encoding(value)
.ok_or_else(|| EvaluatorError::MissingEncoding(value.clone()))
})
.collect()
}

/// Adds a decoding log entry.
pub(crate) fn add_decoding_log(&self, value: &ValueRef, decoding: Decoding) {
self.state().decoding_logs.insert(value.clone(), decoding);
Expand Down
17 changes: 17 additions & 0 deletions garble/mpz-garble/src/generator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,23 @@ impl Generator {
self.state().memory.get_encoding(value)
}

/// Returns the encodings for a slice of values.
pub fn get_encodings(
&self,
values: &[ValueRef],
) -> Result<Vec<EncodedValue<encoding_state::Full>>, GeneratorError> {
let state = self.state();
values
.iter()
.map(|value| {
state
.memory
.get_encoding(value)
.ok_or_else(|| GeneratorError::MissingEncoding(value.clone()))
})
.collect()
}

pub(crate) fn get_encodings_by_id(
&self,
ids: &[ValueId],
Expand Down
19 changes: 14 additions & 5 deletions garble/mpz-garble/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -338,25 +338,34 @@ pub trait Execute {
/// This trait provides methods for proving the output of a circuit.
#[async_trait]
pub trait Prove {
/// Proves the output of the circuit with the provided inputs, assigning to the provided output values
async fn prove(
/// Executes the provided circuit as the prover, assigning to the provided output values.
async fn execute_prove(
&mut self,
circ: Arc<Circuit>,
inputs: &[ValueRef],
outputs: &[ValueRef],
) -> Result<(), ProveError>;

/// Proves the provided values.
async fn prove(&mut self, values: &[ValueRef]) -> Result<(), ProveError>;
}

/// This trait provides methods for verifying the output of a circuit.
#[async_trait]
pub trait Verify {
/// Verifies the output of the circuit with the provided inputs, assigning to the provided output values
async fn verify(
/// Executes the provided circuit as the verifier, assigning to the provided output values.
async fn execute_verify(
&mut self,
circ: Arc<Circuit>,
inputs: &[ValueRef],
outputs: &[ValueRef],
expected_outputs: &[Value],
) -> Result<(), VerifyError>;

/// Verifies the provided values against the expected values.
async fn verify(
&mut self,
values: &[ValueRef],
expected_values: &[Value],
) -> Result<(), VerifyError>;
}

Expand Down
6 changes: 5 additions & 1 deletion garble/mpz-garble/src/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,7 @@ where
value: &ValueRef,
encoding: EncodedValue<T>,
) -> Result<(), EncodingMemoryError> {
let encoding_type = encoding.value_type();
match (value, encoding) {
(ValueRef::Value { id }, encoding) => self.set_encoding_by_id(id, encoding)?,
(ValueRef::Array(array), EncodedValue::Array(encodings))
Expand All @@ -397,7 +398,10 @@ where
self.set_encoding_by_id(id, encoding)?
}
}
_ => panic!("value type {:?} does not match encoding type", value),
_ => panic!(
"value type {:?} does not match encoding type: {:?}",
value, encoding_type
),
}

Ok(())
Expand Down
121 changes: 78 additions & 43 deletions garble/mpz-garble/src/protocol/deap/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,19 +245,17 @@ impl DEAP {
/// * `stream` - The stream to receive messages from.
/// * `ot_recv` - The OT receiver.
#[allow(clippy::too_many_arguments)]
pub async fn defer_prove<T, U, OTR>(
pub async fn execute_prove<S, OTR>(
&self,
id: &str,
circ: Arc<Circuit>,
inputs: &[ValueRef],
outputs: &[ValueRef],
sink: &mut T,
stream: &mut U,
stream: &mut S,
ot_recv: &OTR,
) -> Result<(), DEAPError>
where
T: Sink<GarbleMessage, Error = std::io::Error> + Unpin,
U: Stream<Item = Result<GarbleMessage, std::io::Error>> + Unpin,
S: Stream<Item = Result<GarbleMessage, std::io::Error>> + Unpin,
OTR: OTReceiveEncoding,
{
if matches!(self.role, Role::Follower) {
Expand All @@ -275,60 +273,40 @@ impl DEAP {
.map_err(DEAPError::from)
.await?;

let outputs = self
.ev
self.ev
.evaluate(circ, inputs, outputs, stream)
.map_err(DEAPError::from)
.await?;

let output_digest = outputs.hash();
let (decommitment, commitment) = output_digest.hash_commit();

// Store output proof decommitment until finalization
self.state()
.proof_decommitments
.insert(id.to_string(), decommitment);

sink.send(GarbleMessage::HashCommitment(commitment)).await?;

Ok(())
}

/// Verifies the output of a circuit.
/// Executes the circuit where only the follower is the generator.
///
/// # Notes
///
/// This function can only be called by the follower.
///
/// This function does _not_ verify the output right away,
/// instead the leader commits to the proof and later it is checked
/// during the call to [`finalize`](Self::finalize).
///
/// # Arguments
///
/// * `id` - The ID of the circuit.
/// * `circ` - The circuit to execute.
/// * `inputs` - The inputs to the circuit.
/// * `outputs` - The outputs to the circuit.
/// * `expected_outputs` - The expected outputs of the circuit.
/// * `sink` - The sink to send messages to.
/// * `stream` - The stream to receive messages from.
/// * `ot_send` - The OT sender.
#[allow(clippy::too_many_arguments)]
pub async fn defer_verify<T, U, OTS>(
pub async fn execute_verify<T, OTS>(
&self,
id: &str,
circ: Arc<Circuit>,
inputs: &[ValueRef],
outputs: &[ValueRef],
expected_outputs: &[Value],
sink: &mut T,
stream: &mut U,
ot_send: &OTS,
) -> Result<(), DEAPError>
where
T: Sink<GarbleMessage, Error = std::io::Error> + Unpin,
U: Stream<Item = Result<GarbleMessage, std::io::Error>> + Unpin,
OTS: OTSendEncoding,
{
if matches!(self.role, Role::Leader) {
Expand All @@ -346,19 +324,64 @@ impl DEAP {
.map_err(DEAPError::from)
.await?;

let (encoded_outputs, _) = self
.gen
self.gen
.generate(circ.clone(), inputs, outputs, sink, false)
.map_err(DEAPError::from)
.await?;

let expected_outputs = expected_outputs
Ok(())
}

/// Sends a commitment to the provided values, proving them to the follower upon finalization.
pub async fn defer_prove<S: Sink<GarbleMessage, Error = std::io::Error> + Unpin>(
&self,
id: &str,
values: &[ValueRef],
sink: &mut S,
) -> Result<(), DEAPError> {
let encoded_values = self.ev.get_encodings(values)?;

let encoding_digest = encoded_values.hash();
let (decommitment, commitment) = encoding_digest.hash_commit();

// Store output proof decommitment until finalization
self.state()
.proof_decommitments
.insert(id.to_string(), decommitment);

sink.send(GarbleMessage::HashCommitment(commitment)).await?;

Ok(())
}

/// Receives a commitment to the provided values, and stores it until finalization.
///
/// # Notes
///
/// This function does not verify the values until [`finalize`](Self::finalize).
///
/// # Arguments
///
/// * `id` - The ID of the operation
/// * `values` - The values to receive a commitment to
/// * `expected_values` - The expected values which will be verified against the commitment
/// * `stream` - The stream to receive messages from
pub async fn defer_verify<S: Stream<Item = Result<GarbleMessage, std::io::Error>> + Unpin>(
&self,
id: &str,
values: &[ValueRef],
expected_values: &[Value],
stream: &mut S,
) -> Result<(), DEAPError> {
let encoded_values = self.gen.get_encodings(values)?;

let expected_values = expected_values
.iter()
.zip(encoded_outputs)
.map(|(expected, encoded)| encoded.select(expected.clone()).unwrap())
.collect::<Vec<_>>();
.zip(encoded_values)
.map(|(expected, encoded)| encoded.select(expected.clone()))
.collect::<Result<Vec<_>, _>>()?;

let expected_digest = expected_outputs.hash();
let expected_digest = expected_values.hash();

let commitment = expect_msg_or_err!(stream, GarbleMessage::HashCommitment)?;

Expand Down Expand Up @@ -1359,18 +1382,22 @@ mod tests {

async move {
leader
.defer_prove(
"test",
.execute_prove(
"test0",
AES128.clone(),
&[key_ref, msg_ref],
&[ciphertext_ref],
&mut sink,
&[ciphertext_ref.clone()],
&mut stream,
&leader_ot_recv,
)
.await
.unwrap();

leader
.defer_prove("test1", &[ciphertext_ref], &mut sink)
.await
.unwrap();

leader
.finalize(&mut sink, &mut stream, &leader_ot_recv)
.await
Expand All @@ -1388,15 +1415,23 @@ mod tests {

async move {
follower
.defer_verify(
"test",
.execute_verify(
"test0",
AES128.clone(),
&[key_ref, msg_ref],
&[ciphertext_ref.clone()],
&mut sink,
&follower_ot_send,
)
.await
.unwrap();

follower
.defer_verify(
"test1",
&[ciphertext_ref],
&[expected_ciphertext.into()],
&mut sink,
&mut stream,
&follower_ot_send,
)
.await
.unwrap();
Expand Down
Loading