Skip to content

Commit

Permalink
feat: Add async websocket_connect to AsyncTestClient (#3328)
Browse files Browse the repository at this point in the history
feat: Add `websocket_connect` method to AsyncTestClient

Co-authored-by: kedod <kedod>
  • Loading branch information
kedod authored Apr 6, 2024
1 parent fac641a commit 43e3041
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 3 deletions.
59 changes: 57 additions & 2 deletions litestar/testing/client/async_client.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from __future__ import annotations

from contextlib import AsyncExitStack
from typing import TYPE_CHECKING, Any, Generic, Mapping, TypeVar
from typing import TYPE_CHECKING, Any, Generic, Mapping, Sequence, TypeVar
from urllib.parse import urljoin

from httpx import USE_CLIENT_DEFAULT, AsyncClient, Response

from litestar import HttpMethod
from litestar.testing.client.base import BaseTestClient
from litestar.testing.life_span_handler import LifeSpanHandler
from litestar.testing.transport import TestClientTransport
from litestar.testing.transport import ConnectionUpgradeExceptionError, TestClientTransport
from litestar.types import AnyIOBackend, ASGIApp

if TYPE_CHECKING:
Expand All @@ -27,6 +28,7 @@
from typing_extensions import Self

from litestar.middleware.session.base import BaseBackendConfig
from litestar.testing.websocket_test_session import WebSocketTestSession


T = TypeVar("T", bound=ASGIApp)
Expand Down Expand Up @@ -468,6 +470,59 @@ async def delete(
extensions=None if extensions is None else dict(extensions),
)

async def websocket_connect(
self,
url: str,
subprotocols: Sequence[str] | None = None,
params: QueryParamTypes | None = None,
headers: HeaderTypes | None = None,
cookies: CookieTypes | None = None,
auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT,
follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT,
timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT,
extensions: Mapping[str, Any] | None = None,
) -> WebSocketTestSession:
"""Sends a GET request to establish a websocket connection.
Args:
url: Request URL.
subprotocols: Websocket subprotocols.
params: Query parameters.
headers: Request headers.
cookies: Request cookies.
auth: Auth headers.
follow_redirects: Whether to follow redirects.
timeout: Request timeout.
extensions: Dictionary of ASGI extensions.
Returns:
A `WebSocketTestSession <litestar.testing.WebSocketTestSession>` instance.
"""
url = urljoin("ws://testserver", url)
default_headers: dict[str, str] = {}
default_headers.setdefault("connection", "upgrade")
default_headers.setdefault("sec-websocket-key", "testserver==")
default_headers.setdefault("sec-websocket-version", "13")
if subprotocols is not None:
default_headers.setdefault("sec-websocket-protocol", ", ".join(subprotocols))
try:
await AsyncClient.request(
self,
"GET",
url,
headers={**dict(headers or {}), **default_headers}, # type: ignore[misc]
params=params,
cookies=cookies,
auth=auth,
follow_redirects=follow_redirects,
timeout=timeout,
extensions=None if extensions is None else dict(extensions),
)
except ConnectionUpgradeExceptionError as exc:
return exc.session

raise RuntimeError("Expected WebSocket upgrade") # pragma: no cover

async def get_session_data(self) -> dict[str, Any]:
"""Get session data.
Expand Down
51 changes: 50 additions & 1 deletion tests/unit/test_testing/test_test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from litestar import Controller, WebSocket, delete, head, patch, put, websocket
from litestar.status_codes import HTTP_200_OK, HTTP_201_CREATED, HTTP_204_NO_CONTENT
from litestar.testing import AsyncTestClient, WebSocketTestSession, create_test_client
from litestar.testing import AsyncTestClient, WebSocketTestSession, create_async_test_client, create_test_client

if TYPE_CHECKING:
from litestar.middleware.session.base import BaseBackendConfig
Expand Down Expand Up @@ -261,3 +261,52 @@ async def handler(socket: WebSocket) -> None:
Empty
), client.websocket_connect("/"):
pass


@pytest.mark.parametrize("block,timeout", [(False, None), (False, 0.001), (True, 0.001)])
@pytest.mark.parametrize(
"receive_method",
[
WebSocketTestSession.receive,
WebSocketTestSession.receive_json,
WebSocketTestSession.receive_text,
WebSocketTestSession.receive_bytes,
],
)
async def test_websocket_test_session_block_timeout_async(
receive_method: Callable[..., Any], block: bool, timeout: Optional[float], anyio_backend: "AnyIOBackend"
) -> None:
@websocket()
async def handler(socket: WebSocket) -> None:
await socket.accept()

with pytest.raises(Empty):
async with create_async_test_client(handler, backend=anyio_backend) as client:
with await client.websocket_connect("/") as ws:
receive_method(ws, timeout=timeout, block=block)


async def test_websocket_accept_timeout_async(anyio_backend: "AnyIOBackend") -> None:
@websocket()
async def handler(socket: WebSocket) -> None:
pass

async with create_async_test_client(handler, backend=anyio_backend, timeout=0.1) as client:
with pytest.raises(Empty):
with await client.websocket_connect("/"):
pass


async def test_websocket_connect_async(anyio_backend: "AnyIOBackend") -> None:
@websocket()
async def handler(socket: WebSocket) -> None:
await socket.accept()
data = await socket.receive_json()
await socket.send_json(data)
await socket.close()

async with create_async_test_client(handler, backend=anyio_backend, timeout=0.1) as client:
with await client.websocket_connect("/", subprotocols="wamp") as ws:
ws.send_json({"data": "123"})
data = ws.receive_json()
assert data == {"data": "123"}

0 comments on commit 43e3041

Please sign in to comment.