Skip to content

Commit c81ff55

Browse files
authored
fix: Injector plugin requires all annotations in the injected function
1 parent 7ce3b36 commit c81ff55

File tree

2 files changed

+26
-39
lines changed

2 files changed

+26
-39
lines changed

flake8_type_checking/checker.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,9 @@ class InjectorMixin:
475475
def visit(self, node: ast.AST) -> ast.AST: # noqa: D102
476476
...
477477

478+
def lookup_full_name(self, node: ast.AST) -> str | None: # noqa: D102
479+
...
480+
478481
def visit_FunctionDef(self, node: FunctionDef) -> None:
479482
"""Remove and map function arguments and returns."""
480483
super().visit_FunctionDef(node) # type: ignore[misc]
@@ -487,6 +490,12 @@ def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> None:
487490
if self.injector_enabled:
488491
self.handle_injector_declaration(node)
489492

493+
def _has_injected_annotation(self, node: AsyncFunctionDef | FunctionDef) -> bool:
494+
return any(
495+
isinstance(expr, ast.Subscript) and self.lookup_full_name(expr.value) == 'injector.Inject'
496+
for expr in iter_function_annotation_nodes(node)
497+
)
498+
490499
def handle_injector_declaration(self, node: AsyncFunctionDef | FunctionDef) -> None:
491500
"""
492501
Adjust for injector declaration setting.
@@ -496,17 +505,11 @@ def handle_injector_declaration(self, node: AsyncFunctionDef | FunctionDef) -> N
496505
497506
To achieve this, we just visit the annotations to register them as "uses".
498507
"""
499-
for path in [node.args.args, node.args.kwonlyargs]:
500-
for argument in path:
501-
if hasattr(argument, 'annotation') and argument.annotation:
502-
annotation = argument.annotation
503-
if not hasattr(annotation, 'value'):
504-
continue
505-
value = annotation.value
506-
if hasattr(value, 'id') and value.id == 'Inject':
507-
self.visit(argument.annotation)
508-
if hasattr(value, 'attr') and value.attr == 'Inject':
509-
self.visit(argument.annotation)
508+
if not self._has_injected_annotation(node):
509+
return
510+
511+
for expr in iter_function_annotation_nodes(node):
512+
self.visit(expr)
510513

511514

512515
class FastAPIMixin:
@@ -592,6 +595,8 @@ def visit_FunctionDef(self, node: FunctionDef) -> None:
592595
def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> None:
593596
"""Remove and map function arguments and returns."""
594597
super().visit_AsyncFunctionDef(node) # type: ignore[misc]
598+
if self.in_type_checking_block(node.lineno, node.col_offset):
599+
return
595600
if self.has_singledispatch_decorator(node):
596601
for expr in iter_function_annotation_nodes(node):
597602
self.visit(expr)

tests/test_injector.py

Lines changed: 10 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def __init__(self, service: Inject[Service]) -> None:
5454
@pytest.mark.parametrize(
5555
('enabled', 'expected'),
5656
[
57-
(True, {'4:0 ' + TC002.format(module='other_dependency.OtherDependency')}),
57+
(True, set()),
5858
(
5959
False,
6060
{
@@ -65,8 +65,8 @@ def __init__(self, service: Inject[Service]) -> None:
6565
),
6666
],
6767
)
68-
def test_injector_option_only_allows_injected_dependencies(enabled, expected):
69-
"""Whenever an injector option is enabled, only injected dependencies should be ignored."""
68+
def test_injector_option_all_annotations_in_function_are_runtime_dependencies(enabled, expected):
69+
"""Whenever an argument is injected, all the other annotations are runtime required too."""
7070
example = textwrap.dedent(
7171
'''
7272
from injector import Inject
@@ -82,38 +82,20 @@ def __init__(self, service: Inject[Service], other: OtherDependency) -> None:
8282
assert _get_error(example, error_code_filter='TC002', type_checking_injector_enabled=enabled) == expected
8383

8484

85-
@pytest.mark.parametrize(
86-
('enabled', 'expected'),
87-
[
88-
(True, {'4:0 ' + TC002.format(module='other_dependency.OtherDependency')}),
89-
(
90-
False,
91-
{
92-
'2:0 ' + TC002.format(module='injector.Inject'),
93-
'3:0 ' + TC002.format(module='services.Service'),
94-
'4:0 ' + TC002.format(module='other_dependency.OtherDependency'),
95-
},
96-
),
97-
],
98-
)
99-
def test_injector_option_only_allows_injector_slices(enabled, expected):
100-
"""
101-
Whenever an injector option is enabled, only injected dependencies should be ignored,
102-
not any dependencies with slices.
103-
"""
85+
def test_injector_option_require_injections_under_unpack():
86+
"""Whenever an injector option is enabled, injected dependencies should be ignored, even if unpacked."""
10487
example = textwrap.dedent(
10588
"""
89+
from typing import Unpack
10690
from injector import Inject
107-
from services import Service
108-
from other_dependency import OtherDependency
109-
91+
from services import ServiceKwargs
11092
class X:
111-
def __init__(self, service: Inject[Service], other_deps: list[OtherDependency]) -> None:
93+
def __init__(self, service: Inject[Service], **kwargs: Unpack[ServiceKwargs]) -> None:
11294
self.service = service
113-
self.other_deps = other_deps
95+
self.args = args
11496
"""
11597
)
116-
assert _get_error(example, error_code_filter='TC002', type_checking_injector_enabled=enabled) == expected
98+
assert _get_error(example, error_code_filter='TC002', type_checking_injector_enabled=True) == set()
11799

118100

119101
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)