Skip to content

Commit

Permalink
fix channel pools deadlocking everything on connection cancellation.
Browse files Browse the repository at this point in the history
  • Loading branch information
Fuyukai committed Feb 1, 2024
1 parent 5284154 commit 8a4542a
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 38 deletions.
5 changes: 1 addition & 4 deletions src/serena/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def _check_closed(self) -> None:
# todo: switch to our own exception?
raise ClosedResourceError("This channel is closed")

async def _close(self, payload: ChannelClosePayload | None) -> None:
def _close(self, payload: ChannelClosePayload | None) -> None:
"""
Closes this channel.
"""
Expand All @@ -180,9 +180,6 @@ async def _close(self, payload: ChannelClosePayload | None) -> None:

self._close_event.set()

# aclose doesn't seem to checkpoint...
await checkpoint()

def _enqueue_regular(self, frame: MethodFrame[MethodPayload]) -> None:
"""
Enqueues a regular method frame.
Expand Down
27 changes: 23 additions & 4 deletions src/serena/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,14 +486,14 @@ async def _handle_control_frame(self, frame: MethodFrame[MethodPayload]) -> None
# noinspection PyAsyncCall
self._cancel_scope.cancel()

async def _remove_channel(self, channel_id: int, payload: ChannelClosePayload | None) -> None:
def _remove_channel(self, channel_id: int, payload: ChannelClosePayload | None) -> None:
"""
Removes a channel.
"""

self._channels[channel_id] = False
chan = self._channel_channels.pop(channel_id)
await chan._close(payload)
chan._close(payload)

async def _enqueue_frame(self, channel: Channel, frame: Frame) -> None:
"""
Expand Down Expand Up @@ -596,7 +596,7 @@ async def _listen_for_messages(self) -> None:
# ack the close
# todo: should his be here?
try:
await self._remove_channel(channel, close_payload)
self._remove_channel(channel, close_payload)
finally:
if is_unclean:
await self._send_method_frame(channel, ChannelCloseOkPayload())
Expand Down Expand Up @@ -625,12 +625,31 @@ async def _listen_for_messages(self) -> None:
channel_object = self._channel_channels[channel]
await self._enqueue_frame(channel_object, frame)

async def _listen_wrapper(self) -> None:
"""
Wraps ``listen_for_messages`` so that it'll automatically signal to all channels if
the task is closed.
"""

# https://matrix.to/#/%23python-trio_general%3Agitter.im/%24b81b-VPgvL6z6ALlACbL27TYjAk27Bz_jIM6q39QWW0?via=gitter.im&via=matrix.org
# Basically, if the top-most task that holds the connection open is cancelled, then the
# "wait for channel closure" logic won't be fired, because the task actually responsible for
# pumping the messages around is killed.
#
# So, this wraps the listen function in a try/finally that sets the event manually.

try:
await self._listen_for_messages()
finally:
for channel in self._channel_channels.values():
channel._close(None)

def _start_tasks(self, nursery: TaskGroup) -> None:
"""
Starts the background tasks for this connection.
"""
self._cancel_scope = nursery.cancel_scope
nursery.start_soon(self._listen_for_messages)
nursery.start_soon(self._listen_wrapper)
nursery.start_soon(self._heartbeat_loop)

async def _close(self, reply_code: int = 200, reply_text: str = "Normal close") -> None:
Expand Down
41 changes: 18 additions & 23 deletions src/serena/pool.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from collections.abc import AsyncGenerator, AsyncIterable
from contextlib import AbstractAsyncContextManager, asynccontextmanager
from contextlib import asynccontextmanager
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -67,7 +67,6 @@ async def _open(self, initial_size: int) -> None:
"""

for _ in range(0, initial_size):
# noinspection PyAsyncCall
self._qwrite.send_nowait(await self._conn._open_channel())

async def _close(self) -> None:
Expand All @@ -94,20 +93,16 @@ async def _open_channels(self) -> NoReturn:
while True:
await self._needs_new_connection.wait()
channel = await self._conn._open_channel()
# noinspection PyAsyncCall
self._qwrite.send_nowait(channel)

# re-exported so that pycharm likes the type annotation better
def checkout(self) -> AbstractAsyncContextManager[Channel]:
"""
Checks out a new channel from the pool, and uses it persistently. The channel lifetime
will be automatically managed for you.
@asynccontextmanager
async def checkout(self) -> AsyncGenerator[Channel, None]:
"""
Checks out a new channel from the pool, and uses it persistently.
return self._checkout()
The channel lifetime will be automatically managed for you.
"""

@asynccontextmanager
async def _checkout(self) -> AsyncGenerator[Channel, None]:
channel = await self._qread.receive()

try:
Expand Down Expand Up @@ -158,7 +153,7 @@ async def basic_consume(
"""

async with (
self._checkout() as channel,
self.checkout() as channel,
channel.basic_consume(
queue_name=queue_name,
consumer_tag=consumer_tag,
Expand All @@ -182,7 +177,7 @@ async def basic_get(self, queue: str, *, no_ack: bool = False) -> AMQPMessage |
:return: A :class:`.AMQPMessage` if one existed on the queue, otherwise None.
"""

async with self._checkout() as channel:
async with self.checkout() as channel:
return await channel.basic_get(
queue=queue,
no_ack=no_ack,
Expand Down Expand Up @@ -219,7 +214,7 @@ async def basic_publish(
to close.
"""

async with self._checkout() as channel:
async with self.checkout() as channel:
await channel.basic_publish(
exchange_name=exchange_name,
routing_key=routing_key,
Expand Down Expand Up @@ -251,7 +246,7 @@ async def exchange_bind(
:return: Nothing.
"""

async with self._checkout() as channel:
async with self.checkout() as channel:
await channel.exchange_bind(
destination=destination,
source=source,
Expand Down Expand Up @@ -287,7 +282,7 @@ async def exchange_declare(
:return: The name of the exchange, as it exists on the server.
"""

async with self._checkout() as channel:
async with self.checkout() as channel:
return await channel.exchange_declare(
name=name,
type=type,
Expand All @@ -309,7 +304,7 @@ async def exchange_delete(self, name: str, *, if_unused: bool = False) -> None:
"""

async with self._checkout() as channel:
async with self.checkout() as channel:
await channel.exchange_delete(
name=name,
if_unused=if_unused,
Expand Down Expand Up @@ -337,7 +332,7 @@ async def exchange_unbind(
"""

async with self._checkout() as channel:
async with self.checkout() as channel:
await channel.exchange_unbind(
destination=destination,
source=source,
Expand All @@ -363,7 +358,7 @@ async def queue_bind(
"""

async with self._checkout() as channel:
async with self.checkout() as channel:
await channel.queue_bind(
queue_name=queue_name,
exchange_name=exchange_name,
Expand Down Expand Up @@ -400,7 +395,7 @@ async def queue_declare(
:return: The :class:`.QueueDeclareOkPayload` the server returned.
"""

async with self._checkout() as channel:
async with self.checkout() as channel:
return await channel.queue_declare(
name=name,
passive=passive,
Expand All @@ -423,7 +418,7 @@ async def queue_delete(
:return: The number of messages deleted.
"""

async with self._checkout() as channel:
async with self.checkout() as channel:
return await channel.queue_delete(
queue_name=queue_name,
if_empty=if_empty,
Expand All @@ -439,7 +434,7 @@ async def queue_purge(self, queue_name: str) -> int:
:return: The number of messages deleted.
"""

async with self._checkout() as channel:
async with self.checkout() as channel:
return await channel.queue_purge(
queue_name=queue_name,
)
Expand All @@ -461,7 +456,7 @@ async def queue_unbind(
:param arguments: Implementation-specific arguments to use.
"""

async with self._checkout() as channel:
async with self.checkout() as channel:
return await channel.queue_unbind(
queue_name=queue_name,
exchange_name=exchange_name,
Expand Down
6 changes: 2 additions & 4 deletions src/serena/utils/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,12 +360,10 @@ def write_longlong_signed(self, value: int) -> None:
self._write(struct.pack(">q", value))

@overload
def write_timestamp(self, value: datetime) -> None:
...
def write_timestamp(self, value: datetime) -> None: ...

@overload
def write_timestamp(self, value: int) -> None:
...
def write_timestamp(self, value: int) -> None: ...

def write_timestamp(self, value: datetime | int) -> None:
"""
Expand Down
27 changes: 27 additions & 0 deletions tests/channel/test_channel_pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import anyio
import pytest

from tests import _open_connection

pytestmark = pytest.mark.anyio


@pytest.mark.slow
async def test_cancelling_outer_conn_doesnt_cause_hang():
async def _bad_task():
async with (
_open_connection() as conn,
conn.open_channel_pool() as pool,
pool.checkout() as _,
):
# keep
await anyio.sleep_forever()

async with anyio.create_task_group() as tg:
tg.start_soon(_bad_task)

# give it a chance to open
await anyio.sleep(1)
tg.cancel_scope.cancel()

assert True, "what?"
4 changes: 2 additions & 2 deletions tests/channel/test_exchange.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ async def test_ex_declaration_invalid_type():

with pytest.raises(UnexpectedCloseError) as e:
async with _open_connection() as conn:
async with conn.open_channel() as channel:
await channel.exchange_declare(name="invalid", type="invalid")
async with conn.open_channel() as channel:
await channel.exchange_declare(name="invalid", type="invalid")

assert e.value.reply_code in (ReplyCode.command_invalid, ReplyCode.precondition_failed)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from serena.enums import ReplyCode
from serena.exc import UnexpectedCloseError

from tests import _open_connection, AMQP_HOST, AMQP_PORT, AMQP_USERNAME
from tests import AMQP_HOST, AMQP_PORT, AMQP_USERNAME, _open_connection

pytestmark = pytest.mark.anyio

Expand Down

0 comments on commit 8a4542a

Please sign in to comment.