Skip to content

Commit 9b1fba4

Browse files
committed
stubtest: attempt to resolve decorators from their type
Fixes #19689
1 parent 87e9425 commit 9b1fba4

File tree

2 files changed

+88
-2
lines changed

2 files changed

+88
-2
lines changed

mypy/stubtest.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import enum
1313
import functools
1414
import importlib
15-
import importlib.machinery
1615
import inspect
1716
import os
1817
import pkgutil
@@ -35,11 +34,11 @@
3534

3635
import mypy.build
3736
import mypy.checkexpr
38-
import mypy.checkmember
3937
import mypy.erasetype
4038
import mypy.modulefinder
4139
import mypy.nodes
4240
import mypy.state
41+
import mypy.subtypes
4342
import mypy.types
4443
import mypy.version
4544
from mypy import nodes
@@ -1540,11 +1539,48 @@ def apply_decorator_to_funcitem(
15401539
for decorator in dec.original_decorators:
15411540
resulting_func = apply_decorator_to_funcitem(decorator, func)
15421541
if resulting_func is None:
1542+
# We couldn't figure out how to apply the decorator by transforming nodes, so try to
1543+
# reconstitute a FuncDef from the resulting type of the decorator
1544+
# This is worse because e.g. we lose the values of defaults
1545+
dec_type = mypy.types.get_proper_type(dec.type)
1546+
callable_type = None
1547+
if isinstance(dec_type, mypy.types.Instance):
1548+
callable_type = mypy.subtypes.find_member(
1549+
"__call__", dec_type, dec_type, is_operator=True
1550+
)
1551+
elif isinstance(dec_type, mypy.types.CallableType):
1552+
callable_type = dec_type
1553+
1554+
callable_type = mypy.types.get_proper_type(callable_type)
1555+
if isinstance(callable_type, mypy.types.CallableType):
1556+
return _resolve_funcitem_from_callable_type(callable_type)
15431557
return None
1558+
15441559
func = resulting_func
15451560
return func
15461561

15471562

1563+
def _resolve_funcitem_from_callable_type(typ: mypy.types.CallableType) -> nodes.FuncDef:
1564+
args: list[nodes.Argument] = []
1565+
1566+
for i, (arg_type, arg_kind, arg_name) in enumerate(
1567+
zip(typ.arg_types, typ.arg_kinds, typ.arg_names, strict=True)
1568+
):
1569+
var_name = arg_name if arg_name is not None else f"__arg{i}"
1570+
var = nodes.Var(var_name, arg_type)
1571+
pos_only = arg_name is None and arg_kind == nodes.ARG_POS
1572+
args.append(
1573+
nodes.Argument(
1574+
variable=var,
1575+
type_annotation=arg_type,
1576+
initializer=None, # CallableType doesn't store the values of defaults
1577+
kind=arg_kind,
1578+
pos_only=pos_only,
1579+
)
1580+
)
1581+
return nodes.FuncDef(name=typ.name or "", arguments=args, body=nodes.Block([]), typ=typ)
1582+
1583+
15481584
@verify.register(nodes.Decorator)
15491585
def verify_decorator(
15501586
stub: nodes.Decorator, runtime: MaybeMissing[Any], object_path: list[str]

mypy/test/teststubtest.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -900,6 +900,56 @@ def f(a, *args): ...
900900
error=None,
901901
)
902902

903+
@collect_cases
904+
def test_decorated_overload(self) -> Iterator[Case]:
905+
yield Case(
906+
stub="""
907+
from typing import overload
908+
909+
class _dec1:
910+
def __init__(self, func: object) -> None: ...
911+
def __call__(self, x: str) -> str: ...
912+
913+
@overload
914+
def good1(x: int) -> int: ...
915+
@overload
916+
@_dec1
917+
def good1(unrelated: int, whatever: str) -> str: ...
918+
""",
919+
runtime="def good1(x): ...",
920+
error=None,
921+
)
922+
yield Case(
923+
stub="""
924+
class _dec2:
925+
def __init__(self, func: object) -> None: ...
926+
def __call__(self, x: str, y: int) -> str: ...
927+
928+
@overload
929+
def good2(x: int) -> str: ...
930+
@overload
931+
@_dec2
932+
def good2(unrelated: int, whatever: str) -> str: ...
933+
""",
934+
runtime="def good2(x, y=...): ...",
935+
error=None,
936+
)
937+
yield Case(
938+
stub="""
939+
class _dec3:
940+
def __init__(self, func: object) -> None: ...
941+
def __call__(self, x: str, y: int) -> str: ...
942+
943+
@overload
944+
def bad(x: int) -> str: ...
945+
@overload
946+
@_dec3
947+
def bad(unrelated: int, whatever: str) -> str: ...
948+
""",
949+
runtime="def bad(x): ...",
950+
error="bad",
951+
)
952+
903953
@collect_cases
904954
def test_property(self) -> Iterator[Case]:
905955
yield Case(

0 commit comments

Comments
 (0)