From df9978f64b18056b4b070a74f1c2ab2914a8438a Mon Sep 17 00:00:00 2001 From: Dos Moonen Date: Tue, 21 Nov 2023 15:54:44 +0100 Subject: [PATCH] WIP --- src/aio_pika/queue.py | 54 +++++++++++++++++-------------- src/aio_pika/robust_connection.py | 5 ++- src/aio_pika/robust_queue.py | 23 ++++++------- tests/test_amqp.py | 2 +- 4 files changed, 43 insertions(+), 41 deletions(-) diff --git a/src/aio_pika/queue.py b/src/aio_pika/queue.py index b7b6b770..3b979176 100644 --- a/src/aio_pika/queue.py +++ b/src/aio_pika/queue.py @@ -1,10 +1,12 @@ import asyncio import sys +from contextlib import suppress from functools import partial from types import TracebackType from typing import Any, Awaitable, Callable, Optional, Type, overload import aiormq +from aiormq import ConnectionClosed from aiormq.abc import DeliveredMessage from pamqp.common import Arguments @@ -63,6 +65,8 @@ def __init__( self.arguments = arguments self.passive = passive + self.channel.close_callbacks.add(self.close_callbacks) + def __str__(self) -> str: return f"{self.name}" @@ -178,6 +182,7 @@ async def unbind( arguments, ) + self.channel.close_callbacks.discard(self.close_callbacks) channel = await self.channel.get_underlay_channel() return await channel.queue_unbind( queue=self.name, @@ -267,6 +272,7 @@ async def cancel( :return: Basic.CancelOk when operation completed successfully """ + self.channel.close_callbacks.discard(self.close_callbacks) channel = await self.channel.get_underlay_channel() return await channel.basic_cancel( consumer_tag=consumer_tag, nowait=nowait, timeout=timeout, @@ -410,15 +416,18 @@ class QueueIterator(AbstractQueueIterator): def consumer_tag(self) -> Optional[ConsumerTag]: return getattr(self, "_consumer_tag", None) + async def _on_channel_close(self) -> None: + await self.close() + async def close(self, *_: Any) -> None: log.debug("Cancelling queue iterator %r", self) - await self._closed.wait() - if not hasattr(self, "_consumer_tag"): log.debug("Queue iterator %r already cancelled", self) return + self._closed.set() + if self._amqp_queue.channel.is_closed: log.debug("Queue iterator %r channel closed", self) return @@ -428,7 +437,9 @@ async def close(self, *_: Any) -> None: del self._consumer_tag self._amqp_queue.close_callbacks.remove(self.close) - await self._amqp_queue.cancel(consumer_tag) + + with suppress(ConnectionClosed): + await self._amqp_queue.cancel(consumer_tag) log.debug("Queue iterator %r closed", self) @@ -442,23 +453,23 @@ async def close(self, *_: Any) -> None: if msg is None: return - if self._amqp_queue.channel.is_closed: - log.warning( - "Message %r lost when queue iterator %r channel closed", - msg, - self, - ) - return + if self._amqp_queue.channel.is_closed: + log.warning( + "Message %r lost when queue iterator %r channel closed", + msg, + self, + ) + return - if self._consume_kwargs.get("no_ack", False): - log.warning( - "Message %r lost for consumer with no_ack %r", - msg, - self, - ) - return + if self._consume_kwargs.get("no_ack", False): + log.warning( + "Message %r lost for consumer with no_ack %r", + msg, + self, + ) + return - await msg.nack(requeue=True, multiple=True) + await msg.nack(requeue=True, multiple=True) def __str__(self) -> str: return f"queue[{self._amqp_queue}](...)" @@ -476,14 +487,7 @@ def __init__(self, queue: Queue, **kwargs: Any): self._queue = asyncio.Queue() self._consume_kwargs = kwargs self._closed = asyncio.Event() - self.close_task = asyncio.create_task(self.close()) - - async def close(*args, **kwargs) -> None: - self._closed.set() - - await self.close_task - setattr(self, "close", close) self._amqp_queue.close_callbacks.add(self.close) async def on_message(self, message: AbstractIncomingMessage) -> None: diff --git a/src/aio_pika/robust_connection.py b/src/aio_pika/robust_connection.py index c1a28368..6b0bcf02 100644 --- a/src/aio_pika/robust_connection.py +++ b/src/aio_pika/robust_connection.py @@ -111,10 +111,11 @@ async def _on_connected(self) -> None: closing = self.loop.create_future() closing.set_exception(e) await self.close_callbacks(closing) - await asyncio.gather( + values = await asyncio.gather( transport.connection.close(e), return_exceptions=True, ) + pass raise if self.connection_attempt: @@ -144,6 +145,8 @@ async def __connection_factory(self) -> None: self.__fail_fast_future.set_result(None) log.debug("Connection made on %r", self) + + continue except CONNECTION_EXCEPTIONS as e: if not self.__fail_fast_future.done(): self.__fail_fast_future.set_exception(e) diff --git a/src/aio_pika/robust_queue.py b/src/aio_pika/robust_queue.py index e85e7cd4..27c3c88a 100644 --- a/src/aio_pika/robust_queue.py +++ b/src/aio_pika/robust_queue.py @@ -1,19 +1,21 @@ import uuid import warnings -from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Union +from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Union, TYPE_CHECKING import aiormq from aiormq import ChannelInvalidStateError from pamqp.common import Arguments from .abc import ( - AbstractChannel, AbstractExchange, AbstractIncomingMessage, + AbstractExchange, AbstractIncomingMessage, AbstractQueueIterator, AbstractRobustQueue, ConsumerTag, TimeoutType, ) from .exchange import ExchangeParamType from .log import get_logger from .queue import Queue, QueueIterator +if TYPE_CHECKING: + from .channel import Channel log = get_logger(__name__) @@ -26,7 +28,7 @@ class RobustQueue(Queue, AbstractRobustQueue): def __init__( self, - channel: AbstractChannel, + channel: "Channel", name: Optional[str], durable: bool = False, exclusive: bool = False, @@ -152,18 +154,11 @@ def iterator(self, **kwargs: Any) -> AbstractQueueIterator: class RobustQueueIterator(QueueIterator): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + async def _on_channel_close(self) -> None: + if not self._amqp_queue.channel._closed: + return - inherited = self.close - - async def close(*args, **kwargs) -> None: - if not self._amqp_queue.channel._closed: - return - - await inherited(*args, **kwargs) - - setattr(self, "close", close) + await super()._on_channel_close() async def consume(self) -> None: while True: diff --git a/tests/test_amqp.py b/tests/test_amqp.py index d0ba5fc5..3abba57e 100644 --- a/tests/test_amqp.py +++ b/tests/test_amqp.py @@ -1298,7 +1298,7 @@ async def application_stop_request(): data = list() - await asyncio.gather(p, c, asr) + await asyncio.gather(p, c, asr, return_exceptions=True) assert data == list(map(lambda x: str(x).encode(), range(messages.maxsize)))