diff --git a/stubs/networkx/@tests/test_cases/check_tricky_function_params.py b/stubs/networkx/@tests/test_cases/check_tricky_function_params.py index 6a1105347ee3..20a38a85a8ad 100644 --- a/stubs/networkx/@tests/test_cases/check_tricky_function_params.py +++ b/stubs/networkx/@tests/test_cases/check_tricky_function_params.py @@ -1,4 +1,7 @@ +from typing_extensions import assert_type + import networkx as nx +from networkx.classes.reportviews import DegreeView, DiDegreeView # Test covariant dict type for `pos` in nx_latex functions G: "nx.Graph[int]" = nx.Graph([(1, 2), (2, 3), (3, 4)]) @@ -9,3 +12,14 @@ nx.to_latex_raw(G, pos=pos2) # OK: dict[node, str] pos3: dict[int, int] = {1: 1, 2: 3, 3: 5, 4: 7} nx.to_latex_raw(G, pos=pos3) # type: ignore # dict keys must be str or collection + +# Test that we don't confuse str and Iterable[str] in DiDegreeView.__call__ +G_str = nx.Graph[str]() +di_degree_view = DiDegreeView(G_str) +assert_type(di_degree_view(""), int) +assert_type(di_degree_view([""]), DiDegreeView[str]) +assert_type(di_degree_view({""}), DiDegreeView[str]) +degree_view = DegreeView(G_str) +assert_type(degree_view(""), int) +assert_type(degree_view([""]), DegreeView[str]) +assert_type(degree_view({""}), DegreeView[str]) diff --git a/stubs/networkx/networkx/classes/reportviews.pyi b/stubs/networkx/networkx/classes/reportviews.pyi index dc9175291af4..b141f395feac 100644 --- a/stubs/networkx/networkx/classes/reportviews.pyi +++ b/stubs/networkx/networkx/classes/reportviews.pyi @@ -44,8 +44,11 @@ class NodeView(Mapping[_Node, dict[str, Any]], AbstractSet[_Node]): @overload def __call__(self, data: Literal[False] = False, default=None) -> Self: ... @overload - def __call__(self, data: Literal[True] | str, default=None) -> Self: ... - def data(self, data: bool | str = True, default=None) -> Self: ... + def __call__(self, data: Literal[True] | str, default=None) -> NodeDataView[_Node]: ... + @overload + def data(self, data: Literal[False], default=None) -> Self: ... + @overload + def data(self, data: Literal[True] | str = True, default=None) -> NodeDataView[_Node]: ... class NodeDataView(AbstractSet[_Node]): __slots__ = ("_nodes", "_data", "_default") @@ -57,12 +60,12 @@ class NodeDataView(AbstractSet[_Node]): class DiDegreeView(Generic[_Node]): def __init__(self, G: Graph[_Node], nbunch: _NBunch[_Node] = None, weight: None | bool | str = None) -> None: ... + @overload # Use this overload first in case _Node=str, since `str` matches `Iterable[str]` + def __call__(self, nbunch: _Node, weight: None | bool | str = None) -> int: ... # type: ignore[overload-overlap] @overload - def __call__(self, nbunch: None = None, weight: None | bool | str = None) -> int: ... # type: ignore[overload-overlap] - @overload - def __call__(self, nbunch: None | Iterable[_Node], weight: None | bool | str = None) -> Self: ... - def __getitem__(self, n: _Node) -> float: ... - def __iter__(self) -> Iterator[tuple[_Node, float]]: ... + def __call__(self, nbunch: Iterable[_Node] | None = None, weight: None | bool | str = None) -> Self: ... + def __getitem__(self, n: _Node) -> int: ... + def __iter__(self) -> Iterator[tuple[_Node, int]]: ... def __len__(self) -> int: ... class DegreeView(DiDegreeView[_Node]): ...