From 5d07bcfbdc4b2abc38896a2674bd9b623935e3fe Mon Sep 17 00:00:00 2001 From: Brechy Date: Sat, 8 Jun 2024 16:07:53 -0300 Subject: [PATCH] implement missing checks --- crates/mpz-circuits-generic/src/circuit.rs | 309 ++++++++------------- 1 file changed, 111 insertions(+), 198 deletions(-) diff --git a/crates/mpz-circuits-generic/src/circuit.rs b/crates/mpz-circuits-generic/src/circuit.rs index ac929b11..97255640 100644 --- a/crates/mpz-circuits-generic/src/circuit.rs +++ b/crates/mpz-circuits-generic/src/circuit.rs @@ -3,10 +3,6 @@ //! Main circuit module. use crate::{model::Component, Node}; -use std::{ - collections::{HashMap, VecDeque}, - mem::take, -}; use thiserror::Error; /// The Circuit Builder assembles a collection of gates into a circuit. @@ -75,7 +71,9 @@ where { let gate = f(&mut Next(&mut self.current_node)); - // Verify the gate has at least one input and output + if gate.get_inputs().count() == 0 || gate.get_outputs().count() == 0 { + return Err(CircuitBuilderError::DisconnectedGate); + } self.gates.push(gate); @@ -84,11 +82,26 @@ where /// Builds the circuit. pub fn build(self) -> Result, CircuitBuilderError> { - // Verify that the circuit has at least one gate. + if self.gates.is_empty() { + return Err(CircuitBuilderError::EmptyCircuit); + } // Get the total stack length. - - // Verifying that no node is out of index. + let total_length = self.current_node.0; + + // Verify that all nodes are within the stack. + for gate in &self.gates { + for node in gate.get_inputs() { + if node.0 >= total_length { + return Err(CircuitBuilderError::NodeOutOfIndex); + } + } + for node in gate.get_outputs() { + if node.0 >= total_length { + return Err(CircuitBuilderError::NodeOutOfIndex); + } + } + } Ok(Circuit::new( self.inputs.len(), @@ -96,80 +109,6 @@ where self.gates, )) } - - // /// Performs a topological sort of the gates. - // /// - // /// This ensures that the gates are linearly ordered such that the - // /// dependencies (input gates) of each gate are processed before the gate itself. - // /// - // /// This requires that the gates form a directed acyclic graph (DAG). - // /// - // /// The sorting is done using Kahn's Algorithm. - // fn sort_gates(&mut self) -> Result<(), CircuitBuilderError> { - // // In-degree: the number of gates that provide input to each gate - // // This represents how many other gates need to be processed before this gate - // let mut in_degree = vec![0; self.gates.len()]; - // // Adjacency list: for each gate, list the gates that directly depend on its output - // // This is used to keep track of which gates need to be updated after processing a gate - // let mut adjacency_list = vec![vec![]; self.gates.len()]; - - // // Populate lists - // for (i, gate) in self.gates.iter().enumerate() { - // for output in gate.get_outputs() { - // let output = self.input_map.get(&output.id); - - // if let Some(&gate_index) = output { - // adjacency_list[i].push(gate_index); - // in_degree[gate_index] += 1; - // } - // } - // } - - // let mut queue = VecDeque::new(); - // let mut sorted_indices = Vec::with_capacity(self.gates.len()); - - // // Push ready-to-process nodes (no dependencies) to the queue - // for (i, °ree) in in_degree.iter().enumerate() { - // if degree == 0 { - // queue.push_back(i); - // } - // } - - // // Process nodes - // while let Some(node) = queue.pop_front() { - // sorted_indices.push(node); - - // // Reduce in-degree of dependent nodes - // for &neighbor in &adjacency_list[node] { - // in_degree[neighbor] -= 1; - - // // If the dependent node is now ready to be processed, add it to the queue - // if in_degree[neighbor] == 0 { - // queue.push_back(neighbor); - // } - // } - // } - - // // If some node is left unprocessed, there is a cycle - // if sorted_indices.len() != self.gates.len() { - // return Err(CircuitBuilderError::CycleDetected); - // } - - // // Sort the gates - // // To preserve the order of the gates we create this temporary vector of optionals - // let mut temp_gates: Vec> = self.gates.drain(..).map(Some).collect(); - // let mut sorted_gates = Vec::with_capacity(temp_gates.len()); - // for &i in &sorted_indices { - // // Whenever we take a gate from the vector we replace it with None - // // This way we avoid shifting items - // if let Some(gate) = temp_gates[i].take() { - // sorted_gates.push(gate); - // } - // } - - // self.gates = sorted_gates; - // Ok(()) - // } } /// A circuit constructed from a collection of gates. @@ -212,8 +151,12 @@ impl Circuit { /// Circuit errors. #[derive(Debug, Error, PartialEq, Eq)] pub enum CircuitBuilderError { - #[error("Cycle detected")] - CycleDetected, + #[error("Disconnected gate")] + DisconnectedGate, + #[error("Empty circuit")] + EmptyCircuit, + #[error("Node out of index")] + NodeOutOfIndex, } #[cfg(test)] @@ -284,129 +227,99 @@ mod tests { ); assert_eq!( gates[1].get_outputs().collect::>(), - vec![&Node(10)], - "Second gate outputs mismatch" // Gate 5 + vec![&Node(3)], + "Second gate outputs mismatch" // Gate 2 ); assert_eq!( gates[2].get_outputs().collect::>(), - vec![&Node(5)], - "Third gate outputs mismatch" // Gate 2 + vec![&Node(4)], + "Third gate outputs mismatch" // Gate 3 ); + } + + #[test] + fn test_builder_add_gate() { + // Setup circuit builder + let mut builder = CircuitBuilder::::new(); + + let (in_0, in_1) = (builder.add_input(), builder.add_input()); + + // Add a valid gate + let &Gate { output, .. } = builder + .add_gate(|next| Gate { + inputs: vec![in_0, in_1], + output: next.next(), + }) + .unwrap(); + + // Add a disconnected gate + let gate_result = builder.add_gate(|next| Gate { + inputs: Vec::new(), + output: next.next(), + }); + + // Ensure the disconnected gate is detected + assert!(gate_result.is_err(), "Expected disconnected gate error"); assert_eq!( - gates[3].get_outputs().collect::>(), - vec![&Node(6)], - "Fourth gate outputs mismatch" // Gate 3 + gate_result.unwrap_err(), + CircuitBuilderError::DisconnectedGate, + "Unexpected error type" ); - assert_eq!( - gates[4].get_outputs().collect::>(), - vec![&Node(7)], - "Fifth gate outputs mismatch" // Gate 4 + + // Add valid gate + let &Gate { output, .. } = builder + .add_gate(|next| Gate { + inputs: vec![output, in_1], + output: next.next(), + }) + .unwrap(); + + builder.add_output(output); + + // Build the circuit + let circuit = builder.build(); + assert!( + circuit.is_ok(), + "Failed to build circuit: {:?}", + circuit.err() ); + } + + #[test] + fn test_empty_circuit() { + let builder = CircuitBuilder::::new(); + + let circuit = builder.build(); + + assert!(circuit.is_err(), "Expected empty circuit error"); assert_eq!( - gates[5].get_outputs().collect::>(), - vec![&Node(11)], - "Sixth gate outputs mismatch" // Gate 6 + circuit.unwrap_err(), + CircuitBuilderError::EmptyCircuit, + "Unexpected error type" ); } - // #[test] - // fn test_cycle_detection() { - // // Setup circuit builder - // let mut circuit_builder = CircuitBuilder::::new(); - - // // Define gates - // let gate1 = Gate { - // inputs: vec![Node(0), Node(1)], - // output: Node(2), - // }; - // let gate2 = Gate { - // inputs: vec![Node(2), Node(3)], - // output: Node(4), - // }; - // let cycle_gate = Gate { - // inputs: vec![Node(4)], - // output: Node(0), - // }; - - // // Add gates - // circuit_builder - // .add_gate(gate1) - // .add_gate(gate2) - // .add_gate(cycle_gate); - - // // Expect build to fail - // let circuit = circuit_builder.build(); - // assert!(circuit.is_err(), "Expected cycle detection error"); - // assert_eq!( - // circuit.unwrap_err(), - // CircuitBuilderError::CycleDetected, - // "Unexpected error type" - // ); - // } - - // #[test] - // fn test_disconnected_gate() { - // // Setup circuit builder - // let mut circuit_builder = CircuitBuilder::::new(); - - // // Define gates, with one gate disconnected - // let gate1 = Gate { - // inputs: vec![Node(0), Node(1)], - // output: Node(2), - // }; - // let gate2 = Gate { - // inputs: vec![Node(3), Node(4)], - // output: Node(5), - // }; - // let gate3 = Gate { - // inputs: vec![Node(2), Node(5)], - // output: Node(6), - // }; - // let disconnected_gate = Gate { - // inputs: vec![Node(7), Node(8)], - // output: Node(9), - // }; - - // // Add gates including the disconnected gate - // circuit_builder - // .add_gate(gate1) - // .add_gate(gate2) - // .add_gate(gate3) - // .add_gate(disconnected_gate); - - // // Build circuit - // let circuit = circuit_builder.build(); - // assert!( - // circuit.is_ok(), - // "Failed to build circuit: {:?}", - // circuit.err() - // ); - // let circuit = circuit.unwrap(); - // let gates = circuit.gates(); - - // // Verify order - // // Gate 1 and 2 were added first and have in_degree 0 so they will be processed right away - // // The disconnected gate also has in_degree 0 so it will be put next to them - // // Gate 3 will be processed last for having in_degree > 0 - // assert_eq!( - // gates[0].get_outputs().collect::>(), - // vec![&Node(2)], - // "First gate outputs mismatch" // Gate 1 - // ); - // assert_eq!( - // gates[1].get_outputs().collect::>(), - // vec![&Node(5)], - // "Second gate outputs mismatch" // Gate 2 - // ); - // assert_eq!( - // gates[2].get_outputs().collect::>(), - // vec![&Node(9)], - // "Third gate outputs mismatch" // Disconnected Gate - // ); - // assert_eq!( - // gates[3].get_outputs().collect::>(), - // vec![&Node(6)], - // "Fourth gate outputs mismatch" // Gate 3 - // ); - // } + #[test] + fn test_node_out_of_index() { + let mut builder = CircuitBuilder::::new(); + + let input = builder.add_input(); + + // Add a gate with an out-of-index node + builder + .add_gate(|next| Gate { + inputs: vec![input, Node(100)], + output: next.next(), + }) + .unwrap(); + + let circuit = builder.build(); + + assert!(circuit.is_err(), "Expected node out of index error"); + assert_eq!( + circuit.unwrap_err(), + CircuitBuilderError::NodeOutOfIndex, + "Unexpected error type" + ); + } }