Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added missing tests for __init__ and __await__ #12

Merged
merged 1 commit into from
Dec 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/async_object/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ def __new__(mcs: type[__Self], name: str, bases: tuple[type, ...], namespace: di
b.__name__ for b in bases if not issubclass(b, absolute_base_class) and (b.__init__ is not object.__init__)
]:
raise TypeError(f"These non-async base classes define a custom __init__: {', '.join(map(repr, invalid_bases))}")
if invalid_bases := [b.__name__ for b in bases if hasattr(b, "__await__")]:
raise TypeError(f"These base classes define __await__: {', '.join(map(repr, invalid_bases))}")
return super().__new__(mcs, name, bases, namespace, **kwargs)

def __setattr__(cls, name: str, value: Any, /) -> None:
Expand Down
31 changes: 29 additions & 2 deletions src/async_object/contrib/mypy/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,35 @@ def async_class_def_callback(ctx: ClassDefContext) -> None:

info.names[new_ctor_name] = info.names[ctor]

if info.get_method("__await__") is not None:
ctx.api.fail('AsyncObject subclasses must not have "__await__" method', ctx.cls, code=errorcodes.OVERRIDE)
if dunder_await_node := info.names.get("__await__"):
node_ctx = dunder_await_node.node if dunder_await_node.node else ctx.cls
ctx.api.fail(
'AsyncObject subclasses must not have "__await__" method',
node_ctx,
serious=True,
code=errorcodes.OVERRIDE,
)
elif base_classes_with_dunder_await := [cls_info.defn.name for cls_info in info.mro[1:] if "__await__" in cls_info.names]:
ctx.api.fail(
f"These base classes define __await__: {', '.join(map(repr, base_classes_with_dunder_await))}",
ctx.cls,
serious=True,
code=errorcodes.OVERRIDE,
)

non_async_base_class_info_list = list(
filter(lambda cls_info: not cls_info.has_base(_ASYNC_OBJECT_BASE_CLASS_FULLNAME), info.mro[1:-1])
)

if non_async_base_classes_with_dunder_init := [
cls_info.defn.name for cls_info in non_async_base_class_info_list if "__init__" in cls_info.names
]:
ctx.api.fail(
f"These non-async base classes define a custom __init__: {', '.join(map(repr, non_async_base_classes_with_dunder_init))}",
ctx.cls,
serious=True,
code=errorcodes.OVERRIDE,
)


def __set_func_def_name(defn: FuncDef, name: str) -> None:
Expand Down
21 changes: 18 additions & 3 deletions tests/test_metaclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class B:
def __init__(self) -> None:
pass

class _(AsyncObject, A, B):
class _(AsyncObject, A, B): # type: ignore[override]
pass


Expand All @@ -52,11 +52,26 @@ def __init__(self) -> None: # type: ignore[override]
def test_dunder_await_defined() -> None:
with pytest.raises(TypeError, match=r"^AsyncObject subclasses must not have __await__ method$"):

class _(AsyncObject): # type: ignore[override] # We are testing the final case
def __await__(self) -> Generator[Any, Any, Any]:
class _(AsyncObject):
def __await__(self) -> Generator[Any, Any, Any]: # type: ignore[override] # We are testing the final case
raise NotImplementedError


def test_base_class_with_dunder_await() -> None:
with pytest.raises(TypeError, match=r"^These base classes define __await__: 'A', 'B'$"):

class A:
def __await__(self) -> None:
pass

class B:
def __await__(self) -> None:
pass

class _(AsyncObject, A, B): # type: ignore[override]
pass


def test_AsyncObject_immutable_on_set() -> None:
with pytest.raises(AttributeError, match=r"^AsyncObject is immutable$"):
setattr(AsyncObject, "something", None)
Expand Down