From 08cf25d205845c209b73076ed0067d2503898f5e Mon Sep 17 00:00:00 2001 From: Brechy Date: Mon, 10 Jun 2024 18:12:49 -0300 Subject: [PATCH] add output validation --- crates/mpz-circuits-generic/src/circuit.rs | 95 +++++++++++++++------- crates/mpz-circuits-generic/src/model.rs | 2 +- 2 files changed, 68 insertions(+), 29 deletions(-) diff --git a/crates/mpz-circuits-generic/src/circuit.rs b/crates/mpz-circuits-generic/src/circuit.rs index 97255640..c861975a 100644 --- a/crates/mpz-circuits-generic/src/circuit.rs +++ b/crates/mpz-circuits-generic/src/circuit.rs @@ -16,6 +16,7 @@ pub struct CircuitBuilder { inputs: Vec, outputs: Vec, gates: Vec, + stack_size: usize, } impl Default for CircuitBuilder { @@ -25,6 +26,7 @@ impl Default for CircuitBuilder { inputs: Default::default(), outputs: Default::default(), gates: Default::default(), + stack_size: 0, } } } @@ -53,6 +55,7 @@ where pub fn add_input(&mut self) -> Node { let input = self.current_node.next(); self.inputs.push(input); + self.stack_size += 1; input } @@ -71,10 +74,14 @@ where { let gate = f(&mut Next(&mut self.current_node)); - if gate.get_inputs().count() == 0 || gate.get_outputs().count() == 0 { + let output_count = gate.get_outputs().count(); + + if output_count == 0 || gate.get_inputs().count() == 0 { return Err(CircuitBuilderError::DisconnectedGate); } + self.stack_size += output_count; + self.gates.push(gate); Ok(self.gates.last().unwrap()) @@ -86,23 +93,34 @@ where return Err(CircuitBuilderError::EmptyCircuit); } - // Get the total stack length. - let total_length = self.current_node.0; + let mut gate_inputs = std::collections::HashSet::new(); + let mut gate_outputs = std::collections::HashSet::new(); - // Verify that all nodes are within the stack. for gate in &self.gates { - for node in gate.get_inputs() { - if node.0 >= total_length { + for input in gate.get_inputs() { + if input.0 as usize >= self.stack_size { return Err(CircuitBuilderError::NodeOutOfIndex); } + gate_inputs.insert(*input); } - for node in gate.get_outputs() { - if node.0 >= total_length { + + for output in gate.get_outputs() { + if output.0 as usize >= self.stack_size { return Err(CircuitBuilderError::NodeOutOfIndex); } + gate_outputs.insert(*output); } } + // Verify that output nodes are not inputs to any gate + if self + .outputs + .iter() + .any(|output| gate_inputs.contains(output)) + { + return Err(CircuitBuilderError::OutputValidationFailed); + } + Ok(Circuit::new( self.inputs.len(), self.outputs.len(), @@ -155,6 +173,8 @@ pub enum CircuitBuilderError { DisconnectedGate, #[error("Empty circuit")] EmptyCircuit, + #[error("Output validation failed")] + OutputValidationFailed, #[error("Node out of index")] NodeOutOfIndex, } @@ -245,7 +265,7 @@ mod tests { let (in_0, in_1) = (builder.add_input(), builder.add_input()); // Add a valid gate - let &Gate { output, .. } = builder + let &Gate { .. } = builder .add_gate(|next| Gate { inputs: vec![in_0, in_1], output: next.next(), @@ -258,31 +278,12 @@ mod tests { output: next.next(), }); - // Ensure the disconnected gate is detected assert!(gate_result.is_err(), "Expected disconnected gate error"); assert_eq!( gate_result.unwrap_err(), CircuitBuilderError::DisconnectedGate, "Unexpected error type" ); - - // Add valid gate - let &Gate { output, .. } = builder - .add_gate(|next| Gate { - inputs: vec![output, in_1], - output: next.next(), - }) - .unwrap(); - - builder.add_output(output); - - // Build the circuit - let circuit = builder.build(); - assert!( - circuit.is_ok(), - "Failed to build circuit: {:?}", - circuit.err() - ); } #[test] @@ -322,4 +323,42 @@ mod tests { "Unexpected error type" ); } + + #[test] + fn test_output_validation() { + let mut builder = CircuitBuilder::::new(); + + let in_0 = builder.add_input(); + let in_1 = builder.add_input(); + + let &Gate { output, .. } = builder + .add_gate(|next| Gate { + inputs: vec![in_0, in_1], + output: next.next(), + }) + .unwrap(); + + builder.add_output(output); + + // Use the output node as an input to a new gate + let &Gate { + output: new_output, .. + } = builder + .add_gate(|next| Gate { + inputs: vec![output, in_0], + output: next.next(), + }) + .unwrap(); + + builder.add_output(new_output); + + let circuit = builder.build(); + + assert!(circuit.is_err(), "Expected output validation error"); + assert_eq!( + circuit.unwrap_err(), + CircuitBuilderError::OutputValidationFailed, + "Unexpected error type" + ); + } } diff --git a/crates/mpz-circuits-generic/src/model.rs b/crates/mpz-circuits-generic/src/model.rs index a262da84..bfe6d2ad 100644 --- a/crates/mpz-circuits-generic/src/model.rs +++ b/crates/mpz-circuits-generic/src/model.rs @@ -12,7 +12,7 @@ pub trait Component { } /// A circuit node. -#[derive(Debug, Clone, Copy, Eq, PartialEq)] +#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)] pub struct Node(pub(crate) u32); impl Node {