Skip to content

Commit

Permalink
implement missing checks
Browse files Browse the repository at this point in the history
  • Loading branch information
brech1 committed Jun 8, 2024
1 parent cfe3d7f commit 5d07bcf
Showing 1 changed file with 111 additions and 198 deletions.
309 changes: 111 additions & 198 deletions crates/mpz-circuits-generic/src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,6 @@
//! Main circuit module.
use crate::{model::Component, Node};
use std::{
collections::{HashMap, VecDeque},
mem::take,
};
use thiserror::Error;

/// The Circuit Builder assembles a collection of gates into a circuit.
Expand Down Expand Up @@ -75,7 +71,9 @@ where
{
let gate = f(&mut Next(&mut self.current_node));

// Verify the gate has at least one input and output
if gate.get_inputs().count() == 0 || gate.get_outputs().count() == 0 {
return Err(CircuitBuilderError::DisconnectedGate);
}

self.gates.push(gate);

Expand All @@ -84,92 +82,33 @@ where

/// Builds the circuit.
pub fn build(self) -> Result<Circuit<T>, CircuitBuilderError> {
// Verify that the circuit has at least one gate.
if self.gates.is_empty() {
return Err(CircuitBuilderError::EmptyCircuit);
}

// Get the total stack length.

// Verifying that no node is out of index.
let total_length = self.current_node.0;

// Verify that all nodes are within the stack.
for gate in &self.gates {
for node in gate.get_inputs() {
if node.0 >= total_length {
return Err(CircuitBuilderError::NodeOutOfIndex);
}
}
for node in gate.get_outputs() {
if node.0 >= total_length {
return Err(CircuitBuilderError::NodeOutOfIndex);
}
}
}

Ok(Circuit::new(
self.inputs.len(),
self.outputs.len(),
self.gates,
))
}

// /// Performs a topological sort of the gates.
// ///
// /// This ensures that the gates are linearly ordered such that the
// /// dependencies (input gates) of each gate are processed before the gate itself.
// ///
// /// This requires that the gates form a directed acyclic graph (DAG).
// ///
// /// The sorting is done using Kahn's Algorithm.
// fn sort_gates(&mut self) -> Result<(), CircuitBuilderError> {
// // In-degree: the number of gates that provide input to each gate
// // This represents how many other gates need to be processed before this gate
// let mut in_degree = vec![0; self.gates.len()];
// // Adjacency list: for each gate, list the gates that directly depend on its output
// // This is used to keep track of which gates need to be updated after processing a gate
// let mut adjacency_list = vec![vec![]; self.gates.len()];

// // Populate lists
// for (i, gate) in self.gates.iter().enumerate() {
// for output in gate.get_outputs() {
// let output = self.input_map.get(&output.id);

// if let Some(&gate_index) = output {
// adjacency_list[i].push(gate_index);
// in_degree[gate_index] += 1;
// }
// }
// }

// let mut queue = VecDeque::new();
// let mut sorted_indices = Vec::with_capacity(self.gates.len());

// // Push ready-to-process nodes (no dependencies) to the queue
// for (i, &degree) in in_degree.iter().enumerate() {
// if degree == 0 {
// queue.push_back(i);
// }
// }

// // Process nodes
// while let Some(node) = queue.pop_front() {
// sorted_indices.push(node);

// // Reduce in-degree of dependent nodes
// for &neighbor in &adjacency_list[node] {
// in_degree[neighbor] -= 1;

// // If the dependent node is now ready to be processed, add it to the queue
// if in_degree[neighbor] == 0 {
// queue.push_back(neighbor);
// }
// }
// }

// // If some node is left unprocessed, there is a cycle
// if sorted_indices.len() != self.gates.len() {
// return Err(CircuitBuilderError::CycleDetected);
// }

// // Sort the gates
// // To preserve the order of the gates we create this temporary vector of optionals
// let mut temp_gates: Vec<Option<T>> = self.gates.drain(..).map(Some).collect();
// let mut sorted_gates = Vec::with_capacity(temp_gates.len());
// for &i in &sorted_indices {
// // Whenever we take a gate from the vector we replace it with None
// // This way we avoid shifting items
// if let Some(gate) = temp_gates[i].take() {
// sorted_gates.push(gate);
// }
// }

// self.gates = sorted_gates;
// Ok(())
// }
}

/// A circuit constructed from a collection of gates.
Expand Down Expand Up @@ -212,8 +151,12 @@ impl<T> Circuit<T> {
/// Circuit errors.
#[derive(Debug, Error, PartialEq, Eq)]
pub enum CircuitBuilderError {
#[error("Cycle detected")]
CycleDetected,
#[error("Disconnected gate")]
DisconnectedGate,
#[error("Empty circuit")]
EmptyCircuit,
#[error("Node out of index")]
NodeOutOfIndex,
}

#[cfg(test)]
Expand Down Expand Up @@ -284,129 +227,99 @@ mod tests {
);
assert_eq!(
gates[1].get_outputs().collect::<Vec<_>>(),
vec![&Node(10)],
"Second gate outputs mismatch" // Gate 5
vec![&Node(3)],
"Second gate outputs mismatch" // Gate 2
);
assert_eq!(
gates[2].get_outputs().collect::<Vec<_>>(),
vec![&Node(5)],
"Third gate outputs mismatch" // Gate 2
vec![&Node(4)],
"Third gate outputs mismatch" // Gate 3
);
}

#[test]
fn test_builder_add_gate() {
// Setup circuit builder
let mut builder = CircuitBuilder::<Gate>::new();

let (in_0, in_1) = (builder.add_input(), builder.add_input());

// Add a valid gate
let &Gate { output, .. } = builder
.add_gate(|next| Gate {
inputs: vec![in_0, in_1],
output: next.next(),
})
.unwrap();

// Add a disconnected gate
let gate_result = builder.add_gate(|next| Gate {
inputs: Vec::new(),
output: next.next(),
});

// Ensure the disconnected gate is detected
assert!(gate_result.is_err(), "Expected disconnected gate error");
assert_eq!(
gates[3].get_outputs().collect::<Vec<_>>(),
vec![&Node(6)],
"Fourth gate outputs mismatch" // Gate 3
gate_result.unwrap_err(),
CircuitBuilderError::DisconnectedGate,
"Unexpected error type"
);
assert_eq!(
gates[4].get_outputs().collect::<Vec<_>>(),
vec![&Node(7)],
"Fifth gate outputs mismatch" // Gate 4

// Add valid gate
let &Gate { output, .. } = builder
.add_gate(|next| Gate {
inputs: vec![output, in_1],
output: next.next(),
})
.unwrap();

builder.add_output(output);

// Build the circuit
let circuit = builder.build();
assert!(
circuit.is_ok(),
"Failed to build circuit: {:?}",
circuit.err()
);
}

#[test]
fn test_empty_circuit() {
let builder = CircuitBuilder::<Gate>::new();

let circuit = builder.build();

assert!(circuit.is_err(), "Expected empty circuit error");
assert_eq!(
gates[5].get_outputs().collect::<Vec<_>>(),
vec![&Node(11)],
"Sixth gate outputs mismatch" // Gate 6
circuit.unwrap_err(),
CircuitBuilderError::EmptyCircuit,
"Unexpected error type"
);
}

// #[test]
// fn test_cycle_detection() {
// // Setup circuit builder
// let mut circuit_builder = CircuitBuilder::<Gate>::new();

// // Define gates
// let gate1 = Gate {
// inputs: vec![Node(0), Node(1)],
// output: Node(2),
// };
// let gate2 = Gate {
// inputs: vec![Node(2), Node(3)],
// output: Node(4),
// };
// let cycle_gate = Gate {
// inputs: vec![Node(4)],
// output: Node(0),
// };

// // Add gates
// circuit_builder
// .add_gate(gate1)
// .add_gate(gate2)
// .add_gate(cycle_gate);

// // Expect build to fail
// let circuit = circuit_builder.build();
// assert!(circuit.is_err(), "Expected cycle detection error");
// assert_eq!(
// circuit.unwrap_err(),
// CircuitBuilderError::CycleDetected,
// "Unexpected error type"
// );
// }

// #[test]
// fn test_disconnected_gate() {
// // Setup circuit builder
// let mut circuit_builder = CircuitBuilder::<Gate>::new();

// // Define gates, with one gate disconnected
// let gate1 = Gate {
// inputs: vec![Node(0), Node(1)],
// output: Node(2),
// };
// let gate2 = Gate {
// inputs: vec![Node(3), Node(4)],
// output: Node(5),
// };
// let gate3 = Gate {
// inputs: vec![Node(2), Node(5)],
// output: Node(6),
// };
// let disconnected_gate = Gate {
// inputs: vec![Node(7), Node(8)],
// output: Node(9),
// };

// // Add gates including the disconnected gate
// circuit_builder
// .add_gate(gate1)
// .add_gate(gate2)
// .add_gate(gate3)
// .add_gate(disconnected_gate);

// // Build circuit
// let circuit = circuit_builder.build();
// assert!(
// circuit.is_ok(),
// "Failed to build circuit: {:?}",
// circuit.err()
// );
// let circuit = circuit.unwrap();
// let gates = circuit.gates();

// // Verify order
// // Gate 1 and 2 were added first and have in_degree 0 so they will be processed right away
// // The disconnected gate also has in_degree 0 so it will be put next to them
// // Gate 3 will be processed last for having in_degree > 0
// assert_eq!(
// gates[0].get_outputs().collect::<Vec<_>>(),
// vec![&Node(2)],
// "First gate outputs mismatch" // Gate 1
// );
// assert_eq!(
// gates[1].get_outputs().collect::<Vec<_>>(),
// vec![&Node(5)],
// "Second gate outputs mismatch" // Gate 2
// );
// assert_eq!(
// gates[2].get_outputs().collect::<Vec<_>>(),
// vec![&Node(9)],
// "Third gate outputs mismatch" // Disconnected Gate
// );
// assert_eq!(
// gates[3].get_outputs().collect::<Vec<_>>(),
// vec![&Node(6)],
// "Fourth gate outputs mismatch" // Gate 3
// );
// }
#[test]
fn test_node_out_of_index() {
let mut builder = CircuitBuilder::<Gate>::new();

let input = builder.add_input();

// Add a gate with an out-of-index node
builder
.add_gate(|next| Gate {
inputs: vec![input, Node(100)],
output: next.next(),
})
.unwrap();

let circuit = builder.build();

assert!(circuit.is_err(), "Expected node out of index error");
assert_eq!(
circuit.unwrap_err(),
CircuitBuilderError::NodeOutOfIndex,
"Unexpected error type"
);
}
}

0 comments on commit 5d07bcf

Please sign in to comment.