From 7e340469b5b854247eab9bb526d0f4dfd0e8e42a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francis=20Clairicia-Rose-Claire-Jos=C3=A9phine?= Date: Wed, 27 Dec 2023 14:08:04 +0100 Subject: [PATCH] Added missing tests for `__init__` and `__await__` (#12) --- src/async_object/__init__.py | 2 ++ src/async_object/contrib/mypy/plugin.py | 31 +++++++++++++++++++++++-- tests/test_metaclass.py | 21 ++++++++++++++--- 3 files changed, 49 insertions(+), 5 deletions(-) diff --git a/src/async_object/__init__.py b/src/async_object/__init__.py index 60641fa..84bc70b 100644 --- a/src/async_object/__init__.py +++ b/src/async_object/__init__.py @@ -64,6 +64,8 @@ def __new__(mcs: type[__Self], name: str, bases: tuple[type, ...], namespace: di b.__name__ for b in bases if not issubclass(b, absolute_base_class) and (b.__init__ is not object.__init__) ]: raise TypeError(f"These non-async base classes define a custom __init__: {', '.join(map(repr, invalid_bases))}") + if invalid_bases := [b.__name__ for b in bases if hasattr(b, "__await__")]: + raise TypeError(f"These base classes define __await__: {', '.join(map(repr, invalid_bases))}") return super().__new__(mcs, name, bases, namespace, **kwargs) def __setattr__(cls, name: str, value: Any, /) -> None: diff --git a/src/async_object/contrib/mypy/plugin.py b/src/async_object/contrib/mypy/plugin.py index b8d3e2f..6c81c68 100644 --- a/src/async_object/contrib/mypy/plugin.py +++ b/src/async_object/contrib/mypy/plugin.py @@ -99,8 +99,35 @@ def async_class_def_callback(ctx: ClassDefContext) -> None: info.names[new_ctor_name] = info.names[ctor] - if info.get_method("__await__") is not None: - ctx.api.fail('AsyncObject subclasses must not have "__await__" method', ctx.cls, code=errorcodes.OVERRIDE) + if dunder_await_node := info.names.get("__await__"): + node_ctx = dunder_await_node.node if dunder_await_node.node else ctx.cls + ctx.api.fail( + 'AsyncObject subclasses must not have "__await__" method', + node_ctx, + serious=True, + code=errorcodes.OVERRIDE, + ) + elif base_classes_with_dunder_await := [cls_info.defn.name for cls_info in info.mro[1:] if "__await__" in cls_info.names]: + ctx.api.fail( + f"These base classes define __await__: {', '.join(map(repr, base_classes_with_dunder_await))}", + ctx.cls, + serious=True, + code=errorcodes.OVERRIDE, + ) + + non_async_base_class_info_list = list( + filter(lambda cls_info: not cls_info.has_base(_ASYNC_OBJECT_BASE_CLASS_FULLNAME), info.mro[1:-1]) + ) + + if non_async_base_classes_with_dunder_init := [ + cls_info.defn.name for cls_info in non_async_base_class_info_list if "__init__" in cls_info.names + ]: + ctx.api.fail( + f"These non-async base classes define a custom __init__: {', '.join(map(repr, non_async_base_classes_with_dunder_init))}", + ctx.cls, + serious=True, + code=errorcodes.OVERRIDE, + ) def __set_func_def_name(defn: FuncDef, name: str) -> None: diff --git a/tests/test_metaclass.py b/tests/test_metaclass.py index ab65a62..8f2d086 100644 --- a/tests/test_metaclass.py +++ b/tests/test_metaclass.py @@ -37,7 +37,7 @@ class B: def __init__(self) -> None: pass - class _(AsyncObject, A, B): + class _(AsyncObject, A, B): # type: ignore[override] pass @@ -52,11 +52,26 @@ def __init__(self) -> None: # type: ignore[override] def test_dunder_await_defined() -> None: with pytest.raises(TypeError, match=r"^AsyncObject subclasses must not have __await__ method$"): - class _(AsyncObject): # type: ignore[override] # We are testing the final case - def __await__(self) -> Generator[Any, Any, Any]: + class _(AsyncObject): + def __await__(self) -> Generator[Any, Any, Any]: # type: ignore[override] # We are testing the final case raise NotImplementedError +def test_base_class_with_dunder_await() -> None: + with pytest.raises(TypeError, match=r"^These base classes define __await__: 'A', 'B'$"): + + class A: + def __await__(self) -> None: + pass + + class B: + def __await__(self) -> None: + pass + + class _(AsyncObject, A, B): # type: ignore[override] + pass + + def test_AsyncObject_immutable_on_set() -> None: with pytest.raises(AttributeError, match=r"^AsyncObject is immutable$"): setattr(AsyncObject, "something", None)