Skip to content

Commit

Permalink
add output validation
Browse files Browse the repository at this point in the history
  • Loading branch information
brech1 committed Jun 10, 2024
1 parent 5d07bcf commit 08cf25d
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 29 deletions.
95 changes: 67 additions & 28 deletions crates/mpz-circuits-generic/src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ pub struct CircuitBuilder<T> {
inputs: Vec<Node>,
outputs: Vec<Node>,
gates: Vec<T>,
stack_size: usize,
}

impl<T> Default for CircuitBuilder<T> {
Expand All @@ -25,6 +26,7 @@ impl<T> Default for CircuitBuilder<T> {
inputs: Default::default(),
outputs: Default::default(),
gates: Default::default(),
stack_size: 0,
}
}
}
Expand Down Expand Up @@ -53,6 +55,7 @@ where
pub fn add_input(&mut self) -> Node {
let input = self.current_node.next();
self.inputs.push(input);
self.stack_size += 1;
input
}

Expand All @@ -71,10 +74,14 @@ where
{
let gate = f(&mut Next(&mut self.current_node));

if gate.get_inputs().count() == 0 || gate.get_outputs().count() == 0 {
let output_count = gate.get_outputs().count();

if output_count == 0 || gate.get_inputs().count() == 0 {
return Err(CircuitBuilderError::DisconnectedGate);
}

self.stack_size += output_count;

self.gates.push(gate);

Ok(self.gates.last().unwrap())
Expand All @@ -86,23 +93,34 @@ where
return Err(CircuitBuilderError::EmptyCircuit);
}

// Get the total stack length.
let total_length = self.current_node.0;
let mut gate_inputs = std::collections::HashSet::new();
let mut gate_outputs = std::collections::HashSet::new();

// Verify that all nodes are within the stack.
for gate in &self.gates {
for node in gate.get_inputs() {
if node.0 >= total_length {
for input in gate.get_inputs() {
if input.0 as usize >= self.stack_size {
return Err(CircuitBuilderError::NodeOutOfIndex);
}
gate_inputs.insert(*input);
}
for node in gate.get_outputs() {
if node.0 >= total_length {

for output in gate.get_outputs() {
if output.0 as usize >= self.stack_size {
return Err(CircuitBuilderError::NodeOutOfIndex);
}
gate_outputs.insert(*output);
}
}

// Verify that output nodes are not inputs to any gate
if self
.outputs
.iter()
.any(|output| gate_inputs.contains(output))
{
return Err(CircuitBuilderError::OutputValidationFailed);
}

Ok(Circuit::new(
self.inputs.len(),
self.outputs.len(),
Expand Down Expand Up @@ -155,6 +173,8 @@ pub enum CircuitBuilderError {
DisconnectedGate,
#[error("Empty circuit")]
EmptyCircuit,
#[error("Output validation failed")]
OutputValidationFailed,
#[error("Node out of index")]
NodeOutOfIndex,
}
Expand Down Expand Up @@ -245,7 +265,7 @@ mod tests {
let (in_0, in_1) = (builder.add_input(), builder.add_input());

// Add a valid gate
let &Gate { output, .. } = builder
let &Gate { .. } = builder
.add_gate(|next| Gate {
inputs: vec![in_0, in_1],
output: next.next(),
Expand All @@ -258,31 +278,12 @@ mod tests {
output: next.next(),
});

// Ensure the disconnected gate is detected
assert!(gate_result.is_err(), "Expected disconnected gate error");
assert_eq!(
gate_result.unwrap_err(),
CircuitBuilderError::DisconnectedGate,
"Unexpected error type"
);

// Add valid gate
let &Gate { output, .. } = builder
.add_gate(|next| Gate {
inputs: vec![output, in_1],
output: next.next(),
})
.unwrap();

builder.add_output(output);

// Build the circuit
let circuit = builder.build();
assert!(
circuit.is_ok(),
"Failed to build circuit: {:?}",
circuit.err()
);
}

#[test]
Expand Down Expand Up @@ -322,4 +323,42 @@ mod tests {
"Unexpected error type"
);
}

#[test]
fn test_output_validation() {
let mut builder = CircuitBuilder::<Gate>::new();

let in_0 = builder.add_input();
let in_1 = builder.add_input();

let &Gate { output, .. } = builder
.add_gate(|next| Gate {
inputs: vec![in_0, in_1],
output: next.next(),
})
.unwrap();

builder.add_output(output);

// Use the output node as an input to a new gate
let &Gate {
output: new_output, ..
} = builder
.add_gate(|next| Gate {
inputs: vec![output, in_0],
output: next.next(),
})
.unwrap();

builder.add_output(new_output);

let circuit = builder.build();

assert!(circuit.is_err(), "Expected output validation error");
assert_eq!(
circuit.unwrap_err(),
CircuitBuilderError::OutputValidationFailed,
"Unexpected error type"
);
}
}
2 changes: 1 addition & 1 deletion crates/mpz-circuits-generic/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ pub trait Component {
}

/// A circuit node.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
pub struct Node(pub(crate) u32);

impl Node {
Expand Down

0 comments on commit 08cf25d

Please sign in to comment.