Skip to content

Commit a56f819

Browse files
authored
Fix regression in DiDegreeView (#14732)
1 parent 95656b1 commit a56f819

File tree

2 files changed

+24
-7
lines changed

2 files changed

+24
-7
lines changed

stubs/networkx/@tests/test_cases/check_tricky_function_params.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
from typing_extensions import assert_type
2+
13
import networkx as nx
4+
from networkx.classes.reportviews import DegreeView, DiDegreeView
25

36
# Test covariant dict type for `pos` in nx_latex functions
47
G: "nx.Graph[int]" = nx.Graph([(1, 2), (2, 3), (3, 4)])
@@ -9,3 +12,14 @@
912
nx.to_latex_raw(G, pos=pos2) # OK: dict[node, str]
1013
pos3: dict[int, int] = {1: 1, 2: 3, 3: 5, 4: 7}
1114
nx.to_latex_raw(G, pos=pos3) # type: ignore # dict keys must be str or collection
15+
16+
# Test that we don't confuse str and Iterable[str] in DiDegreeView.__call__
17+
G_str = nx.Graph[str]()
18+
di_degree_view = DiDegreeView(G_str)
19+
assert_type(di_degree_view(""), int)
20+
assert_type(di_degree_view([""]), DiDegreeView[str])
21+
assert_type(di_degree_view({""}), DiDegreeView[str])
22+
degree_view = DegreeView(G_str)
23+
assert_type(degree_view(""), int)
24+
assert_type(degree_view([""]), DegreeView[str])
25+
assert_type(degree_view({""}), DegreeView[str])

stubs/networkx/networkx/classes/reportviews.pyi

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,11 @@ class NodeView(Mapping[_Node, dict[str, Any]], AbstractSet[_Node]):
4444
@overload
4545
def __call__(self, data: Literal[False] = False, default=None) -> Self: ...
4646
@overload
47-
def __call__(self, data: Literal[True] | str, default=None) -> Self: ...
48-
def data(self, data: bool | str = True, default=None) -> Self: ...
47+
def __call__(self, data: Literal[True] | str, default=None) -> NodeDataView[_Node]: ...
48+
@overload
49+
def data(self, data: Literal[False], default=None) -> Self: ...
50+
@overload
51+
def data(self, data: Literal[True] | str = True, default=None) -> NodeDataView[_Node]: ...
4952

5053
class NodeDataView(AbstractSet[_Node]):
5154
__slots__ = ("_nodes", "_data", "_default")
@@ -57,12 +60,12 @@ class NodeDataView(AbstractSet[_Node]):
5760

5861
class DiDegreeView(Generic[_Node]):
5962
def __init__(self, G: Graph[_Node], nbunch: _NBunch[_Node] = None, weight: None | bool | str = None) -> None: ...
63+
@overload # Use this overload first in case _Node=str, since `str` matches `Iterable[str]`
64+
def __call__(self, nbunch: _Node, weight: None | bool | str = None) -> int: ... # type: ignore[overload-overlap]
6065
@overload
61-
def __call__(self, nbunch: None = None, weight: None | bool | str = None) -> int: ... # type: ignore[overload-overlap]
62-
@overload
63-
def __call__(self, nbunch: None | Iterable[_Node], weight: None | bool | str = None) -> Self: ...
64-
def __getitem__(self, n: _Node) -> float: ...
65-
def __iter__(self) -> Iterator[tuple[_Node, float]]: ...
66+
def __call__(self, nbunch: Iterable[_Node] | None = None, weight: None | bool | str = None) -> Self: ...
67+
def __getitem__(self, n: _Node) -> int: ...
68+
def __iter__(self) -> Iterator[tuple[_Node, int]]: ...
6669
def __len__(self) -> int: ...
6770

6871
class DegreeView(DiDegreeView[_Node]): ...

0 commit comments

Comments
 (0)