Skip to content

Commit 87543b7

Browse files
authored
fix: Avoid false negatives for TC001-003 related to typing.cast.
1 parent eef57df commit 87543b7

File tree

2 files changed

+64
-4
lines changed

2 files changed

+64
-4
lines changed

flake8_type_checking/checker.py

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -992,6 +992,31 @@ def visit_annotated_value(self, node: ast.expr) -> None:
992992
self.import_visitor.in_soft_use_context = previous_context
993993

994994

995+
class CastTypeExpressionVisitor(AnnotationVisitor):
996+
"""Visit a cast type expression and collect all the quoted names."""
997+
998+
def __init__(self, typing_lookup: SupportsIsTyping) -> None:
999+
#: All the quoted_names referenced inside the type expression
1000+
self.quoted_names: set[str] = set()
1001+
self._typing_lookup = typing_lookup
1002+
1003+
def is_typing(self, node: ast.AST, symbol: str) -> bool:
1004+
"""Check if the given node matches the given typing symbol."""
1005+
return self._typing_lookup.is_typing(node, symbol)
1006+
1007+
def visit_annotation_name(self, node: ast.Name) -> None:
1008+
"""Ignore visited names."""
1009+
# We could either record them as quoted names pre-emptively or
1010+
# as uses, but neither seems ideal, let's just skip these names
1011+
# as we have previously.
1012+
1013+
def visit_annotation_string(self, node: ast.Constant) -> None:
1014+
"""Collect all the names referenced inside the forward reference."""
1015+
visitor = StringAnnotationVisitor(self._typing_lookup)
1016+
visitor.parse_and_visit_string_annotation(node.value)
1017+
self.quoted_names.update(visitor.names)
1018+
1019+
9951020
class ImportVisitor(
9961021
DunderAllMixin,
9971022
FunctoolsSingledispatchMixin,
@@ -1081,6 +1106,10 @@ def __init__(
10811106
#: Where typing.cast() is called with an unquoted type.
10821107
self.unquoted_types_in_casts: list[tuple[int, int, str]] = []
10831108

1109+
#: All forward referenced names used in cast type expressions
1110+
# we need to track this in order to avoid false negatives for TC001-003
1111+
self.quoted_type_names_in_casts: set[str] = set()
1112+
10841113
#: For tracking which comprehension/IfExp we're currently inside of
10851114
self.active_context: Comprehension | ast.IfExp | None = None
10861115

@@ -1895,6 +1924,10 @@ def register_unquoted_type_in_typing_cast(self, node: ast.Call) -> None:
18951924

18961925
arg = node.args[0]
18971926

1927+
visitor = CastTypeExpressionVisitor(self)
1928+
visitor.visit(arg)
1929+
self.quoted_type_names_in_casts.update(visitor.quoted_names)
1930+
18981931
if isinstance(arg, ast.Constant) and isinstance(arg.value, str):
18991932
return # Type argument is already a string literal.
19001933

@@ -1999,10 +2032,13 @@ def unused_imports(self) -> Flake8Generator:
19992032
unused_imports = all_imports - self.visitor.names - self.visitor.soft_uses
20002033
used_imports = all_imports - unused_imports
20012034
already_imported_modules = [self.visitor.imports[name].module for name in used_imports]
2002-
annotation_names = (
2003-
[n for i in self.visitor.wrapped_annotations for n in i.names]
2004-
+ [i.annotation for i in self.visitor.unwrapped_annotations]
2005-
+ [n for i in self.visitor.excess_wrapped_annotations for n in i.names]
2035+
annotation_names = list(
2036+
chain(
2037+
(n for i in self.visitor.wrapped_annotations for n in i.names),
2038+
(i.annotation for i in self.visitor.unwrapped_annotations),
2039+
(n for i in self.visitor.excess_wrapped_annotations for n in i.names),
2040+
self.visitor.quoted_type_names_in_casts,
2041+
)
20062042
)
20072043

20082044
for name in unused_imports:

tests/test_tc001_to_tc003.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,30 @@ def example() -> Any:
323323
),
324324
set(),
325325
),
326+
# Issue #127
327+
(
328+
textwrap.dedent(
329+
f'''
330+
from {import_} import Foo
331+
from typing import Any, cast
332+
333+
a = cast('Foo', 1)
334+
'''
335+
),
336+
{'2:0 ' + ERROR.format(module=f'{import_}.Foo')},
337+
),
338+
# forward reference in sub-expression of cast type
339+
(
340+
textwrap.dedent(
341+
f'''
342+
from {import_} import Foo
343+
from typing import Any, cast
344+
345+
a = cast(list['Foo'], 1)
346+
'''
347+
),
348+
{'2:0 ' + ERROR.format(module=f'{import_}.Foo')},
349+
),
326350
]
327351

328352
return [

0 commit comments

Comments
 (0)