Skip to content

Commit 6eb332d

Browse files
committed
Improved shard connect and shard disconnect to reliably call the event on time
- Improved "closed mid request" check in RESTClient
1 parent 48d576d commit 6eb332d

File tree

5 files changed

+7
-12
lines changed

5 files changed

+7
-12
lines changed

changes/idk.bugfix.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Ensure shard connect and disconnect always get sent in pairs and properly waited for

hikari/events/shard_events.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ class ShardStateEvent(ShardEvent, abc.ABC):
100100
@attrs_extensions.with_copy
101101
@attrs.define(kw_only=True, weakref_slot=False)
102102
class ShardConnectedEvent(ShardStateEvent):
103-
"""Event fired when a shard connects."""
103+
"""Event fired when a shard successfully connects."""
104104

105105
app: traits.RESTAware = attrs.field(metadata={attrs_extensions.SKIP_DEEP_COPY: True})
106106
# <<inherited docstring from Event>>.

hikari/impl/rest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -731,7 +731,7 @@ async def _request(
731731

732732
await aio.first_completed(request_task, self._close_event.wait())
733733

734-
if not self._close_event.is_set():
734+
if not request_task.cancelled():
735735
return request_task.result()
736736

737737
raise errors.ComponentStateConflictError("The REST client was closed mid-request")

hikari/impl/shard.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -818,7 +818,6 @@ async def _connect(self) -> typing.Tuple[asyncio.Task[None], ...]:
818818
dumps=self._dumps,
819819
url=url,
820820
)
821-
self._event_manager.dispatch(self._event_factory.deserialize_connected_event(self))
822821

823822
# Expect initial HELLO
824823
hello_payload = await self._ws.receive_json()
@@ -893,6 +892,7 @@ async def _keep_alive(self) -> None:
893892
if not self._handshake_event.is_set():
894893
continue
895894

895+
await self._event_manager.dispatch(self._event_factory.deserialize_connected_event(self))
896896
await aio.first_completed(*lifetime_tasks)
897897

898898
# Since nothing went wrong, we can reset the backoff and try again
@@ -957,7 +957,9 @@ async def _keep_alive(self) -> None:
957957
else:
958958
await ws.send_close(code=_RESUME_CLOSE_CODE, message=b"shard disconnecting temporarily")
959959

960-
self._event_manager.dispatch(self._event_factory.deserialize_disconnected_event(self))
960+
if self._handshake_event.is_set():
961+
# We dispatched the connected event, so we can dispatch the disconnected one too
962+
await self._event_manager.dispatch(self._event_factory.deserialize_disconnected_event(self))
961963

962964
def _serialize_and_store_presence_payload(
963965
self,

tests/hikari/impl/test_shard.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1014,10 +1014,6 @@ async def test__connect_when_not_reconnecting(self, client, http_settings, proxy
10141014
dumps=client._dumps,
10151015
url="wss://somewhere.com?somewhere=true&v=400&encoding=json",
10161016
)
1017-
client._event_factory.deserialize_connected_event.assert_called_once_with(client)
1018-
client._event_manager.dispatch.assert_called_once_with(
1019-
client._event_factory.deserialize_connected_event.return_value
1020-
)
10211017

10221018
assert create_task.call_count == 2
10231019
create_task.assert_has_calls(
@@ -1103,10 +1099,6 @@ async def test__connect_when_reconnecting(self, client, http_settings, proxy_set
11031099
transport_compression=True,
11041100
url="wss://notsomewhere.com?somewhere=true&v=400&encoding=json&compress=zlib-stream",
11051101
)
1106-
client._event_factory.deserialize_connected_event.assert_called_once_with(client)
1107-
client._event_manager.dispatch.assert_called_once_with(
1108-
client._event_factory.deserialize_connected_event.return_value
1109-
)
11101102

11111103
assert create_task.call_count == 2
11121104
create_task.assert_has_calls(

0 commit comments

Comments
 (0)