Skip to content

Commit

Permalink
Use the same close reason across all frameworks
Browse files Browse the repository at this point in the history
  • Loading branch information
DoctorJohn committed Sep 17, 2024
1 parent 70096e9 commit 4afadd4
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 11 deletions.
5 changes: 3 additions & 2 deletions strawberry/channels/handlers/ws_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,9 @@ async def receive(self, *args: str, **kwargs: Any) -> None:
# Overriding this so that we can pass the errors to handle_invalid_message
try:
await super().receive(*args, **kwargs)
except ValueError as e:
await self._handler.handle_invalid_message(str(e))
except ValueError:
reason = "WebSocket message type must be text"
await self._handler.handle_invalid_message(reason)

async def receive_json(self, content: Any, **kwargs: Any) -> None:
await self._handler.handle_message(content)
Expand Down
7 changes: 2 additions & 5 deletions tests/websockets/test_graphql_transport_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,10 @@ async def test_ws_messages_must_be_text(ws_raw: WebSocketClient):

await ws.send_bytes(json.dumps(ConnectionInitMessage().as_dict()).encode())

data = await ws.receive(timeout=2)
await ws.receive(timeout=2)
assert ws.closed
assert ws.close_code == 4400
if ws.name() == "channels":
ws.assert_reason("No text section for incoming WebSocket frame!")
else:
ws.assert_reason("WebSocket message type must be text")
ws.assert_reason("WebSocket message type must be text")


async def test_connection_init_timeout(
Expand Down
5 changes: 1 addition & 4 deletions tests/websockets/test_graphql_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,10 +289,7 @@ async def test_ws_messages_must_be_text(ws_raw: WebSocketClient):
await ws.receive(timeout=2)
assert ws.closed
assert ws.close_code == 1002
if ws.name() == "channels":
ws.assert_reason("No text section for incoming WebSocket frame!")
else:
ws.assert_reason("WebSocket message type must be text")
ws.assert_reason("WebSocket message type must be text")


async def test_unknown_protocol_messages_are_ignored(ws_raw: WebSocketClient):
Expand Down

0 comments on commit 4afadd4

Please sign in to comment.