Skip to content

Commit 2591e4b

Browse files
Fix type narrowing
1 parent c6786ab commit 2591e4b

File tree

4 files changed

+90
-6
lines changed

4 files changed

+90
-6
lines changed

mypy/checker.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6436,6 +6436,10 @@ def find_isinstance_check_helper(
64366436
# If the callee is a RefExpr, extract TypeGuard/TypeIs directly.
64376437
if isinstance(node.callee, RefExpr):
64386438
type_is, type_guard = node.callee.type_is, node.callee.type_guard
6439+
if type_guard is not None:
6440+
type_guard = self.expand_narrowed_type(type_guard)
6441+
if type_is is not None:
6442+
type_is = self.expand_narrowed_type(type_is)
64396443
if type_guard is not None or type_is is not None:
64406444
# TODO: Follow *args, **kwargs
64416445
if node.arg_kinds[0] != nodes.ARG_POS:
@@ -7926,7 +7930,7 @@ def conditional_types_with_intersection(
79267930
for types, reason in errors:
79277931
self.msg.impossible_intersection(types, reason, ctx)
79287932
return UninhabitedType(), expr_type
7929-
new_yes_type = make_simplified_union(out)
7933+
new_yes_type: Type = make_simplified_union(out)
79307934
return new_yes_type, expr_type
79317935

79327936
def is_writable_attribute(self, node: Node) -> bool:
@@ -7985,8 +7989,31 @@ def get_isinstance_type(self, expr: Expression) -> list[TypeRange] | None:
79857989
types.append(TypeRange(typ, is_upper_bound=False))
79867990
else: # we didn't see an actual type, but rather a variable with unknown value
79877991
return None
7992+
return self.expand_isinstance_type_ranges(types)
7993+
7994+
def expand_isinstance_type_ranges(self, types: list[TypeRange]) -> list[TypeRange]:
7995+
if disallow_str_iteration_state.disallow_str_iteration:
7996+
str_type = self.named_type("builtins.str")
7997+
return types + [
7998+
TypeRange(str_type, is_upper_bound=type_range.is_upper_bound)
7999+
for type_range in types
8000+
if self._is_str_iteration_protocol_for_narrowing(type_range.item)
8001+
]
79888002
return types
79898003

8004+
def _is_str_iteration_protocol_for_narrowing(self, typ: Type) -> bool:
8005+
proper = get_proper_type(typ)
8006+
return isinstance(proper, Instance) and proper.type.fullname in STR_ITERATION_PROTOCOL_BASES
8007+
8008+
def expand_narrowed_type(self, typ: Type) -> Type:
8009+
if disallow_str_iteration_state.disallow_str_iteration:
8010+
proper = get_proper_type(typ)
8011+
if isinstance(proper, UnionType):
8012+
return make_simplified_union([self.expand_narrowed_type(item) for item in proper.items])
8013+
if self._is_str_iteration_protocol_for_narrowing(proper):
8014+
return make_simplified_union([typ, self.named_type("builtins.str")])
8015+
return typ
8016+
79908017
def is_literal_enum(self, n: Expression) -> bool:
79918018
"""Returns true if this expression (with the given type context) is an Enum literal.
79928019

test-data/unit/check-flags.test

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2452,9 +2452,9 @@ f(memoryview(b"asdf")) # E: Argument 1 to "f" has incompatible type "memoryview
24522452
[builtins fixtures/primitives.pyi]
24532453

24542454
[case testDisallowStrIteration]
2455-
# flags: --disallow-str-iteration
2455+
# flags: --disallow-str-iteration --warn-unreachable
24562456
from abc import abstractmethod
2457-
from typing import Collection, Container, Iterable, Mapping, Protocol, Sequence, SupportsIndex, TypeVar, Union
2457+
from typing import Collection, Container, Hashable, Iterable, Mapping, Protocol, Sequence, SupportsIndex, TypeGuard, TypeIs, TypeVar, Union, runtime_checkable
24582458

24592459
def takes_str(x: str):
24602460
for ch in x: # E: Iterating over "str" is disallowed # N: This is because --disallow-str-iteration is enabled
@@ -2476,7 +2476,7 @@ seq: Sequence[str] = s # E: Incompatible types in assignment (expression has ty
24762476
iterable: Iterable[str] = s # E: Incompatible types in assignment (expression has type "str", variable has type "Iterable[str]")
24772477
collection: Collection[str] = s # E: Incompatible types in assignment (expression has type "str", variable has type "Collection[str]")
24782478

2479-
def takes_maybe_seq(x: "str | Sequence[int]") -> None:
2479+
def takes_maybe_seq(x: Union[str, Sequence[int]]) -> None:
24802480
for ch in x: # E: Iterating over "str | Sequence[int]" is disallowed # N: This is because --disallow-str-iteration is enabled
24812481
reveal_type(ch) # N: Revealed type is "builtins.str | builtins.int"
24822482

@@ -2495,6 +2495,7 @@ def takes_str_subclass(x: StrSubclass):
24952495
for ch in x: # E: Iterating over "StrSubclass" is disallowed # N: This is because --disallow-str-iteration is enabled
24962496
reveal_type(ch) # N: Revealed type is "builtins.str"
24972497

2498+
@runtime_checkable
24982499
class CollectionSubclass(Collection[_T_co], Protocol[_T_co]):
24992500
@abstractmethod
25002501
def __missing_impl__(self): ...
@@ -2512,12 +2513,52 @@ takes_collection_subclass(StrSubclass()) # E: Argument 1 to "takes_collection_s
25122513
def dict_unpacking_unaffected_by_union_simplification(x: Mapping[str, Union[str, Sequence[str]]]) -> None:
25132514
x = {**x}
25142515

2515-
def narrowing(x: "str | Sequence[str]"):
2516+
def narrowing_str(x: Union[str, Sequence[str]]):
25162517
if isinstance(x, str):
25172518
reveal_type(x) # N: Revealed type is "builtins.str"
25182519
else:
25192520
reveal_type(x) # N: Revealed type is "typing.Sequence[builtins.str]"
25202521

2522+
def narrowing_iterable(x: Union[str, Iterable[str], Iterable[str]]):
2523+
if isinstance(x, Iterable):
2524+
reveal_type(x) # N: Revealed type is "builtins.str | typing.Iterable[builtins.str]"
2525+
else:
2526+
reveal_type(x) # E: Statement is unreachable
2527+
2528+
def is_iterable_guard(x: object) -> TypeGuard[Iterable[str]]: ...
2529+
2530+
def is_iterable_is(x: object) -> TypeIs[Iterable[str]]: ...
2531+
2532+
def narrowing_typeguard_iterable(x: Union[str, Iterable[str]]):
2533+
if is_iterable_guard(x):
2534+
reveal_type(x) # N: Revealed type is "builtins.str | typing.Iterable[builtins.str]"
2535+
else:
2536+
reveal_type(x) # N: Revealed type is "builtins.str | typing.Iterable[builtins.str]"
2537+
2538+
def narrowing_typeis_iterable(x: Union[str, Iterable[str]]):
2539+
if is_iterable_is(x):
2540+
reveal_type(x) # N: Revealed type is "builtins.str | typing.Iterable[builtins.str]"
2541+
else:
2542+
reveal_type(x) # E: Statement is unreachable
2543+
2544+
def narrowing_incompatible_collection_subclass(x: Union[str, CollectionSubclass[str]]):
2545+
if isinstance(x, CollectionSubclass):
2546+
reveal_type(x) # N: Revealed type is "__main__.CollectionSubclass[builtins.str]"
2547+
else:
2548+
reveal_type(x) # N: Revealed type is "builtins.str"
2549+
2550+
def narrowing_object(x: object):
2551+
if isinstance(x, Iterable):
2552+
reveal_type(x) # N: Revealed type is "typing.Iterable[Any] | builtins.str"
2553+
else:
2554+
reveal_type(x) # N: Revealed type is "builtins.object"
2555+
2556+
def narrowing_type(x: Union[type[str], type[int]]):
2557+
if issubclass(x, Iterable):
2558+
reveal_type(x) # N: Revealed type is "type[builtins.str]"
2559+
else:
2560+
reveal_type(x) # N: Revealed type is "type[builtins.int]"
2561+
25212562
Item = TypeVar("Item")
25222563
def takes_generic_list(x: list[Item]) -> None: ...
25232564
takes_generic_list(reveal_type([s, seq])) # N: Revealed type is "builtins.list[builtins.object]"

test-data/unit/fixtures/str-iter.pyi

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
# Builtins stub used in disallow-str-iteration tests.
22

33

4-
from typing import Generic, Iterator, Mapping, Sequence, SupportsIndex,TypeVar, overload
4+
from typing import Generic, Iterator, Mapping, Sequence, SupportsIndex, Type, TypeVar, overload
55

66
_T = TypeVar("_T")
77
_KT = TypeVar("_KT")
88
_VT = TypeVar("_VT")
99

1010
class object:
1111
def __init__(self) -> None: pass
12+
def __hash__(self) -> int: pass
1213

1314
class type: pass
1415
class int: pass
@@ -21,6 +22,7 @@ class str(Sequence[str]):
2122
def __len__(self) -> int: pass
2223
def __contains__(self, item: object) -> bool: pass
2324
def __getitem__(self, key: SupportsIndex | slice, /) -> str: pass
25+
def __hash__(self) -> int: pass
2426

2527
class list(Sequence[_T], Generic[_T]):
2628
def __iter__(self) -> Iterator[_T]: pass
@@ -47,3 +49,4 @@ class dict(Mapping[_KT, _VT], Generic[_KT, _VT]):
4749
def __getitem__(self, key: _KT) -> _VT: pass
4850

4951
def isinstance(x: object, t: type) -> bool: pass
52+
def issubclass(x: type, t: type) -> bool: pass

test-data/unit/fixtures/typing-str-iter.pyi

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ TypeVar = 0
88
Generic = 0
99
Protocol = 0
1010
Union = 0
11+
Type = 0
12+
TypeGuard = 0
13+
TypeIs = 0
1114
overload = 0
1215

1316
_T = TypeVar("_T")
@@ -17,25 +20,35 @@ _KT_co = TypeVar("_KT_co", covariant=True) # Key type covariant containers.
1720
_VT_co = TypeVar("_VT_co", covariant=True) # Value type covariant containers.
1821
_TC = TypeVar("_TC", bound=type[object])
1922

23+
@runtime_checkable
24+
class Hashable(Protocol, metaclass=ABCMeta):
25+
@abstractmethod
26+
def __hash__(self) -> int: pass
27+
28+
@runtime_checkable
2029
class Iterable(Protocol[_T_co]):
2130
@abstractmethod
2231
def __iter__(self) -> Iterator[_T_co]: ...
2332

33+
@runtime_checkable
2434
class Iterator(Iterable[_T_co], Protocol[_T_co]):
2535
@abstractmethod
2636
def __next__(self) -> _T_co: ...
2737
def __iter__(self) -> Iterator[_T_co]: ...
2838

39+
@runtime_checkable
2940
class Container(Protocol[_T_co]):
3041
# This is generic more on vibes than anything else
3142
@abstractmethod
3243
def __contains__(self, x: object, /) -> bool: ...
3344

45+
@runtime_checkable
3446
class Collection(Iterable[_T_co], Container[_T_co], Protocol[_T_co]):
3547
# Implement Sized (but don't have it as a base class).
3648
@abstractmethod
3749
def __len__(self) -> int: ...
3850

51+
@runtime_checkable
3952
class SupportsIndex(Protocol, metaclass=ABCMeta):
4053
@abstractmethod
4154
def __index__(self) -> int: ...

0 commit comments

Comments
 (0)