diff --git a/pytools/graph.py b/pytools/graph.py index 5fb933af..a1f35761 100644 --- a/pytools/graph.py +++ b/pytools/graph.py @@ -83,6 +83,7 @@ Mapping, MutableSet, Optional, + Protocol, Set, Tuple, TypeVar, @@ -267,8 +268,13 @@ def __init__(self, node: NodeT) -> None: self.node = node +class _SupportsLT(Protocol): + def __lt__(self, other: object) -> bool: + ... + + @dataclass(frozen=True) -class HeapEntry(Generic[NodeT]): +class _HeapEntry(Generic[NodeT]): """ Helper class to compare associated keys while comparing the elements in heap operations. @@ -277,9 +283,9 @@ class HeapEntry(Generic[NodeT]): . """ node: NodeT - key: Any + key: _SupportsLT - def __lt__(self, other: HeapEntry) -> bool: + def __lt__(self, other: _HeapEntry[NodeT]) -> bool: return self.key < other.key @@ -325,7 +331,7 @@ def compute_topological_order(graph: GraphT[NodeT], # heap: list of instances of HeapEntry(n) where 'n' is a node in # 'graph' with no predecessor. Nodes with no predecessors are the # schedulable candidates. - heap = [HeapEntry(n, keyfunc(n)) + heap = [_HeapEntry(n, keyfunc(n)) for n, num_preds in nodes_to_num_predecessors.items() if num_preds == 0] heapify(heap) @@ -340,7 +346,7 @@ def compute_topological_order(graph: GraphT[NodeT], for child in graph.get(node_to_be_scheduled, ()): nodes_to_num_predecessors[child] -= 1 if nodes_to_num_predecessors[child] == 0: - heappush(heap, HeapEntry(child, keyfunc(child))) + heappush(heap, _HeapEntry(child, keyfunc(child))) if len(order) != total_num_nodes: # any node which has a predecessor left is a part of a cycle @@ -461,11 +467,11 @@ def as_graphviz_dot(graph: GraphT[NodeT], from pytools.graphviz import dot_escape if node_labels is None: - def node_labels(x): + def node_labels(x: NodeT) -> str: return str(x) if edge_labels is None: - def edge_labels(x, y): + def edge_labels(x: NodeT, y: NodeT) -> str: return "" node_to_id = {} diff --git a/run-mypy.sh b/run-mypy.sh index 244d6cc4..0220058b 100755 --- a/run-mypy.sh +++ b/run-mypy.sh @@ -6,5 +6,6 @@ mypy --show-error-codes pytools mypy --strict --follow-imports=silent \ pytools/tag.py \ + pytools/graph.py \ pytools/datatable.py \ pytools/persistent_dict.py