Skip to content

Commit

Permalink
fix: fix issues with topology verification
Browse files Browse the repository at this point in the history
Co-authored-by: Konrad Jałowiecki <dexter2206@gmail.com>
  • Loading branch information
mstechly and dexter2206 committed Jun 13, 2024
1 parent dbdf4e5 commit 2f18936
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 67 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,4 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

.vscode/
4 changes: 2 additions & 2 deletions docs/library/userguide.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,10 @@ from qref.verification import verify_topology

program = load_some_program()

# This will raise if data is not valid
verification_output = verify_topology(program)

if verification_output == False:
if not verification_output:
print("Program topology is incorrect, due to the following issues:")
for problem in verification_output.problems:
print(problem)

Expand Down
36 changes: 19 additions & 17 deletions src/qref/_schema_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,26 +106,28 @@ def __init__(self, **data: Any):
@field_validator("connections", mode="after")
@classmethod
def _validate_connections(cls, v, values) -> list[_ConnectionV1]:
for connection in v:
_validate_connection_end(connection.source, values)
_validate_connection_end(connection.target, values)
children_port_names = [
f"{child.name}.{port.name}"
for child in values.data.get("children")
for port in child.ports
]
parent_port_names = [port.name for port in values.data["ports"]]
available_port_names = set(children_port_names + parent_port_names)

missed_ports = [
port
for connection in v
for port in (connection.source, connection.target)
if port not in available_port_names
]
if missed_ports:
raise ValueError(
"The following ports appear in a connection but are not "
f"among routine's port or their children's ports: {missed_ports}."
)
return v


def _validate_connection_end(port_name, values):
parent_port_names = [port.name for port in values.data["ports"]]
children_port_names = []
for child in values.data.get("children", []):
children_port_names += [".".join([child.name, port.name]) for port in child.ports]

available_port_names = parent_port_names + children_port_names
if port_name not in available_port_names:
raise ValueError(
f"Port name: {port_name} is present in a connection,"
"but is not among routine's ports or their children."
)


class SchemaV1(BaseModel):
"""Root object in Program schema V1."""

Expand Down
105 changes: 62 additions & 43 deletions src/qref/verification.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,12 @@
class TopologyVerificationOutput:
"""Dataclass containing the output of the topology verification"""

is_valid: bool
problems: list[str]

@property
def is_valid(self):
return len(self.problems) == 0

def __bool__(self) -> bool:
return self.is_valid

Expand All @@ -38,51 +41,43 @@ def verify_topology(routine: Union[SchemaV1, RoutineV1]) -> TopologyVerification
"""
if isinstance(routine, SchemaV1):
routine = routine.program
problems = _verify_routine_topology(routine, is_root=True)
if problems:
return TopologyVerificationOutput(False, problems)
else:
return TopologyVerificationOutput(True, problems)
problems = _verify_routine_topology(routine)
return TopologyVerificationOutput(problems)


def _verify_routine_topology(routine: RoutineV1, is_root: bool) -> list[str]:
def _verify_routine_topology(routine: RoutineV1) -> list[str]:
problems = []
flat_graph = _get_flat_graph_from_routine(routine, path=None)
edge_list = []
for source, targets in flat_graph.items():
edge_list += [(source, target) for target in targets]
adjacency_list = _get_adjacency_list_from_routine(routine, path=None)

problems += _find_cycles(flat_graph)
problems += _find_disconnected_ports(routine, is_root)
problems += _find_cycles(adjacency_list)
problems += _find_disconnected_ports(routine)

for child in routine.children:
new_problems = _verify_routine_topology(child, is_root=False)
if new_problems:
problems += new_problems
new_problems = _verify_routine_topology(child)
problems += new_problems
return problems


def _get_flat_graph_from_routine(routine, path) -> dict[str, list[str]]:
def _get_adjacency_list_from_routine(routine: RoutineV1, path: str) -> dict[str, list[str]]:
"""This function creates a flat graph representing one hierarchy level of a routine.
Nodes represent ports and edges represent connections (they're directed).
Additionally, we add a node for each child, with edges coming from all of the child's input ports
into the child node, and from the child node into all of its output ports.
"""
graph = defaultdict(list)
if path is None:
current_path = routine.name
else:
current_path = ".".join([path, routine.name])

input_ports = []
output_ports = []

for port in routine.ports:
if port.direction == "input":
input_ports.append(".".join([current_path, port.name]))
elif port.direction == "output":
output_ports.append(".".join([current_path, port.name]))

# First, we go through all the connections and add them as edges to the graph
for connection in routine.connections:
source = ".".join([current_path, connection.source])
target = ".".join([current_path, connection.target])
graph[source].append(target)

# Then, for each child, we add an extra node and a set of connections
for child in routine.children:
input_ports = []
output_ports = []
Expand All @@ -102,41 +97,65 @@ def _get_flat_graph_from_routine(routine, path) -> dict[str, list[str]]:
return graph


def _find_cycles(edges) -> list[str]:
for node in list(edges.keys()):
visited: list[str] = []
problem = _dfs_iteration(edges, node, node, visited)
def _find_cycles(adjacency_list: dict[str, list[str]]) -> list[str]:
# Note: it only returns the first detected cycle.
for node in list(adjacency_list.keys()):
problem = _dfs_iteration(adjacency_list, node)
if problem:
return problem
return []


def _dfs_iteration(edges, initial_node, node, visited):
if node != initial_node:
# def _dfs_iteration(adjacency_list, initial_node, node, visited):
# if node != initial_node:
# visited.append(node)
# for neighbour in adjacency_list[node]:
# if neighbour not in visited:
# if neighbour == initial_node:
# return [f"Cycle detected for node: {node}. Cycle: {visited}."]
# problem = _dfs_iteration(adjacency_list, initial_node, neighbour, visited)
# if problem:
# return problem


def _dfs_iteration(adjacency_list, start_node) -> list[str]:
to_visit = [start_node]
visited = []
predecessors = {}

while to_visit:
node = to_visit.pop()
visited.append(node)
for neighbour in edges[node]:
if neighbour not in visited:
if neighbour == initial_node:
return [f"Cycle detected for node: {node}. Cycle: {visited}."]
problem = _dfs_iteration(edges, initial_node, neighbour, visited)
if problem:
return problem
for neighbour in adjacency_list[node]:
predecessors[neighbour] = node
if neighbour == start_node:
# Reconstruct the cycle
cycle = [neighbour]
while len(cycle) < 2 or cycle[-1] != start_node:
cycle.append(predecessors[cycle[-1]])
return [f"Cycle detected: {cycle[::-1]}"]
if neighbour not in visited:
to_visit.append(neighbour)
return []


def _find_disconnected_ports(routine: RoutineV1, is_root: bool):
def _find_disconnected_ports(routine: RoutineV1):
problems = []
conns = [c.model_dump() for c in routine.connections]
for child in routine.children:
for port in child.ports:
pname = f"{routine.name}.{child.name}.{port.name}"
if port.direction == "input":
matches_in = [c for c in conns if c["target"] == f"{child.name}.{port.name}"]
matches_in = [
c for c in routine.connections if c.target == f"{child.name}.{port.name}"
]
if len(matches_in) == 0:
problems.append(f"No incoming connections to {pname}.")
elif len(matches_in) > 1:
problems.append(f"Too many incoming connections to {pname}.")
elif port.direction == "output":
matches_out = [c for c in conns if c["source"] == f"{child.name}.{port.name}"]
matches_out = [
c for c in routine.connections if c.source == f"{child.name}.{port.name}"
]
if len(matches_out) == 0:
problems.append(f"No outgoing connections from {pname}.")
elif len(matches_out) > 1:
Expand Down
27 changes: 27 additions & 0 deletions tests/qref/data/invalid_program_examples.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -244,3 +244,30 @@
description: "Target of a paramater link is not namespaced"
error_path: "$.program.linked_params[0].targets[0]"
error_message: "'N' does not match '^[A-Za-z_][A-Za-z0-9_]*\\\\.[A-Za-z_][A-Za-z0-9_]*'"
- input:
version: v1
program:
name: root
children:
- name: foo
ports:
- name: in_0
direction: input
size: 3
- name: out_0
direction: output
size: 3
- name: bar
ports:
- name: in_0
direction: input
size: 3
- name: out_0
direction: output
size: 3
connections:
- source: foo.out_0
target: bar.in_1
description: "Connection contains non-existent port name"
error_path: "$.program.connections[0].source"
error_message: "'foo.foo.out_0' does not match '^(([A-Za-z_][A-Za-z0-9_]*)|([A-Za-z_][A-Za-z0-9_]*\\\\.[A-Za-z_][A-Za-z0-9_]*))$'"
4 changes: 1 addition & 3 deletions tests/qref/data/invalid_topology_programs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
version: v1
description: Program contains cycles
problems:
- "Cycle detected for node: *"
- "Cycle detected: ['root.child_1.out_0', 'root.child_2.in_0', 'root.child_2', 'root.child_2.out_1', 'root.child_3.in_1', 'root.child_3', 'root.child_3.out_1', 'root.child_1.in_1', 'root.child_1', 'root.child_1.out_0']"
- input:
program:
children:
Expand Down Expand Up @@ -100,8 +100,6 @@
- direction: output
name: out_1
size: 1


connections:
- source: in_0
target: child_1.in_0
Expand Down
3 changes: 1 addition & 2 deletions tests/qref/test_topology_verification.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from pathlib import Path

import pytest
Expand Down Expand Up @@ -49,4 +48,4 @@ def test_invalid_program_fails_to_validate_with_schema_v1(input, problems):
assert not verification_output
assert len(problems) == len(verification_output.problems)
for expected_problem, problem in zip(problems, verification_output.problems):
assert re.match(expected_problem, problem)
assert expected_problem == problem

0 comments on commit 2f18936

Please sign in to comment.