Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
Darsstar committed Nov 21, 2023
1 parent c58d690 commit df9978f
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 41 deletions.
54 changes: 29 additions & 25 deletions src/aio_pika/queue.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import asyncio
import sys
from contextlib import suppress
from functools import partial
from types import TracebackType
from typing import Any, Awaitable, Callable, Optional, Type, overload

import aiormq
from aiormq import ConnectionClosed
from aiormq.abc import DeliveredMessage
from pamqp.common import Arguments

Expand Down Expand Up @@ -63,6 +65,8 @@ def __init__(
self.arguments = arguments
self.passive = passive

self.channel.close_callbacks.add(self.close_callbacks)

def __str__(self) -> str:
return f"{self.name}"

Expand Down Expand Up @@ -178,6 +182,7 @@ async def unbind(
arguments,
)

self.channel.close_callbacks.discard(self.close_callbacks)
channel = await self.channel.get_underlay_channel()
return await channel.queue_unbind(
queue=self.name,
Expand Down Expand Up @@ -267,6 +272,7 @@ async def cancel(
:return: Basic.CancelOk when operation completed successfully
"""

self.channel.close_callbacks.discard(self.close_callbacks)
channel = await self.channel.get_underlay_channel()
return await channel.basic_cancel(
consumer_tag=consumer_tag, nowait=nowait, timeout=timeout,
Expand Down Expand Up @@ -410,15 +416,18 @@ class QueueIterator(AbstractQueueIterator):
def consumer_tag(self) -> Optional[ConsumerTag]:
return getattr(self, "_consumer_tag", None)

async def _on_channel_close(self) -> None:
await self.close()

async def close(self, *_: Any) -> None:
log.debug("Cancelling queue iterator %r", self)

await self._closed.wait()

if not hasattr(self, "_consumer_tag"):
log.debug("Queue iterator %r already cancelled", self)
return

self._closed.set()

if self._amqp_queue.channel.is_closed:
log.debug("Queue iterator %r channel closed", self)
return
Expand All @@ -428,7 +437,9 @@ async def close(self, *_: Any) -> None:
del self._consumer_tag

self._amqp_queue.close_callbacks.remove(self.close)
await self._amqp_queue.cancel(consumer_tag)

with suppress(ConnectionClosed):
await self._amqp_queue.cancel(consumer_tag)

log.debug("Queue iterator %r closed", self)

Expand All @@ -442,23 +453,23 @@ async def close(self, *_: Any) -> None:
if msg is None:
return

if self._amqp_queue.channel.is_closed:
log.warning(
"Message %r lost when queue iterator %r channel closed",
msg,
self,
)
return
if self._amqp_queue.channel.is_closed:
log.warning(
"Message %r lost when queue iterator %r channel closed",
msg,
self,
)
return

if self._consume_kwargs.get("no_ack", False):
log.warning(
"Message %r lost for consumer with no_ack %r",
msg,
self,
)
return
if self._consume_kwargs.get("no_ack", False):
log.warning(
"Message %r lost for consumer with no_ack %r",
msg,
self,
)
return

await msg.nack(requeue=True, multiple=True)
await msg.nack(requeue=True, multiple=True)

def __str__(self) -> str:
return f"queue[{self._amqp_queue}](...)"
Expand All @@ -476,14 +487,7 @@ def __init__(self, queue: Queue, **kwargs: Any):
self._queue = asyncio.Queue()
self._consume_kwargs = kwargs
self._closed = asyncio.Event()
self.close_task = asyncio.create_task(self.close())

async def close(*args, **kwargs) -> None:
self._closed.set()

await self.close_task

setattr(self, "close", close)
self._amqp_queue.close_callbacks.add(self.close)

async def on_message(self, message: AbstractIncomingMessage) -> None:
Expand Down
5 changes: 4 additions & 1 deletion src/aio_pika/robust_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,11 @@ async def _on_connected(self) -> None:
closing = self.loop.create_future()
closing.set_exception(e)
await self.close_callbacks(closing)
await asyncio.gather(
values = await asyncio.gather(
transport.connection.close(e),
return_exceptions=True,
)
pass
raise

if self.connection_attempt:
Expand Down Expand Up @@ -144,6 +145,8 @@ async def __connection_factory(self) -> None:
self.__fail_fast_future.set_result(None)

log.debug("Connection made on %r", self)

continue
except CONNECTION_EXCEPTIONS as e:
if not self.__fail_fast_future.done():
self.__fail_fast_future.set_exception(e)
Expand Down
23 changes: 9 additions & 14 deletions src/aio_pika/robust_queue.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
import uuid
import warnings
from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Union
from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Union, TYPE_CHECKING

import aiormq
from aiormq import ChannelInvalidStateError
from pamqp.common import Arguments

from .abc import (
AbstractChannel, AbstractExchange, AbstractIncomingMessage,
AbstractExchange, AbstractIncomingMessage,
AbstractQueueIterator, AbstractRobustQueue, ConsumerTag, TimeoutType,
)
from .exchange import ExchangeParamType
from .log import get_logger
from .queue import Queue, QueueIterator

if TYPE_CHECKING:
from .channel import Channel

log = get_logger(__name__)

Expand All @@ -26,7 +28,7 @@ class RobustQueue(Queue, AbstractRobustQueue):

def __init__(
self,
channel: AbstractChannel,
channel: "Channel",
name: Optional[str],
durable: bool = False,
exclusive: bool = False,
Expand Down Expand Up @@ -152,18 +154,11 @@ def iterator(self, **kwargs: Any) -> AbstractQueueIterator:

class RobustQueueIterator(QueueIterator):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
async def _on_channel_close(self) -> None:
if not self._amqp_queue.channel._closed:
return

inherited = self.close

async def close(*args, **kwargs) -> None:
if not self._amqp_queue.channel._closed:
return

await inherited(*args, **kwargs)

setattr(self, "close", close)
await super()._on_channel_close()

async def consume(self) -> None:
while True:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_amqp.py
Original file line number Diff line number Diff line change
Expand Up @@ -1298,7 +1298,7 @@ async def application_stop_request():

data = list()

await asyncio.gather(p, c, asr)
await asyncio.gather(p, c, asr, return_exceptions=True)

assert data == list(map(lambda x: str(x).encode(), range(messages.maxsize)))

Expand Down

0 comments on commit df9978f

Please sign in to comment.