Skip to content

Commit

Permalink
Get pytools.graph to typecheck strictly
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Sep 4, 2024
1 parent 0182fd9 commit 09e7e39
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 7 deletions.
20 changes: 13 additions & 7 deletions pytools/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
Mapping,
MutableSet,
Optional,
Protocol,
Set,
Tuple,
TypeVar,
Expand Down Expand Up @@ -267,8 +268,13 @@ def __init__(self, node: NodeT) -> None:
self.node = node


class _SupportsLT(Protocol):
def __lt__(self, other: object) -> bool:
...


@dataclass(frozen=True)
class HeapEntry(Generic[NodeT]):
class _HeapEntry(Generic[NodeT]):
"""
Helper class to compare associated keys while comparing the elements in
heap operations.
Expand All @@ -277,9 +283,9 @@ class HeapEntry(Generic[NodeT]):
<https://github.com/python/cpython/blob/8d21aa21f2cbc6d50aab3f420bb23be1d081dac4/Lib/heapq.py#L135-L138>.
"""
node: NodeT
key: Any
key: _SupportsLT

def __lt__(self, other: HeapEntry) -> bool:
def __lt__(self, other: _HeapEntry[NodeT]) -> bool:
return self.key < other.key


Expand Down Expand Up @@ -325,7 +331,7 @@ def compute_topological_order(graph: GraphT[NodeT],
# heap: list of instances of HeapEntry(n) where 'n' is a node in
# 'graph' with no predecessor. Nodes with no predecessors are the
# schedulable candidates.
heap = [HeapEntry(n, keyfunc(n))
heap = [_HeapEntry(n, keyfunc(n))
for n, num_preds in nodes_to_num_predecessors.items()
if num_preds == 0]
heapify(heap)
Expand All @@ -340,7 +346,7 @@ def compute_topological_order(graph: GraphT[NodeT],
for child in graph.get(node_to_be_scheduled, ()):
nodes_to_num_predecessors[child] -= 1
if nodes_to_num_predecessors[child] == 0:
heappush(heap, HeapEntry(child, keyfunc(child)))
heappush(heap, _HeapEntry(child, keyfunc(child)))

if len(order) != total_num_nodes:
# any node which has a predecessor left is a part of a cycle
Expand Down Expand Up @@ -461,11 +467,11 @@ def as_graphviz_dot(graph: GraphT[NodeT],
from pytools.graphviz import dot_escape

if node_labels is None:
def node_labels(x):
def node_labels(x: NodeT) -> str:
return str(x)

if edge_labels is None:
def edge_labels(x, y):
def edge_labels(x: NodeT, y: NodeT) -> str:
return ""

node_to_id = {}
Expand Down
1 change: 1 addition & 0 deletions run-mypy.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@ mypy --show-error-codes pytools

mypy --strict --follow-imports=silent \
pytools/tag.py \
pytools/graph.py \
pytools/datatable.py \
pytools/persistent_dict.py

0 comments on commit 09e7e39

Please sign in to comment.