From f5a62af4fb9bb085bc5665798e6e59c9f433aca1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Konrad=20Ja=C5=82owiecki?= Date: Mon, 30 Sep 2024 16:57:28 -0700 Subject: [PATCH] fix: Fix topology verification (#123) * Reimplement checking of ports arity * Streamline handling of paths * Improve naming * Add new example that would trip previous implementation * Fix malformed f-String * Tidy up imports * Fix name clash --- src/qref/schema_v1.py | 2 +- src/qref/verification.py | 118 ++++++++++++------ .../qref/data/invalid_topology_programs.yaml | 62 ++++++++- tests/qref/test_schema_validation.py | 18 +++ tests/qref/test_topology_verification.py | 7 +- 5 files changed, 164 insertions(+), 43 deletions(-) diff --git a/src/qref/schema_v1.py b/src/qref/schema_v1.py index 766877b..b096f29 100644 --- a/src/qref/schema_v1.py +++ b/src/qref/schema_v1.py @@ -208,7 +208,7 @@ def _validate_connections(self) -> Self: if missed_ports: raise ValueError( "The following ports appear in a connection but are not " - "among routine's port or their children's ports: {missed_ports}." + f"among routine's port or their children's ports: {missed_ports}." ) return self diff --git a/src/qref/verification.py b/src/qref/verification.py index 2b0fac7..4222936 100644 --- a/src/qref/verification.py +++ b/src/qref/verification.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections import defaultdict +from collections import Counter, defaultdict from dataclasses import dataclass +from typing import Callable from .functools import accepts_all_qref_types -from .schema_v1 import RoutineV1, SchemaV1 +from .schema_v1 import RoutineV1 AdjacencyList = dict[str, list[str]] @@ -36,7 +37,7 @@ def __bool__(self) -> bool: @accepts_all_qref_types -def verify_topology(routine: SchemaV1 | RoutineV1) -> TopologyVerificationOutput: +def verify_topology(routine: RoutineV1) -> TopologyVerificationOutput: """Checks whether program has correct topology. Correct topology cannot include cycles or disconnected ports. @@ -44,23 +45,32 @@ def verify_topology(routine: SchemaV1 | RoutineV1) -> TopologyVerificationOutput Args: routine: Routine or program to be verified. """ - if isinstance(routine, SchemaV1): - routine = routine.program problems = _verify_routine_topology(routine) return TopologyVerificationOutput(problems) -def _verify_routine_topology(routine: RoutineV1) -> list[str]: - adjacency_list = _get_adjacency_list_from_routine(routine, path=None) +def _make_prefixer(ancestor_path: tuple[str, ...]) -> Callable[[str], str]: + def _prefix(text: str) -> str: + return ".".join([*ancestor_path, text]) + + return _prefix + + +def _verify_routine_topology(routine: RoutineV1, ancestor_path: tuple[str, ...] = ()) -> list[str]: + adjacency_list = _get_adjacency_list_from_routine(routine, path=ancestor_path) return [ - *_find_cycles(adjacency_list), - *_find_disconnected_ports(routine), - *[problem for child in routine.children for problem in _verify_routine_topology(child)], + *_find_cycles(adjacency_list, ancestor_path), + *_find_disconnected_ports(routine, ancestor_path), + *[ + problem + for child in routine.children + for problem in _verify_routine_topology(child, ancestor_path + (routine.name,)) + ], ] -def _get_adjacency_list_from_routine(routine: RoutineV1, path: str | None) -> AdjacencyList: +def _get_adjacency_list_from_routine(routine: RoutineV1, path: tuple[str, ...]) -> AdjacencyList: """This function creates a flat graph representing one hierarchy level of a routine. Nodes represent ports and edges represent connections (they're directed). @@ -68,15 +78,12 @@ def _get_adjacency_list_from_routine(routine: RoutineV1, path: str | None) -> Ad into the children, and from the children into all the output ports. """ graph = defaultdict[str, list[str]](list) - if path is None: - current_path = routine.name - else: - current_path = ".".join([path, routine.name]) + _prefix = _make_prefixer(path + (routine.name,)) # First, we go through all the connections and add them as adges to the graph for connection in routine.connections: - source = ".".join([current_path, connection.source]) - target = ".".join([current_path, connection.target]) + source = _prefix(connection.source) + target = _prefix(connection.target) graph[source].append(target) # Then for each children we add an extra node and set of connections @@ -84,7 +91,7 @@ def _get_adjacency_list_from_routine(routine: RoutineV1, path: str | None) -> Ad input_ports: list[str] = [] output_ports: list[str] = [] - child_path = ".".join([current_path, child.name]) + child_path = _prefix(child.name) for port in child.ports: if port.direction == "input": input_ports.append(".".join([child_path, port.name])) @@ -99,20 +106,22 @@ def _get_adjacency_list_from_routine(routine: RoutineV1, path: str | None) -> Ad return graph -def _find_cycles(adjacency_list: AdjacencyList) -> list[str]: +def _find_cycles(adjacency_list: AdjacencyList, ancestor_path: tuple[str, ...]) -> list[str]: # Note: it only returns the first detected cycle. for node in list(adjacency_list): - problem = _dfs_iteration(adjacency_list, node) + problem = _dfs_iteration(adjacency_list, node, ancestor_path) if problem: return problem return [] -def _dfs_iteration(adjacency_list: AdjacencyList, start_node: str) -> list[str]: +def _dfs_iteration(adjacency_list: AdjacencyList, start_node: str, ancestor_path: tuple[str, ...]) -> list[str]: to_visit: list[str] = [start_node] visited: list[str] = [] predecessors: dict[str, str] = {} + _prefix = _make_prefixer(ancestor_path) + while to_visit: node = to_visit.pop() visited.append(node) @@ -122,29 +131,66 @@ def _dfs_iteration(adjacency_list: AdjacencyList, start_node: str) -> list[str]: # Reconstruct the cycle cycle = [neighbour] while len(cycle) < 2 or cycle[-1] != start_node: - cycle.append(predecessors[cycle[-1]]) + cycle.append(_prefix(predecessors[cycle[-1]])) return [f"Cycle detected: {cycle[::-1]}"] if neighbour not in visited: to_visit.append(neighbour) return [] -def _find_disconnected_ports(routine: RoutineV1) -> list[str]: +def _find_disconnected_ports(routine: RoutineV1, ancestor_path: tuple[str, ...]) -> list[str]: problems: list[str] = [] + + _prefix = _make_prefixer(ancestor_path + (routine.name,)) + + sources_counts = Counter[str]() + target_counts = Counter[str]() + + for connection in routine.connections: + sources_counts[connection.source] += 1 + target_counts[connection.target] += 1 + + multi_sources = [source for source, count in sources_counts.items() if count > 1] + + multi_targets = [target for target, count in target_counts.items() if count > 1] + + if multi_sources: + problems.append(f"Too many outgoing connections from {','.join(_prefix(target) for target in multi_sources)}.") + + if multi_targets: + problems.append(f"Too many incoming connections to {','.join(_prefix(target) for target in multi_targets)}.") + + requiring_outgoing = set[str]() + requiring_incoming = set[str]() + thru_ports = set[str]() + + for port in routine.ports: + if port.direction == "input" and routine.children: + requiring_outgoing.add(port.name) + elif port.direction == "output" and routine.children: + requiring_incoming.add(port.name) + elif port.direction == "through": # Note: through ports have to be valid regardless of existence of children + thru_ports.add(port.name) + for child in routine.children: + # Directions are reversed compared to parent + through ports have to be connected on both ends for port in child.ports: - pname = f"{routine.name}.{child.name}.{port.name}" - if port.direction == "input": - matches_in = [c for c in routine.connections if c.target == f"{child.name}.{port.name}"] - if len(matches_in) == 0: - problems.append(f"No incoming connections to {pname}.") - elif len(matches_in) > 1: - problems.append(f"Too many incoming connections to {pname}.") - elif port.direction == "output": - matches_out = [c for c in routine.connections if c.source == f"{child.name}.{port.name}"] - if len(matches_out) == 0: - problems.append(f"No outgoing connections from {pname}.") - elif len(matches_out) > 1: - problems.append(f"Too many outgoing connections from {pname}.") + pname = f"{child.name}.{port.name}" + if port.direction != "output": + requiring_incoming.add(pname) + if port.direction != "input": + requiring_outgoing.add(pname) + + for pname in requiring_outgoing: + if pname not in sources_counts: + problems.append(f"No outgoing connection from {_prefix(pname)}.") + + for pname in requiring_incoming: + if pname not in target_counts: + problems.append(f"No incoming connection to {_prefix(pname)}.") + + for pname in thru_ports: + if pname in sources_counts or pname in target_counts: + problems.append(f"A through port {_prefix(pname)} is connected via an internal connection.") return problems diff --git a/tests/qref/data/invalid_topology_programs.yaml b/tests/qref/data/invalid_topology_programs.yaml index d9b2620..3dd45ba 100644 --- a/tests/qref/data/invalid_topology_programs.yaml +++ b/tests/qref/data/invalid_topology_programs.yaml @@ -128,7 +128,65 @@ version: v1 description: Program has badly connected ports problems: - - "No incoming connections to root.child_1.in_1." + - "No incoming connection to root.child_1.in_1." - "Too many incoming connections to root.child_2.in_0." - "Too many outgoing connections from root.child_2.out_0." - - "No outgoing connections from root.child_2.out_1." + - "No outgoing connection from root.child_2.out_1." +- input: + version: v1 + program: + name: root + ports: + - direction: through + name: thru_0 + size: 1 + - direction: output + name: out_0 + size: 1 + connections: + - thru_0 -> out_0 + description: "A routine with its thru port connected to its output port." + problems: + - "A through port root.thru_0 is connected via an internal connection." +- input: + version: v1 + program: + name: root + ports: + - direction: input + name: in_0 + size: N + - direction: output + name: out_0 + size: N + - direction: output + name: out_1 + size: N + connections: + - in_0 -> foo.in_0 + - foo.out_0 -> out_0 + - foo.out_1 -> out_1 + children: + - name: foo + ports: + - direction: input + name: in_0 + size: N + - direction: output + name: out_0 + size: N + - direction: output + name: out_1 + size: N + connections: + - bar.out_0 -> out_1 + children: + - name: bar + ports: + - name: out_0 + size: N + direction: output + description: "Program with disconnected container ports" + problems: + - "No outgoing connection from root.foo.in_0." + - "No incoming connection to root.foo.out_0." diff --git a/tests/qref/test_schema_validation.py b/tests/qref/test_schema_validation.py index 1e37538..32dd322 100644 --- a/tests/qref/test_schema_validation.py +++ b/tests/qref/test_schema_validation.py @@ -44,3 +44,21 @@ def test_invalid_program_fails_to_validate_with_pydantic_model_v1(input): def test_valid_program_succesfully_validate_with_pydantic_model_v1(valid_program): SchemaV1.model_validate(valid_program) + + +def test_validation_error_includes_name_of_the_missed_port(): + input = { + "version": "v1", + "program": { + "name": "root", + "ports": [{"name": "in_0", "direction": "input", "size": 1}], + "connections": ["in_0 -> out_0"], + }, + } + + pattern = ( + "The following ports appear in a connection but are not among routine's port " + r"or their children's ports: \['out_0'\]." # <- out_0 is the important bit here + ) + with pytest.raises(pydantic.ValidationError, match=pattern): + SchemaV1.model_validate(input) diff --git a/tests/qref/test_topology_verification.py b/tests/qref/test_topology_verification.py index e38e37a..7c0c95f 100644 --- a/tests/qref/test_topology_verification.py +++ b/tests/qref/test_topology_verification.py @@ -45,7 +45,6 @@ def test_correct_routines_pass_topology_validation(valid_program): def test_invalid_program_fails_to_validate_with_schema_v1(input, problems): verification_output = verify_topology(SchemaV1(**input)) - assert not verification_output - assert len(problems) == len(verification_output.problems) - for expected_problem, problem in zip(problems, verification_output.problems): - assert expected_problem == problem + # We use sorted here, to make sure that we don't test the order in which the + # problems appear, as the order is only an implementation detail. + assert sorted(verification_output.problems) == sorted(problems)