From 2f18936c00ece9419ff5c0223ee23c2a037e2078 Mon Sep 17 00:00:00 2001 From: mstechly Date: Thu, 13 Jun 2024 11:35:49 -0400 Subject: [PATCH] fix: fix issues with topology verification MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Konrad Jałowiecki --- .gitignore | 1 + docs/library/userguide.md | 4 +- src/qref/_schema_v1.py | 36 +++--- src/qref/verification.py | 105 +++++++++++------- tests/qref/data/invalid_program_examples.yaml | 27 +++++ .../qref/data/invalid_topology_programs.yaml | 4 +- tests/qref/test_topology_verification.py | 3 +- 7 files changed, 113 insertions(+), 67 deletions(-) diff --git a/.gitignore b/.gitignore index b1cb160..5e8f6be 100644 --- a/.gitignore +++ b/.gitignore @@ -159,3 +159,4 @@ cython_debug/ # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ +.vscode/ \ No newline at end of file diff --git a/docs/library/userguide.md b/docs/library/userguide.md index e2a5f18..215b89f 100644 --- a/docs/library/userguide.md +++ b/docs/library/userguide.md @@ -67,10 +67,10 @@ from qref.verification import verify_topology program = load_some_program() -# This will raise if data is not valid verification_output = verify_topology(program) -if verification_output == False: +if not verification_output: + print("Program topology is incorrect, due to the following issues:") for problem in verification_output.problems: print(problem) diff --git a/src/qref/_schema_v1.py b/src/qref/_schema_v1.py index 5655c2c..64866c7 100644 --- a/src/qref/_schema_v1.py +++ b/src/qref/_schema_v1.py @@ -106,26 +106,28 @@ def __init__(self, **data: Any): @field_validator("connections", mode="after") @classmethod def _validate_connections(cls, v, values) -> list[_ConnectionV1]: - for connection in v: - _validate_connection_end(connection.source, values) - _validate_connection_end(connection.target, values) + children_port_names = [ + f"{child.name}.{port.name}" + for child in 
values.data.get("children") + for port in child.ports + ] + parent_port_names = [port.name for port in values.data["ports"]] + available_port_names = set(children_port_names + parent_port_names) + + missed_ports = [ + port + for connection in v + for port in (connection.source, connection.target) + if port not in available_port_names + ] + if missed_ports: + raise ValueError( + "The following ports appear in a connection but are not " + "among routine's port or their children's ports: {missed_ports}." + ) return v -def _validate_connection_end(port_name, values): - parent_port_names = [port.name for port in values.data["ports"]] - children_port_names = [] - for child in values.data.get("children", []): - children_port_names += [".".join([child.name, port.name]) for port in child.ports] - - available_port_names = parent_port_names + children_port_names - if port_name not in available_port_names: - raise ValueError( - f"Port name: {port_name} is present in a connection," - "but is not among routine's ports or their children." 
- ) - - class SchemaV1(BaseModel): """Root object in Program schema V1.""" diff --git a/src/qref/verification.py b/src/qref/verification.py index 5454e4c..48fa7f0 100644 --- a/src/qref/verification.py +++ b/src/qref/verification.py @@ -23,9 +23,12 @@ class TopologyVerificationOutput: """Dataclass containing the output of the topology verification""" - is_valid: bool problems: list[str] + @property + def is_valid(self): + return len(self.problems) == 0 + def __bool__(self) -> bool: return self.is_valid @@ -38,51 +41,43 @@ def verify_topology(routine: Union[SchemaV1, RoutineV1]) -> TopologyVerification """ if isinstance(routine, SchemaV1): routine = routine.program - problems = _verify_routine_topology(routine, is_root=True) - if problems: - return TopologyVerificationOutput(False, problems) - else: - return TopologyVerificationOutput(True, problems) + problems = _verify_routine_topology(routine) + return TopologyVerificationOutput(problems) -def _verify_routine_topology(routine: RoutineV1, is_root: bool) -> list[str]: +def _verify_routine_topology(routine: RoutineV1) -> list[str]: problems = [] - flat_graph = _get_flat_graph_from_routine(routine, path=None) - edge_list = [] - for source, targets in flat_graph.items(): - edge_list += [(source, target) for target in targets] + adjacency_list = _get_adjacency_list_from_routine(routine, path=None) - problems += _find_cycles(flat_graph) - problems += _find_disconnected_ports(routine, is_root) + problems += _find_cycles(adjacency_list) + problems += _find_disconnected_ports(routine) for child in routine.children: - new_problems = _verify_routine_topology(child, is_root=False) - if new_problems: - problems += new_problems + new_problems = _verify_routine_topology(child) + problems += new_problems return problems -def _get_flat_graph_from_routine(routine, path) -> dict[str, list[str]]: +def _get_adjacency_list_from_routine(routine: RoutineV1, path: str) -> dict[str, list[str]]: + """This function creates a flat graph 
representing one hierarchy level of a routine. + + Nodes represent ports and edges represent connections (they're directed). + Additionally, we add a node for each child and edges coming from all the input ports + into the child, and from the child into all the output ports. + """ graph = defaultdict(list) if path is None: current_path = routine.name else: current_path = ".".join([path, routine.name]) - input_ports = [] - output_ports = [] - - for port in routine.ports: - if port.direction == "input": - input_ports.append(".".join([current_path, port.name])) - elif port.direction == "output": - output_ports.append(".".join([current_path, port.name])) - + # First, we go through all the connections and add them as edges to the graph for connection in routine.connections: source = ".".join([current_path, connection.source]) target = ".".join([current_path, connection.target]) graph[source].append(target) + # Then for each child we add an extra node and set of connections for child in routine.children: input_ports = [] output_ports = [] @@ -102,41 +97,65 @@ def _get_flat_graph_from_routine(routine, path) -> dict[str, list[str]]: return graph -def _find_cycles(edges) -> list[str]: - for node in list(edges.keys()): - visited: list[str] = [] - problem = _dfs_iteration(edges, node, node, visited) +def _find_cycles(adjacency_list: dict[str, list[str]]) -> list[str]: + # Note: it only returns the first detected cycle. + for node in list(adjacency_list.keys()): + problem = _dfs_iteration(adjacency_list, node) if problem: return problem return [] -def _dfs_iteration(edges, initial_node, node, visited): - if node != initial_node: +# def _dfs_iteration(adjacency_list, initial_node, node, visited): +# if node != initial_node: +# visited.append(node) +# for neighbour in adjacency_list[node]: +# if neighbour not in visited: +# if neighbour == initial_node: +# return [f"Cycle detected for node: {node}. 
Cycle: {visited}."] +# problem = _dfs_iteration(adjacency_list, initial_node, neighbour, visited) +# if problem: +# return problem + + +def _dfs_iteration(adjacency_list, start_node) -> list[str]: + to_visit = [start_node] + visited = [] + predecessors = {} + + while to_visit: + node = to_visit.pop() visited.append(node) - for neighbour in edges[node]: - if neighbour not in visited: - if neighbour == initial_node: - return [f"Cycle detected for node: {node}. Cycle: {visited}."] - problem = _dfs_iteration(edges, initial_node, neighbour, visited) - if problem: - return problem + for neighbour in adjacency_list[node]: + predecessors[neighbour] = node + if neighbour == start_node: + # Reconstruct the cycle + cycle = [neighbour] + while len(cycle) < 2 or cycle[-1] != start_node: + cycle.append(predecessors[cycle[-1]]) + return [f"Cycle detected: {cycle[::-1]}"] + if neighbour not in visited: + to_visit.append(neighbour) + return [] -def _find_disconnected_ports(routine: RoutineV1, is_root: bool): +def _find_disconnected_ports(routine: RoutineV1): problems = [] - conns = [c.model_dump() for c in routine.connections] for child in routine.children: for port in child.ports: pname = f"{routine.name}.{child.name}.{port.name}" if port.direction == "input": - matches_in = [c for c in conns if c["target"] == f"{child.name}.{port.name}"] + matches_in = [ + c for c in routine.connections if c.target == f"{child.name}.{port.name}" + ] if len(matches_in) == 0: problems.append(f"No incoming connections to {pname}.") elif len(matches_in) > 1: problems.append(f"Too many incoming connections to {pname}.") elif port.direction == "output": - matches_out = [c for c in conns if c["source"] == f"{child.name}.{port.name}"] + matches_out = [ + c for c in routine.connections if c.source == f"{child.name}.{port.name}" + ] if len(matches_out) == 0: problems.append(f"No outgoing connections from {pname}.") elif len(matches_out) > 1: diff --git a/tests/qref/data/invalid_program_examples.yaml 
b/tests/qref/data/invalid_program_examples.yaml index 8391ea1..6157e48 100644 --- a/tests/qref/data/invalid_program_examples.yaml +++ b/tests/qref/data/invalid_program_examples.yaml @@ -244,3 +244,30 @@ description: "Target of a paramater link is not namespaced" error_path: "$.program.linked_params[0].targets[0]" error_message: "'N' does not match '^[A-Za-z_][A-Za-z0-9_]*\\\\.[A-Za-z_][A-Za-z0-9_]*'" +- input: + version: v1 + program: + name: root + children: + - name: foo + ports: + - name: in_0 + direction: input + size: 3 + - name: out_0 + direction: output + size: 3 + - name: bar + ports: + - name: in_0 + direction: input + size: 3 + - name: out_0 + direction: output + size: 3 + connections: + - source: foo.out_0 + target: bar.in_1 + description: "Connection contains non-existent port name" + error_path: "$.program.connections[0].source" + error_message: "'foo.foo.out_0' does not match '^(([A-Za-z_][A-Za-z0-9_]*)|([A-Za-z_][A-Za-z0-9_]*\\\\.[A-Za-z_][A-Za-z0-9_]*))$'" diff --git a/tests/qref/data/invalid_topology_programs.yaml b/tests/qref/data/invalid_topology_programs.yaml index 5b4ca2d..d9b2620 100644 --- a/tests/qref/data/invalid_topology_programs.yaml +++ b/tests/qref/data/invalid_topology_programs.yaml @@ -71,7 +71,7 @@ version: v1 description: Program contains cycles problems: - - "Cycle detected for node: *" + - "Cycle detected: ['root.child_1.out_0', 'root.child_2.in_0', 'root.child_2', 'root.child_2.out_1', 'root.child_3.in_1', 'root.child_3', 'root.child_3.out_1', 'root.child_1.in_1', 'root.child_1', 'root.child_1.out_0']" - input: program: children: @@ -100,8 +100,6 @@ - direction: output name: out_1 size: 1 - - connections: - source: in_0 target: child_1.in_0 diff --git a/tests/qref/test_topology_verification.py b/tests/qref/test_topology_verification.py index c3e7c77..e38e37a 100644 --- a/tests/qref/test_topology_verification.py +++ b/tests/qref/test_topology_verification.py @@ -12,7 +12,6 @@ # See the License for the specific language governing 
permissions and # limitations under the License. -import re from pathlib import Path import pytest @@ -49,4 +48,4 @@ def test_invalid_program_fails_to_validate_with_schema_v1(input, problems): assert not verification_output assert len(problems) == len(verification_output.problems) for expected_problem, problem in zip(problems, verification_output.problems): - assert re.match(expected_problem, problem) + assert expected_problem == problem