From c0b429bd16704e78f4f6a0771c6797fc870bfa3e Mon Sep 17 00:00:00 2001 From: Taylor Lowery Date: Tue, 7 May 2024 23:15:29 +0000 Subject: [PATCH 01/61] expose all ws_connect arguments --- gql/transport/aiohttp_websockets.py | 126 ++++++++++++++++++++++++++++ 1 file changed, 126 insertions(+) create mode 100644 gql/transport/aiohttp_websockets.py diff --git a/gql/transport/aiohttp_websockets.py b/gql/transport/aiohttp_websockets.py new file mode 100644 index 00000000..5f04becd --- /dev/null +++ b/gql/transport/aiohttp_websockets.py @@ -0,0 +1,126 @@ +from enum import verify +import re +import time +import aiohttp +from gql.transport.async_transport import AsyncTransport +from typing import Optional, Union, Collection +from aiohttp.typedefs import LooseHeaders, Mapping, StrOrURL +from aiohttp.helpers import hdrs, BasicAuth, _SENTINEL + +"""HTTP Client for asyncio.""" + +from typing import ( + Collection, + Mapping, + Optional, + Union, +) + + +from aiohttp import hdrs +from aiohttp.client_reqrep import ( + Fingerprint, +) +from aiohttp.helpers import ( + _SENTINEL, + BasicAuth, + sentinel, +) +from aiohttp.typedefs import LooseHeaders, StrOrURL + +from ssl import SSLContext + + +class AIOHTTPWebsocketsTransport(AsyncTransport): + + def __init__( + self, + url: StrOrURL, + *, + method: str = hdrs.METH_GET, + protocols: Collection[str] = (), + timeout: Union[float, _SENTINEL, None] = sentinel, + receive_timeout: Optional[float] = None, + autoclose: bool = True, + autoping: bool = True, + heartbeat: Optional[float] = None, + auth: Optional[BasicAuth] = None, + origin: Optional[str] = None, + params: Optional[Mapping[str, str]] = None, + headers: Optional[LooseHeaders] = None, + proxy: Optional[StrOrURL] = None, + proxy_auth: Optional[BasicAuth] = None, + ssl: Union[SSLContext, bool, Fingerprint] = True, + ssl_context: Optional[SSLContext] = None, + verify_ssl: Optional[bool] = True, + server_hostname: Optional[str] = None, + proxy_headers: Optional[LooseHeaders] = None, + compress: int = 0, + max_msg_size: int = 4 * 1024 * 1024, + ) -> None: + self.url: str = url + self.headers: Optional[LooseHeaders] = headers + self.auth: Optional[BasicAuth] = auth + self.autoclose: bool = autoclose + self.autoping: bool = autoping + self.compress: int = compress + self.heartbeat: Optional[float] = heartbeat + self.max_msg_size: int = max_msg_size + self.method: str = method + self.origin: Optional[str] = origin + self.params: Optional[Mapping[str, str]] = params + self.protocols: Optional[list[str]] = protocols + self.proxy: Optional[StrOrURL] = proxy + self.proxy_auth: Optional[BasicAuth] = proxy_auth + self.proxy_headers: Optional[LooseHeaders] = proxy_headers + self.receive_timeout: Optional[float] = receive_timeout + self.ssl: Union[SSLContext, bool] = ssl + self.ssl_context: Optional[SSLContext] = ssl_context + self.timeout: Union[float, _SENTINEL, None] = timeout + self.verify_ssl: Optional[bool] = verify_ssl + + + self.session: Optional[aiohttp.ClientSession] = None + self.websocket: Optional[aiohttp.ClientWebSocketResponse] = None + + super().__init__() + + async def connect(self) -> None: + if self.session is None: + self.session = aiohttp.ClientSession() + + if self.session is not None: + try: + self.websocket = await self.session.ws_connect( + method=self.method, + url=self.url, + headers=self.headers, + auth=self.auth, + autoclose=self.autoclose, + autoping=self.autoping, + compress=self.compress, + heartbeat=self.heartbeat, + max_msg_size=self.max_msg_size, + origin=self.origin, + params=self.params, + protocols=self.protocols, + proxy=self.proxy, + proxy_auth=self.proxy_auth, + proxy_headers=self.proxy_headers, + receive_timeout=self.receive_timeout, + ssl=self.ssl, + ssl_context=None, + receive_timeout=self.receive_timeout, + timeout=self.timeout, + verify_ssl=self.verify_ssl, + ) + except Exception as e: + raise e + finally: + ... + + async def close(self): ... + + async def execute(self, document, variable_values=None, operation_name=None): ... + + def subscribe(self, document, variable_values=None, operation_name=None): ... From c875fd8dc0321993e4ffca4f498b1773913e0156 Mon Sep 17 00:00:00 2001 From: Taylor Lowery Date: Wed, 8 May 2024 17:53:38 +0000 Subject: [PATCH 02/61] wip: mirrroring websockets_base interface in aiohttp_websockets --- gql/transport/aiohttp_websockets.py | 120 ++++++++++++++++++++++++++-- 1 file changed, 113 insertions(+), 7 deletions(-) diff --git a/gql/transport/aiohttp_websockets.py b/gql/transport/aiohttp_websockets.py index 5f04becd..3d57a78c 100644 --- a/gql/transport/aiohttp_websockets.py +++ b/gql/transport/aiohttp_websockets.py @@ -1,11 +1,16 @@ -from enum import verify +import logging import re import time import aiohttp from gql.transport.async_transport import AsyncTransport -from typing import Optional, Union, Collection +from typing import Any, AsyncGenerator, Dict, Optional, Union, Collection from aiohttp.typedefs import LooseHeaders, Mapping, StrOrURL from aiohttp.helpers import hdrs, BasicAuth, _SENTINEL +from gql.transport.exceptions import TransportClosed, TransportProtocolError, TransportQueryError +from graphql import DocumentNode, ExecutionResult +from h11 import Data +from websockets import ConnectionClosed +from gql.transport.websockets_base import ListenerQueue """HTTP Client for asyncio.""" @@ -30,6 +35,7 @@ from ssl import SSLContext +log = logging.getLogger("gql.transport.aiohttp_websockets") class AIOHTTPWebsocketsTransport(AsyncTransport): @@ -57,7 +63,9 @@ def __init__( proxy_headers: Optional[LooseHeaders] = None, compress: int = 0, max_msg_size: int = 4 * 1024 * 1024, + **kwargs, ) -> None: + super().__init__(**kwargs) self.url: str = url self.headers: Optional[LooseHeaders] = headers self.auth: Optional[BasicAuth] = auth @@ -83,7 +91,67 @@ def __init__( self.session: Optional[aiohttp.ClientSession] = None self.websocket: Optional[aiohttp.ClientWebSocketResponse] = None - super().__init__() + + async def _initialize(self): + """Hook to send the initialization messages after the connection + and potentially wait for the backend ack. + """ + pass # pragma: no cover + + async def _stop_listener(self, query_id: int): + """Hook to stop to listen to a specific query. + Will send a stop message in some subclasses. + """ + pass # pragma: no cover + + async def _after_connect(self): + """Hook to add custom code for subclasses after the connection + has been established. + """ + pass # pragma: no cover + + async def _after_initialize(self): + """Hook to add custom code for subclasses after the initialization + has been done. + """ + pass # pragma: no cover + + async def _close_hook(self): + """Hook to add custom code for subclasses for the connection close""" + pass # pragma: no cover + + async def _connection_terminate(self): + """Hook to add custom code for subclasses after the initialization + has been done. + """ + pass # pragma: no cover + + async def _send(self, message: str) -> None: + if self.websocket is None: + raise TransportClosed("WebSocket connection is closed") + + try: + await self.websocket.send_str(message) + log.info(">>> %s", message) + except ConnectionClosed as e: + await self._fail(e, clean_close=False) + raise e + + async def _receive(self) -> str: + + if self.websocket is None: + raise TransportClosed("WebSocket connection is closed") + + data: Data = await self.websocket.receive() + + if not isinstance(data, str): + raise TransportProtocolError("Binary data received in the websocket") + + answer: str = data + + log.info("<<< %s", answer) + + return answer async def connect(self) -> None: if self.session is None: @@ -110,7 +178,6 @@ async def connect(self) -> None: receive_timeout=self.receive_timeout, ssl=self.ssl, ssl_context=None, - receive_timeout=self.receive_timeout, timeout=self.timeout, verify_ssl=self.verify_ssl, ) @@ -119,8 +186,47 @@ async def connect(self) -> None: finally: ... - async def close(self): ... + async def close(self) -> None: ... + + async def execute( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + ) -> ExecutionResult: + """Execute the provided document AST against the configured remote server + using the current session. + + Send a query but close the async generator as soon as we have the first answer. + + The result is sent as an ExecutionResult object. + """ + first_result = None - async def execute(self, document, variable_values=None, operation_name=None): ... + generator = self.subscribe( + document, variable_values, operation_name, send_stop=False + ) - def subscribe(self, document, variable_values=None, operation_name=None): ... + async for result in generator: + first_result = result + + # Note: we need to run generator.aclose() here or the finally block in + # the subscribe will not be reached in pypy3 (python version 3.6.1) + await generator.aclose() + + break + + if first_result is None: + raise TransportQueryError( + "Query completed without any answer received from the server" + ) + + return first_result + + async def subscribe( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + send_stop: Optional[bool] = True, + ) -> AsyncGenerator[ExecutionResult, None]: ... From 1117d3a12502074f3fe6b69d603f0cd6de77d1cc Mon Sep 17 00:00:00 2001 From: Taylor Lowery Date: Wed, 8 May 2024 18:06:33 +0000 Subject: [PATCH 03/61] black format --- gql/transport/aiohttp_websockets.py | 59 ++++++++++++++++++++++++----- 1 file changed, 50 insertions(+), 9 deletions(-) diff --git a/gql/transport/aiohttp_websockets.py b/gql/transport/aiohttp_websockets.py index 3d57a78c..571be863 100644 --- a/gql/transport/aiohttp_websockets.py +++ b/gql/transport/aiohttp_websockets.py @@ -1,3 +1,4 @@ +import json import logging import re import time @@ -6,8 +7,12 @@ from typing import Any, AsyncGenerator, Dict, Optional, Union, Collection from aiohttp.typedefs import LooseHeaders, Mapping, StrOrURL from aiohttp.helpers import hdrs, BasicAuth, _SENTINEL -from gql.transport.exceptions import TransportClosed, TransportProtocolError, TransportQueryError -from graphql import DocumentNode, ExecutionResult +from gql.transport.exceptions import ( + TransportClosed, + TransportProtocolError, + TransportQueryError, +) +from graphql import DocumentNode, ExecutionResult, print_ast from h11 import Data from websockets import ConnectionClosed from gql.transport.websockets_base import ListenerQueue @@ -37,6 +42,7 @@ log = logging.getLogger("gql.transport.aiohttp_websockets") + class AIOHTTPWebsocketsTransport(AsyncTransport): def __init__( @@ -84,13 +90,13 @@ def __init__( self.receive_timeout: Optional[float] = receive_timeout self.ssl: Union[SSLContext, bool] = ssl self.ssl_context: Optional[SSLContext] = ssl_context - self.timeout: Union[float, _SENTINEL, None] = timeout + self.timeout: Union[float, _SENTINEL, None] = timeout self.verify_ssl: Optional[bool] = verify_ssl - self.session: Optional[aiohttp.ClientSession] = None self.websocket: Optional[aiohttp.ClientWebSocketResponse] = None - + self.next_query_id: int = 1 + self.listeners: Dict[int, ListenerQueue] = {} async def _initialize(self): """Hook to send the initialization messages after the connection @@ -126,6 +132,41 @@ async def _connection_terminate(self): """ pass # pragma: no cover + async def _send_query( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + ) -> int: + """Send a query to the provided websocket connection. + + We use an incremented id to reference the query. + + Returns the used id for this query. + """ + + query_id = self.next_query_id + self.next_query_id += 1 + + payload: Dict[str, Any] = {"query": print_ast(document)} + if variable_values: + payload["variables"] = variable_values + if operation_name: + payload["operationName"] = operation_name + + query_type = "start" + + if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: + query_type = "subscribe" + + query_str = json.dumps( + {"id": str(query_id), "type": query_type, "payload": payload} + ) + + await self._send(query_str) + + return query_id + async def _send(self, message: str) -> None: if self.websocket is None: raise TransportClosed("WebSocket connection is closed") @@ -136,19 +177,19 @@ async def _send(self, message: str) -> None: except ConnectionClosed as e: await self._fail(e, clean_close=False) raise e - + async def _receive(self) -> str: if self.websocket is None: raise TransportClosed("WebSocket connection is closed") - data: Data = await self.websocket.receive() + data: Data = await self.websocket.receive() if not isinstance(data, str): raise TransportProtocolError("Binary data received in the websocket") - + answer: str = data - + log.info("<<< %s", answer) return answer From f3a927cd362c46a4074b7a8c30ec6a067a6f7c51 Mon Sep 17 00:00:00 2001 From: Taylor Lowery Date: Wed, 8 May 2024 21:56:58 +0000 Subject: [PATCH 04/61] wip: add more methods in pursuit of parity --- gql/transport/aiohttp_websockets.py | 233 +++++++++++++++++++++++++++- 1 file changed, 227 insertions(+), 6 deletions(-) diff --git a/gql/transport/aiohttp_websockets.py b/gql/transport/aiohttp_websockets.py index 571be863..f3982bc3 100644 --- a/gql/transport/aiohttp_websockets.py +++ b/gql/transport/aiohttp_websockets.py @@ -1,7 +1,10 @@ +import asyncio +from contextlib import suppress import json import logging import re import time +from tkinter import W import aiohttp from gql.transport.async_transport import AsyncTransport from typing import Any, AsyncGenerator, Dict, Optional, Union, Collection @@ -45,6 +48,11 @@ class AIOHTTPWebsocketsTransport(AsyncTransport): + # This transport supports two subprotocols and will autodetect the + # subprotocol supported on the server + APOLLO_SUBPROTOCOL: str = "graphql-ws" + GRAPHQLWS_SUBPROTOCOL: str = "graphql-transport-ws" + def __init__( self, url: StrOrURL, @@ -69,6 +77,10 @@ def __init__( proxy_headers: Optional[LooseHeaders] = None, compress: int = 0, max_msg_size: int = 4 * 1024 * 1024, + connect_timeout: Optional[Union[int, float]] = 10, + close_timeout: Optional[Union[int, float]] = 10, + ack_timeout: Optional[Union[int, float]] = 10, + keep_alive_timeout: Optional[Union[int, float]] = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -93,11 +105,26 @@ def __init__( self.timeout: Union[float, _SENTINEL, None] = timeout self.verify_ssl: Optional[bool] = verify_ssl + self.connect_timeout: Optional[Union[int, float]] = connect_timeout + self.close_timeout: Optional[Union[int, float]] = close_timeout + self.ack_timeout: Optional[Union[int, float]] = ack_timeout + self.keep_alive_timeout: Optional[Union[int, float]] = keep_alive_timeout + self.session: Optional[aiohttp.ClientSession] = None self.websocket: Optional[aiohttp.ClientWebSocketResponse] = None self.next_query_id: int = 1 self.listeners: Dict[int, ListenerQueue] = {} + self.receive_data_task: Optional[asyncio.Future] = None + self.check_keep_alive_task: Optional[asyncio.Future] = None + self.close_task: Optional[asyncio.Future] = None + + self._wait_closed: asyncio.Event = asyncio.Event() + self._wait_closed.set() + + self._no_more_listeners: asyncio.Event = asyncio.Event() + self._no_more_listeners.set() + async def _initialize(self): """Hook to send the initialization messages after the connection and potentially wait for the backend ack. @@ -111,10 +138,18 @@ async def _stop_listener(self, query_id: int): pass # pragma: no cover async def _after_connect(self): - """Hook to add custom code for subclasses after the connection - has been established. - """ - pass # pragma: no cover + + # Find the backend subprotocol returned in the response headers + # TODO: find the equivalent of response_headers in aiohttp websocket response + subprotocol = self.websocket.protocol + try: + self.subprotocol = subprotocol + except KeyError: + # If the server does not send the subprotocol header, using + # the apollo subprotocol by default + self.subprotocol = self.APOLLO_SUBPROTOCOL + + log.debug(f"backend subprotocol returned: {self.subprotocol!r}") async def _after_initialize(self): """Hook to add custom code for subclasses after the initialization @@ -168,6 +203,8 @@ async def _send_query( return query_id async def _send(self, message: str) -> None: + """Send the provided message to the websocket connection and log the message""" + if self.websocket is None: raise TransportClosed("WebSocket connection is closed") @@ -194,6 +231,19 @@ async def _receive(self) -> str: return answer + def _remove_listener(self, query_id) -> None: + """After exiting from a subscription, remove the listener and + signal an event if this was the last listener for the client. + """ + if query_id in self.listeners: + del self.listeners[query_id] + + remaining = len(self.listeners) + log.debug(f"listener {query_id} deleted, {remaining} remaining") + + if remaining == 0: + self._no_more_listeners.set() + async def connect(self) -> None: if self.session is None: self.session = aiohttp.ClientSession() @@ -226,8 +276,128 @@ async def connect(self) -> None: raise e finally: ... + await self._after_connect() + + async def _clean_close(self, e: Exception) -> None: + """Coroutine which will: + + - send stop messages for each active subscription to the server + - send the connection terminate message + """ + + # Send 'stop' message for all current queries + for query_id, listener in self.listeners.items(): + + if listener.send_stop: + await self._stop_listener(query_id) + listener.send_stop = False + + # Wait that there is no more listeners (we received 'complete' for all queries) + try: + await asyncio.wait_for(self._no_more_listeners.wait(), self.close_timeout) + except asyncio.TimeoutError: # pragma: no cover + log.debug("Timer close_timeout fired") + + # Calling the subclass hook + await self._connection_terminate() + + async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: + """Coroutine which will: + + - do a clean_close if possible: + - send stop messages for each active query to the server + - send the connection terminate message + - close the websocket connection + - send the exception to all the remaining listeners + """ + + log.debug("_close_coro: starting") + + try: + + # We should always have an active websocket connection here + assert self.websocket is not None + + # Properly shut down liveness checker if enabled + if self.check_keep_alive_task is not None: + # More info: https://stackoverflow.com/a/43810272/1113207 + self.check_keep_alive_task.cancel() + with suppress(asyncio.CancelledError): + await self.check_keep_alive_task + + # Calling the subclass close hook + await self._close_hook() + + # Saving exception to raise it later if trying to use the transport + # after it has already closed. + self.close_exception = e + + if clean_close: + log.debug("_close_coro: starting clean_close") + try: + await self._clean_close(e) + except Exception as exc: # pragma: no cover + log.warning("Ignoring exception in _clean_close: " + repr(exc)) + + log.debug("_close_coro: sending exception to listeners") + + # Send an exception to all remaining listeners + for query_id, listener in self.listeners.items(): + await listener.set_exception(e) + + log.debug("_close_coro: close websocket connection") + + await self.websocket.close() + + log.debug("_close_coro: websocket connection closed") + + except Exception as exc: # pragma: no cover + log.warning("Exception catched in _close_coro: " + repr(exc)) + + finally: + + log.debug("_close_coro: start cleanup") + + self.websocket = None + self.close_task = None + self.check_keep_alive_task = None + self._wait_closed.set() + + log.debug("_close_coro: exiting") + async def _fail(self, e: Exception, clean_close: bool = True) -> None: + log.debug("_fail: starting with exception: " + repr(e)) + + if self.close_task is None: + + if self.websocket is None: + log.debug("_fail started with self.websocket == None -> already closed") + else: + self.close_task = asyncio.shield( + asyncio.ensure_future(self._close_coro(e, clean_close=clean_close)) + ) + else: + log.debug( + "close_task is not None in _fail. Previous exception is: " + + repr(self.close_exception) + + " New exception is: " + + repr(e) + ) + + async def close(self) -> None: + log.debug("close: starting") + + await self._fail(TransportClosed("Websocket GraphQL transport closed by user")) + await self.wait_closed() + + log.debug("close: done") + + async def wait_closed(self) -> None: + log.debug("wait_close: starting") + + await self._wait_closed.wait() + + log.debug("wait_close: done") - async def close(self) -> None: ... async def execute( self, @@ -270,4 +440,55 @@ async def subscribe( variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, send_stop: Optional[bool] = True, - ) -> AsyncGenerator[ExecutionResult, None]: ... + ) -> AsyncGenerator[ExecutionResult, None]: + """Send a query and receive the results using a python async generator. + + The query can be a graphql query, mutation or subscription. + + The results are sent as an ExecutionResult object. + """ + + # Send the query and receive the id + query_id: int = await self._send_query( + document, variable_values, operation_name + ) + + # Create a queue to receive the answers for this query_id + listener = ListenerQueue(query_id, send_stop=(send_stop is True)) + self.listeners[query_id] = listener + + # We will need to wait at close for this query to clean properly + self._no_more_listeners.clear() + + try: + # Loop over the received answers + while True: + + # Wait for the answer from the queue of this query_id + # This can raise a TransportError or ConnectionClosed exception. + answer_type, execution_result = await listener.get() + + # If the received answer contains data, + # Then we will yield the results back as an ExecutionResult object + if execution_result is not None: + yield execution_result + + # If we receive a 'complete' answer from the server, + # Then we will end this async generator output without errors + elif answer_type == "complete": + log.debug( + f"Complete received for query {query_id} --> exit without error" + ) + break + + except (asyncio.CancelledError, GeneratorExit) as e: + log.debug(f"Exception in subscribe: {e!r}") + if listener.send_stop: + await self._stop_listener(query_id) + listener.send_stop = False + + finally: + log.debug(f"In subscribe finally for query_id {query_id}") + self._remove_listener(query_id) + + From 7378a53cd9dcdff2b0ae08903d20d6272e064e50 Mon Sep 17 00:00:00 2001 From: Micah Pegman Date: Thu, 9 May 2024 18:20:49 +0000 Subject: [PATCH 05/61] WIP: aiohttp websockets parity --- gql/transport/aiohttp_websockets.py | 510 +++++++++++++++++++++++++--- 1 file changed, 455 insertions(+), 55 deletions(-) diff --git a/gql/transport/aiohttp_websockets.py b/gql/transport/aiohttp_websockets.py index f3982bc3..d53c0c76 100644 --- a/gql/transport/aiohttp_websockets.py +++ b/gql/transport/aiohttp_websockets.py @@ -1,47 +1,37 @@ +"""Websockets Client for asyncio.""" + +from ssl import SSLContext + +import aiohttp import asyncio -from contextlib import suppress import json import logging -import re -import time -from tkinter import W -import aiohttp -from gql.transport.async_transport import AsyncTransport -from typing import Any, AsyncGenerator, Dict, Optional, Union, Collection +from aiohttp import WSMessage, WSMsgType +from aiohttp.client_reqrep import Fingerprint +from aiohttp.helpers import BasicAuth, hdrs from aiohttp.typedefs import LooseHeaders, Mapping, StrOrURL -from aiohttp.helpers import hdrs, BasicAuth, _SENTINEL -from gql.transport.exceptions import ( - TransportClosed, - TransportProtocolError, - TransportQueryError, -) +from contextlib import suppress from graphql import DocumentNode, ExecutionResult, print_ast -from h11 import Data -from websockets import ConnectionClosed -from gql.transport.websockets_base import ListenerQueue - -"""HTTP Client for asyncio.""" - from typing import ( + Any, + AsyncGenerator, Collection, + Dict, Mapping, Optional, + Tuple, Union, ) - -from aiohttp import hdrs -from aiohttp.client_reqrep import ( - Fingerprint, -) -from aiohttp.helpers import ( - _SENTINEL, - BasicAuth, - sentinel, +from gql.transport.async_transport import AsyncTransport +from gql.transport.exceptions import ( + TransportAlreadyConnected, + TransportClosed, + TransportProtocolError, + TransportQueryError, + TransportServerError, ) -from aiohttp.typedefs import LooseHeaders, StrOrURL - -from ssl import SSLContext +from gql.transport.websockets_base import ListenerQueue log = logging.getLogger("gql.transport.aiohttp_websockets") @@ -59,7 +49,7 @@ def __init__( *, method: str = hdrs.METH_GET, protocols: Collection[str] = (), - timeout: Union[float, _SENTINEL, None] = sentinel, + timeout: float = 10.0, receive_timeout: Optional[float] = None, autoclose: bool = True, autoping: bool = True, @@ -73,7 +63,6 @@ def __init__( ssl: Union[SSLContext, bool, Fingerprint] = True, ssl_context: Optional[SSLContext] = None, verify_ssl: Optional[bool] = True, - server_hostname: Optional[str] = None, proxy_headers: Optional[LooseHeaders] = None, compress: int = 0, max_msg_size: int = 4 * 1024 * 1024, @@ -81,10 +70,14 @@ def __init__( close_timeout: Optional[Union[int, float]] = 10, ack_timeout: Optional[Union[int, float]] = 10, keep_alive_timeout: Optional[Union[int, float]] = None, + init_payload: Dict[str, Any] = {}, + ping_interval: Optional[Union[int, float]] = None, + pong_timeout: Optional[Union[int, float]] = None, + answer_pings: bool = True, **kwargs, ) -> None: super().__init__(**kwargs) - self.url: str = url + self.url: StrOrURL = url self.headers: Optional[LooseHeaders] = headers self.auth: Optional[BasicAuth] = auth self.autoclose: bool = autoclose @@ -95,25 +88,29 @@ def __init__( self.method: str = method self.origin: Optional[str] = origin self.params: Optional[Mapping[str, str]] = params - self.protocols: Optional[list[str]] = protocols + self.protocols: Collection[str] = protocols self.proxy: Optional[StrOrURL] = proxy self.proxy_auth: Optional[BasicAuth] = proxy_auth self.proxy_headers: Optional[LooseHeaders] = proxy_headers self.receive_timeout: Optional[float] = receive_timeout - self.ssl: Union[SSLContext, bool] = ssl + self.ssl: Union[SSLContext, bool, Fingerprint] = ssl self.ssl_context: Optional[SSLContext] = ssl_context - self.timeout: Union[float, _SENTINEL, None] = timeout + self.timeout: float = timeout self.verify_ssl: Optional[bool] = verify_ssl + self.init_payload: Dict[str, Any] = init_payload self.connect_timeout: Optional[Union[int, float]] = connect_timeout self.close_timeout: Optional[Union[int, float]] = close_timeout self.ack_timeout: Optional[Union[int, float]] = ack_timeout self.keep_alive_timeout: Optional[Union[int, float]] = keep_alive_timeout + self._next_keep_alive_message: asyncio.Event = asyncio.Event() + self._next_keep_alive_message.set() self.session: Optional[aiohttp.ClientSession] = None self.websocket: Optional[aiohttp.ClientWebSocketResponse] = None self.next_query_id: int = 1 self.listeners: Dict[int, ListenerQueue] = {} + self._connecting: bool = False self.receive_data_task: Optional[asyncio.Future] = None self.check_keep_alive_task: Optional[asyncio.Future] = None @@ -125,11 +122,235 @@ def __init__( self._no_more_listeners: asyncio.Event = asyncio.Event() self._no_more_listeners.set() - async def _initialize(self): - """Hook to send the initialization messages after the connection - and potentially wait for the backend ack. + self.payloads: Dict[str, Any] = {} + + self.ping_interval: Optional[Union[int, float]] = ping_interval + self.pong_timeout: Optional[Union[int, float]] + self.answer_pings: bool = answer_pings + + if ping_interval is not None: + if pong_timeout is None: + self.pong_timeout = ping_interval / 2 + else: + self.pong_timeout = pong_timeout + + self.send_ping_task: Optional[asyncio.Future] = None + + self.ping_received: asyncio.Event = asyncio.Event() + """ping_received is an asyncio Event which will fire each time + a ping is received with the graphql-ws protocol""" + + self.pong_received: asyncio.Event = asyncio.Event() + """pong_received is an asyncio Event which will fire each time + a pong is received with the graphql-ws protocol""" + + if protocols is None: + self.supported_subprotocols = [ + self.APOLLO_SUBPROTOCOL, + self.GRAPHQLWS_SUBPROTOCOL, + ] + else: + self.supported_subprotocols = protocols + + def _parse_answer_graphqlws( + self, json_answer: Dict[str, Any] + ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: + """Parse the answer received from the server if the server supports the + graphql-ws protocol. + + Returns a list consisting of: + - the answer_type (between: + 'connection_ack', 'ping', 'pong', 'data', 'error', 'complete') + - the answer id (Integer) if received or None + - an execution Result if the answer_type is 'data' or None + + Differences with the apollo websockets protocol (superclass): + - the "data" message is now called "next" + - the "stop" message is now called "complete" + - there is no connection_terminate or connection_error messages + - instead of a unidirectional keep-alive (ka) message from server to client, + there is now the possibility to send bidirectional ping/pong messages + - connection_ack has an optional payload + - the 'error' answer type returns a list of errors instead of a single error """ - pass # pragma: no cover + + answer_type: str = "" + answer_id: Optional[int] = None + execution_result: Optional[ExecutionResult] = None + + try: + answer_type = str(json_answer.get("type")) + + if answer_type in ["next", "error", "complete"]: + answer_id = int(str(json_answer.get("id"))) + + if answer_type == "next" or answer_type == "error": + + payload = json_answer.get("payload") + + if answer_type == "next": + + if not isinstance(payload, dict): + raise ValueError("payload is not a dict") + + if "errors" not in payload and "data" not in payload: + raise ValueError( + "payload does not contain 'data' or 'errors' fields" + ) + + execution_result = ExecutionResult( + errors=payload.get("errors"), + data=payload.get("data"), + extensions=payload.get("extensions"), + ) + + # Saving answer_type as 'data' to be understood with superclass + answer_type = "data" + + elif answer_type == "error": + + if not isinstance(payload, list): + raise ValueError("payload is not a list") + + raise TransportQueryError( + str(payload[0]), query_id=answer_id, errors=payload + ) + + elif answer_type in ["ping", "pong", "connection_ack"]: + self.payloads[answer_type] = json_answer.get("payload", None) + + else: + raise ValueError + + if self.check_keep_alive_task is not None: + self._next_keep_alive_message.set() + + except ValueError as e: + raise TransportProtocolError( + f"Server did not return a GraphQL result: {json_answer}" + ) from e + + return answer_type, answer_id, execution_result + + def _parse_answer_apollo( + self, json_answer: Dict[str, Any] + ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: + """Parse the answer received from the server if the server supports the + apollo websockets protocol. + + Returns a list consisting of: + - the answer_type (between: + 'connection_ack', 'ka', 'connection_error', 'data', 'error', 'complete') + - the answer id (Integer) if received or None + - an execution Result if the answer_type is 'data' or None + """ + + answer_type: str = "" + answer_id: Optional[int] = None + execution_result: Optional[ExecutionResult] = None + + try: + answer_type = str(json_answer.get("type")) + + if answer_type in ["data", "error", "complete"]: + answer_id = int(str(json_answer.get("id"))) + + if answer_type == "data" or answer_type == "error": + + payload = json_answer.get("payload") + + if not isinstance(payload, dict): + raise ValueError("payload is not a dict") + + if answer_type == "data": + + if "errors" not in payload and "data" not in payload: + raise ValueError( + "payload does not contain 'data' or 'errors' fields" + ) + + execution_result = ExecutionResult( + errors=payload.get("errors"), + data=payload.get("data"), + extensions=payload.get("extensions"), + ) + + elif answer_type == "error": + + raise TransportQueryError( + str(payload), query_id=answer_id, errors=[payload] + ) + + elif answer_type == "ka": + # Keep-alive message + if self.check_keep_alive_task is not None: + self._next_keep_alive_message.set() + elif answer_type == "connection_ack": + pass + elif answer_type == "connection_error": + error_payload = json_answer.get("payload") + raise TransportServerError(f"Server error: '{repr(error_payload)}'") + else: + raise ValueError + + except ValueError as e: + raise TransportProtocolError( + f"Server did not return a GraphQL result: {json_answer}" + ) from e + + return answer_type, answer_id, execution_result + + def _parse_answer( + self, answer: str + ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: + """Parse the answer received from the server depending on + the detected subprotocol. + """ + try: + json_answer = json.loads(answer) + except ValueError: + raise TransportProtocolError( + f"Server did not return a GraphQL result: {answer}" + ) + + if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: + return self._parse_answer_graphqlws(json_answer) + + return self._parse_answer_apollo(json_answer) + + async def _wait_ack(self) -> None: + """Wait for the connection_ack message. Keep alive messages are ignored""" + + while True: + init_answer = await self._receive() + + answer_type, _, _ = self._parse_answer(init_answer) + + if answer_type == "connection_ack": + return + + if answer_type != "ka": + raise TransportProtocolError( + "Websocket server did not return a connection ack" + ) + + async def _send_init_message_and_wait_ack(self) -> None: + """Send init message to the provided websocket and wait for the connection ACK. + + If the answer is not a connection_ack message, we will return an Exception. + """ + + init_message = json.dumps( + {"type": "connection_init", "payload": self.init_payload} + ) + + await self._send(init_message) + + # Wait for the connection_ack message or raise a TimeoutError + await asyncio.wait_for(self._wait_ack(), self.ack_timeout) + + async def _initialize(self): + await self._send_init_message_and_wait_ack() async def _stop_listener(self, query_id: int): """Hook to stop to listen to a specific query. @@ -138,6 +359,8 @@ async def _stop_listener(self, query_id: int): pass # pragma: no cover async def _after_connect(self): + if self.websocket is None: + raise TransportClosed("WebSocket connection is closed") # Find the backend subprotocol returned in the response headers # TODO: find the equivalent of response_headers in aiohttp websocket response @@ -151,11 +374,58 @@ async def _after_connect(self): log.debug(f"backend subprotocol returned: {self.subprotocol!r}") - async def _after_initialize(self): - """Hook to add custom code for subclasses after the initialization - has been done. + async def send_ping(self, payload: Optional[Any] = None) -> None: + """Send a ping message for the graphql-ws protocol""" + + ping_message = {"type": "ping"} + + if payload is not None: + ping_message["payload"] = payload + + await self._send(json.dumps(ping_message)) + + async def _send_ping_coro(self) -> None: + """Coroutine to periodically send a ping from the client to the backend. + + Only used for the graphql-ws protocol. + + Send a ping every ping_interval seconds. + Close the connection if a pong is not received within pong_timeout seconds. """ - pass # pragma: no cover + + assert self.ping_interval is not None + + try: + while True: + await asyncio.sleep(self.ping_interval) + + await self.send_ping() + + await asyncio.wait_for(self.pong_received.wait(), self.pong_timeout) + + # Reset for the next iteration + self.pong_received.clear() + + except asyncio.TimeoutError: + # No pong received in the appriopriate time, close with error + # If the timeout happens during a close already in progress, do nothing + if self.close_task is None: + await self._fail( + TransportServerError( + f"No pong received after {self.pong_timeout!r} seconds" + ), + clean_close=False, + ) + + async def _after_initialize(self): + + # If requested, create a task to send periodic pings to the backend + if ( + self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL + and self.ping_interval is not None + ): + + self.send_ping_task = asyncio.ensure_future(self._send_ping_coro()) async def _close_hook(self): """Hook to add custom code for subclasses for the connection close""" @@ -211,7 +481,7 @@ async def _send(self, message: str) -> None: try: await self.websocket.send_str(message) log.info(">>> %s", message) - except ConnectionClosed as e: + except ConnectionResetError as e: await self._fail(e, clean_close=False) raise e @@ -220,12 +490,12 @@ async def _receive(self) -> str: if self.websocket is None: raise TransportClosed("WebSocket connection is closed") - data: Data = await self.websocket.receive() + data: WSMessage = await self.websocket.receive() - if not isinstance(data, str): + if data.type != WSMsgType.TEXT: raise TransportProtocolError("Binary data received in the websocket") - answer: str = data + answer: str = data.data log.info("<<< %s", answer) @@ -244,11 +514,112 @@ def _remove_listener(self, query_id) -> None: if remaining == 0: self._no_more_listeners.set() + async def _check_ws_liveness(self) -> None: + """Coroutine which will periodically check the liveness of the connection + through keep-alive messages + """ + + try: + while True: + await asyncio.wait_for( + self._next_keep_alive_message.wait(), self.keep_alive_timeout + ) + + # Reset for the next iteration + self._next_keep_alive_message.clear() + + except asyncio.TimeoutError: + # No keep-alive message in the appriopriate interval, close with error + # while trying to notify the server of a proper close (in case + # the keep-alive interval of the client or server was not aligned + # the connection still remains) + + # If the timeout happens during a close already in progress, do nothing + if self.close_task is None: + await self._fail( + TransportServerError( + "No keep-alive message has been received within " + "the expected interval ('keep_alive_timeout' parameter)" + ), + clean_close=False, + ) + + except asyncio.CancelledError: + # The client is probably closing, handle it properly + pass + + async def _handle_answer( + self, + answer_type: str, + answer_id: Optional[int], + execution_result: Optional[ExecutionResult], + ) -> None: + + try: + # Put the answer in the queue + if answer_id is not None: + await self.listeners[answer_id].put((answer_type, execution_result)) + except KeyError: + # Do nothing if no one is listening to this query_id. + pass + + async def _receive_data_loop(self) -> None: + """Main asyncio task which will listen to the incoming messages and will + call the parse_answer and handle_answer methods of the subclass.""" + try: + while True: + + # Wait the next answer from the websocket server + try: + answer = await self._receive() + except (ConnectionResetError, TransportProtocolError) as e: + await self._fail(e, clean_close=False) + break + except TransportClosed: + break + + # Parse the answer + try: + answer_type, answer_id, execution_result = self._parse_answer( + answer + ) + except TransportQueryError as e: + # Received an exception for a specific query + # ==> Add an exception to this query queue + # The exception is raised for this specific query, + # but the transport is not closed. + assert isinstance( + e.query_id, int + ), "TransportQueryError should have a query_id defined here" + try: + await self.listeners[e.query_id].set_exception(e) + except KeyError: + # Do nothing if no one is listening to this query_id + pass + + continue + + except (TransportServerError, TransportProtocolError) as e: + # Received a global exception for this transport + # ==> close the transport + # The exception will be raised for all current queries. + await self._fail(e, clean_close=False) + break + + await self._handle_answer(answer_type, answer_id, execution_result) + + finally: + log.debug("Exiting _receive_data_loop()") + async def connect(self) -> None: + log.debug("connect: starting") + if self.session is None: self.session = aiohttp.ClientSession() - if self.session is not None: + if self.websocket is None and not self._connecting: + self._connecting = True + try: self.websocket = await self.session.ws_connect( method=self.method, @@ -274,10 +645,41 @@ async def connect(self) -> None: ) except Exception as e: raise e - finally: - ... + + self._connecting = False + await self._after_connect() + self.next_query_id = 1 + self.close_exception = None + self._wait_closed.clear() + + # Send the init message and wait for the ack from the server + # Note: This should generate a TimeoutError + # if no ACKs are received within the ack_timeout + try: + await self._initialize() + except ConnectionResetError as e: + raise e + except (TransportProtocolError, asyncio.TimeoutError) as e: + await self._fail(e, clean_close=False) + raise e + + # Run the after_init hook of the subclass + await self._after_initialize() + + # If specified, create a task to check liveness of the connection + # through keep-alive messages + if self.keep_alive_timeout is not None: + self.check_keep_alive_task = asyncio.ensure_future( + self._check_ws_liveness() + ) + + # Create a task to listen to the incoming websocket messages + self.receive_data_task = asyncio.ensure_future(self._receive_data_loop()) + + log.debug("connect: done") + async def _clean_close(self, e: Exception) -> None: """Coroutine which will: @@ -364,6 +766,7 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: self._wait_closed.set() log.debug("_close_coro: exiting") + async def _fail(self, e: Exception, clean_close: bool = True) -> None: log.debug("_fail: starting with exception: " + repr(e)) @@ -398,7 +801,6 @@ async def wait_closed(self) -> None: log.debug("wait_close: done") - async def execute( self, document: DocumentNode, @@ -490,5 +892,3 @@ async def subscribe( finally: log.debug(f"In subscribe finally for query_id {query_id}") self._remove_listener(query_id) - - From 1c8eda588579d1de657e16c89e144ac4cca23f30 Mon Sep 17 00:00:00 2001 From: Micah Pegman Date: Thu, 30 May 2024 16:39:09 +0000 Subject: [PATCH 06/61] answer improvements Signed-off-by: Micah Pegman --- gql/transport/aiohttp_websockets.py | 84 +++++++++++++---------------- 1 file changed, 36 insertions(+), 48 deletions(-) diff --git a/gql/transport/aiohttp_websockets.py b/gql/transport/aiohttp_websockets.py index d53c0c76..1df3cbca 100644 --- a/gql/transport/aiohttp_websockets.py +++ b/gql/transport/aiohttp_websockets.py @@ -4,14 +4,13 @@ import aiohttp import asyncio -import json import logging -from aiohttp import WSMessage, WSMsgType from aiohttp.client_reqrep import Fingerprint from aiohttp.helpers import BasicAuth, hdrs -from aiohttp.typedefs import LooseHeaders, Mapping, StrOrURL +from aiohttp.typedefs import LooseHeaders, StrOrURL from contextlib import suppress from graphql import DocumentNode, ExecutionResult, print_ast +from multidict import CIMultiDict, CIMultiDictProxy from typing import ( Any, AsyncGenerator, @@ -111,6 +110,7 @@ def __init__( self.next_query_id: int = 1 self.listeners: Dict[int, ListenerQueue] = {} self._connecting: bool = False + self.response_headers: Optional[CIMultiDictProxy[str]] = None self.receive_data_task: Optional[asyncio.Future] = None self.check_keep_alive_task: Optional[asyncio.Future] = None @@ -144,13 +144,11 @@ def __init__( """pong_received is an asyncio Event which will fire each time a pong is received with the graphql-ws protocol""" - if protocols is None: - self.supported_subprotocols = [ - self.APOLLO_SUBPROTOCOL, - self.GRAPHQLWS_SUBPROTOCOL, - ] - else: - self.supported_subprotocols = protocols + self.supported_subprotocols: Collection[str] = protocols or ( + self.APOLLO_SUBPROTOCOL, + self.GRAPHQLWS_SUBPROTOCOL, + ) + self.close_exception: Optional[Exception] = None def _parse_answer_graphqlws( self, json_answer: Dict[str, Any] @@ -233,7 +231,7 @@ def _parse_answer_graphqlws( return answer_type, answer_id, execution_result def _parse_answer_apollo( - self, json_answer: Dict[str, Any] + self, answer: Dict[str, Any] ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: """Parse the answer received from the server if the server supports the apollo websockets protocol. @@ -250,14 +248,14 @@ def _parse_answer_apollo( execution_result: Optional[ExecutionResult] = None try: - answer_type = str(json_answer.get("type")) + answer_type = str(answer.get("type")) if answer_type in ["data", "error", "complete"]: - answer_id = int(str(json_answer.get("id"))) + answer_id = int(str(answer.get("id"))) if answer_type == "data" or answer_type == "error": - payload = json_answer.get("payload") + payload = answer.get("payload") if not isinstance(payload, dict): raise ValueError("payload is not a dict") @@ -288,35 +286,28 @@ def _parse_answer_apollo( elif answer_type == "connection_ack": pass elif answer_type == "connection_error": - error_payload = json_answer.get("payload") + error_payload = answer.get("payload") raise TransportServerError(f"Server error: '{repr(error_payload)}'") else: raise ValueError except ValueError as e: raise TransportProtocolError( - f"Server did not return a GraphQL result: {json_answer}" + f"Server did not return a GraphQL result: {answer}" ) from e return answer_type, answer_id, execution_result def _parse_answer( - self, answer: str + self, answer: Dict[str, Any] ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: """Parse the answer received from the server depending on the detected subprotocol. """ - try: - json_answer = json.loads(answer) - except ValueError: - raise TransportProtocolError( - f"Server did not return a GraphQL result: {answer}" - ) - if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: - return self._parse_answer_graphqlws(json_answer) + return self._parse_answer_graphqlws(answer) - return self._parse_answer_apollo(json_answer) + return self._parse_answer_apollo(answer) async def _wait_ack(self) -> None: """Wait for the connection_ack message. Keep alive messages are ignored""" @@ -340,9 +331,7 @@ async def _send_init_message_and_wait_ack(self) -> None: If the answer is not a connection_ack message, we will return an Exception. """ - init_message = json.dumps( - {"type": "connection_init", "payload": self.init_payload} - ) + init_message = {"type": "connection_init", "payload": self.init_payload} await self._send(init_message) @@ -382,7 +371,7 @@ async def send_ping(self, payload: Optional[Any] = None) -> None: if payload is not None: ping_message["payload"] = payload - await self._send(json.dumps(ping_message)) + await self._send(ping_message) async def _send_ping_coro(self) -> None: """Coroutine to periodically send a ping from the client to the backend. @@ -464,38 +453,31 @@ async def _send_query( if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: query_type = "subscribe" - query_str = json.dumps( - {"id": str(query_id), "type": query_type, "payload": payload} - ) + query = {"id": str(query_id), "type": query_type, "payload": payload} - await self._send(query_str) + await self._send(query) return query_id - async def _send(self, message: str) -> None: + async def _send(self, message: Dict[str, Any]) -> None: """Send the provided message to the websocket connection and log the message""" if self.websocket is None: raise TransportClosed("WebSocket connection is closed") try: - await self.websocket.send_str(message) + await self.websocket.send_json(message) log.info(">>> %s", message) except ConnectionResetError as e: await self._fail(e, clean_close=False) raise e - async def _receive(self) -> str: + async def _receive(self) -> Dict[str, Any]: if self.websocket is None: raise TransportClosed("WebSocket connection is closed") - data: WSMessage = await self.websocket.receive() - - if data.type != WSMsgType.TEXT: - raise TransportProtocolError("Binary data received in the websocket") - - answer: str = data.data + answer = await self.websocket.receive_json() log.info("<<< %s", answer) @@ -643,10 +625,13 @@ async def connect(self) -> None: timeout=self.timeout, verify_ssl=self.verify_ssl, ) - except Exception as e: - raise e + finally: + self._connecting = False - self._connecting = False + try: + self.response_headers = self.websocket._response.headers + except AttributeError: + self.response_headers = CIMultiDictProxy(CIMultiDict()) await self._after_connect() @@ -678,9 +663,12 @@ async def connect(self) -> None: # Create a task to listen to the incoming websocket messages self.receive_data_task = asyncio.ensure_future(self._receive_data_loop()) + else: + raise TransportAlreadyConnected("Transport is already connected") + log.debug("connect: done") - async def _clean_close(self, e: Exception) -> None: + async def _clean_close(self) -> None: """Coroutine which will: - send stop messages for each active subscription to the server @@ -737,7 +725,7 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: if clean_close: log.debug("_close_coro: starting clean_close") try: - await self._clean_close(e) + await self._clean_close() except Exception as exc: # pragma: no cover log.warning("Ignoring exception in _clean_close: " + repr(exc)) From 34fc580ef71ad974487edc1577240b7a316c258d Mon Sep 17 00:00:00 2001 From: Micah Pegman Date: Thu, 30 May 2024 18:56:53 +0000 Subject: [PATCH 07/61] linting fixes Signed-off-by: Micah Pegman --- gql/transport/aiohttp_websockets.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/gql/transport/aiohttp_websockets.py b/gql/transport/aiohttp_websockets.py index 1df3cbca..220f6416 100644 --- a/gql/transport/aiohttp_websockets.py +++ b/gql/transport/aiohttp_websockets.py @@ -1,16 +1,9 @@ """Websockets Client for asyncio.""" -from ssl import SSLContext - -import aiohttp import asyncio import logging -from aiohttp.client_reqrep import Fingerprint -from aiohttp.helpers import BasicAuth, hdrs -from aiohttp.typedefs import LooseHeaders, StrOrURL from contextlib import suppress -from graphql import DocumentNode, ExecutionResult, print_ast -from multidict import CIMultiDict, CIMultiDictProxy +from ssl import SSLContext from typing import ( Any, AsyncGenerator, @@ -22,6 +15,13 @@ Union, ) +import aiohttp +from aiohttp.client_reqrep import Fingerprint +from aiohttp.helpers import BasicAuth, hdrs +from aiohttp.typedefs import LooseHeaders, StrOrURL +from graphql import DocumentNode, ExecutionResult, print_ast +from multidict import CIMultiDict, CIMultiDictProxy + from gql.transport.async_transport import AsyncTransport from gql.transport.exceptions import ( TransportAlreadyConnected, @@ -73,9 +73,7 @@ def __init__( ping_interval: Optional[Union[int, float]] = None, pong_timeout: Optional[Union[int, float]] = None, answer_pings: bool = True, - **kwargs, ) -> None: - super().__init__(**kwargs) self.url: StrOrURL = url self.headers: Optional[LooseHeaders] = headers self.auth: Optional[BasicAuth] = auth From f5d208cd329dab6c9916bab3eebcf2f9c639bef8 Mon Sep 17 00:00:00 2001 From: Taylor Lowery Date: Thu, 30 May 2024 22:11:41 +0000 Subject: [PATCH 08/61] wip: initial tests --- docs/code_examples/console_async.py | 2 +- docs/code_examples/fastapi_async.py | 2 +- .../reconnecting_mutation_http.py | 3 +- .../code_examples/reconnecting_mutation_ws.py | 3 +- gql/cli.py | 6 +- gql/client.py | 146 +++++++----------- gql/dsl.py | 14 +- gql/graphql_request.py | 3 +- gql/transport/aiohttp.py | 14 +- gql/transport/aiohttp_websockets.py | 1 - gql/transport/appsync_websockets.py | 6 +- gql/transport/async_transport.py | 3 +- gql/transport/httpx.py | 5 +- gql/transport/local_schema.py | 3 +- gql/transport/phoenix_channel_websockets.py | 3 +- gql/transport/requests.py | 3 +- gql/transport/transport.py | 3 +- gql/transport/websockets.py | 6 +- gql/transport/websockets_base.py | 8 +- gql/utilities/get_introspection_query_ast.py | 3 +- gql/utilities/node_tree.py | 3 +- gql/utilities/parse_result.py | 3 +- gql/utilities/serialize_variable_values.py | 3 +- gql/utilities/update_schema_enum.py | 3 +- gql/utilities/update_schema_scalars.py | 3 +- tests/conftest.py | 26 +++- tests/custom_scalars/test_datetime.py | 5 +- tests/custom_scalars/test_enum_colors.py | 3 +- tests/custom_scalars/test_json.py | 3 +- tests/custom_scalars/test_money.py | 17 +- tests/nested_input/schema.py | 1 - tests/starwars/fixtures.py | 3 +- tests/starwars/schema.py | 1 - tests/starwars/test_subscription.py | 1 - tests/test_aiohttp.py | 40 ++++- tests/test_aiohttp_online.py | 3 +- tests/test_aiohttp_websocket_exceptions.py | 77 +++++++++ tests/test_aiohttp_websocket_online.py | 0 tests/test_aiohttp_websocket_query.py | 0 tests/test_aiohttp_websocket_subscription.py | 0 tests/test_appsync_auth.py | 7 +- tests/test_appsync_http.py | 4 +- tests/test_appsync_websockets.py | 18 ++- tests/test_async_client_validation.py | 3 +- tests/test_cli.py | 5 +- tests/test_client.py | 5 +- tests/test_graphql_request.py | 5 +- tests/test_graphqlws_exceptions.py | 4 +- tests/test_graphqlws_subscription.py | 8 +- tests/test_httpx.py | 20 ++- tests/test_httpx_async.py | 34 +++- tests/test_httpx_online.py | 3 +- tests/test_phoenix_channel_exceptions.py | 8 +- tests/test_phoenix_channel_subscription.py | 3 +- tests/test_requests.py | 22 ++- tests/test_requests_batch.py | 18 ++- tests/test_transport.py | 2 +- tests/test_transport_batch.py | 2 +- tests/test_websocket_exceptions.py | 7 +- tests/test_websocket_online.py | 3 +- tests/test_websocket_query.py | 11 +- tests/test_websocket_subscription.py | 5 +- 62 files changed, 389 insertions(+), 240 deletions(-) create mode 100644 tests/test_aiohttp_websocket_exceptions.py create mode 100644 tests/test_aiohttp_websocket_online.py create mode 100644 tests/test_aiohttp_websocket_query.py create mode 100644 tests/test_aiohttp_websocket_subscription.py diff --git a/docs/code_examples/console_async.py b/docs/code_examples/console_async.py index 9a5e94e5..2ec4feec 100644 --- a/docs/code_examples/console_async.py +++ b/docs/code_examples/console_async.py @@ -1,7 +1,7 @@ import asyncio import logging - from aioconsole import ainput + from gql import Client, gql from gql.transport.aiohttp import AIOHTTPTransport diff --git a/docs/code_examples/fastapi_async.py b/docs/code_examples/fastapi_async.py index 80920252..511b4abc 100644 --- a/docs/code_examples/fastapi_async.py +++ b/docs/code_examples/fastapi_async.py @@ -7,9 +7,9 @@ # uvicorn fastapi_async:app --reload import logging - from fastapi import FastAPI, HTTPException from fastapi.responses import HTMLResponse + from gql import Client, gql from gql.transport.aiohttp import AIOHTTPTransport diff --git a/docs/code_examples/reconnecting_mutation_http.py b/docs/code_examples/reconnecting_mutation_http.py index f4329c8b..b379be91 100644 --- a/docs/code_examples/reconnecting_mutation_http.py +++ b/docs/code_examples/reconnecting_mutation_http.py @@ -1,7 +1,6 @@ import asyncio -import logging - import backoff +import logging from gql import Client, gql from gql.transport.aiohttp import AIOHTTPTransport diff --git a/docs/code_examples/reconnecting_mutation_ws.py b/docs/code_examples/reconnecting_mutation_ws.py index 7d7c8f8a..b407ddaa 100644 --- a/docs/code_examples/reconnecting_mutation_ws.py +++ b/docs/code_examples/reconnecting_mutation_ws.py @@ -1,7 +1,6 @@ import asyncio -import logging - import backoff +import logging from gql import Client, gql from gql.transport.websockets import WebsocketsTransport diff --git a/gql/cli.py b/gql/cli.py index dd991546..234478de 100644 --- a/gql/cli.py +++ b/gql/cli.py @@ -5,9 +5,8 @@ import sys import textwrap from argparse import ArgumentParser, Namespace, RawTextHelpFormatter -from typing import Any, Dict, Optional - from graphql import GraphQLError, print_schema +from typing import Any, Dict, Optional from yarl import URL from gql import Client, __version__, gql @@ -358,9 +357,10 @@ def get_transport(args: Namespace) -> Optional[AsyncTransport]: auth = AppSyncJWTAuthentication(host=url.host, jwt=args.jwt) else: - from gql.transport.appsync_auth import AppSyncIAMAuthentication from botocore.exceptions import NoRegionError + from gql.transport.appsync_auth import AppSyncIAMAuthentication + try: auth = AppSyncIAMAuthentication(host=url.host) except NoRegionError: diff --git a/gql/client.py b/gql/client.py index 0d9e36c7..17b8be49 100644 --- a/gql/client.py +++ b/gql/client.py @@ -1,9 +1,21 @@ import asyncio +import backoff import logging import sys import time import warnings +from anyio import fail_after from concurrent.futures import Future +from graphql import ( + DocumentNode, + ExecutionResult, + GraphQLSchema, + IntrospectionQuery, + build_ast_schema, + get_introspection_query, + parse, + validate, +) from queue import Queue from threading import Event, Thread from typing import ( @@ -21,19 +33,6 @@ overload, ) -import backoff -from anyio import fail_after -from graphql import ( - DocumentNode, - ExecutionResult, - GraphQLSchema, - IntrospectionQuery, - build_ast_schema, - get_introspection_query, - parse, - validate, -) - from .graphql_request import GraphQLRequest from .transport.async_transport import AsyncTransport from .transport.exceptions import TransportClosed, TransportQueryError @@ -202,8 +201,7 @@ def execute_sync( *, # https://github.com/python/mypy/issues/7333#issuecomment-788255229 get_execution_result: Literal[False] = ..., **kwargs, - ) -> Dict[str, Any]: - ... # pragma: no cover + ) -> Dict[str, Any]: ... # pragma: no cover @overload def execute_sync( @@ -216,8 +214,7 @@ def execute_sync( *, get_execution_result: Literal[True], **kwargs, - ) -> ExecutionResult: - ... # pragma: no cover + ) -> ExecutionResult: ... # pragma: no cover @overload def execute_sync( @@ -230,8 +227,7 @@ def execute_sync( *, get_execution_result: bool, **kwargs, - ) -> Union[Dict[str, Any], ExecutionResult]: - ... # pragma: no cover + ) -> Union[Dict[str, Any], ExecutionResult]: ... # pragma: no cover def execute_sync( self, @@ -264,8 +260,7 @@ def execute_batch_sync( parse_result: Optional[bool] = None, get_execution_result: Literal[False], **kwargs, - ) -> List[Dict[str, Any]]: - ... # pragma: no cover + ) -> List[Dict[str, Any]]: ... # pragma: no cover @overload def execute_batch_sync( @@ -276,8 +271,7 @@ def execute_batch_sync( parse_result: Optional[bool] = None, get_execution_result: Literal[True], **kwargs, - ) -> List[ExecutionResult]: - ... # pragma: no cover + ) -> List[ExecutionResult]: ... # pragma: no cover @overload def execute_batch_sync( @@ -288,8 +282,7 @@ def execute_batch_sync( parse_result: Optional[bool] = None, get_execution_result: bool, **kwargs, - ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: - ... # pragma: no cover + ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: ... # pragma: no cover def execute_batch_sync( self, @@ -321,8 +314,7 @@ async def execute_async( *, # https://github.com/python/mypy/issues/7333#issuecomment-788255229 get_execution_result: Literal[False] = ..., **kwargs, - ) -> Dict[str, Any]: - ... # pragma: no cover + ) -> Dict[str, Any]: ... # pragma: no cover @overload async def execute_async( @@ -335,8 +327,7 @@ async def execute_async( *, get_execution_result: Literal[True], **kwargs, - ) -> ExecutionResult: - ... # pragma: no cover + ) -> ExecutionResult: ... # pragma: no cover @overload async def execute_async( @@ -349,8 +340,7 @@ async def execute_async( *, get_execution_result: bool, **kwargs, - ) -> Union[Dict[str, Any], ExecutionResult]: - ... # pragma: no cover + ) -> Union[Dict[str, Any], ExecutionResult]: ... # pragma: no cover async def execute_async( self, @@ -385,8 +375,7 @@ def execute( *, # https://github.com/python/mypy/issues/7333#issuecomment-788255229 get_execution_result: Literal[False] = ..., **kwargs, - ) -> Dict[str, Any]: - ... # pragma: no cover + ) -> Dict[str, Any]: ... # pragma: no cover @overload def execute( @@ -399,8 +388,7 @@ def execute( *, get_execution_result: Literal[True], **kwargs, - ) -> ExecutionResult: - ... # pragma: no cover + ) -> ExecutionResult: ... # pragma: no cover @overload def execute( @@ -413,8 +401,7 @@ def execute( *, get_execution_result: bool, **kwargs, - ) -> Union[Dict[str, Any], ExecutionResult]: - ... # pragma: no cover + ) -> Union[Dict[str, Any], ExecutionResult]: ... # pragma: no cover def execute( self, @@ -500,8 +487,7 @@ def execute_batch( parse_result: Optional[bool] = None, get_execution_result: Literal[False], **kwargs, - ) -> List[Dict[str, Any]]: - ... # pragma: no cover + ) -> List[Dict[str, Any]]: ... # pragma: no cover @overload def execute_batch( @@ -512,8 +498,7 @@ def execute_batch( parse_result: Optional[bool] = None, get_execution_result: Literal[True], **kwargs, - ) -> List[ExecutionResult]: - ... # pragma: no cover + ) -> List[ExecutionResult]: ... # pragma: no cover @overload def execute_batch( @@ -524,8 +509,7 @@ def execute_batch( parse_result: Optional[bool] = None, get_execution_result: bool, **kwargs, - ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: - ... # pragma: no cover + ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: ... # pragma: no cover def execute_batch( self, @@ -581,8 +565,7 @@ def subscribe_async( *, get_execution_result: Literal[False] = ..., **kwargs, - ) -> AsyncGenerator[Dict[str, Any], None]: - ... # pragma: no cover + ) -> AsyncGenerator[Dict[str, Any], None]: ... # pragma: no cover @overload def subscribe_async( @@ -595,8 +578,7 @@ def subscribe_async( *, get_execution_result: Literal[True], **kwargs, - ) -> AsyncGenerator[ExecutionResult, None]: - ... # pragma: no cover + ) -> AsyncGenerator[ExecutionResult, None]: ... # pragma: no cover @overload def subscribe_async( @@ -611,8 +593,7 @@ def subscribe_async( **kwargs, ) -> Union[ AsyncGenerator[Dict[str, Any], None], AsyncGenerator[ExecutionResult, None] - ]: - ... # pragma: no cover + ]: ... # pragma: no cover async def subscribe_async( self, @@ -652,8 +633,7 @@ def subscribe( *, get_execution_result: Literal[False] = ..., **kwargs, - ) -> Generator[Dict[str, Any], None, None]: - ... # pragma: no cover + ) -> Generator[Dict[str, Any], None, None]: ... # pragma: no cover @overload def subscribe( @@ -666,8 +646,7 @@ def subscribe( *, get_execution_result: Literal[True], **kwargs, - ) -> Generator[ExecutionResult, None, None]: - ... # pragma: no cover + ) -> Generator[ExecutionResult, None, None]: ... # pragma: no cover @overload def subscribe( @@ -682,8 +661,7 @@ def subscribe( **kwargs, ) -> Union[ Generator[Dict[str, Any], None, None], Generator[ExecutionResult, None, None] - ]: - ... # pragma: no cover + ]: ... # pragma: no cover def subscribe( self, @@ -953,8 +931,7 @@ def execute( *, get_execution_result: Literal[False] = ..., **kwargs, - ) -> Dict[str, Any]: - ... # pragma: no cover + ) -> Dict[str, Any]: ... # pragma: no cover @overload def execute( @@ -967,8 +944,7 @@ def execute( *, get_execution_result: Literal[True], **kwargs, - ) -> ExecutionResult: - ... # pragma: no cover + ) -> ExecutionResult: ... # pragma: no cover @overload def execute( @@ -981,8 +957,7 @@ def execute( *, get_execution_result: bool, **kwargs, - ) -> Union[Dict[str, Any], ExecutionResult]: - ... # pragma: no cover + ) -> Union[Dict[str, Any], ExecutionResult]: ... # pragma: no cover def execute( self, @@ -1075,9 +1050,11 @@ def _execute_batch( serialize_variables is None and self.client.serialize_variables ): requests = [ - req.serialize_variable_values(self.client.schema) - if req.variable_values is not None - else req + ( + req.serialize_variable_values(self.client.schema) + if req.variable_values is not None + else req + ) for req in requests ] @@ -1105,8 +1082,7 @@ def execute_batch( parse_result: Optional[bool] = None, get_execution_result: Literal[False], **kwargs, - ) -> List[Dict[str, Any]]: - ... # pragma: no cover + ) -> List[Dict[str, Any]]: ... # pragma: no cover @overload def execute_batch( @@ -1117,8 +1093,7 @@ def execute_batch( parse_result: Optional[bool] = None, get_execution_result: Literal[True], **kwargs, - ) -> List[ExecutionResult]: - ... # pragma: no cover + ) -> List[ExecutionResult]: ... # pragma: no cover @overload def execute_batch( @@ -1129,8 +1104,7 @@ def execute_batch( parse_result: Optional[bool] = None, get_execution_result: bool, **kwargs, - ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: - ... # pragma: no cover + ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: ... # pragma: no cover def execute_batch( self, @@ -1355,13 +1329,13 @@ async def _subscribe( ) # Subscribe to the transport - inner_generator: AsyncGenerator[ - ExecutionResult, None - ] = self.transport.subscribe( - document, - variable_values=variable_values, - operation_name=operation_name, - **kwargs, + inner_generator: AsyncGenerator[ExecutionResult, None] = ( + self.transport.subscribe( + document, + variable_values=variable_values, + operation_name=operation_name, + **kwargs, + ) ) # Keep a reference to the inner generator to allow the user to call aclose() @@ -1397,8 +1371,7 @@ def subscribe( *, get_execution_result: Literal[False] = ..., **kwargs, - ) -> AsyncGenerator[Dict[str, Any], None]: - ... # pragma: no cover + ) -> AsyncGenerator[Dict[str, Any], None]: ... # pragma: no cover @overload def subscribe( @@ -1411,8 +1384,7 @@ def subscribe( *, get_execution_result: Literal[True], **kwargs, - ) -> AsyncGenerator[ExecutionResult, None]: - ... # pragma: no cover + ) -> AsyncGenerator[ExecutionResult, None]: ... # pragma: no cover @overload def subscribe( @@ -1427,8 +1399,7 @@ def subscribe( **kwargs, ) -> Union[ AsyncGenerator[Dict[str, Any], None], AsyncGenerator[ExecutionResult, None] - ]: - ... # pragma: no cover + ]: ... # pragma: no cover async def subscribe( self, @@ -1564,8 +1535,7 @@ async def execute( *, get_execution_result: Literal[False] = ..., **kwargs, - ) -> Dict[str, Any]: - ... # pragma: no cover + ) -> Dict[str, Any]: ... # pragma: no cover @overload async def execute( @@ -1578,8 +1548,7 @@ async def execute( *, get_execution_result: Literal[True], **kwargs, - ) -> ExecutionResult: - ... # pragma: no cover + ) -> ExecutionResult: ... # pragma: no cover @overload async def execute( @@ -1592,8 +1561,7 @@ async def execute( *, get_execution_result: bool, **kwargs, - ) -> Union[Dict[str, Any], ExecutionResult]: - ... # pragma: no cover + ) -> Union[Dict[str, Any], ExecutionResult]: ... # pragma: no cover async def execute( self, diff --git a/gql/dsl.py b/gql/dsl.py index 536a8b8b..bc32d875 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -2,12 +2,10 @@ .. image:: http://www.plantuml.com/plantuml/png/ZLAzJWCn3Dxz51vXw1im50ag8L4XwC1OkLTJ8gMvAd4GwEYxGuC8pTbKtUxy_TZEvsaIYfAt7e1MII9rWfsdbF1cSRzWpvtq4GT0JENduX8GXr_g7brQlf5tw-MBOx_-HlS0LV_Kzp8xr1kZav9PfCsMWvolEA_1VylHoZCExKwKv4Tg2s_VkSkca2kof2JDb0yxZYIk3qMZYUe1B1uUZOROXn96pQMugEMUdRnUUqUf6DBXQyIz2zu5RlgUQAFVNYaeRfBI79_JrUTaeg9JZFQj5MmUc69PDmNGE2iU61fDgfri3x36gxHw3gDHD6xqqQ7P4vjKqz2-602xtkO7uo17SCLhVSv25VjRjUAFcUE73Sspb8ADBl8gTT7j2cFAOPst_Wi0 # noqa :alt: UML diagram """ + import logging import re from abc import ABC, abstractmethod -from math import isfinite -from typing import Any, Dict, Iterable, Mapping, Optional, Tuple, Union, cast - from graphql import ( ArgumentNode, BooleanValueNode, @@ -62,6 +60,8 @@ print_ast, ) from graphql.pyutils import inspect +from math import isfinite +from typing import Any, Dict, Iterable, Mapping, Optional, Tuple, Union, cast from .utils import to_camel_case @@ -595,9 +595,11 @@ def get_ast_definitions(self) -> Tuple[VariableDefinitionNode, ...]: VariableDefinitionNode( type=var.ast_variable_type, variable=var.ast_variable_name, - default_value=None - if var.default_value is None - else ast_from_value(var.default_value, var.type), + default_value=( + None + if var.default_value is None + else ast_from_value(var.default_value, var.type) + ), directives=(), ) for var in self.variables.values() diff --git a/gql/graphql_request.py b/gql/graphql_request.py index b0c68f5c..41504dcd 100644 --- a/gql/graphql_request.py +++ b/gql/graphql_request.py @@ -1,7 +1,6 @@ from dataclasses import dataclass -from typing import Any, Dict, Optional - from graphql import DocumentNode, GraphQLSchema +from typing import Any, Dict, Optional from .utilities import serialize_variable_values diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index be22ce9c..1269b097 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -1,18 +1,18 @@ +from ssl import SSLContext + +import aiohttp import asyncio import functools import io import json import logging -from ssl import SSLContext -from typing import Any, AsyncGenerator, Callable, Dict, Optional, Tuple, Type, Union - -import aiohttp from aiohttp.client_exceptions import ClientResponseError from aiohttp.client_reqrep import Fingerprint from aiohttp.helpers import BasicAuth from aiohttp.typedefs import LooseCookies, LooseHeaders from graphql import DocumentNode, ExecutionResult, print_ast from multidict import CIMultiDictProxy +from typing import Any, AsyncGenerator, Callable, Dict, Optional, Tuple, Type, Union from ..utils import extract_files from .appsync_auth import AppSyncAuthentication @@ -101,9 +101,9 @@ async def connect(self) -> None: client_session_args: Dict[str, Any] = { "cookies": self.cookies, "headers": self.headers, - "auth": None - if isinstance(self.auth, AppSyncAuthentication) - else self.auth, + "auth": ( + None if isinstance(self.auth, AppSyncAuthentication) else self.auth + ), "json_serialize": self.json_serialize, } diff --git a/gql/transport/aiohttp_websockets.py b/gql/transport/aiohttp_websockets.py index d53c0c76..b0989ec0 100644 --- a/gql/transport/aiohttp_websockets.py +++ b/gql/transport/aiohttp_websockets.py @@ -17,7 +17,6 @@ AsyncGenerator, Collection, Dict, - Mapping, Optional, Tuple, Union, diff --git a/gql/transport/appsync_websockets.py b/gql/transport/appsync_websockets.py index 66091747..655acb19 100644 --- a/gql/transport/appsync_websockets.py +++ b/gql/transport/appsync_websockets.py @@ -1,11 +1,11 @@ +from ssl import SSLContext + import json import logging -from ssl import SSLContext +from graphql import DocumentNode, ExecutionResult, print_ast from typing import Any, Dict, Optional, Tuple, Union, cast from urllib.parse import urlparse -from graphql import DocumentNode, ExecutionResult, print_ast - from .appsync_auth import AppSyncAuthentication, AppSyncIAMAuthentication from .exceptions import TransportProtocolError, TransportServerError from .websockets import WebsocketsTransport, WebsocketsTransportBase diff --git a/gql/transport/async_transport.py b/gql/transport/async_transport.py index 4cecc9f9..2d180b65 100644 --- a/gql/transport/async_transport.py +++ b/gql/transport/async_transport.py @@ -1,7 +1,6 @@ import abc -from typing import Any, AsyncGenerator, Dict, Optional - from graphql import DocumentNode, ExecutionResult +from typing import Any, AsyncGenerator, Dict, Optional class AsyncTransport(abc.ABC): diff --git a/gql/transport/httpx.py b/gql/transport/httpx.py index 811601b8..4f8d8334 100644 --- a/gql/transport/httpx.py +++ b/gql/transport/httpx.py @@ -1,6 +1,8 @@ +import httpx import io import json import logging +from graphql import DocumentNode, ExecutionResult, print_ast from typing import ( Any, AsyncGenerator, @@ -14,9 +16,6 @@ cast, ) -import httpx -from graphql import DocumentNode, ExecutionResult, print_ast - from ..utils import extract_files from . import AsyncTransport, Transport from .exceptions import ( diff --git a/gql/transport/local_schema.py b/gql/transport/local_schema.py index 04ed4ff1..787e1491 100644 --- a/gql/transport/local_schema.py +++ b/gql/transport/local_schema.py @@ -1,9 +1,8 @@ import asyncio +from graphql import DocumentNode, ExecutionResult, GraphQLSchema, execute, subscribe from inspect import isawaitable from typing import AsyncGenerator, Awaitable, cast -from graphql import DocumentNode, ExecutionResult, GraphQLSchema, execute, subscribe - from gql.transport import AsyncTransport diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py index b8226234..fae8cd3a 100644 --- a/gql/transport/phoenix_channel_websockets.py +++ b/gql/transport/phoenix_channel_websockets.py @@ -1,9 +1,8 @@ import asyncio import json import logging -from typing import Any, Dict, Optional, Tuple - from graphql import DocumentNode, ExecutionResult, print_ast +from typing import Any, Dict, Optional, Tuple from websockets.exceptions import ConnectionClosed from .exceptions import ( diff --git a/gql/transport/requests.py b/gql/transport/requests.py index 0c6eb3fc..b6d19292 100644 --- a/gql/transport/requests.py +++ b/gql/transport/requests.py @@ -1,14 +1,13 @@ import io import json import logging -from typing import Any, Callable, Collection, Dict, List, Optional, Tuple, Type, Union - import requests from graphql import DocumentNode, ExecutionResult, print_ast from requests.adapters import HTTPAdapter, Retry from requests.auth import AuthBase from requests.cookies import RequestsCookieJar from requests_toolbelt.multipart.encoder import MultipartEncoder +from typing import Any, Callable, Collection, Dict, List, Optional, Tuple, Type, Union from gql.transport import Transport diff --git a/gql/transport/transport.py b/gql/transport/transport.py index a5bd7100..cb04d4d8 100644 --- a/gql/transport/transport.py +++ b/gql/transport/transport.py @@ -1,7 +1,6 @@ import abc -from typing import List - from graphql import DocumentNode, ExecutionResult +from typing import List from ..graphql_request import GraphQLRequest diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index c385d3d7..e127dc37 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -1,11 +1,11 @@ +from ssl import SSLContext + import asyncio import json import logging from contextlib import suppress -from ssl import SSLContext -from typing import Any, Dict, List, Optional, Tuple, Union, cast - from graphql import DocumentNode, ExecutionResult, print_ast +from typing import Any, Dict, List, Optional, Tuple, Union, cast from websockets.datastructures import HeadersLike from websockets.typing import Subprotocol diff --git a/gql/transport/websockets_base.py b/gql/transport/websockets_base.py index 45c96d3e..a952611a 100644 --- a/gql/transport/websockets_base.py +++ b/gql/transport/websockets_base.py @@ -1,13 +1,13 @@ +from ssl import SSLContext + import asyncio import logging import warnings +import websockets from abc import abstractmethod from contextlib import suppress -from ssl import SSLContext -from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union, cast - -import websockets from graphql import DocumentNode, ExecutionResult +from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union, cast from websockets.client import WebSocketClientProtocol from websockets.datastructures import Headers, HeadersLike from websockets.exceptions import ConnectionClosed diff --git a/gql/utilities/get_introspection_query_ast.py b/gql/utilities/get_introspection_query_ast.py index d35a2a75..0abbec30 100644 --- a/gql/utilities/get_introspection_query_ast.py +++ b/gql/utilities/get_introspection_query_ast.py @@ -1,6 +1,5 @@ -from itertools import repeat - from graphql import DocumentNode, GraphQLSchema +from itertools import repeat from gql.dsl import DSLFragment, DSLMetaField, DSLQuery, DSLSchema, dsl_gql diff --git a/gql/utilities/node_tree.py b/gql/utilities/node_tree.py index c307d937..a8369b1a 100644 --- a/gql/utilities/node_tree.py +++ b/gql/utilities/node_tree.py @@ -1,6 +1,5 @@ -from typing import Any, Iterable, List, Optional, Sized - from graphql import Node +from typing import Any, Iterable, List, Optional, Sized def _node_tree_recursive( diff --git a/gql/utilities/parse_result.py b/gql/utilities/parse_result.py index 02355425..c626f196 100644 --- a/gql/utilities/parse_result.py +++ b/gql/utilities/parse_result.py @@ -1,6 +1,4 @@ import logging -from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Union, cast - from graphql import ( IDLE, REMOVE, @@ -29,6 +27,7 @@ ) from graphql.language.visitor import VisitorActionEnum from graphql.pyutils import inspect +from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Union, cast log = logging.getLogger(__name__) diff --git a/gql/utilities/serialize_variable_values.py b/gql/utilities/serialize_variable_values.py index 38ad1995..cc8740c3 100644 --- a/gql/utilities/serialize_variable_values.py +++ b/gql/utilities/serialize_variable_values.py @@ -1,5 +1,3 @@ -from typing import Any, Dict, Optional - from graphql import ( DocumentNode, GraphQLEnumType, @@ -15,6 +13,7 @@ type_from_ast, ) from graphql.pyutils import inspect +from typing import Any, Dict, Optional def _get_document_operation( diff --git a/gql/utilities/update_schema_enum.py b/gql/utilities/update_schema_enum.py index 80c73862..2888ae08 100644 --- a/gql/utilities/update_schema_enum.py +++ b/gql/utilities/update_schema_enum.py @@ -1,7 +1,6 @@ from enum import Enum -from typing import Any, Dict, Mapping, Type, Union, cast - from graphql import GraphQLEnumType, GraphQLSchema +from typing import Any, Dict, Mapping, Type, Union, cast def update_schema_enum( diff --git a/gql/utilities/update_schema_scalars.py b/gql/utilities/update_schema_scalars.py index db3adb17..8ba366b3 100644 --- a/gql/utilities/update_schema_scalars.py +++ b/gql/utilities/update_schema_scalars.py @@ -1,6 +1,5 @@ -from typing import Iterable, List - from graphql import GraphQLScalarType, GraphQLSchema +from typing import Iterable, List def update_schema_scalar(schema: GraphQLSchema, name: str, scalar: GraphQLScalarType): diff --git a/tests/conftest.py b/tests/conftest.py index 6a37a5d3..0732ac7a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,19 +1,20 @@ +from random import sample +import ssl + import asyncio import json import logging import os import pathlib +import pytest +import pytest_asyncio import re -import ssl import sys import tempfile import types from concurrent.futures import ThreadPoolExecutor from typing import Union -import pytest -import pytest_asyncio - from gql import Client all_transport_dependencies = ["aiohttp", "requests", "httpx", "websockets", "botocore"] @@ -119,6 +120,7 @@ async def ssl_aiohttp_server(): for name in [ "websockets.legacy.server", "gql.transport.aiohttp", + "gql.transport.aiohttp_websockets", "gql.transport.appsync", "gql.transport.phoenix_channel_websockets", "gql.transport.requests", @@ -482,6 +484,22 @@ async def client_and_graphqlws_server(graphqlws_server): # Yield both client session and server yield session, graphqlws_server +@pytest_asyncio.fixture +async def aiohttp_client_and_server(aiohttp_server): + """Helper fixture to start an aiohttp server and a client connected to its port.""" + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + # Generate transport to connect to the server fixture + path = "/graphql" + url = f"ws://{aiohttp_server.hostname}:{aiohttp_server.port}{path}" + sample_transport = AIOHTTPWebsocketsTransport(url=url) + + async with Client(transport=sample_transport) as session: + + # Yield both client session and server + yield session, aiohttp_server + @pytest_asyncio.fixture async def run_sync_test(): diff --git a/tests/custom_scalars/test_datetime.py b/tests/custom_scalars/test_datetime.py index b3e717c5..61d3a9e3 100644 --- a/tests/custom_scalars/test_datetime.py +++ b/tests/custom_scalars/test_datetime.py @@ -1,7 +1,5 @@ -from datetime import datetime, timedelta -from typing import Any, Dict, Optional - import pytest +from datetime import datetime, timedelta from graphql.error import GraphQLError from graphql.language import ValueNode from graphql.pyutils import inspect @@ -17,6 +15,7 @@ GraphQLSchema, ) from graphql.utilities import value_from_ast_untyped +from typing import Any, Dict, Optional from gql import Client, gql diff --git a/tests/custom_scalars/test_enum_colors.py b/tests/custom_scalars/test_enum_colors.py index 2f15a8ca..9ddc7df3 100644 --- a/tests/custom_scalars/test_enum_colors.py +++ b/tests/custom_scalars/test_enum_colors.py @@ -1,6 +1,5 @@ -from enum import Enum - import pytest +from enum import Enum from graphql import ( GraphQLArgument, GraphQLEnumType, diff --git a/tests/custom_scalars/test_json.py b/tests/custom_scalars/test_json.py index d3eae3b8..4c9505cc 100644 --- a/tests/custom_scalars/test_json.py +++ b/tests/custom_scalars/test_json.py @@ -1,5 +1,3 @@ -from typing import Any, Dict, Optional - import pytest from graphql import ( GraphQLArgument, @@ -14,6 +12,7 @@ ) from graphql.language import ValueNode from graphql.utilities import value_from_ast_untyped +from typing import Any, Dict, Optional from gql import Client, gql from gql.dsl import DSLSchema diff --git a/tests/custom_scalars/test_money.py b/tests/custom_scalars/test_money.py index 374c70e6..234e6cb9 100644 --- a/tests/custom_scalars/test_money.py +++ b/tests/custom_scalars/test_money.py @@ -1,7 +1,4 @@ import asyncio -from math import isfinite -from typing import Any, Dict, NamedTuple, Optional - import pytest from graphql import ExecutionResult, graphql_sync from graphql.error import GraphQLError @@ -19,6 +16,8 @@ GraphQLSchema, ) from graphql.utilities import value_from_ast_untyped +from math import isfinite +from typing import Any, Dict, NamedTuple, Optional from gql import Client, GraphQLRequest, gql from gql.transport.exceptions import TransportQueryError @@ -441,9 +440,9 @@ def handle_single(data: Dict[str, Any]) -> ExecutionResult: [ { "data": result.data, - "errors": [str(e) for e in result.errors] - if result.errors - else None, + "errors": ( + [str(e) for e in result.errors] if result.errors else None + ), } for result in results ] @@ -453,9 +452,9 @@ def handle_single(data: Dict[str, Any]) -> ExecutionResult: return web.json_response( { "data": result.data, - "errors": [str(e) for e in result.errors] - if result.errors - else None, + "errors": ( + [str(e) for e in result.errors] if result.errors else None + ), } ) diff --git a/tests/nested_input/schema.py b/tests/nested_input/schema.py index ccdebb4a..d8a2f929 100644 --- a/tests/nested_input/schema.py +++ b/tests/nested_input/schema.py @@ -1,5 +1,4 @@ import json - from graphql import ( GraphQLArgument, GraphQLField, diff --git a/tests/starwars/fixtures.py b/tests/starwars/fixtures.py index 59d7ddfa..1d179f60 100644 --- a/tests/starwars/fixtures.py +++ b/tests/starwars/fixtures.py @@ -148,9 +148,10 @@ def create_review(episode, review): async def make_starwars_backend(aiohttp_server): from aiohttp import web - from .schema import StarWarsSchema from graphql import graphql_sync + from .schema import StarWarsSchema + async def handler(request): data = await request.json() source = data["query"] diff --git a/tests/starwars/schema.py b/tests/starwars/schema.py index 4b672ad3..ef196213 100644 --- a/tests/starwars/schema.py +++ b/tests/starwars/schema.py @@ -1,5 +1,4 @@ import asyncio - from graphql import ( GraphQLArgument, GraphQLEnumType, diff --git a/tests/starwars/test_subscription.py b/tests/starwars/test_subscription.py index 0f412acc..7c1be4cf 100644 --- a/tests/starwars/test_subscription.py +++ b/tests/starwars/test_subscription.py @@ -1,5 +1,4 @@ import asyncio - import pytest from graphql import ExecutionResult, GraphQLError, subscribe diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index b16964d0..1ed708cd 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -1,8 +1,7 @@ import io import json -from typing import Mapping - import pytest +from typing import Mapping from gql import Client, gql from gql.cli import get_parser, main @@ -43,6 +42,7 @@ @pytest.mark.asyncio async def test_aiohttp_query(event_loop, aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -82,6 +82,7 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_ignore_backend_content_type(event_loop, aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -111,6 +112,7 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_cookies(event_loop, aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -144,6 +146,7 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_error_code_401(event_loop, aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -175,6 +178,7 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_error_code_429(event_loop, aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -222,6 +226,7 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_error_code_500(event_loop, aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -257,6 +262,7 @@ async def handler(request): @pytest.mark.parametrize("query_error", transport_query_error_responses) async def test_aiohttp_error_code(event_loop, aiohttp_server, query_error): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -312,6 +318,7 @@ async def handler(request): @pytest.mark.parametrize("param", invalid_protocol_responses) async def test_aiohttp_invalid_protocol(event_loop, aiohttp_server, param): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport response = param["response"] @@ -340,6 +347,7 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_subscribe_not_supported(event_loop, aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -365,6 +373,7 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_cannot_connect_twice(event_loop, aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -387,6 +396,7 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_cannot_execute_if_not_connected(event_loop, aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -409,6 +419,7 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_extra_args(event_loop, aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -456,6 +467,7 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_query_variable_values(event_loop, aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -492,6 +504,7 @@ async def test_aiohttp_query_variable_values_fix_issue_292(event_loop, aiohttp_s See https://github.com/graphql-python/gql/issues/292""" from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -524,6 +537,7 @@ async def test_aiohttp_execute_running_in_thread( event_loop, aiohttp_server, run_sync_test ): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -552,6 +566,7 @@ async def test_aiohttp_subscribe_running_in_thread( event_loop, aiohttp_server, run_sync_test ): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -638,6 +653,7 @@ async def single_upload_handler(request): @pytest.mark.asyncio async def test_aiohttp_file_upload(event_loop, aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport app = web.Application() @@ -703,6 +719,7 @@ async def single_upload_handler_with_content_type(request): @pytest.mark.asyncio async def test_aiohttp_file_upload_with_content_type(event_loop, aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport app = web.Application() @@ -743,6 +760,7 @@ async def test_aiohttp_file_upload_without_session( event_loop, aiohttp_server, run_sync_test ): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport app = web.Application() @@ -811,6 +829,7 @@ async def binary_upload_handler(request): @pytest.mark.asyncio async def test_aiohttp_binary_file_upload(event_loop, aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport app = web.Application() @@ -845,7 +864,8 @@ async def test_aiohttp_binary_file_upload(event_loop, aiohttp_server): @pytest.mark.asyncio async def test_aiohttp_stream_reader_upload(event_loop, aiohttp_server): - from aiohttp import web, ClientSession + from aiohttp import ClientSession, web + from gql.transport.aiohttp import AIOHTTPTransport async def binary_data_handler(request): @@ -884,6 +904,7 @@ async def binary_data_handler(request): async def test_aiohttp_async_generator_upload(event_loop, aiohttp_server): import aiofiles from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport app = web.Application() @@ -946,6 +967,7 @@ async def file_sender(file_name): @pytest.mark.asyncio async def test_aiohttp_file_upload_two_files(event_loop, aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -1037,6 +1059,7 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_file_upload_list_of_two_files(event_loop, aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -1258,6 +1281,7 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_query_with_extensions(event_loop, aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -1286,6 +1310,7 @@ async def handler(request): @pytest.mark.parametrize("ssl_close_timeout", [0, 10]) async def test_aiohttp_query_https(event_loop, ssl_aiohttp_server, ssl_close_timeout): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -1320,6 +1345,7 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_error_fetching_schema(event_loop, aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport error_answer = """ @@ -1363,6 +1389,7 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_reconnecting_session(event_loop, aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -1403,6 +1430,7 @@ async def test_aiohttp_reconnecting_session_retries( event_loop, aiohttp_server, retries ): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -1436,6 +1464,7 @@ async def test_aiohttp_reconnecting_session_start_connecting_task_twice( event_loop, aiohttp_server, caplog ): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -1469,6 +1498,7 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_json_serializer(event_loop, aiohttp_server, caplog): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -1527,6 +1557,7 @@ async def test_aiohttp_json_deserializer(event_loop, aiohttp_server): from aiohttp import web from decimal import Decimal from functools import partial + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -1563,7 +1594,8 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_connector_owner_false(event_loop, aiohttp_server): - from aiohttp import web, TCPConnector + from aiohttp import TCPConnector, web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): diff --git a/tests/test_aiohttp_online.py b/tests/test_aiohttp_online.py index 39b8a9d2..74e00aee 100644 --- a/tests/test_aiohttp_online.py +++ b/tests/test_aiohttp_online.py @@ -1,9 +1,8 @@ import asyncio +import pytest import sys from typing import Dict -import pytest - from gql import Client, gql from gql.transport.exceptions import TransportQueryError diff --git a/tests/test_aiohttp_websocket_exceptions.py b/tests/test_aiohttp_websocket_exceptions.py new file mode 100644 index 00000000..707cf827 --- /dev/null +++ b/tests/test_aiohttp_websocket_exceptions.py @@ -0,0 +1,77 @@ +import asyncio +import json +import pytest +import types +from typing import List + +from gql import Client, gql +from gql.transport.exceptions import ( + TransportAlreadyConnected, + TransportClosed, + TransportProtocolError, + TransportQueryError, +) + +from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + +from .conftest import MS, WebSocketServerHelper + +# Marking all tests in this file with the websockets marker +pytestmark = pytest.mark.aiohttp_websockets + +invalid_query_str = """ + query getContinents { + continents { + code + bloh + } + } +""" + +invalid_query1_server_answer = ( + '{{"type":"data","id":"{query_id}",' + '"payload":{{"errors":[' + '{{"message":"Cannot query field \\"bloh\\" on type \\"Continent\\".",' + '"locations":[{{"line":4,"column":5}}],' + '"extensions":{{"code":"INTERNAL_SERVER_ERROR"}}}}]}}}}' +) + +invalid_query1_server = [invalid_query1_server_answer] + +@pytest.mark.asyncio +@pytest.mark.parametrize("aiohttp_server", [invalid_query1_server], indirect=True) +@pytest.mark.parametrize("query_str", [invalid_query_str]) +async def test_aiohttp_websocket_invalid_query(event_loop, aiohttp_server, query_str: str,): + + from aiohttp import web + + async def handler(request): + return web.Response( + text=invalid_query1_server_answer, + content_type="application/json", + headers={"dummy": "this should not be returned"}, + ) + + app = web.Application() + app.router.add_get("/ws", handler) + server = await aiohttp_server(app) + + url = server.make_url("/ws") + + transport = AIOHTTPWebsocketsTransport(url=url, timeout=10) + + async with Client(transport=transport) as session: + + + query = gql(query_str) + + with pytest.raises(TransportQueryError) as exc_info: + await session.execute(query) + + exception = exc_info.value + + assert isinstance(exception.errors, List) + + error = exception.errors[0] + + assert error["extensions"]["code"] == "INTERNAL_SERVER_ERROR" \ No newline at end of file diff --git a/tests/test_aiohttp_websocket_online.py b/tests/test_aiohttp_websocket_online.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_aiohttp_websocket_query.py b/tests/test_aiohttp_websocket_query.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_aiohttp_websocket_subscription.py b/tests/test_aiohttp_websocket_subscription.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_appsync_auth.py b/tests/test_appsync_auth.py index cb279ae5..89591426 100644 --- a/tests/test_appsync_auth.py +++ b/tests/test_appsync_auth.py @@ -23,6 +23,7 @@ def test_appsync_init_with_minimal_args(fake_session_factory): @pytest.mark.botocore def test_appsync_init_with_no_credentials(caplog, fake_session_factory): import botocore.exceptions + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport with pytest.raises(botocore.exceptions.NoCredentialsError): @@ -72,6 +73,7 @@ def test_appsync_init_with_apikey_auth(): @pytest.mark.botocore def test_appsync_init_with_iam_auth_without_creds(fake_session_factory): import botocore.exceptions + from gql.transport.appsync_auth import AppSyncIAMAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport @@ -108,9 +110,10 @@ def test_appsync_init_with_iam_auth_and_no_region( - you have the AWS_DEFAULT_REGION environment variable set """ - from gql.transport.appsync_websockets import AppSyncWebsocketsTransport - from botocore.exceptions import NoRegionError import logging + from botocore.exceptions import NoRegionError + + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport caplog.set_level(logging.WARNING) diff --git a/tests/test_appsync_http.py b/tests/test_appsync_http.py index ca3a3fcb..21d2c8ea 100644 --- a/tests/test_appsync_http.py +++ b/tests/test_appsync_http.py @@ -1,5 +1,4 @@ import json - import pytest from gql import Client, gql @@ -12,9 +11,10 @@ async def test_appsync_iam_mutation( event_loop, aiohttp_server, fake_credentials_factory ): from aiohttp import web + from urllib.parse import urlparse + from gql.transport.aiohttp import AIOHTTPTransport from gql.transport.appsync_auth import AppSyncIAMAuthentication - from urllib.parse import urlparse async def handler(request): data = { diff --git a/tests/test_appsync_websockets.py b/tests/test_appsync_websockets.py index 14c40e75..25cbe200 100644 --- a/tests/test_appsync_websockets.py +++ b/tests/test_appsync_websockets.py @@ -1,11 +1,10 @@ import asyncio import json +import pytest from base64 import b64decode from typing import List from urllib import parse -import pytest - from gql import Client, gql from .conftest import MS, WebSocketServerHelper @@ -424,9 +423,10 @@ async def test_appsync_subscription_api_key(event_loop, server): @pytest.mark.parametrize("server", [realtime_appsync_server], indirect=True) async def test_appsync_subscription_iam_with_token(event_loop, server): + from botocore.credentials import Credentials + from gql.transport.appsync_auth import AppSyncIAMAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport - from botocore.credentials import Credentials path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" @@ -451,9 +451,10 @@ async def test_appsync_subscription_iam_with_token(event_loop, server): @pytest.mark.parametrize("server", [realtime_appsync_server], indirect=True) async def test_appsync_subscription_iam_without_token(event_loop, server): + from botocore.credentials import Credentials + from gql.transport.appsync_auth import AppSyncIAMAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport - from botocore.credentials import Credentials path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" @@ -477,9 +478,10 @@ async def test_appsync_subscription_iam_without_token(event_loop, server): @pytest.mark.parametrize("server", [realtime_appsync_server], indirect=True) async def test_appsync_execute_method_not_allowed(event_loop, server): + from botocore.credentials import Credentials + from gql.transport.appsync_auth import AppSyncIAMAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport - from botocore.credentials import Credentials path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" @@ -524,9 +526,10 @@ async def test_appsync_execute_method_not_allowed(event_loop, server): @pytest.mark.botocore async def test_appsync_fetch_schema_from_transport_not_allowed(event_loop): + from botocore.credentials import Credentials + from gql.transport.appsync_auth import AppSyncIAMAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport - from botocore.credentials import Credentials dummy_credentials = Credentials( access_key=DUMMY_ACCESS_KEY_ID, @@ -577,10 +580,11 @@ async def test_appsync_subscription_api_key_unauthorized(event_loop, server): @pytest.mark.parametrize("server", [realtime_appsync_server], indirect=True) async def test_appsync_subscription_iam_not_allowed(event_loop, server): + from botocore.credentials import Credentials + from gql.transport.appsync_auth import AppSyncIAMAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport from gql.transport.exceptions import TransportQueryError - from botocore.credentials import Credentials path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" diff --git a/tests/test_async_client_validation.py b/tests/test_async_client_validation.py index d39019e8..b2e7588d 100644 --- a/tests/test_async_client_validation.py +++ b/tests/test_async_client_validation.py @@ -1,7 +1,6 @@ import asyncio -import json - import graphql +import json import pytest from gql import Client, gql diff --git a/tests/test_cli.py b/tests/test_cli.py index f0534957..a6f0d0d8 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,5 +1,4 @@ import logging - import pytest from gql import __version__ @@ -270,8 +269,8 @@ async def test_cli_main_appsync_websockets_iam(parser, url): ) def test_cli_get_transport_appsync_websockets_api_key(parser, url): - from gql.transport.appsync_websockets import AppSyncWebsocketsTransport from gql.transport.appsync_auth import AppSyncApiKeyAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport args = parser.parse_args( [url, "--transport", "appsync_websockets", "--api-key", "test-api-key"] @@ -291,8 +290,8 @@ def test_cli_get_transport_appsync_websockets_api_key(parser, url): ) def test_cli_get_transport_appsync_websockets_jwt(parser, url): - from gql.transport.appsync_websockets import AppSyncWebsocketsTransport from gql.transport.appsync_auth import AppSyncJWTAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport args = parser.parse_args( [url, "--transport", "appsync_websockets", "--jwt", "test-jwt"] diff --git a/tests/test_client.py b/tests/test_client.py index ada129c6..955c2780 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,8 +1,7 @@ -import os -from contextlib import suppress - import mock +import os import pytest +from contextlib import suppress from graphql import build_ast_schema, parse from gql import Client, GraphQLRequest, gql diff --git a/tests/test_graphql_request.py b/tests/test_graphql_request.py index 4c9e7d76..00628e02 100644 --- a/tests/test_graphql_request.py +++ b/tests/test_graphql_request.py @@ -1,7 +1,4 @@ import asyncio -from math import isfinite -from typing import Any, Dict, NamedTuple, Optional - import pytest from graphql.error import GraphQLError from graphql.language import ValueNode @@ -17,6 +14,8 @@ GraphQLSchema, ) from graphql.utilities import value_from_ast_untyped +from math import isfinite +from typing import Any, Dict, NamedTuple, Optional from gql import GraphQLRequest, gql diff --git a/tests/test_graphqlws_exceptions.py b/tests/test_graphqlws_exceptions.py index ca689c47..37de6e2e 100644 --- a/tests/test_graphqlws_exceptions.py +++ b/tests/test_graphqlws_exceptions.py @@ -1,7 +1,6 @@ import asyncio -from typing import List - import pytest +from typing import List from gql import Client, gql from gql.transport.exceptions import ( @@ -234,6 +233,7 @@ async def server_closing_directly(ws, path): @pytest.mark.parametrize("graphqlws_server", [server_closing_directly], indirect=True) async def test_graphqlws_server_closing_directly(event_loop, graphqlws_server): import websockets + from gql.transport.websockets import WebsocketsTransport url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql" diff --git a/tests/test_graphqlws_subscription.py b/tests/test_graphqlws_subscription.py index cb705368..51eb2da9 100644 --- a/tests/test_graphqlws_subscription.py +++ b/tests/test_graphqlws_subscription.py @@ -1,11 +1,10 @@ import asyncio import json +import pytest import sys import warnings -from typing import List - -import pytest from parse import search +from typing import List from gql import Client, gql from gql.transport.exceptions import TransportServerError @@ -816,8 +815,9 @@ async def test_graphqlws_subscription_reconnecting_session( ): import websockets - from gql.transport.websockets import WebsocketsTransport + from gql.transport.exceptions import TransportClosed + from gql.transport.websockets import WebsocketsTransport path = "/graphql" url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" diff --git a/tests/test_httpx.py b/tests/test_httpx.py index af12f717..95b16a54 100644 --- a/tests/test_httpx.py +++ b/tests/test_httpx.py @@ -1,6 +1,5 @@ -from typing import Mapping - import pytest +from typing import Mapping from gql import Client, gql from gql.transport.exceptions import ( @@ -38,6 +37,7 @@ @pytest.mark.asyncio async def test_httpx_query(event_loop, aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -81,6 +81,7 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_cookies(event_loop, aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -118,6 +119,7 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_error_code_401(event_loop, aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -153,6 +155,7 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_error_code_429(event_loop, aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -202,6 +205,7 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_error_code_500(event_loop, aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -234,6 +238,7 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_error_code(event_loop, aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -274,6 +279,7 @@ async def test_httpx_invalid_protocol( event_loop, aiohttp_server, response, run_sync_test ): from aiohttp import web + from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -302,6 +308,7 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_cannot_connect_twice(event_loop, aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -330,6 +337,7 @@ async def test_httpx_cannot_execute_if_not_connected( event_loop, aiohttp_server, run_sync_test ): from aiohttp import web + from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -367,6 +375,7 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_query_with_extensions(event_loop, aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -422,6 +431,7 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_file_upload(event_loop, aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport async def single_upload_handler(request): @@ -484,6 +494,7 @@ async def test_httpx_file_upload_with_content_type( event_loop, aiohttp_server, run_sync_test ): from aiohttp import web + from gql.transport.httpx import HTTPXTransport async def single_upload_handler(request): @@ -552,6 +563,7 @@ async def test_httpx_file_upload_additional_headers( event_loop, aiohttp_server, run_sync_test ): from aiohttp import web + from gql.transport.httpx import HTTPXTransport async def single_upload_handler(request): @@ -614,6 +626,7 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_binary_file_upload(event_loop, aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport # This is a sample binary file content containing all possible byte values @@ -687,6 +700,7 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_file_upload_two_files(event_loop, aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport file_upload_mutation_2 = """ @@ -787,6 +801,7 @@ async def test_httpx_file_upload_list_of_two_files( event_loop, aiohttp_server, run_sync_test ): from aiohttp import web + from gql.transport.httpx import HTTPXTransport file_upload_mutation_3 = """ @@ -876,6 +891,7 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_error_fetching_schema(event_loop, aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport error_answer = """ diff --git a/tests/test_httpx_async.py b/tests/test_httpx_async.py index 3665f5d8..2066d964 100644 --- a/tests/test_httpx_async.py +++ b/tests/test_httpx_async.py @@ -1,8 +1,7 @@ import io import json -from typing import Mapping - import pytest +from typing import Mapping from gql import Client, gql from gql.cli import get_parser, main @@ -44,6 +43,7 @@ @pytest.mark.asyncio async def test_httpx_query(event_loop, aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -84,6 +84,7 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_ignore_backend_content_type(event_loop, aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -114,6 +115,7 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_cookies(event_loop, aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -148,6 +150,7 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_error_code_401(event_loop, aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -180,6 +183,7 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_error_code_429(event_loop, aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -228,6 +232,7 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_error_code_500(event_loop, aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -264,6 +269,7 @@ async def handler(request): @pytest.mark.parametrize("query_error", transport_query_error_responses) async def test_httpx_error_code(event_loop, aiohttp_server, query_error): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -320,6 +326,7 @@ async def handler(request): @pytest.mark.parametrize("param", invalid_protocol_responses) async def test_httpx_invalid_protocol(event_loop, aiohttp_server, param): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport response = param["response"] @@ -349,6 +356,7 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_subscribe_not_supported(event_loop, aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -375,6 +383,7 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_cannot_connect_twice(event_loop, aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -398,6 +407,7 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_cannot_execute_if_not_connected(event_loop, aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -420,9 +430,10 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio async def test_httpx_extra_args(event_loop, aiohttp_server): + import httpx from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport - import httpx async def handler(request): return web.Response(text=query1_server_answer, content_type="application/json") @@ -466,6 +477,7 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_query_variable_values(event_loop, aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -503,6 +515,7 @@ async def test_httpx_query_variable_values_fix_issue_292(event_loop, aiohttp_ser See https://github.com/graphql-python/gql/issues/292""" from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -536,6 +549,7 @@ async def test_httpx_execute_running_in_thread( event_loop, aiohttp_server, run_sync_test ): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -565,6 +579,7 @@ async def test_httpx_subscribe_running_in_thread( event_loop, aiohttp_server, run_sync_test ): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -652,6 +667,7 @@ async def single_upload_handler(request): @pytest.mark.asyncio async def test_httpx_file_upload(event_loop, aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport app = web.Application() @@ -690,6 +706,7 @@ async def test_httpx_file_upload_without_session( event_loop, aiohttp_server, run_sync_test ): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport app = web.Application() @@ -759,6 +776,7 @@ async def binary_upload_handler(request): @pytest.mark.asyncio async def test_httpx_binary_file_upload(event_loop, aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport app = web.Application() @@ -817,6 +835,7 @@ async def test_httpx_binary_file_upload(event_loop, aiohttp_server): @pytest.mark.asyncio async def test_httpx_file_upload_two_files(event_loop, aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -909,6 +928,7 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_file_upload_list_of_two_files(event_loop, aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -1135,6 +1155,7 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_query_with_extensions(event_loop, aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -1163,6 +1184,7 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_query_https(event_loop, ssl_aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -1198,6 +1220,7 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_error_fetching_schema(event_loop, aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport error_answer = """ @@ -1242,6 +1265,7 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_reconnecting_session(event_loop, aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -1281,6 +1305,7 @@ async def handler(request): @pytest.mark.parametrize("retries", [False, lambda e: e]) async def test_httpx_reconnecting_session_retries(event_loop, aiohttp_server, retries): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -1315,6 +1340,7 @@ async def test_httpx_reconnecting_session_start_connecting_task_twice( event_loop, aiohttp_server, caplog ): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -1349,6 +1375,7 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_json_serializer(event_loop, aiohttp_server, caplog): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -1408,6 +1435,7 @@ async def test_httpx_json_deserializer(event_loop, aiohttp_server): from aiohttp import web from decimal import Decimal from functools import partial + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): diff --git a/tests/test_httpx_online.py b/tests/test_httpx_online.py index 23d28dcc..dfa19fde 100644 --- a/tests/test_httpx_online.py +++ b/tests/test_httpx_online.py @@ -1,9 +1,8 @@ import asyncio +import pytest import sys from typing import Dict -import pytest - from gql import Client, gql from gql.transport.exceptions import TransportQueryError diff --git a/tests/test_phoenix_channel_exceptions.py b/tests/test_phoenix_channel_exceptions.py index e2bf0091..f59245e7 100644 --- a/tests/test_phoenix_channel_exceptions.py +++ b/tests/test_phoenix_channel_exceptions.py @@ -1,5 +1,4 @@ import asyncio - import pytest from gql import Client, gql @@ -19,9 +18,7 @@ def ensure_list(s): return ( s if s is None or isinstance(s, list) - else list(s) - if isinstance(s, tuple) - else [s] + else list(s) if isinstance(s, tuple) else [s] ) @@ -360,9 +357,10 @@ def subscription_server( data_answers=default_subscription_data_answer, unsubscribe_answers=default_subscription_unsubscribe_answer, ): - from .conftest import PhoenixChannelServerHelper import json + from .conftest import PhoenixChannelServerHelper + async def phoenix_server(ws, path): await PhoenixChannelServerHelper.send_connection_ack(ws) await ws.recv() diff --git a/tests/test_phoenix_channel_subscription.py b/tests/test_phoenix_channel_subscription.py index 6367945d..127e3a20 100644 --- a/tests/test_phoenix_channel_subscription.py +++ b/tests/test_phoenix_channel_subscription.py @@ -1,8 +1,7 @@ import asyncio import json -import sys - import pytest +import sys from parse import search from gql import Client, gql diff --git a/tests/test_requests.py b/tests/test_requests.py index ba666243..7d5d237d 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -1,6 +1,5 @@ -from typing import Mapping - import pytest +from typing import Mapping from gql import Client, gql from gql.transport.exceptions import ( @@ -38,6 +37,7 @@ @pytest.mark.asyncio async def test_requests_query(event_loop, aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -81,6 +81,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_cookies(event_loop, aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -118,6 +119,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_error_code_401(event_loop, aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -153,6 +155,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_error_code_429(event_loop, aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -202,6 +205,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_error_code_500(event_loop, aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -234,6 +238,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_error_code(event_loop, aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -274,6 +279,7 @@ async def test_requests_invalid_protocol( event_loop, aiohttp_server, response, run_sync_test ): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -302,6 +308,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_cannot_connect_twice(event_loop, aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -330,6 +337,7 @@ async def test_requests_cannot_execute_if_not_connected( event_loop, aiohttp_server, run_sync_test ): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -369,6 +377,7 @@ async def test_requests_query_with_extensions( event_loop, aiohttp_server, run_sync_test ): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -424,6 +433,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_file_upload(event_loop, aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def single_upload_handler(request): @@ -486,6 +496,7 @@ async def test_requests_file_upload_with_content_type( event_loop, aiohttp_server, run_sync_test ): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def single_upload_handler(request): @@ -554,6 +565,7 @@ async def test_requests_file_upload_additional_headers( event_loop, aiohttp_server, run_sync_test ): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def single_upload_handler(request): @@ -616,6 +628,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_binary_file_upload(event_loop, aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport # This is a sample binary file content containing all possible byte values @@ -691,6 +704,7 @@ async def test_requests_file_upload_two_files( event_loop, aiohttp_server, run_sync_test ): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport file_upload_mutation_2 = """ @@ -791,6 +805,7 @@ async def test_requests_file_upload_list_of_two_files( event_loop, aiohttp_server, run_sync_test ): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport file_upload_mutation_3 = """ @@ -882,6 +897,7 @@ async def test_requests_error_fetching_schema( event_loop, aiohttp_server, run_sync_test ): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport error_answer = """ @@ -932,6 +948,7 @@ async def test_requests_json_serializer( ): import json from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -994,6 +1011,7 @@ async def test_requests_json_deserializer(event_loop, aiohttp_server, run_sync_t from aiohttp import web from decimal import Decimal from functools import partial + from gql.transport.requests import RequestsHTTPTransport async def handler(request): diff --git a/tests/test_requests_batch.py b/tests/test_requests_batch.py index 4d8bf27e..7be46fd7 100644 --- a/tests/test_requests_batch.py +++ b/tests/test_requests_batch.py @@ -1,6 +1,5 @@ -from typing import Mapping - import pytest +from typing import Mapping from gql import Client, GraphQLRequest, gql from gql.transport.exceptions import ( @@ -50,6 +49,7 @@ @pytest.mark.asyncio async def test_requests_query(event_loop, aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -95,6 +95,7 @@ async def test_requests_query_auto_batch_enabled( event_loop, aiohttp_server, run_sync_test ): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -143,9 +144,10 @@ async def test_requests_query_auto_batch_enabled_two_requests( event_loop, aiohttp_server, run_sync_test ): from aiohttp import web - from gql.transport.requests import RequestsHTTPTransport from threading import Thread + from gql.transport.requests import RequestsHTTPTransport + async def handler(request): return web.Response( text=query1_server_answer_twice_list, @@ -201,6 +203,7 @@ def test_thread(): @pytest.mark.asyncio async def test_requests_cookies(event_loop, aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -240,6 +243,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_error_code_401(event_loop, aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -277,6 +281,7 @@ async def test_requests_error_code_401_auto_batch_enabled( event_loop, aiohttp_server, run_sync_test ): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -315,6 +320,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_error_code_429(event_loop, aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -364,6 +370,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_error_code_500(event_loop, aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -396,6 +403,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_error_code(event_loop, aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -441,6 +449,7 @@ async def test_requests_invalid_protocol( event_loop, aiohttp_server, response, run_sync_test ): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -471,6 +480,7 @@ async def test_requests_cannot_execute_if_not_connected( event_loop, aiohttp_server, run_sync_test ): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -512,6 +522,7 @@ async def test_requests_query_with_extensions( event_loop, aiohttp_server, run_sync_test ): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -551,6 +562,7 @@ def test_code(): def test_requests_sync_batch_auto(): from threading import Thread + from gql.transport.requests import RequestsHTTPTransport client = Client( diff --git a/tests/test_transport.py b/tests/test_transport.py index d9a3eced..27730c07 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -1,5 +1,4 @@ import os - import pytest from gql import Client, gql @@ -28,6 +27,7 @@ def use_cassette(name): @pytest.fixture def client(): import requests + from gql.transport.requests import RequestsHTTPTransport with use_cassette("client"): diff --git a/tests/test_transport_batch.py b/tests/test_transport_batch.py index a9b21e6a..abd2152e 100644 --- a/tests/test_transport_batch.py +++ b/tests/test_transport_batch.py @@ -1,5 +1,4 @@ import os - import pytest from gql import Client, GraphQLRequest, gql @@ -28,6 +27,7 @@ def use_cassette(name): @pytest.fixture def client(): import requests + from gql.transport.requests import RequestsHTTPTransport with use_cassette("client"): diff --git a/tests/test_websocket_exceptions.py b/tests/test_websocket_exceptions.py index 72db8a87..719a948c 100644 --- a/tests/test_websocket_exceptions.py +++ b/tests/test_websocket_exceptions.py @@ -1,10 +1,9 @@ import asyncio import json +import pytest import types from typing import List -import pytest - from gql import Client, gql from gql.transport.exceptions import ( TransportAlreadyConnected, @@ -273,6 +272,7 @@ async def server_closing_directly(ws, path): @pytest.mark.parametrize("server", [server_closing_directly], indirect=True) async def test_websocket_server_closing_directly(event_loop, server): import websockets + from gql.transport.websockets import WebsocketsTransport url = f"ws://{server.hostname}:{server.port}/graphql" @@ -372,9 +372,10 @@ async def test_websocket_using_cli_invalid_query( url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - from gql.cli import main, get_parser import io + from gql.cli import get_parser, main + parser = get_parser(with_examples=True) args = parser.parse_args([url]) diff --git a/tests/test_websocket_online.py b/tests/test_websocket_online.py index b5fca837..45564d55 100644 --- a/tests/test_websocket_online.py +++ b/tests/test_websocket_online.py @@ -1,10 +1,9 @@ import asyncio import logging +import pytest import sys from typing import Dict -import pytest - from gql import Client, gql from gql.transport.exceptions import TransportError, TransportQueryError diff --git a/tests/test_websocket_query.py b/tests/test_websocket_query.py index e8b7a022..3fb76b58 100644 --- a/tests/test_websocket_query.py +++ b/tests/test_websocket_query.py @@ -1,11 +1,11 @@ +import ssl + import asyncio import json -import ssl +import pytest import sys from typing import Dict, Mapping -import pytest - from gql import Client, gql from gql.transport.exceptions import ( TransportAlreadyConnected, @@ -53,6 +53,7 @@ @pytest.mark.parametrize("server", [server1_answers], indirect=True) async def test_websocket_starting_client_in_context_manager(event_loop, server): import websockets + from gql.transport.websockets import WebsocketsTransport url = f"ws://{server.hostname}:{server.port}/graphql" @@ -93,6 +94,7 @@ async def test_websocket_starting_client_in_context_manager(event_loop, server): @pytest.mark.parametrize("ws_ssl_server", [server1_answers], indirect=True) async def test_websocket_using_ssl_connection(event_loop, ws_ssl_server): import websockets + from gql.transport.websockets import WebsocketsTransport server = ws_ssl_server @@ -547,10 +549,11 @@ async def test_websocket_using_cli(event_loop, server, monkeypatch, capsys): url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - from gql.cli import main, get_parser import io import json + from gql.cli import get_parser, main + parser = get_parser(with_examples=True) args = parser.parse_args([url]) diff --git a/tests/test_websocket_subscription.py b/tests/test_websocket_subscription.py index 4419783b..43129c4f 100644 --- a/tests/test_websocket_subscription.py +++ b/tests/test_websocket_subscription.py @@ -1,12 +1,11 @@ import asyncio import json +import pytest import sys import warnings -from typing import List - -import pytest from graphql import ExecutionResult from parse import search +from typing import List from gql import Client, gql from gql.transport.exceptions import TransportServerError From b44fd4501aea5f673eea6cac0471f8b3c070d05f Mon Sep 17 00:00:00 2001 From: Taylor Lowery Date: Tue, 4 Jun 2024 20:45:29 +0000 Subject: [PATCH 09/61] wip: initial tests for aiohttp websockets --- tests/test_aiohttp_websocket_exceptions.py | 392 ++++++++++- tests/test_aiohttp_websocket_query.py | 606 +++++++++++++++++ tests/test_aiohttp_websocket_subscription.py | 648 +++++++++++++++++++ 3 files changed, 1618 insertions(+), 28 deletions(-) diff --git a/tests/test_aiohttp_websocket_exceptions.py b/tests/test_aiohttp_websocket_exceptions.py index 707cf827..70621ce0 100644 --- a/tests/test_aiohttp_websocket_exceptions.py +++ b/tests/test_aiohttp_websocket_exceptions.py @@ -1,9 +1,10 @@ import asyncio import json -import pytest import types from typing import List +import pytest + from gql import Client, gql from gql.transport.exceptions import ( TransportAlreadyConnected, @@ -12,12 +13,10 @@ TransportQueryError, ) -from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport - from .conftest import MS, WebSocketServerHelper # Marking all tests in this file with the websockets marker -pytestmark = pytest.mark.aiohttp_websockets +pytestmark = pytest.mark.websockets invalid_query_str = """ query getContinents { @@ -38,40 +37,377 @@ invalid_query1_server = [invalid_query1_server_answer] + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [invalid_query1_server], indirect=True) +@pytest.mark.parametrize("query_str", [invalid_query_str]) +async def test_aiohttp_websocket_invalid_query( + event_loop, aiohttp_client_and_server, query_str +): + + session, server = aiohttp_client_and_server + + query = gql(query_str) + + with pytest.raises(TransportQueryError) as exc_info: + await session.execute(query) + + exception = exc_info.value + + assert isinstance(exception.errors, List) + + error = exception.errors[0] + + assert error["extensions"]["code"] == "INTERNAL_SERVER_ERROR" + + +invalid_subscription_str = """ + subscription getContinents { + continents { + code + bloh + } + } +""" + + +async def server_invalid_subscription(ws, path): + await WebSocketServerHelper.send_connection_ack(ws) + await ws.recv() + await ws.send(invalid_query1_server_answer.format(query_id=1)) + await WebSocketServerHelper.send_complete(ws, 1) + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_invalid_subscription], indirect=True) +@pytest.mark.parametrize("query_str", [invalid_subscription_str]) +async def test_aiohttp_websocket_invalid_subscription( + event_loop, aiohttp_client_and_server, query_str +): + + session, server = aiohttp_client_and_server + + query = gql(query_str) + + with pytest.raises(TransportQueryError) as exc_info: + async for result in session.subscribe(query): + pass + + exception = exc_info.value + + assert isinstance(exception.errors, List) + + error = exception.errors[0] + + assert error["extensions"]["code"] == "INTERNAL_SERVER_ERROR" + + +connection_error_server_answer = ( + '{"type":"connection_error","id":null,' + '"payload":{"message":"Unexpected token Q in JSON at position 0"}}' +) + + +async def server_no_ack(ws, path): + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_no_ack], indirect=True) +@pytest.mark.parametrize("query_str", [invalid_query_str]) +async def test_aiohttp_websocket_server_does_not_send_ack( + event_loop, server, query_str +): + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + + sample_transport = AIOHTTPWebsocketsTransport(url=url, ack_timeout=1) + + with pytest.raises(asyncio.TimeoutError): + async with Client(transport=sample_transport): + pass + + +async def server_connection_error(ws, path): + await WebSocketServerHelper.send_connection_ack(ws) + result = await ws.recv() + print(f"Server received: {result}") + await ws.send(connection_error_server_answer) + await ws.wait_closed() + + @pytest.mark.asyncio -@pytest.mark.parametrize("aiohttp_server", [invalid_query1_server], indirect=True) +@pytest.mark.parametrize("server", [server_connection_error], indirect=True) @pytest.mark.parametrize("query_str", [invalid_query_str]) -async def test_aiohttp_websocket_invalid_query(event_loop, aiohttp_server, query_str: str,): - - from aiohttp import web - - async def handler(request): - return web.Response( - text=invalid_query1_server_answer, - content_type="application/json", - headers={"dummy": "this should not be returned"}, +async def test_aiohttp_websocket_sending_invalid_data( + event_loop, aiohttp_client_and_server, query_str +): + + session, server = aiohttp_client_and_server + + invalid_data = "QSDF" + print(f">>> {invalid_data}") + await session.transport.websocket.send(invalid_data) + + await asyncio.sleep(2 * MS) + + +invalid_payload_server_answer = ( + '{"type":"error","id":"1","payload":{"message":"Must provide document"}}' +) + + +async def server_invalid_payload(ws, path): + await WebSocketServerHelper.send_connection_ack(ws) + result = await ws.recv() + print(f"Server received: {result}") + await ws.send(invalid_payload_server_answer) + await WebSocketServerHelper.wait_connection_terminate(ws) + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_invalid_payload], indirect=True) +@pytest.mark.parametrize("query_str", [invalid_query_str]) +async def test_aiohttp_websocket_sending_invalid_payload( + event_loop, aiohttp_client_and_server, query_str +): + + session, server = aiohttp_client_and_server + + # Monkey patching the _send_query method to send an invalid payload + + async def monkey_patch_send_query( + self, + document, + variable_values=None, + operation_name=None, + ) -> int: + query_id = self.next_query_id + self.next_query_id += 1 + + query_str = json.dumps( + {"id": str(query_id), "type": "start", "payload": "BLAHBLAH"} ) - - app = web.Application() - app.router.add_get("/ws", handler) - server = await aiohttp_server(app) - url = server.make_url("/ws") + await self._send(query_str) + return query_id + + session.transport._send_query = types.MethodType( + monkey_patch_send_query, session.transport + ) + + query = gql(query_str) + + with pytest.raises(TransportQueryError) as exc_info: + await session.execute(query) + + exception = exc_info.value + + assert isinstance(exception.errors, List) + + error = exception.errors[0] + + assert error["message"] == "Must provide document" + + +not_json_answer = ["BLAHBLAH"] +missing_type_answer = ["{}"] +missing_id_answer_1 = ['{"type": "data"}'] +missing_id_answer_2 = ['{"type": "error"}'] +missing_id_answer_3 = ['{"type": "complete"}'] +data_without_payload = ['{"type": "data", "id":"1"}'] +error_without_payload = ['{"type": "error", "id":"1"}'] +payload_is_not_a_dict = ['{"type": "data", "id":"1", "payload": "BLAH"}'] +empty_payload = ['{"type": "data", "id":"1", "payload": {}}'] +sending_bytes = [b"\x01\x02\x03"] + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "server", + [ + not_json_answer, + missing_type_answer, + missing_id_answer_1, + missing_id_answer_2, + missing_id_answer_3, + data_without_payload, + error_without_payload, + payload_is_not_a_dict, + empty_payload, + sending_bytes, + ], + indirect=True, +) +async def test_aiohttp_websocket_transport_protocol_errors( + event_loop, aiohttp_client_and_server +): + + session, server = aiohttp_client_and_server + + query = gql("query { hello }") + + with pytest.raises(TransportProtocolError): + await session.execute(query) + + +async def server_without_ack(ws, path): + # Sending something else than an ack + await WebSocketServerHelper.send_complete(ws, 1) + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_without_ack], indirect=True) +async def test_aiohttp_websocket_server_does_not_ack(event_loop, server): + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + sample_transport = AIOHTTPWebsocketsTransport(url=url) + + with pytest.raises(TransportProtocolError): + async with Client(transport=sample_transport): + pass + + +async def server_closing_directly(ws, path): + await ws.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_closing_directly], indirect=True) +async def test_aiohttp_websocket_server_closing_directly(event_loop, server): + import websockets + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + sample_transport = AIOHTTPWebsocketsTransport(url=url) + + with pytest.raises(websockets.exceptions.ConnectionClosed): + async with Client(transport=sample_transport): + pass + + +async def server_closing_after_ack(ws, path): + await WebSocketServerHelper.send_connection_ack(ws) + await ws.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_closing_after_ack], indirect=True) +async def test_aiohttp_websocket_server_closing_after_ack( + event_loop, aiohttp_client_and_server +): + + import websockets + + session, server = aiohttp_client_and_server + + query = gql("query { hello }") + + with pytest.raises(websockets.exceptions.ConnectionClosed): + await session.execute(query) + + await session.transport.wait_closed() + + with pytest.raises(TransportClosed): + await session.execute(query) + + +async def server_sending_invalid_query_errors(ws, path): + await WebSocketServerHelper.send_connection_ack(ws) + invalid_error = ( + '{"type":"error","id":"404","payload":' + '{"message":"error for no good reason on non existing query"}}' + ) + await ws.send(invalid_error) + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_sending_invalid_query_errors], indirect=True) +async def test_aiohttp_websocket_server_sending_invalid_query_errors( + event_loop, server +): + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + sample_transport = AIOHTTPWebsocketsTransport(url=url) + + # Invalid server message is ignored + async with Client(transport=sample_transport): + await asyncio.sleep(2 * MS) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_sending_invalid_query_errors], indirect=True) +async def test_aiohttp_websocket_non_regression_bug_105(event_loop, server): + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + # This test will check a fix to a race condition which happens if the user is trying + # to connect using the same client twice at the same time + # See bug #105 + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + sample_transport = AIOHTTPWebsocketsTransport(url=url) + + client = Client(transport=sample_transport) + + # Create a coroutine which start the connection with the transport but does nothing + async def client_connect(client): + async with client: + await asyncio.sleep(2 * MS) + + # Create two tasks which will try to connect using the same client (not allowed) + connect_task1 = asyncio.ensure_future(client_connect(client)) + connect_task2 = asyncio.ensure_future(client_connect(client)) + + with pytest.raises(TransportAlreadyConnected): + await asyncio.gather(connect_task1, connect_task2) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [invalid_query1_server], indirect=True) +async def test_aiohttp_websocket_using_cli_invalid_query( + event_loop, server, monkeypatch, capsys +): + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") - transport = AIOHTTPWebsocketsTransport(url=url, timeout=10) + import io - async with Client(transport=transport) as session: + from gql.cli import get_parser, main + parser = get_parser(with_examples=True) + args = parser.parse_args([url]) - query = gql(query_str) + # Monkeypatching sys.stdin to simulate getting the query + # via the standard input + monkeypatch.setattr("sys.stdin", io.StringIO(invalid_query_str)) - with pytest.raises(TransportQueryError) as exc_info: - await session.execute(query) + # Flush captured output + captured = capsys.readouterr() - exception = exc_info.value + await main(args) - assert isinstance(exception.errors, List) + # Check that the error has been printed on stdout + captured = capsys.readouterr() + captured_err = str(captured.err).strip() + print(f"Captured: {captured_err}") - error = exception.errors[0] + expected_error = 'Cannot query field "bloh" on type "Continent"' - assert error["extensions"]["code"] == "INTERNAL_SERVER_ERROR" \ No newline at end of file + assert expected_error in captured_err diff --git a/tests/test_aiohttp_websocket_query.py b/tests/test_aiohttp_websocket_query.py index e69de29b..489b8814 100644 --- a/tests/test_aiohttp_websocket_query.py +++ b/tests/test_aiohttp_websocket_query.py @@ -0,0 +1,606 @@ +import asyncio +import json +import ssl +import sys +from typing import Dict, Mapping + +import pytest + +from gql import Client, gql +from gql.transport.exceptions import ( + TransportAlreadyConnected, + TransportClosed, + TransportQueryError, + TransportServerError, +) + +from .conftest import MS, WebSocketServerHelper + +# Marking all tests in this file with the websockets marker +pytestmark = pytest.mark.aiohttp_websockets + +query1_str = """ + query getContinents { + continents { + code + name + } + } +""" + +query1_server_answer_data = ( + '{"continents":[' + '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' + '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' + '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' + '{"code":"SA","name":"South America"}]}' +) + +query1_server_answer = ( + '{{"type":"data","id":"{query_id}","payload":{{"data":{{"continents":[' + '{{"code":"AF","name":"Africa"}},{{"code":"AN","name":"Antarctica"}},' + '{{"code":"AS","name":"Asia"}},{{"code":"EU","name":"Europe"}},' + '{{"code":"NA","name":"North America"}},{{"code":"OC","name":"Oceania"}},' + '{{"code":"SA","name":"South America"}}]}}}}}}' +) + +server1_answers = [ + query1_server_answer, +] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server1_answers], indirect=True) +async def test_aiohttp_websocket_starting_client_in_context_manager(event_loop, server): + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + transport = AIOHTTPWebsocketsTransport(url=url, timeout=10) + + async with Client(transport=transport) as session: + + query1 = gql(query1_str) + + result = await session.execute(query1) + + print("Client received:", result) + + # Verify result + assert isinstance(result, Dict) + + continents = result["continents"] + africa = continents[0] + + assert africa["code"] == "AF" + + # Checking response headers are saved in the transport + assert hasattr(transport, "response_headers") + assert isinstance(transport.response_headers, Mapping) + assert transport.response_headers["dummy"] == "test1234" + + # Check client is disconnect here + assert transport.websocket is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize("ws_ssl_server", [server1_answers], indirect=True) +async def test_aiohttp_websocket_using_ssl_connection(event_loop, ws_ssl_server): + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + server = ws_ssl_server + + url = f"wss://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ssl_context.load_verify_locations(ws_ssl_server.testcert) + + transport = AIOHTTPWebsocketsTransport(url=url, ssl=ssl_context) + + async with Client(transport=transport) as session: + + query1 = gql(query1_str) + + result = await session.execute(query1) + + print("Client received:", result) + + # Verify result + assert isinstance(result, Dict) + + continents = result["continents"] + africa = continents[0] + + assert africa["code"] == "AF" + + # Check client is disconnect here + assert transport.websocket is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server1_answers], indirect=True) +@pytest.mark.parametrize("query_str", [query1_str]) +async def test_aiohttp_websocket_simple_query( + event_loop, aiohttp_client_and_server, query_str +): + + session, server = aiohttp_client_and_server + + query = gql(query_str) + + result = await session.execute(query) + + print("Client received:", result) + + +server1_two_answers_in_series = [ + query1_server_answer, + query1_server_answer, +] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server1_two_answers_in_series], indirect=True) +@pytest.mark.parametrize("query_str", [query1_str]) +async def test_aiohttp_websocket_two_queries_in_series( + event_loop, aiohttp_client_and_server, query_str +): + + session, server = aiohttp_client_and_server + + query = gql(query_str) + + result1 = await session.execute(query) + + print("Query1 received:", result1) + + result2 = await session.execute(query) + + print("Query2 received:", result2) + + assert result1 == result2 + + +async def server1_two_queries_in_parallel(ws, path): + await WebSocketServerHelper.send_connection_ack(ws) + result = await ws.recv() + print(f"Server received: {result}", file=sys.stderr) + result = await ws.recv() + print(f"Server received: {result}", file=sys.stderr) + await ws.send(query1_server_answer.format(query_id=1)) + await ws.send(query1_server_answer.format(query_id=2)) + await WebSocketServerHelper.send_complete(ws, 1) + await WebSocketServerHelper.send_complete(ws, 2) + await WebSocketServerHelper.wait_connection_terminate(ws) + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server1_two_queries_in_parallel], indirect=True) +@pytest.mark.parametrize("query_str", [query1_str]) +async def test_aiohttp_websocket_two_queries_in_parallel( + event_loop, aiohttp_client_and_server, query_str +): + + session, server = aiohttp_client_and_server + + query = gql(query_str) + + result1 = None + result2 = None + + async def task1_coro(): + nonlocal result1 + result1 = await session.execute(query) + + async def task2_coro(): + nonlocal result2 + result2 = await session.execute(query) + + task1 = asyncio.ensure_future(task1_coro()) + task2 = asyncio.ensure_future(task2_coro()) + + await asyncio.gather(task1, task2) + + print("Query1 received:", result1) + print("Query2 received:", result2) + + assert result1 == result2 + + +async def server_closing_while_we_are_doing_something_else(ws, path): + await WebSocketServerHelper.send_connection_ack(ws) + result = await ws.recv() + print(f"Server received: {result}", file=sys.stderr) + await ws.send(query1_server_answer.format(query_id=1)) + await WebSocketServerHelper.send_complete(ws, 1) + await asyncio.sleep(1 * MS) + + # Closing server after first query + await ws.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "server", [server_closing_while_we_are_doing_something_else], indirect=True +) +@pytest.mark.parametrize("query_str", [query1_str]) +async def test_aiohttp_websocket_server_closing_after_first_query( + event_loop, aiohttp_client_and_server, query_str +): + + session, server = aiohttp_client_and_server + + query = gql(query_str) + + # First query is working + await session.execute(query) + + # Then we do other things + await asyncio.sleep(1000 * MS) + + # Now the server is closed but we don't know it yet, we have to send a query + # to notice it and to receive the exception + with pytest.raises(TransportClosed): + await session.execute(query) + + +ignore_invalid_id_answers = [ + query1_server_answer, + '{"type":"complete","id": "55"}', + query1_server_answer, +] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [ignore_invalid_id_answers], indirect=True) +@pytest.mark.parametrize("query_str", [query1_str]) +async def test_aiohttp_websocket_ignore_invalid_id( + event_loop, aiohttp_client_and_server, query_str +): + + session, server = aiohttp_client_and_server + + query = gql(query_str) + + # First query is working + await session.execute(query) + + # Second query gets no answer -> raises + with pytest.raises(TransportQueryError): + await session.execute(query) + + # Third query is working + await session.execute(query) + + +async def assert_client_is_working(session): + query1 = gql(query1_str) + + result = await session.execute(query1) + + print("Client received:", result) + + # Verify result + assert isinstance(result, Dict) + + continents = result["continents"] + africa = continents[0] + + assert africa["code"] == "AF" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server1_answers], indirect=True) +async def test_aiohttp_websocket_multiple_connections_in_series(event_loop, server): + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + transport = AIOHTTPWebsocketsTransport(url=url) + + async with Client(transport=transport) as session: + await assert_client_is_working(session) + + # Check client is disconnect here + assert transport.websocket is None + + async with Client(transport=transport) as session: + await assert_client_is_working(session) + + # Check client is disconnect here + assert transport.websocket is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server1_answers], indirect=True) +async def test_aiohttp_websocket_multiple_connections_in_parallel(event_loop, server): + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + async def task_coro(): + transport = AIOHTTPWebsocketsTransport(url=url) + async with Client(transport=transport) as session: + await assert_client_is_working(session) + + task1 = asyncio.ensure_future(task_coro()) + task2 = asyncio.ensure_future(task_coro()) + + await asyncio.gather(task1, task2) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server1_answers], indirect=True) +async def test_aiohttp_websocket_trying_to_connect_to_already_connected_transport( + event_loop, server +): + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + transport = AIOHTTPWebsocketsTransport(url=url) + async with Client(transport=transport) as session: + await assert_client_is_working(session) + + with pytest.raises(TransportAlreadyConnected): + async with Client(transport=transport): + pass + + +async def server_with_authentication_in_connection_init_payload(ws, path): + # Wait the connection_init message + init_message_str = await ws.recv() + init_message = json.loads(init_message_str) + payload = init_message["payload"] + + if "Authorization" in payload: + if payload["Authorization"] == 12345: + await ws.send('{"type":"connection_ack"}') + + result = await ws.recv() + print(f"Server received: {result}", file=sys.stderr) + await ws.send(query1_server_answer.format(query_id=1)) + await WebSocketServerHelper.send_complete(ws, 1) + else: + await ws.send( + '{"type":"connection_error", "payload": "Invalid Authorization token"}' + ) + else: + await ws.send( + '{"type":"connection_error", "payload": "No Authorization token"}' + ) + + await ws.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "server", [server_with_authentication_in_connection_init_payload], indirect=True +) +@pytest.mark.parametrize("query_str", [query1_str]) +async def test_aiohttp_websocket_connect_success_with_authentication_in_connection_init( + event_loop, server, query_str +): + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + init_payload = {"Authorization": 12345} + + transport = AIOHTTPWebsocketsTransport(url=url, init_payload=init_payload) + + async with Client(transport=transport) as session: + + query1 = gql(query_str) + + result = await session.execute(query1) + + print("Client received:", result) + + # Verify result + assert isinstance(result, Dict) + + continents = result["continents"] + africa = continents[0] + + assert africa["code"] == "AF" + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "server", [server_with_authentication_in_connection_init_payload], indirect=True +) +@pytest.mark.parametrize("query_str", [query1_str]) +@pytest.mark.parametrize("init_payload", [{}, {"Authorization": "invalid_code"}]) +async def test_aiohttp_websocket_connect_failed_with_authentication_in_connection_init( + event_loop, server, query_str, init_payload +): + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + transport = AIOHTTPWebsocketsTransport(url=url, init_payload=init_payload) + + with pytest.raises(TransportServerError): + async with Client(transport=transport) as session: + query1 = gql(query_str) + + await session.execute(query1) + + +@pytest.mark.parametrize("server", [server1_answers], indirect=True) +def test_aiohttp_websocket_execute_sync(server): + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + transport = AIOHTTPWebsocketsTransport(url=url) + + client = Client(transport=transport) + + query1 = gql(query1_str) + + result = client.execute(query1) + + print("Client received:", result) + + # Verify result + assert isinstance(result, Dict) + + continents = result["continents"] + africa = continents[0] + + assert africa["code"] == "AF" + + # Execute sync a second time + result = client.execute(query1) + + print("Client received:", result) + + # Verify result + assert isinstance(result, Dict) + + continents = result["continents"] + africa = continents[0] + + assert africa["code"] == "AF" + + # Check client is disconnect here + assert transport.websocket is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server1_answers], indirect=True) +async def test_aiohttp_websocket_add_extra_parameters_to_connect(event_loop, server): + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + + # Increase max payload size to avoid websockets.exceptions.PayloadTooBig exceptions + transport = AIOHTTPWebsocketsTransport(url=url, connect_args={"max_size": 2**21}) + + query = gql(query1_str) + + async with Client(transport=transport) as session: + await session.execute(query) + + +async def server_sending_keep_alive_before_connection_ack(ws, path): + await WebSocketServerHelper.send_keepalive(ws) + await WebSocketServerHelper.send_keepalive(ws) + await WebSocketServerHelper.send_keepalive(ws) + await WebSocketServerHelper.send_keepalive(ws) + await WebSocketServerHelper.send_connection_ack(ws) + result = await ws.recv() + print(f"Server received: {result}", file=sys.stderr) + await ws.send(query1_server_answer.format(query_id=1)) + await WebSocketServerHelper.send_complete(ws, 1) + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "server", [server_sending_keep_alive_before_connection_ack], indirect=True +) +@pytest.mark.parametrize("query_str", [query1_str]) +async def test_aiohttp_websocket_non_regression_bug_108( + event_loop, aiohttp_client_and_server, query_str +): + + # This test will check that we now ignore keepalive message + # arriving before the connection_ack + # See bug #108 + + session, server = aiohttp_client_and_server + + query = gql(query_str) + + result = await session.execute(query) + + print("Client received:", result) + + continents = result["continents"] + africa = continents[0] + + assert africa["code"] == "AF" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server1_answers], indirect=True) +async def test_aiohttp_websocket_using_cli(event_loop, server, monkeypatch, capsys): + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + import io + import json + + from gql.cli import get_parser, main + + parser = get_parser(with_examples=True) + args = parser.parse_args([url]) + + # Monkeypatching sys.stdin to simulate getting the query + # via the standard input + monkeypatch.setattr("sys.stdin", io.StringIO(query1_str)) + + # Flush captured output + captured = capsys.readouterr() + + exit_code = await main(args) + + assert exit_code == 0 + + # Check that the result has been printed on stdout + captured = capsys.readouterr() + captured_out = str(captured.out).strip() + + expected_answer = json.loads(query1_server_answer_data) + print(f"Captured: {captured_out}") + received_answer = json.loads(captured_out) + + assert received_answer == expected_answer + + +query1_server_answer_with_extensions = ( + '{{"type":"data","id":"{query_id}","payload":{{"data":{{"continents":[' + '{{"code":"AF","name":"Africa"}},{{"code":"AN","name":"Antarctica"}},' + '{{"code":"AS","name":"Asia"}},{{"code":"EU","name":"Europe"}},' + '{{"code":"NA","name":"North America"}},{{"code":"OC","name":"Oceania"}},' + '{{"code":"SA","name":"South America"}}]}},' + '"extensions": {{"key1": "val1"}}}}}}' +) + +server1_answers_with_extensions = [ + query1_server_answer_with_extensions, +] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server1_answers_with_extensions], indirect=True) +@pytest.mark.parametrize("query_str", [query1_str]) +async def test_aiohttp_websocket_simple_query_with_extensions( + event_loop, aiohttp_client_and_server, query_str +): + + session, server = aiohttp_client_and_server + + query = gql(query_str) + + execution_result = await session.execute(query, get_execution_result=True) + + assert execution_result.extensions["key1"] == "val1" diff --git a/tests/test_aiohttp_websocket_subscription.py b/tests/test_aiohttp_websocket_subscription.py index e69de29b..d493cfc8 100644 --- a/tests/test_aiohttp_websocket_subscription.py +++ b/tests/test_aiohttp_websocket_subscription.py @@ -0,0 +1,648 @@ +import asyncio +import json +import sys +import warnings +from typing import List + +import pytest +from graphql import ExecutionResult +from parse import search + +from gql import Client, gql +from gql.transport.exceptions import TransportServerError + +from .conftest import MS, WebSocketServerHelper + +# Marking all tests in this file with the websockets marker +pytestmark = pytest.mark.aiohttp_websockets + +countdown_server_answer = ( + '{{"type":"data","id":"{query_id}","payload":{{"data":{{"number":{number}}}}}}}' +) + +WITH_KEEPALIVE = False + + +# List which can used to store received messages by the server +logged_messages: List[str] = [] + + +async def server_countdown(ws, path): + import websockets + + logged_messages.clear() + + global WITH_KEEPALIVE + try: + await WebSocketServerHelper.send_connection_ack(ws) + if WITH_KEEPALIVE: + await WebSocketServerHelper.send_keepalive(ws) + + result = await ws.recv() + logged_messages.append(result) + + json_result = json.loads(result) + assert json_result["type"] == "start" + payload = json_result["payload"] + query = payload["query"] + query_id = json_result["id"] + + count_found = search("count: {:d}", query) + count = count_found[0] + print(f"Countdown started from: {count}") + + async def counting_coro(): + for number in range(count, -1, -1): + await ws.send( + countdown_server_answer.format(query_id=query_id, number=number) + ) + await asyncio.sleep(2 * MS) + + counting_task = asyncio.ensure_future(counting_coro()) + + async def stopping_coro(): + nonlocal counting_task + while True: + + try: + result = await ws.recv() + logged_messages.append(result) + except websockets.exceptions.ConnectionClosed: + break + + json_result = json.loads(result) + + if json_result["type"] == "stop" and json_result["id"] == str(query_id): + print("Cancelling counting task now") + counting_task.cancel() + + async def keepalive_coro(): + while True: + await asyncio.sleep(5 * MS) + try: + await WebSocketServerHelper.send_keepalive(ws) + except websockets.exceptions.ConnectionClosed: + break + + stopping_task = asyncio.ensure_future(stopping_coro()) + keepalive_task = asyncio.ensure_future(keepalive_coro()) + + try: + await counting_task + except asyncio.CancelledError: + print("Now counting task is cancelled") + + stopping_task.cancel() + + try: + await stopping_task + except asyncio.CancelledError: + print("Now stopping task is cancelled") + + if WITH_KEEPALIVE: + keepalive_task.cancel() + try: + await keepalive_task + except asyncio.CancelledError: + print("Now keepalive task is cancelled") + + await WebSocketServerHelper.send_complete(ws, query_id) + await WebSocketServerHelper.wait_connection_terminate(ws) + except websockets.exceptions.ConnectionClosedOK: + pass + finally: + await ws.wait_closed() + + +countdown_subscription_str = """ + subscription {{ + countdown (count: {count}) {{ + number + }} + }} +""" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_subscription( + event_loop, aiohttp_client_and_server, subscription_str +): + + session, server = aiohttp_client_and_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_subscription_get_execution_result( + event_loop, aiohttp_client_and_server, subscription_str +): + + session, server = aiohttp_client_and_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async for result in session.subscribe(subscription, get_execution_result=True): + + assert isinstance(result, ExecutionResult) + + number = result.data["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_subscription_break( + event_loop, aiohttp_client_and_server, subscription_str +): + + session, server = aiohttp_client_and_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + + if count <= 5: + # Note: the following line is only necessary for pypy3 v3.6.1 + if sys.version_info < (3, 7): + await session._generator.aclose() + break + + count -= 1 + + assert count == 5 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_subscription_task_cancel( + event_loop, aiohttp_client_and_server, subscription_str +): + + session, server = aiohttp_client_and_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async def task_coro(): + nonlocal count + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + + count -= 1 + + task = asyncio.ensure_future(task_coro()) + + async def cancel_task_coro(): + nonlocal task + + await asyncio.sleep(11 * MS) + + task.cancel() + + cancel_task = asyncio.ensure_future(cancel_task_coro()) + + await asyncio.gather(task, cancel_task) + + assert count > 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_subscription_close_transport( + event_loop, aiohttp_client_and_server, subscription_str +): + + session, server = aiohttp_client_and_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async def task_coro(): + nonlocal count + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + + count -= 1 + + task = asyncio.ensure_future(task_coro()) + + async def close_transport_task_coro(): + nonlocal task + + await asyncio.sleep(11 * MS) + + await session.transport.close() + + close_transport_task = asyncio.ensure_future(close_transport_task_coro()) + + await asyncio.gather(task, close_transport_task) + + assert count > 0 + + +async def server_countdown_close_connection_in_middle(ws, path): + await WebSocketServerHelper.send_connection_ack(ws) + + result = await ws.recv() + json_result = json.loads(result) + assert json_result["type"] == "start" + payload = json_result["payload"] + query = payload["query"] + query_id = json_result["id"] + + count_found = search("count: {:d}", query) + count = count_found[0] + stopping_before = count // 2 + print(f"Countdown started from: {count}, stopping server before {stopping_before}") + for number in range(count, stopping_before, -1): + await ws.send(countdown_server_answer.format(query_id=query_id, number=number)) + await asyncio.sleep(2 * MS) + + print("Closing server while subscription is still running now") + await ws.close() + await ws.wait_closed() + print("Server is now closed") + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "server", [server_countdown_close_connection_in_middle], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_subscription_server_connection_closed( + event_loop, aiohttp_client_and_server, subscription_str +): + import websockets + + session, server = aiohttp_client_and_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + with pytest.raises(websockets.exceptions.ConnectionClosedOK): + + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + + count -= 1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_subscription_slow_consumer( + event_loop, aiohttp_client_and_server, subscription_str +): + + session, server = aiohttp_client_and_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async for result in session.subscribe(subscription): + await asyncio.sleep(10 * MS) + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + + count -= 1 + + assert count == -1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_subscription_with_operation_name( + event_loop, aiohttp_client_and_server, subscription_str +): + + session, server = aiohttp_client_and_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async for result in session.subscribe( + subscription, operation_name="CountdownSubscription" + ): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + # Check that the query contains the operationName + assert '"operationName": "CountdownSubscription"' in logged_messages[0] + + +WITH_KEEPALIVE = True + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_subscription_with_keepalive( + event_loop, aiohttp_client_and_server, subscription_str +): + + session, server = aiohttp_client_and_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_subscription_with_keepalive_with_timeout_ok( + event_loop, server, subscription_str +): + + from gql.transport.websockets import WebsocketsTransport + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + sample_transport = WebsocketsTransport(url=url, keep_alive_timeout=(20 * MS)) + + client = Client(transport=sample_transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async with client as session: + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_subscription_with_keepalive_with_timeout_nok( + event_loop, server, subscription_str +): + + from gql.transport.websockets import WebsocketsTransport + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + sample_transport = WebsocketsTransport(url=url, keep_alive_timeout=(1 * MS)) + + client = Client(transport=sample_transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async with client as session: + with pytest.raises(TransportServerError) as exc_info: + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert "No keep-alive message has been received" in str(exc_info.value) + + +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +def test_aiohttp_websocket_subscription_sync(server, subscription_str): + from gql.transport.websockets import WebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + sample_transport = WebsocketsTransport(url=url) + + client = Client(transport=sample_transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + for result in client.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +def test_aiohttp_websocket_subscription_sync_user_exception(server, subscription_str): + from gql.transport.websockets import WebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + sample_transport = WebsocketsTransport(url=url) + + client = Client(transport=sample_transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + with pytest.raises(Exception) as exc_info: + for result in client.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + if count == 5: + raise Exception("This is an user exception") + + assert count == 5 + assert "This is an user exception" in str(exc_info.value) + + +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +def test_aiohttp_websocket_subscription_sync_break(server, subscription_str): + from gql.transport.websockets import WebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + sample_transport = WebsocketsTransport(url=url) + + client = Client(transport=sample_transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + for result in client.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + if count == 5: + break + + assert count == 5 + + +@pytest.mark.skipif(sys.platform.startswith("win"), reason="test failing on windows") +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +def test_aiohttp_websocket_subscription_sync_graceful_shutdown( + server, subscription_str +): + """Note: this test will simulate a control-C happening while a sync subscription + is in progress. To do that we will throw a KeyboardInterrupt exception inside + the subscription async generator. + + The code should then do a clean close: + - send stop messages for each active query + - send a connection_terminate message + Then the KeyboardInterrupt will be reraise (to warn potential user code) + + This test does not work on Windows but the behaviour with Windows is correct. + """ + from gql.transport.websockets import WebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + sample_transport = WebsocketsTransport(url=url) + + client = Client(transport=sample_transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + interrupt_task = None + + with pytest.raises(KeyboardInterrupt): + for result in client.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + + if count == 5: + + # Simulate a KeyboardInterrupt in the generator + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message="There is no current event loop" + ) + interrupt_task = asyncio.ensure_future( + client.session._generator.athrow(KeyboardInterrupt) + ) + + count -= 1 + + assert count == 4 + + # Catch interrupt_task exception to remove warning + interrupt_task.exception() + + # Check that the server received a connection_terminate message last + assert logged_messages.pop() == '{"type": "connection_terminate"}' + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_subscription_running_in_thread( + event_loop, server, subscription_str, run_sync_test +): + from gql.transport.websockets import WebsocketsTransport + + def test_code(): + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + sample_transport = WebsocketsTransport(url=url) + + client = Client(transport=sample_transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + for result in client.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + await run_sync_test(event_loop, server, test_code) From f43853112770c3742a0be489a52a429dc440f609 Mon Sep 17 00:00:00 2001 From: Taylor Lowery Date: Tue, 4 Jun 2024 20:48:51 +0000 Subject: [PATCH 10/61] fix some minor import errors --- gql/transport/aiohttp_websockets.py | 1 + tests/conftest.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/gql/transport/aiohttp_websockets.py b/gql/transport/aiohttp_websockets.py index 99487cc5..aa0000de 100644 --- a/gql/transport/aiohttp_websockets.py +++ b/gql/transport/aiohttp_websockets.py @@ -12,6 +12,7 @@ Optional, Tuple, Union, + Mapping, ) import aiohttp diff --git a/tests/conftest.py b/tests/conftest.py index 0732ac7a..eb022708 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,3 @@ -from random import sample import ssl import asyncio From 8f44fc64c1811ebf452e372a42699feb8a38fcff Mon Sep 17 00:00:00 2001 From: Taylor Lowery Date: Tue, 4 Jun 2024 21:24:20 +0000 Subject: [PATCH 11/61] fix incorrect fixture --- tests/conftest.py | 33 ++++++++++++++++----------------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index eb022708..75b5ab9f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -457,6 +457,22 @@ async def client_and_server(server): url = f"ws://{server.hostname}:{server.port}{path}" sample_transport = WebsocketsTransport(url=url) + async with Client(transport=sample_transport) as session: + + # Yield both client session and server + yield session, server@pytest_asyncio.fixture + +@pytest_asyncio.fixture +async def aiohttp_client_and_server(server): + """Helper fixture to start a server and a client connected to its port with an aiohttp websockets transport.""" + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + # Generate transport to connect to the server fixture + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + sample_transport = AIOHTTPWebsocketsTransport(url=url) + async with Client(transport=sample_transport) as session: # Yield both client session and server @@ -483,23 +499,6 @@ async def client_and_graphqlws_server(graphqlws_server): # Yield both client session and server yield session, graphqlws_server -@pytest_asyncio.fixture -async def aiohttp_client_and_server(aiohttp_server): - """Helper fixture to start an aiohttp server and a client connected to its port.""" - - from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport - - # Generate transport to connect to the server fixture - path = "/graphql" - url = f"ws://{aiohttp_server.hostname}:{aiohttp_server.port}{path}" - sample_transport = AIOHTTPWebsocketsTransport(url=url) - - async with Client(transport=sample_transport) as session: - - # Yield both client session and server - yield session, aiohttp_server - - @pytest_asyncio.fixture async def run_sync_test(): async def run_sync_test_inner(event_loop, server, test_function): From ef84507e92958d4720f5f21b6f4f62d58999485f Mon Sep 17 00:00:00 2001 From: Taylor Lowery Date: Tue, 4 Jun 2024 21:43:41 +0000 Subject: [PATCH 12/61] update query tests to match AIOHTTPWebsocketsTransport behavior --- gql/transport/aiohttp_websockets.py | 1 + tests/test_aiohttp_websocket_exceptions.py | 2 +- tests/test_aiohttp_websocket_query.py | 7 +++++-- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/gql/transport/aiohttp_websockets.py b/gql/transport/aiohttp_websockets.py index aa0000de..9ddfe212 100644 --- a/gql/transport/aiohttp_websockets.py +++ b/gql/transport/aiohttp_websockets.py @@ -73,6 +73,7 @@ def __init__( ping_interval: Optional[Union[int, float]] = None, pong_timeout: Optional[Union[int, float]] = None, answer_pings: bool = True, + ) -> None: self.url: StrOrURL = url self.headers: Optional[LooseHeaders] = headers diff --git a/tests/test_aiohttp_websocket_exceptions.py b/tests/test_aiohttp_websocket_exceptions.py index 70621ce0..0cba0de1 100644 --- a/tests/test_aiohttp_websocket_exceptions.py +++ b/tests/test_aiohttp_websocket_exceptions.py @@ -16,7 +16,7 @@ from .conftest import MS, WebSocketServerHelper # Marking all tests in this file with the websockets marker -pytestmark = pytest.mark.websockets +pytestmark = pytest.mark.aiohttp_websockets invalid_query_str = """ query getContinents { diff --git a/tests/test_aiohttp_websocket_query.py b/tests/test_aiohttp_websocket_query.py index 489b8814..3062e5e7 100644 --- a/tests/test_aiohttp_websocket_query.py +++ b/tests/test_aiohttp_websocket_query.py @@ -245,7 +245,7 @@ async def test_aiohttp_websocket_server_closing_after_first_query( # Now the server is closed but we don't know it yet, we have to send a query # to notice it and to receive the exception - with pytest.raises(TransportClosed): + with pytest.raises(ConnectionResetError): await session.execute(query) @@ -491,7 +491,10 @@ async def test_aiohttp_websocket_add_extra_parameters_to_connect(event_loop, ser url = f"ws://{server.hostname}:{server.port}/graphql" # Increase max payload size to avoid websockets.exceptions.PayloadTooBig exceptions - transport = AIOHTTPWebsocketsTransport(url=url, connect_args={"max_size": 2**21}) + transport = AIOHTTPWebsocketsTransport( + url=url, + max_msg_size=(2**21), + ) query = gql(query1_str) From 9870336fad3190bdc24e8862533229d4a0c7296d Mon Sep 17 00:00:00 2001 From: Taylor Lowery Date: Tue, 4 Jun 2024 22:44:02 +0000 Subject: [PATCH 13/61] wip: updating tests to follow AIOHTTPWebsockets protocol --- gql/transport/aiohttp_websockets.py | 1 - tests/conftest.py | 2 +- tests/test_aiohttp_websocket_exceptions.py | 4 ++-- tests/test_aiohttp_websocket_online.py | 0 4 files changed, 3 insertions(+), 4 deletions(-) delete mode 100644 tests/test_aiohttp_websocket_online.py diff --git a/gql/transport/aiohttp_websockets.py b/gql/transport/aiohttp_websockets.py index 9ddfe212..aa0000de 100644 --- a/gql/transport/aiohttp_websockets.py +++ b/gql/transport/aiohttp_websockets.py @@ -73,7 +73,6 @@ def __init__( ping_interval: Optional[Union[int, float]] = None, pong_timeout: Optional[Union[int, float]] = None, answer_pings: bool = True, - ) -> None: self.url: StrOrURL = url self.headers: Optional[LooseHeaders] = headers diff --git a/tests/conftest.py b/tests/conftest.py index 75b5ab9f..77802d00 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -460,7 +460,7 @@ async def client_and_server(server): async with Client(transport=sample_transport) as session: # Yield both client session and server - yield session, server@pytest_asyncio.fixture + yield session, server @pytest_asyncio.fixture async def aiohttp_client_and_server(server): diff --git a/tests/test_aiohttp_websocket_exceptions.py b/tests/test_aiohttp_websocket_exceptions.py index 0cba0de1..e6423208 100644 --- a/tests/test_aiohttp_websocket_exceptions.py +++ b/tests/test_aiohttp_websocket_exceptions.py @@ -149,7 +149,7 @@ async def test_aiohttp_websocket_sending_invalid_data( invalid_data = "QSDF" print(f">>> {invalid_data}") - await session.transport.websocket.send(invalid_data) + await session.transport.websocket.send_str(invalid_data) await asyncio.sleep(2 * MS) @@ -313,7 +313,7 @@ async def test_aiohttp_websocket_server_closing_after_ack( query = gql("query { hello }") - with pytest.raises(websockets.exceptions.ConnectionClosed): + with pytest.raises(ConnectionResetError): await session.execute(query) await session.transport.wait_closed() diff --git a/tests/test_aiohttp_websocket_online.py b/tests/test_aiohttp_websocket_online.py deleted file mode 100644 index e69de29b..00000000 From 6584f549b1d35a4421ddd7c134b24601a7703329 Mon Sep 17 00:00:00 2001 From: Taylor Lowery Date: Wed, 5 Jun 2024 19:27:40 +0000 Subject: [PATCH 14/61] remove unnecessary websockets imports --- tests/test_aiohttp_websocket_exceptions.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/test_aiohttp_websocket_exceptions.py b/tests/test_aiohttp_websocket_exceptions.py index e6423208..b2e53188 100644 --- a/tests/test_aiohttp_websocket_exceptions.py +++ b/tests/test_aiohttp_websocket_exceptions.py @@ -282,7 +282,6 @@ async def server_closing_directly(ws, path): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_closing_directly], indirect=True) async def test_aiohttp_websocket_server_closing_directly(event_loop, server): - import websockets from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport @@ -291,7 +290,7 @@ async def test_aiohttp_websocket_server_closing_directly(event_loop, server): sample_transport = AIOHTTPWebsocketsTransport(url=url) - with pytest.raises(websockets.exceptions.ConnectionClosed): + with pytest.raises(ConnectionResetError): async with Client(transport=sample_transport): pass @@ -307,8 +306,6 @@ async def test_aiohttp_websocket_server_closing_after_ack( event_loop, aiohttp_client_and_server ): - import websockets - session, server = aiohttp_client_and_server query = gql("query { hello }") From fc73ba12dfdb97ef5560d88e42fc08227b35fdea Mon Sep 17 00:00:00 2001 From: Micah Pegman Date: Mon, 10 Jun 2024 18:45:37 +0000 Subject: [PATCH 15/61] fix some tests --- gql/transport/aiohttp_websockets.py | 37 ++++++++++++++++------ tests/test_aiohttp_websocket_exceptions.py | 2 +- 2 files changed, 29 insertions(+), 10 deletions(-) diff --git a/gql/transport/aiohttp_websockets.py b/gql/transport/aiohttp_websockets.py index aa0000de..5cf88bbb 100644 --- a/gql/transport/aiohttp_websockets.py +++ b/gql/transport/aiohttp_websockets.py @@ -16,8 +16,7 @@ ) import aiohttp -from aiohttp.client_reqrep import Fingerprint -from aiohttp.helpers import BasicAuth, hdrs +from aiohttp import hdrs, BasicAuth, Fingerprint, WSMsgType from aiohttp.typedefs import LooseHeaders, StrOrURL from graphql import DocumentNode, ExecutionResult, print_ast from multidict import CIMultiDict, CIMultiDictProxy @@ -32,6 +31,11 @@ ) from gql.transport.websockets_base import ListenerQueue +try: + from json.decoder import JSONDecodeError +except ImportError: + from simplejson import JSONDecodeError + log = logging.getLogger("gql.transport.aiohttp_websockets") @@ -149,7 +153,7 @@ def __init__( self.close_exception: Optional[Exception] = None def _parse_answer_graphqlws( - self, json_answer: Dict[str, Any] + self, answer: Dict[str, Any] ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: """Parse the answer received from the server if the server supports the graphql-ws protocol. @@ -175,14 +179,14 @@ def _parse_answer_graphqlws( execution_result: Optional[ExecutionResult] = None try: - answer_type = str(json_answer.get("type")) + answer_type = str(answer.get("type")) if answer_type in ["next", "error", "complete"]: - answer_id = int(str(json_answer.get("id"))) + answer_id = int(str(answer.get("id"))) if answer_type == "next" or answer_type == "error": - payload = json_answer.get("payload") + payload = answer.get("payload") if answer_type == "next": @@ -213,7 +217,7 @@ def _parse_answer_graphqlws( ) elif answer_type in ["ping", "pong", "connection_ack"]: - self.payloads[answer_type] = json_answer.get("payload", None) + self.payloads[answer_type] = answer.get("payload", None) else: raise ValueError @@ -223,7 +227,7 @@ def _parse_answer_graphqlws( except ValueError as e: raise TransportProtocolError( - f"Server did not return a GraphQL result: {json_answer}" + f"Server did not return a GraphQL result: {answer}" ) from e return answer_type, answer_id, execution_result @@ -471,14 +475,27 @@ async def _send(self, message: Dict[str, Any]) -> None: raise e async def _receive(self) -> Dict[str, Any]: + log.debug("Entering _receive()") if self.websocket is None: raise TransportClosed("WebSocket connection is closed") - answer = await self.websocket.receive_json() + try: + answer = await self.websocket.receive_json() + except TypeError as e: + answer = await self.websocket.receive() + if answer.type in (WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.CLOSING): + self._fail(e, clean_close=True) + raise ConnectionResetError + else: + self._fail(e, clean_close=False) + except JSONDecodeError as e: + self._fail(e) log.info("<<< %s", answer) + log.debug("Exiting _receive()") + return answer def _remove_listener(self, query_id) -> None: @@ -546,6 +563,8 @@ async def _handle_answer( async def _receive_data_loop(self) -> None: """Main asyncio task which will listen to the incoming messages and will call the parse_answer and handle_answer methods of the subclass.""" + log.debug("Entering _receive_data_loop()") + try: while True: diff --git a/tests/test_aiohttp_websocket_exceptions.py b/tests/test_aiohttp_websocket_exceptions.py index b2e53188..d50ac887 100644 --- a/tests/test_aiohttp_websocket_exceptions.py +++ b/tests/test_aiohttp_websocket_exceptions.py @@ -250,7 +250,7 @@ async def test_aiohttp_websocket_transport_protocol_errors( query = gql("query { hello }") - with pytest.raises(TransportProtocolError): + with pytest.raises((TransportProtocolError, TransportQueryError)): await session.execute(query) From e4dbce8b66fdf96fc4853a3b4f1d072e9926b4c7 Mon Sep 17 00:00:00 2001 From: Micah Pegman Date: Mon, 10 Jun 2024 18:59:10 +0000 Subject: [PATCH 16/61] more test fixes Signed-off-by: Micah Pegman --- tests/test_aiohttp_websocket_exceptions.py | 5 ----- tests/test_aiohttp_websocket_query.py | 2 +- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/tests/test_aiohttp_websocket_exceptions.py b/tests/test_aiohttp_websocket_exceptions.py index d50ac887..48312544 100644 --- a/tests/test_aiohttp_websocket_exceptions.py +++ b/tests/test_aiohttp_websocket_exceptions.py @@ -310,11 +310,6 @@ async def test_aiohttp_websocket_server_closing_after_ack( query = gql("query { hello }") - with pytest.raises(ConnectionResetError): - await session.execute(query) - - await session.transport.wait_closed() - with pytest.raises(TransportClosed): await session.execute(query) diff --git a/tests/test_aiohttp_websocket_query.py b/tests/test_aiohttp_websocket_query.py index 3062e5e7..2899038e 100644 --- a/tests/test_aiohttp_websocket_query.py +++ b/tests/test_aiohttp_websocket_query.py @@ -245,7 +245,7 @@ async def test_aiohttp_websocket_server_closing_after_first_query( # Now the server is closed but we don't know it yet, we have to send a query # to notice it and to receive the exception - with pytest.raises(ConnectionResetError): + with pytest.raises(TransportClosed): await session.execute(query) From e47ce3bf4479aca43bfa4453764ec36c24108a6b Mon Sep 17 00:00:00 2001 From: Taylor Lowery Date: Mon, 10 Jun 2024 20:26:34 +0000 Subject: [PATCH 17/61] add ci test job for aiohttp_websockets --- .github/workflows/tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 30e8289c..62fbe24a 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -44,7 +44,7 @@ jobs: strategy: fail-fast: false matrix: - dependency: ["aiohttp", "requests", "httpx", "websockets"] + dependency: ["aiohttp", "requests", "httpx", "websockets", "aiohttp_websockets"] steps: - uses: actions/checkout@v3 From cd71e91f6304d947dd0c677f90ae7509a8d85e91 Mon Sep 17 00:00:00 2001 From: Taylor Lowery Date: Fri, 14 Jun 2024 20:12:27 +0000 Subject: [PATCH 18/61] add aiohttp websocket tests for graphql websockets --- tests/conftest.py | 34 +- ...st_aiohttp_websocket_graphql_exceptions.py | 273 ++++++ ..._aiohttp_websocket_graphql_subscription.py | 879 ++++++++++++++++++ 3 files changed, 1181 insertions(+), 5 deletions(-) create mode 100644 tests/test_aiohttp_websocket_graphql_exceptions.py create mode 100644 tests/test_aiohttp_websocket_graphql_subscription.py diff --git a/tests/conftest.py b/tests/conftest.py index 77802d00..8cdf6c9f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,19 +1,19 @@ -import ssl - import asyncio import json import logging import os import pathlib -import pytest -import pytest_asyncio import re +import ssl import sys import tempfile import types from concurrent.futures import ThreadPoolExecutor from typing import Union +import pytest +import pytest_asyncio + from gql import Client all_transport_dependencies = ["aiohttp", "requests", "httpx", "websockets", "botocore"] @@ -462,9 +462,13 @@ async def client_and_server(server): # Yield both client session and server yield session, server + @pytest_asyncio.fixture async def aiohttp_client_and_server(server): - """Helper fixture to start a server and a client connected to its port with an aiohttp websockets transport.""" + """ + Helper fixture to start a server and a client connected to its port + with an aiohttp websockets transport. + """ from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport @@ -499,6 +503,26 @@ async def client_and_graphqlws_server(graphqlws_server): # Yield both client session and server yield session, graphqlws_server +@pytest_asyncio.fixture +async def client_and_aiohttp_websocket_graphql_server(graphqlws_server): + """Helper fixture to start a server with the graphql-ws prototocol + and a client connected to its port.""" + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + # Generate transport to connect to the server fixture + path = "/graphql" + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" + sample_transport = AIOHTTPWebsocketsTransport( + url=url, + protocols=[AIOHTTPWebsocketsTransport.GRAPHQLWS_SUBPROTOCOL], + ) + + async with Client(transport=sample_transport) as session: + + # Yield both client session and server + yield session, graphqlws_server + @pytest_asyncio.fixture async def run_sync_test(): async def run_sync_test_inner(event_loop, server, test_function): diff --git a/tests/test_aiohttp_websocket_graphql_exceptions.py b/tests/test_aiohttp_websocket_graphql_exceptions.py new file mode 100644 index 00000000..497ab52d --- /dev/null +++ b/tests/test_aiohttp_websocket_graphql_exceptions.py @@ -0,0 +1,273 @@ +import asyncio +from typing import List + +import pytest + +from gql import Client, gql +from gql.transport.exceptions import ( + TransportClosed, + TransportProtocolError, + TransportQueryError, +) + +from .conftest import WebSocketServerHelper + +# Marking all tests in this file with the websockets marker +pytestmark = pytest.mark.aiohttp_websockets + +invalid_query_str = """ + query getContinents { + continents { + code + bloh + } + } +""" + +invalid_query1_server_answer = ( + '{{"type":"next","id":"{query_id}",' + '"payload":{{"errors":[' + '{{"message":"Cannot query field \\"bloh\\" on type \\"Continent\\".",' + '"locations":[{{"line":4,"column":5}}],' + '"extensions":{{"code":"INTERNAL_SERVER_ERROR"}}}}]}}}}' +) + +invalid_query1_server = [invalid_query1_server_answer] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [invalid_query1_server], indirect=True) +@pytest.mark.parametrize("query_str", [invalid_query_str]) +async def test_aiohttp_graphqlws_invalid_query( + event_loop, client_and_aiohttp_websocket_graphql_server, query_str +): + + session, server = client_and_aiohttp_websocket_graphql_server + + query = gql(query_str) + + with pytest.raises(TransportQueryError) as exc_info: + await session.execute(query) + + exception = exc_info.value + + assert isinstance(exception.errors, List) + + error = exception.errors[0] + + assert error["extensions"]["code"] == "INTERNAL_SERVER_ERROR" + + +invalid_subscription_str = """ + subscription getContinents { + continents { + code + bloh + } + } +""" + + +async def server_invalid_subscription(ws, path): + await WebSocketServerHelper.send_connection_ack(ws) + await ws.recv() + await ws.send(invalid_query1_server_answer.format(query_id=1)) + await WebSocketServerHelper.send_complete(ws, 1) + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", [server_invalid_subscription], indirect=True +) +@pytest.mark.parametrize("query_str", [invalid_subscription_str]) +async def test_aiohttp_graphqlws_invalid_subscription( + event_loop, client_and_aiohttp_websocket_graphql_server, query_str +): + + session, server = client_and_aiohttp_websocket_graphql_server + + query = gql(query_str) + + with pytest.raises(TransportQueryError) as exc_info: + async for result in session.subscribe(query): + pass + + exception = exc_info.value + + assert isinstance(exception.errors, List) + + error = exception.errors[0] + + assert error["extensions"]["code"] == "INTERNAL_SERVER_ERROR" + + +async def server_no_ack(ws, path): + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [server_no_ack], indirect=True) +@pytest.mark.parametrize("query_str", [invalid_query_str]) +async def test_aiohttp_graphqlws_server_does_not_send_ack( + event_loop, graphqlws_server, query_str +): + from gql.transport.websockets import WebsocketsTransport + + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql" + + sample_transport = WebsocketsTransport(url=url, ack_timeout=1) + + with pytest.raises(asyncio.TimeoutError): + async with Client(transport=sample_transport): + pass + + +invalid_query_server_answer = ( + '{"id":"1","type":"error","payload":[{"message":"Cannot query field ' + '\\"helo\\" on type \\"Query\\". Did you mean \\"hello\\"?",' + '"locations":[{"line":2,"column":3}]}]}' +) + + +async def server_invalid_query(ws, path): + await WebSocketServerHelper.send_connection_ack(ws) + result = await ws.recv() + print(f"Server received: {result}") + await ws.send(invalid_query_server_answer) + await WebSocketServerHelper.wait_connection_terminate(ws) + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [server_invalid_query], indirect=True) +async def test_aiohttp_graphqlws_sending_invalid_query(event_loop, client_and_aiohttp_websocket_graphql_server): + + session, server = client_and_aiohttp_websocket_graphql_server + + query = gql("{helo}") + + with pytest.raises(TransportQueryError) as exc_info: + await session.execute(query) + + exception = exc_info.value + + assert isinstance(exception.errors, List) + + error = exception.errors[0] + + assert ( + error["message"] + == 'Cannot query field "helo" on type "Query". Did you mean "hello"?' + ) + + +not_json_answer = ["BLAHBLAH"] +missing_type_answer = ["{}"] +missing_id_answer_1 = ['{"type": "next"}'] +missing_id_answer_2 = ['{"type": "error"}'] +missing_id_answer_3 = ['{"type": "complete"}'] +data_without_payload = ['{"type": "next", "id":"1"}'] +error_without_payload = ['{"type": "error", "id":"1"}'] +error_with_payload_not_a_list = ['{"type": "error", "id":"1", "payload": "NOT A LIST"}'] +payload_is_not_a_dict = ['{"type": "next", "id":"1", "payload": "BLAH"}'] +empty_payload = ['{"type": "next", "id":"1", "payload": {}}'] +sending_bytes = [b"\x01\x02\x03"] + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", + [ + not_json_answer, + missing_type_answer, + missing_id_answer_1, + missing_id_answer_2, + missing_id_answer_3, + data_without_payload, + error_without_payload, + payload_is_not_a_dict, + error_with_payload_not_a_list, + empty_payload, + sending_bytes, + ], + indirect=True, +) +async def test_aiohttp_graphqlws_transport_protocol_errors( + event_loop, client_and_aiohttp_websocket_graphql_server +): + + session, server = client_and_aiohttp_websocket_graphql_server + + query = gql("query { hello }") + + with pytest.raises(TransportProtocolError): + await session.execute(query) + + +async def server_without_ack(ws, path): + # Sending something else than an ack + await WebSocketServerHelper.send_complete(ws, 1) + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [server_without_ack], indirect=True) +async def test_aiohttp_graphqlws_server_does_not_ack(event_loop, graphqlws_server): + from gql.transport.websockets import WebsocketsTransport + + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql" + print(f"url = {url}") + + sample_transport = WebsocketsTransport(url=url) + + with pytest.raises(TransportProtocolError): + async with Client(transport=sample_transport): + pass + + +async def server_closing_directly(ws, path): + await ws.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [server_closing_directly], indirect=True) +async def test_aiohttp_graphqlws_server_closing_directly(event_loop, graphqlws_server): + import websockets + + from gql.transport.websockets import WebsocketsTransport + + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql" + print(f"url = {url}") + + sample_transport = WebsocketsTransport(url=url) + + with pytest.raises(websockets.exceptions.ConnectionClosed): + async with Client(transport=sample_transport): + pass + + +async def server_closing_after_ack(ws, path): + await WebSocketServerHelper.send_connection_ack(ws) + await ws.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [server_closing_after_ack], indirect=True) +async def test_aiohttp_graphqlws_server_closing_after_ack( + event_loop, client_and_aiohttp_websocket_graphql_server +): + + import websockets + + session, server = client_and_aiohttp_websocket_graphql_server + + query = gql("query { hello }") + + with pytest.raises(websockets.exceptions.ConnectionClosed): + await session.execute(query) + + await session.transport.wait_closed() + + with pytest.raises(TransportClosed): + await session.execute(query) diff --git a/tests/test_aiohttp_websocket_graphql_subscription.py b/tests/test_aiohttp_websocket_graphql_subscription.py new file mode 100644 index 00000000..0f4106a1 --- /dev/null +++ b/tests/test_aiohttp_websocket_graphql_subscription.py @@ -0,0 +1,879 @@ +import asyncio +import json +import sys +import warnings +from typing import List + +import pytest +from parse import search + +from gql import Client, gql +from gql.transport.exceptions import TransportServerError + +from .conftest import MS, WebSocketServerHelper + +# Marking all tests in this file with the websockets marker +pytestmark = pytest.mark.aiohttp_websockets + +countdown_server_answer = ( + '{{"type":"next","id":"{query_id}","payload":{{"data":{{"number":{number}}}}}}}' +) + +COUNTING_DELAY = 20 * MS +PING_SENDING_DELAY = 50 * MS +PONG_TIMEOUT = 100 * MS + +# List which can used to store received messages by the server +logged_messages: List[str] = [] + + +def server_countdown_factory( + keepalive=False, answer_pings=True, simulate_disconnect=False +): + async def server_countdown_template(ws, path): + import websockets + + logged_messages.clear() + + try: + await WebSocketServerHelper.send_connection_ack( + ws, payload="dummy_connection_ack_payload" + ) + + result = await ws.recv() + logged_messages.append(result) + + json_result = json.loads(result) + assert json_result["type"] == "subscribe" + payload = json_result["payload"] + query = payload["query"] + query_id = json_result["id"] + + count_found = search("count: {:d}", query) + count = count_found[0] + print(f" Server: Countdown started from: {count}") + + if simulate_disconnect and count == 8: + await ws.close() + + pong_received: asyncio.Event = asyncio.Event() + + async def counting_coro(): + print(" Server: counting task started") + try: + for number in range(count, -1, -1): + await ws.send( + countdown_server_answer.format( + query_id=query_id, number=number + ) + ) + await asyncio.sleep(COUNTING_DELAY) + finally: + print(" Server: counting task ended") + + print(" Server: starting counting task") + counting_task = asyncio.ensure_future(counting_coro()) + + async def keepalive_coro(): + print(" Server: keepalive task started") + try: + while True: + await asyncio.sleep(PING_SENDING_DELAY) + try: + # Send a ping + await WebSocketServerHelper.send_ping( + ws, payload="dummy_ping_payload" + ) + + # Wait for a pong + try: + await asyncio.wait_for( + pong_received.wait(), PONG_TIMEOUT + ) + except asyncio.TimeoutError: + print( + "\n Server: No pong received in time!\n" + ) + break + + pong_received.clear() + + except websockets.exceptions.ConnectionClosed: + break + finally: + print(" Server: keepalive task ended") + + if keepalive: + print(" Server: starting keepalive task") + keepalive_task = asyncio.ensure_future(keepalive_coro()) + + async def receiving_coro(): + print(" Server: receiving task started") + try: + nonlocal counting_task + while True: + + try: + result = await ws.recv() + logged_messages.append(result) + except websockets.exceptions.ConnectionClosed: + break + + json_result = json.loads(result) + + answer_type = json_result["type"] + + if answer_type == "complete" and json_result["id"] == str( + query_id + ): + print("Cancelling counting task now") + counting_task.cancel() + if keepalive: + print("Cancelling keep alive task now") + keepalive_task.cancel() + + elif answer_type == "ping": + if answer_pings: + payload = json_result.get("payload", None) + await WebSocketServerHelper.send_pong( + ws, payload=payload + ) + + elif answer_type == "pong": + pong_received.set() + finally: + print(" Server: receiving task ended") + if keepalive: + keepalive_task.cancel() + + print(" Server: starting receiving task") + receiving_task = asyncio.ensure_future(receiving_coro()) + + try: + print(" Server: waiting for counting task to complete") + await counting_task + except asyncio.CancelledError: + print(" Server: Now counting task is cancelled") + + print(" Server: sending complete message") + await WebSocketServerHelper.send_complete(ws, query_id) + + if keepalive: + print(" Server: cancelling keepalive task") + keepalive_task.cancel() + try: + await keepalive_task + except asyncio.CancelledError: + print(" Server: Now keepalive task is cancelled") + + print(" Server: waiting for client to close the connection") + try: + await asyncio.wait_for(receiving_task, 1000 * MS) + except asyncio.TimeoutError: + pass + + print(" Server: cancelling receiving task") + receiving_task.cancel() + + try: + await receiving_task + except asyncio.CancelledError: + print(" Server: Now receiving task is cancelled") + + except websockets.exceptions.ConnectionClosedOK: + pass + except AssertionError as e: + print(f"\n Server: Assertion failed: {e!s}\n") + finally: + print(" Server: waiting for websocket connection to close") + await ws.wait_closed() + print(" Server: connection closed") + + return server_countdown_template + + +async def server_countdown(ws, path): + + server = server_countdown_factory() + await server(ws, path) + + +async def server_countdown_keepalive(ws, path): + + server = server_countdown_factory(keepalive=True) + await server(ws, path) + + +async def server_countdown_dont_answer_pings(ws, path): + + server = server_countdown_factory(answer_pings=False) + await server(ws, path) + + +async def server_countdown_disconnect(ws, path): + + server = server_countdown_factory(simulate_disconnect=True) + await server(ws, path) + + +countdown_subscription_str = """ + subscription {{ + countdown (count: {count}) {{ + number + }} + }} +""" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_graphqlws_subscription( + event_loop, client_and_aiohttp_websocket_graphql_server, subscription_str +): + + session, server = client_and_aiohttp_websocket_graphql_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_graphqlws_subscription_break( + event_loop, client_and_aiohttp_websocket_graphql_server, subscription_str +): + + session, server = client_and_aiohttp_websocket_graphql_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + + if count <= 5: + # Note: the following line is only necessary for pypy3 v3.6.1 + if sys.version_info < (3, 7): + await session._generator.aclose() + break + + count -= 1 + + assert count == 5 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_graphqlws_subscription_task_cancel( + event_loop, client_and_aiohttp_websocket_graphql_server, subscription_str +): + + session, server = client_and_aiohttp_websocket_graphql_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async def task_coro(): + nonlocal count + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + + count -= 1 + + task = asyncio.ensure_future(task_coro()) + + async def cancel_task_coro(): + nonlocal task + + await asyncio.sleep(5.5 * COUNTING_DELAY) + + task.cancel() + + cancel_task = asyncio.ensure_future(cancel_task_coro()) + + await asyncio.gather(task, cancel_task) + + assert count > 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_graphqlws_subscription_close_transport( + event_loop, client_and_aiohttp_websocket_graphql_server, subscription_str +): + + session, server = client_and_aiohttp_websocket_graphql_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async def task_coro(): + nonlocal count + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + + count -= 1 + + task = asyncio.ensure_future(task_coro()) + + async def close_transport_task_coro(): + nonlocal task + + await asyncio.sleep(5.5 * COUNTING_DELAY) + + await session.transport.close() + + close_transport_task = asyncio.ensure_future(close_transport_task_coro()) + + await asyncio.gather(task, close_transport_task) + + assert count > 0 + + +async def server_countdown_close_connection_in_middle(ws, path): + await WebSocketServerHelper.send_connection_ack(ws) + + result = await ws.recv() + json_result = json.loads(result) + assert json_result["type"] == "subscribe" + payload = json_result["payload"] + query = payload["query"] + query_id = json_result["id"] + + count_found = search("count: {:d}", query) + count = count_found[0] + stopping_before = count // 2 + print(f"Countdown started from: {count}, stopping server before {stopping_before}") + for number in range(count, stopping_before, -1): + await ws.send(countdown_server_answer.format(query_id=query_id, number=number)) + await asyncio.sleep(COUNTING_DELAY) + + print("Closing server while subscription is still running now") + await ws.close() + await ws.wait_closed() + print("Server is now closed") + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_close_connection_in_middle], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_graphqlws_subscription_server_connection_closed( + event_loop, client_and_aiohttp_websocket_graphql_server, subscription_str +): + import websockets + + session, server = client_and_aiohttp_websocket_graphql_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + with pytest.raises(ConnectionResetError): + + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + + count -= 1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_graphqlws_subscription_with_operation_name( + event_loop, client_and_aiohttp_websocket_graphql_server, subscription_str +): + + session, server = client_and_aiohttp_websocket_graphql_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async for result in session.subscribe( + subscription, operation_name="CountdownSubscription" + ): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + # Check that the query contains the operationName + assert '"operationName": "CountdownSubscription"' in logged_messages[0] + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_keepalive], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_graphqlws_subscription_with_keepalive( + event_loop, client_and_aiohttp_websocket_graphql_server, subscription_str +): + + session, server = client_and_aiohttp_websocket_graphql_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + assert "ping" in session.transport.payloads + assert session.transport.payloads["ping"] == "dummy_ping_payload" + assert ( + session.transport.payloads["connection_ack"] == "dummy_connection_ack_payload" + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_keepalive], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_graphqlws_subscription_with_keepalive_with_timeout_ok( + event_loop, graphqlws_server, subscription_str +): + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + path = "/graphql" + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" + transport = AIOHTTPWebsocketsTransport(url=url, keep_alive_timeout=(5 * COUNTING_DELAY)) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async with client as session: + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_keepalive], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_graphqlws_subscription_with_keepalive_with_timeout_nok( + event_loop, graphqlws_server, subscription_str +): + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + path = "/graphql" + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" + transport = AIOHTTPWebsocketsTransport(url=url, keep_alive_timeout=(COUNTING_DELAY / 2)) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async with client as session: + with pytest.raises(TransportServerError) as exc_info: + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert "No keep-alive message has been received" in str(exc_info.value) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_keepalive], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_graphqlws_subscription_with_ping_interval_ok( + event_loop, graphqlws_server, subscription_str +): + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + path = "/graphql" + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" + transport = AIOHTTPWebsocketsTransport( + url=url, + ping_interval=(5 * COUNTING_DELAY), + pong_timeout=(4 * COUNTING_DELAY), + ) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async with client as session: + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_dont_answer_pings], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_graphqlws_subscription_with_ping_interval_nok( + event_loop, graphqlws_server, subscription_str +): + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + path = "/graphql" + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" + transport = AIOHTTPWebsocketsTransport(url=url, ping_interval=(5 * COUNTING_DELAY)) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async with client as session: + with pytest.raises(TransportServerError) as exc_info: + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert "No pong received" in str(exc_info.value) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_keepalive], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_graphqlws_subscription_manual_pings_with_payload( + event_loop, graphqlws_server, subscription_str +): + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + path = "/graphql" + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" + transport = AIOHTTPWebsocketsTransport(url=url) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async with client as session: + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + payload = {"count_received": count} + + await transport.send_ping(payload=payload) + + await asyncio.wait_for(transport.pong_received.wait(), 10000 * MS) + + transport.pong_received.clear() + + assert transport.payloads["pong"] == payload + + assert number == count + count -= 1 + + assert count == -1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_keepalive], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_graphqlws_subscription_manual_pong_answers_with_payload( + event_loop, graphqlws_server, subscription_str +): + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + path = "/graphql" + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" + transport = AIOHTTPWebsocketsTransport(url=url, answer_pings=False) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async with client as session: + + async def answer_ping_coro(): + while True: + await transport.ping_received.wait() + transport.ping_received.clear() + await transport.send_pong(payload={"some": "data"}) + + answer_ping_task = asyncio.ensure_future(answer_ping_coro()) + + try: + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + finally: + answer_ping_task.cancel() + + assert count == -1 + + +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_keepalive], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +def test_aiohttp_graphqlws_subscription_sync(graphqlws_server, subscription_str): + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql" + print(f"url = {url}") + + transport = AIOHTTPWebsocketsTransport(url=url) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + for result in client.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + +@pytest.mark.skipif(sys.platform.startswith("win"), reason="test failing on windows") +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_keepalive], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +def test_aiohttp_graphqlws_subscription_sync_graceful_shutdown( + graphqlws_server, subscription_str +): + """Note: this test will simulate a control-C happening while a sync subscription + is in progress. To do that we will throw a KeyboardInterrupt exception inside + the subscription async generator. + + The code should then do a clean close: + - send stop messages for each active query + - send a connection_terminate message + Then the KeyboardInterrupt will be reraise (to warn potential user code) + + This test does not work on Windows but the behaviour with Windows is correct. + """ + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql" + print(f"url = {url}") + + transport = AIOHTTPWebsocketsTransport(url=url) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + with pytest.raises(KeyboardInterrupt): + for result in client.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + + if count == 5: + + # Simulate a KeyboardInterrupt in the generator + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message="There is no current event loop" + ) + asyncio.ensure_future( + client.session._generator.athrow(KeyboardInterrupt) + ) + + count -= 1 + + assert count == 4 + + # Check that the server received a connection_terminate message last + # assert logged_messages.pop() == '{"type": "connection_terminate"}' + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_keepalive], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_graphqlws_subscription_running_in_thread( + event_loop, graphqlws_server, subscription_str, run_sync_test +): + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + def test_code(): + path = "/graphql" + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" + transport = AIOHTTPWebsocketsTransport(url=url) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + for result in client.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + await run_sync_test(event_loop, graphqlws_server, test_code) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_disconnect], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +@pytest.mark.parametrize("execute_instead_of_subscribe", [False, True]) +async def test_aiohttp_graphqlws_subscription_reconnecting_session( + event_loop, graphqlws_server, subscription_str, execute_instead_of_subscribe +): + + import websockets + + from gql.transport.exceptions import TransportClosed + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + path = "/graphql" + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" + transport = AIOHTTPWebsocketsTransport(url=url) + + client = Client(transport=transport) + + count = 8 + subscription_with_disconnect = gql(subscription_str.format(count=count)) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + session = await client.connect_async( + reconnecting=True, retry_connect=False, retry_execute=False + ) + + # First we make a subscription which will cause a disconnect in the backend + # (count=8) + try: + print("\nSUBSCRIPTION_1_WITH_DISCONNECT\n") + async for result in session.subscribe(subscription_with_disconnect): + pass + except websockets.exceptions.ConnectionClosedOK: + pass + + await asyncio.sleep(50 * MS) + + # Then with the same session handle, we make a subscription or an execute + # which will detect that the transport is closed so that the client could + # try to reconnect + try: + if execute_instead_of_subscribe: + print("\nEXECUTION_2\n") + await session.execute(subscription) + else: + print("\nSUBSCRIPTION_2\n") + async for result in session.subscribe(subscription): + pass + except TransportClosed: + pass + + await asyncio.sleep(50 * MS) + + # And finally with the same session handle, we make a subscription + # which works correctly + print("\nSUBSCRIPTION_3\n") + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + await client.close_async() From bbddf25b85b7e1184f9486f1497399180cfbc31b Mon Sep 17 00:00:00 2001 From: Taylor Lowery Date: Fri, 14 Jun 2024 20:14:02 +0000 Subject: [PATCH 19/61] update tests to AIOHTTPWebsocketsTransport --- tests/test_aiohttp_websocket_subscription.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/tests/test_aiohttp_websocket_subscription.py b/tests/test_aiohttp_websocket_subscription.py index d493cfc8..e5c5aba4 100644 --- a/tests/test_aiohttp_websocket_subscription.py +++ b/tests/test_aiohttp_websocket_subscription.py @@ -311,14 +311,13 @@ async def server_countdown_close_connection_in_middle(ws, path): async def test_aiohttp_websocket_subscription_server_connection_closed( event_loop, aiohttp_client_and_server, subscription_str ): - import websockets session, server = aiohttp_client_and_server count = 10 subscription = gql(subscription_str.format(count=count)) - with pytest.raises(websockets.exceptions.ConnectionClosedOK): + with pytest.raises(ConnectionResetError): async for result in session.subscribe(subscription): @@ -416,11 +415,11 @@ async def test_aiohttp_websocket_subscription_with_keepalive_with_timeout_ok( event_loop, server, subscription_str ): - from gql.transport.websockets import WebsocketsTransport + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" - sample_transport = WebsocketsTransport(url=url, keep_alive_timeout=(20 * MS)) + sample_transport = AIOHTTPWebsocketsTransport(url=url, keep_alive_timeout=(20 * MS)) client = Client(transport=sample_transport) @@ -446,11 +445,11 @@ async def test_aiohttp_websocket_subscription_with_keepalive_with_timeout_nok( event_loop, server, subscription_str ): - from gql.transport.websockets import WebsocketsTransport + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" - sample_transport = WebsocketsTransport(url=url, keep_alive_timeout=(1 * MS)) + sample_transport = AIOHTTPWebsocketsTransport(url=url, keep_alive_timeout=(1 * MS)) client = Client(transport=sample_transport) From 252969eb59a102ad939da5fb7aaf9b6a1e38d258 Mon Sep 17 00:00:00 2001 From: Taylor Lowery Date: Fri, 14 Jun 2024 20:15:18 +0000 Subject: [PATCH 20/61] add missing import --- gql/transport/aiohttp_websockets.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/gql/transport/aiohttp_websockets.py b/gql/transport/aiohttp_websockets.py index 5cf88bbb..c20baa1f 100644 --- a/gql/transport/aiohttp_websockets.py +++ b/gql/transport/aiohttp_websockets.py @@ -3,20 +3,21 @@ import asyncio import logging from contextlib import suppress +from json.decoder import JSONDecodeError from ssl import SSLContext from typing import ( Any, AsyncGenerator, Collection, Dict, + Mapping, Optional, Tuple, Union, - Mapping, ) import aiohttp -from aiohttp import hdrs, BasicAuth, Fingerprint, WSMsgType +from aiohttp import BasicAuth, Fingerprint, WSMsgType, hdrs from aiohttp.typedefs import LooseHeaders, StrOrURL from graphql import DocumentNode, ExecutionResult, print_ast from multidict import CIMultiDict, CIMultiDictProxy @@ -31,11 +32,6 @@ ) from gql.transport.websockets_base import ListenerQueue -try: - from json.decoder import JSONDecodeError -except ImportError: - from simplejson import JSONDecodeError - log = logging.getLogger("gql.transport.aiohttp_websockets") From 1b729c2febb9e80a34674b3ae8f684962f4511dc Mon Sep 17 00:00:00 2001 From: Taylor Lowery Date: Fri, 14 Jun 2024 20:24:52 +0000 Subject: [PATCH 21/61] set exceptions tests to expect the correct errors --- tests/test_aiohttp_websocket_graphql_exceptions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_aiohttp_websocket_graphql_exceptions.py b/tests/test_aiohttp_websocket_graphql_exceptions.py index 497ab52d..93ecb52d 100644 --- a/tests/test_aiohttp_websocket_graphql_exceptions.py +++ b/tests/test_aiohttp_websocket_graphql_exceptions.py @@ -201,7 +201,7 @@ async def test_aiohttp_graphqlws_transport_protocol_errors( query = gql("query { hello }") - with pytest.raises(TransportProtocolError): + with pytest.raises((TransportProtocolError, TransportQueryError)): await session.execute(query) @@ -264,7 +264,7 @@ async def test_aiohttp_graphqlws_server_closing_after_ack( query = gql("query { hello }") - with pytest.raises(websockets.exceptions.ConnectionClosed): + with pytest.raises(TransportClosed): await session.execute(query) await session.transport.wait_closed() From 3d9940230a82ffc1437c1d9650a504136ae36590 Mon Sep 17 00:00:00 2001 From: Micah Pegman Date: Mon, 17 Jun 2024 21:26:58 +0000 Subject: [PATCH 22/61] test updates Signed-off-by: Micah Pegman --- gql/transport/aiohttp_websockets.py | 49 ++++++++++++++++--- ...st_aiohttp_websocket_graphql_exceptions.py | 8 +-- ..._aiohttp_websocket_graphql_subscription.py | 22 ++++----- tests/test_aiohttp_websocket_subscription.py | 2 +- 4 files changed, 56 insertions(+), 25 deletions(-) diff --git a/gql/transport/aiohttp_websockets.py b/gql/transport/aiohttp_websockets.py index c20baa1f..496ee4a1 100644 --- a/gql/transport/aiohttp_websockets.py +++ b/gql/transport/aiohttp_websockets.py @@ -343,7 +343,13 @@ async def _stop_listener(self, query_id: int): """Hook to stop to listen to a specific query. Will send a stop message in some subclasses. """ - pass # pragma: no cover + log.debug(f"stop listener {query_id}") + + if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: + await self._send_complete_message(query_id) + await self.listeners[query_id].put(("complete", None)) + else: + await self._send_stop_message(query_id) async def _after_connect(self): if self.websocket is None: @@ -351,13 +357,8 @@ async def _after_connect(self): # Find the backend subprotocol returned in the response headers # TODO: find the equivalent of response_headers in aiohttp websocket response - subprotocol = self.websocket.protocol - try: - self.subprotocol = subprotocol - except KeyError: - # If the server does not send the subprotocol header, using - # the apollo subprotocol by default - self.subprotocol = self.APOLLO_SUBPROTOCOL + subprotocol = self.websocket.protocol or self.GRAPHQLWS_SUBPROTOCOL + self.subprotocol = subprotocol log.debug(f"backend subprotocol returned: {self.subprotocol!r}") @@ -371,6 +372,36 @@ async def send_ping(self, payload: Optional[Any] = None) -> None: await self._send(ping_message) + async def send_pong(self, payload: Optional[Any] = None) -> None: + """Send a pong message for the graphql-ws protocol""" + + pong_message = {"type": "pong"} + + if payload is not None: + pong_message["payload"] = payload + + await self._send(pong_message) + + async def _send_stop_message(self, query_id: int) -> None: + """Send stop message to the provided websocket connection and query_id. + + The server should afterwards return a 'complete' message. + """ + + stop_message = {"id": str(query_id), "type": "stop"} + + await self._send(stop_message) + + async def _send_complete_message(self, query_id: int) -> None: + """Send a complete message for the provided query_id. + + This is only for the graphql-ws protocol. + """ + + complete_message = {"id": str(query_id), "type": "complete"} + + await self._send(complete_message) + async def _send_ping_coro(self) -> None: """Coroutine to periodically send a ping from the client to the backend. @@ -687,9 +718,11 @@ async def _clean_close(self) -> None: - send stop messages for each active subscription to the server - send the connection terminate message """ + log.debug(f"Listeners: {self.listeners}") # Send 'stop' message for all current queries for query_id, listener in self.listeners.items(): + print(f"Listener {query_id} send_stop: {listener.send_stop}") if listener.send_stop: await self._stop_listener(query_id) diff --git a/tests/test_aiohttp_websocket_graphql_exceptions.py b/tests/test_aiohttp_websocket_graphql_exceptions.py index 93ecb52d..d49ee7d0 100644 --- a/tests/test_aiohttp_websocket_graphql_exceptions.py +++ b/tests/test_aiohttp_websocket_graphql_exceptions.py @@ -141,7 +141,9 @@ async def server_invalid_query(ws, path): @pytest.mark.asyncio @pytest.mark.parametrize("graphqlws_server", [server_invalid_query], indirect=True) -async def test_aiohttp_graphqlws_sending_invalid_query(event_loop, client_and_aiohttp_websocket_graphql_server): +async def test_aiohttp_graphqlws_sending_invalid_query( + event_loop, client_and_aiohttp_websocket_graphql_server +): session, server = client_and_aiohttp_websocket_graphql_server @@ -258,9 +260,7 @@ async def test_aiohttp_graphqlws_server_closing_after_ack( event_loop, client_and_aiohttp_websocket_graphql_server ): - import websockets - - session, server = client_and_aiohttp_websocket_graphql_server + session, _ = client_and_aiohttp_websocket_graphql_server query = gql("query { hello }") diff --git a/tests/test_aiohttp_websocket_graphql_subscription.py b/tests/test_aiohttp_websocket_graphql_subscription.py index 0f4106a1..1c3348bc 100644 --- a/tests/test_aiohttp_websocket_graphql_subscription.py +++ b/tests/test_aiohttp_websocket_graphql_subscription.py @@ -388,17 +388,13 @@ async def server_countdown_close_connection_in_middle(ws, path): async def test_aiohttp_graphqlws_subscription_server_connection_closed( event_loop, client_and_aiohttp_websocket_graphql_server, subscription_str ): - import websockets - - session, server = client_and_aiohttp_websocket_graphql_server + session, _ = client_and_aiohttp_websocket_graphql_server count = 10 subscription = gql(subscription_str.format(count=count)) with pytest.raises(ConnectionResetError): - async for result in session.subscribe(subscription): - number = result["number"] print(f"Number received: {number}") @@ -478,7 +474,9 @@ async def test_aiohttp_graphqlws_subscription_with_keepalive_with_timeout_ok( path = "/graphql" url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" - transport = AIOHTTPWebsocketsTransport(url=url, keep_alive_timeout=(5 * COUNTING_DELAY)) + transport = AIOHTTPWebsocketsTransport( + url=url, keep_alive_timeout=(5 * COUNTING_DELAY) + ) client = Client(transport=transport) @@ -510,7 +508,9 @@ async def test_aiohttp_graphqlws_subscription_with_keepalive_with_timeout_nok( path = "/graphql" url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" - transport = AIOHTTPWebsocketsTransport(url=url, keep_alive_timeout=(COUNTING_DELAY / 2)) + transport = AIOHTTPWebsocketsTransport( + url=url, keep_alive_timeout=(COUNTING_DELAY / 2) + ) client = Client(transport=transport) @@ -545,8 +545,8 @@ async def test_aiohttp_graphqlws_subscription_with_ping_interval_ok( url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" transport = AIOHTTPWebsocketsTransport( url=url, - ping_interval=(5 * COUNTING_DELAY), - pong_timeout=(4 * COUNTING_DELAY), + ping_interval=(10 * COUNTING_DELAY), + pong_timeout=(8 * COUNTING_DELAY), ) client = Client(transport=transport) @@ -815,8 +815,6 @@ async def test_aiohttp_graphqlws_subscription_reconnecting_session( event_loop, graphqlws_server, subscription_str, execute_instead_of_subscribe ): - import websockets - from gql.transport.exceptions import TransportClosed from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport @@ -842,7 +840,7 @@ async def test_aiohttp_graphqlws_subscription_reconnecting_session( print("\nSUBSCRIPTION_1_WITH_DISCONNECT\n") async for result in session.subscribe(subscription_with_disconnect): pass - except websockets.exceptions.ConnectionClosedOK: + except ConnectionResetError: pass await asyncio.sleep(50 * MS) diff --git a/tests/test_aiohttp_websocket_subscription.py b/tests/test_aiohttp_websocket_subscription.py index e5c5aba4..74f67619 100644 --- a/tests/test_aiohttp_websocket_subscription.py +++ b/tests/test_aiohttp_websocket_subscription.py @@ -247,7 +247,7 @@ async def test_aiohttp_websocket_subscription_close_transport( event_loop, aiohttp_client_and_server, subscription_str ): - session, server = aiohttp_client_and_server + session, _ = aiohttp_client_and_server count = 10 subscription = gql(subscription_str.format(count=count)) From 10e21729f68ba4b70206a8741ce3a4222660bb37 Mon Sep 17 00:00:00 2001 From: Taylor Lowery Date: Tue, 18 Jun 2024 07:58:19 +0000 Subject: [PATCH 23/61] get response headers from aiohttp client response --- gql/transport/aiohttp_websockets.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/gql/transport/aiohttp_websockets.py b/gql/transport/aiohttp_websockets.py index 496ee4a1..8d9b66b6 100644 --- a/gql/transport/aiohttp_websockets.py +++ b/gql/transport/aiohttp_websockets.py @@ -357,8 +357,11 @@ async def _after_connect(self): # Find the backend subprotocol returned in the response headers # TODO: find the equivalent of response_headers in aiohttp websocket response - subprotocol = self.websocket.protocol or self.GRAPHQLWS_SUBPROTOCOL - self.subprotocol = subprotocol + response_headers = self.websocket._response.headers + try: + self.subprotocol = response_headers["Sec-WebSocket-Protocol"] + except KeyError: + self.subprotocol = self.APOLLO_SUBPROTOCOL log.debug(f"backend subprotocol returned: {self.subprotocol!r}") From 3bd3126982c1e55629d8a800c2ca01b8dbc89e0d Mon Sep 17 00:00:00 2001 From: Micah Pegman Date: Tue, 18 Jun 2024 21:14:33 +0000 Subject: [PATCH 24/61] transport feature updates Signed-off-by: Micah Pegman --- gql/transport/aiohttp_websockets.py | 57 ++++++++++++++++++++++++++--- 1 file changed, 51 insertions(+), 6 deletions(-) diff --git a/gql/transport/aiohttp_websockets.py b/gql/transport/aiohttp_websockets.py index 8d9b66b6..607b379c 100644 --- a/gql/transport/aiohttp_websockets.py +++ b/gql/transport/aiohttp_websockets.py @@ -2,6 +2,7 @@ import asyncio import logging +import warnings from contextlib import suppress from json.decoder import JSONDecodeError from ssl import SSLContext @@ -96,6 +97,18 @@ def __init__( self.verify_ssl: Optional[bool] = verify_ssl self.init_payload: Dict[str, Any] = init_payload + # We need to set an event loop here if there is none + # Or else we will not be able to create an asyncio.Event() + try: + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message="There is no current event loop" + ) + self._loop = asyncio.get_event_loop() + except RuntimeError: + self._loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._loop) + self.connect_timeout: Optional[Union[int, float]] = connect_timeout self.close_timeout: Optional[Union[int, float]] = close_timeout self.ack_timeout: Optional[Union[int, float]] = ack_timeout @@ -337,6 +350,9 @@ async def _send_init_message_and_wait_ack(self) -> None: await asyncio.wait_for(self._wait_ack(), self.ack_timeout) async def _initialize(self): + """Hook to send the initialization messages after the connection + and potentially wait for the backend ack. + """ await self._send_init_message_and_wait_ack() async def _stop_listener(self, query_id: int): @@ -352,12 +368,13 @@ async def _stop_listener(self, query_id: int): await self._send_stop_message(query_id) async def _after_connect(self): - if self.websocket is None: - raise TransportClosed("WebSocket connection is closed") - + """Hook to add custom code for subclasses after the connection + has been established. + """ # Find the backend subprotocol returned in the response headers # TODO: find the equivalent of response_headers in aiohttp websocket response response_headers = self.websocket._response.headers + log.debug(f"Response headers: {response_headers!r}") try: self.subprotocol = response_headers["Sec-WebSocket-Protocol"] except KeyError: @@ -439,6 +456,9 @@ async def _send_ping_coro(self) -> None: ) async def _after_initialize(self): + """Hook to add custom code for subclasses after the initialization + has been done. + """ # If requested, create a task to send periodic pings to the backend if ( @@ -450,13 +470,29 @@ async def _after_initialize(self): async def _close_hook(self): """Hook to add custom code for subclasses for the connection close""" - pass # pragma: no cover + # Properly shut down the send ping task if enabled + if self.send_ping_task is not None: + self.send_ping_task.cancel() + with suppress(asyncio.CancelledError): + await self.send_ping_task + self.send_ping_task = None async def _connection_terminate(self): """Hook to add custom code for subclasses after the initialization has been done. """ - pass # pragma: no cover + if self.subprotocol == self.APOLLO_SUBPROTOCOL: + await self._send_connection_terminate_message() + + async def _send_connection_terminate_message(self) -> None: + """Send a connection_terminate message to the provided websocket connection. + + This message indicates that the connection will disconnect. + """ + + connection_terminate_message = {"type": "connection_terminate"} + + await self._send(connection_terminate_message) async def _send_query( self, @@ -590,6 +626,15 @@ async def _handle_answer( # Do nothing if no one is listening to this query_id. pass + # Answer pong to ping for graphql-ws protocol + if answer_type == "ping": + self.ping_received.set() + if self.answer_pings: + await self.send_pong() + + elif answer_type == "pong": + self.pong_received.set() + async def _receive_data_loop(self) -> None: """Main asyncio task which will listen to the incoming messages and will call the parse_answer and handle_answer methods of the subclass.""" @@ -662,7 +707,7 @@ async def connect(self) -> None: max_msg_size=self.max_msg_size, origin=self.origin, params=self.params, - protocols=self.protocols, + protocols=self.supported_subprotocols, proxy=self.proxy, proxy_auth=self.proxy_auth, proxy_headers=self.proxy_headers, From 2ad42abd74a954353bb3c2c80bbc9b885739f0f0 Mon Sep 17 00:00:00 2001 From: Taylor Lowery Date: Mon, 8 Jul 2024 19:21:40 +0000 Subject: [PATCH 25/61] increase test coverage over edge cases --- tests/test_aiohttp_websocket_subscription.py | 166 ++++++++++++++++++- 1 file changed, 165 insertions(+), 1 deletion(-) diff --git a/tests/test_aiohttp_websocket_subscription.py b/tests/test_aiohttp_websocket_subscription.py index 74f67619..cf6111b5 100644 --- a/tests/test_aiohttp_websocket_subscription.py +++ b/tests/test_aiohttp_websocket_subscription.py @@ -9,9 +9,73 @@ from parse import search from gql import Client, gql -from gql.transport.exceptions import TransportServerError +from gql.transport.exceptions import TransportClosed, TransportServerError from .conftest import MS, WebSocketServerHelper +from .starwars.schema import StarWarsIntrospection, StarWarsSchema, StarWarsTypeDef + +starwars_expected_one = { + "stars": 3, + "commentary": "Was expecting more stuff", + "episode": "JEDI", +} + +starwars_expected_two = { + "stars": 5, + "commentary": "This is a great movie!", + "episode": "JEDI", +} + + +async def server_starwars(ws, path): + import websockets + + await WebSocketServerHelper.send_connection_ack(ws) + + try: + await ws.recv() + + reviews = [starwars_expected_one, starwars_expected_two] + + for review in reviews: + + data = ( + '{"type":"data","id":"1","payload":{"data":{"reviewAdded": ' + + json.dumps(review) + + "}}}" + ) + await ws.send(data) + await asyncio.sleep(2 * MS) + + await WebSocketServerHelper.send_complete(ws, 1) + await WebSocketServerHelper.wait_connection_terminate(ws) + + except websockets.exceptions.ConnectionClosedOK: + pass + + print("Server is now closed") + + +starwars_subscription_str = """ + subscription ListenEpisodeReviews($ep: Episode!) { + reviewAdded(episode: $ep) { + stars, + commentary, + episode + } + } +""" + +starwars_invalid_subscription_str = """ + subscription ListenEpisodeReviews($ep: Episode!) { + reviewAdded(episode: $ep) { + not_valid_field, + stars, + commentary, + episode + } + } +""" # Marking all tests in this file with the websockets marker pytestmark = pytest.mark.aiohttp_websockets @@ -645,3 +709,103 @@ def test_code(): assert count == -1 await run_sync_test(event_loop, server, test_code) + +@pytest.mark.aiohttp_websockets +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_starwars], indirect=True) +@pytest.mark.parametrize("subscription_str", [starwars_subscription_str]) +@pytest.mark.parametrize( + "client_params", + [ + {"schema": StarWarsSchema}, + {"introspection": StarWarsIntrospection}, + {"schema": StarWarsTypeDef}, + ], +) +async def test_async_aiohttp_client_validation( + event_loop, server, subscription_str, client_params +): + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + + sample_transport = AIOHTTPWebsocketsTransport(url=url) + + client = Client(transport=sample_transport, **client_params) + + async with client as session: + + variable_values = {"ep": "JEDI"} + + subscription = gql(subscription_str) + + expected = [] + + async for result in session.subscribe( + subscription, variable_values=variable_values, parse_result=False + ): + + review = result["reviewAdded"] + expected.append(review) + + assert "stars" in review + assert "commentary" in review + assert "episode" in review + + assert expected[0] == starwars_expected_one + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_subscribe_on_closing_transport( + event_loop, server, subscription_str +): + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + + transport = AIOHTTPWebsocketsTransport(url=url) + + client = Client(transport=transport) + count = 1 + subscription = gql(subscription_str.format(count=count)) + + async with client as session: + session.transport.websocket._writer._closing = True + + with pytest.raises(ConnectionResetError) as e: + async for _ in session.subscribe(subscription): + pass + + assert e.value.args[0] == "Cannot write to closing transport" + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_subscribe_on_null_transport( + event_loop, server, subscription_str +): + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + + transport = AIOHTTPWebsocketsTransport(url=url, receive_timeout=0.1) + + client = Client(transport=transport) + count = 1 + subscription = gql(subscription_str.format(count=count)) + + async with client as session: + + session.transport.websocket = None + + with pytest.raises(TransportClosed) as e: + async for _ in session.subscribe(subscription): + pass + + assert e.value.args[0] == "WebSocket connection is closed" + From e5770c7ff5a80508e79ade704c575aa9cf1de247 Mon Sep 17 00:00:00 2001 From: Taylor Lowery Date: Mon, 8 Jul 2024 19:34:13 +0000 Subject: [PATCH 26/61] ran formatter --- docs/code_examples/console_async.py | 2 +- docs/code_examples/fastapi_async.py | 2 +- docs/code_examples/httpx_async_trio.py | 1 - .../reconnecting_mutation_http.py | 3 +- .../code_examples/reconnecting_mutation_ws.py | 3 +- gql/cli.py | 3 +- gql/client.py | 138 +++++++++++------- gql/dsl.py | 5 +- gql/graphql_request.py | 3 +- gql/transport/aiohttp.py | 8 +- gql/transport/aiohttp_websockets.py | 10 +- gql/transport/appsync_websockets.py | 6 +- gql/transport/async_transport.py | 3 +- gql/transport/httpx.py | 5 +- gql/transport/local_schema.py | 3 +- gql/transport/phoenix_channel_websockets.py | 3 +- gql/transport/requests.py | 3 +- gql/transport/transport.py | 3 +- gql/transport/websockets.py | 6 +- gql/transport/websockets_base.py | 8 +- gql/utilities/get_introspection_query_ast.py | 3 +- gql/utilities/node_tree.py | 3 +- gql/utilities/parse_result.py | 3 +- gql/utilities/serialize_variable_values.py | 3 +- gql/utilities/update_schema_enum.py | 3 +- gql/utilities/update_schema_scalars.py | 3 +- tests/conftest.py | 2 + tests/custom_scalars/test_datetime.py | 5 +- tests/custom_scalars/test_enum_colors.py | 3 +- tests/custom_scalars/test_json.py | 3 +- tests/custom_scalars/test_money.py | 5 +- tests/nested_input/schema.py | 1 + tests/starwars/schema.py | 1 + tests/starwars/test_subscription.py | 1 + tests/test_aiohttp.py | 3 +- tests/test_aiohttp_online.py | 3 +- tests/test_aiohttp_websocket_subscription.py | 17 +-- tests/test_appsync_http.py | 1 + tests/test_appsync_websockets.py | 3 +- tests/test_async_client_validation.py | 3 +- tests/test_cli.py | 1 + tests/test_client.py | 5 +- tests/test_graphql_request.py | 5 +- tests/test_graphqlws_exceptions.py | 3 +- tests/test_graphqlws_subscription.py | 5 +- tests/test_httpx.py | 3 +- tests/test_httpx_async.py | 3 +- tests/test_httpx_online.py | 3 +- tests/test_phoenix_channel_exceptions.py | 5 +- tests/test_phoenix_channel_subscription.py | 3 +- tests/test_requests.py | 3 +- tests/test_requests_batch.py | 3 +- tests/test_transport.py | 1 + tests/test_transport_batch.py | 1 + tests/test_websocket_exceptions.py | 3 +- tests/test_websocket_online.py | 3 +- tests/test_websocket_query.py | 6 +- tests/test_websocket_subscription.py | 5 +- 58 files changed, 213 insertions(+), 135 deletions(-) diff --git a/docs/code_examples/console_async.py b/docs/code_examples/console_async.py index 2ec4feec..9a5e94e5 100644 --- a/docs/code_examples/console_async.py +++ b/docs/code_examples/console_async.py @@ -1,7 +1,7 @@ import asyncio import logging -from aioconsole import ainput +from aioconsole import ainput from gql import Client, gql from gql.transport.aiohttp import AIOHTTPTransport diff --git a/docs/code_examples/fastapi_async.py b/docs/code_examples/fastapi_async.py index 511b4abc..80920252 100644 --- a/docs/code_examples/fastapi_async.py +++ b/docs/code_examples/fastapi_async.py @@ -7,9 +7,9 @@ # uvicorn fastapi_async:app --reload import logging + from fastapi import FastAPI, HTTPException from fastapi.responses import HTMLResponse - from gql import Client, gql from gql.transport.aiohttp import AIOHTTPTransport diff --git a/docs/code_examples/httpx_async_trio.py b/docs/code_examples/httpx_async_trio.py index 058b952b..b76dab42 100644 --- a/docs/code_examples/httpx_async_trio.py +++ b/docs/code_examples/httpx_async_trio.py @@ -1,5 +1,4 @@ import trio - from gql import Client, gql from gql.transport.httpx import HTTPXAsyncTransport diff --git a/docs/code_examples/reconnecting_mutation_http.py b/docs/code_examples/reconnecting_mutation_http.py index b379be91..f4329c8b 100644 --- a/docs/code_examples/reconnecting_mutation_http.py +++ b/docs/code_examples/reconnecting_mutation_http.py @@ -1,7 +1,8 @@ import asyncio -import backoff import logging +import backoff + from gql import Client, gql from gql.transport.aiohttp import AIOHTTPTransport diff --git a/docs/code_examples/reconnecting_mutation_ws.py b/docs/code_examples/reconnecting_mutation_ws.py index b407ddaa..7d7c8f8a 100644 --- a/docs/code_examples/reconnecting_mutation_ws.py +++ b/docs/code_examples/reconnecting_mutation_ws.py @@ -1,7 +1,8 @@ import asyncio -import backoff import logging +import backoff + from gql import Client, gql from gql.transport.websockets import WebsocketsTransport diff --git a/gql/cli.py b/gql/cli.py index 234478de..55e03ccb 100644 --- a/gql/cli.py +++ b/gql/cli.py @@ -5,8 +5,9 @@ import sys import textwrap from argparse import ArgumentParser, Namespace, RawTextHelpFormatter -from graphql import GraphQLError, print_schema from typing import Any, Dict, Optional + +from graphql import GraphQLError, print_schema from yarl import URL from gql import Client, __version__, gql diff --git a/gql/client.py b/gql/client.py index 17b8be49..dd9c2c5b 100644 --- a/gql/client.py +++ b/gql/client.py @@ -1,21 +1,9 @@ import asyncio -import backoff import logging import sys import time import warnings -from anyio import fail_after from concurrent.futures import Future -from graphql import ( - DocumentNode, - ExecutionResult, - GraphQLSchema, - IntrospectionQuery, - build_ast_schema, - get_introspection_query, - parse, - validate, -) from queue import Queue from threading import Event, Thread from typing import ( @@ -33,6 +21,19 @@ overload, ) +import backoff +from anyio import fail_after +from graphql import ( + DocumentNode, + ExecutionResult, + GraphQLSchema, + IntrospectionQuery, + build_ast_schema, + get_introspection_query, + parse, + validate, +) + from .graphql_request import GraphQLRequest from .transport.async_transport import AsyncTransport from .transport.exceptions import TransportClosed, TransportQueryError @@ -201,7 +202,8 @@ def execute_sync( *, # https://github.com/python/mypy/issues/7333#issuecomment-788255229 get_execution_result: Literal[False] = ..., **kwargs, - ) -> Dict[str, Any]: ... # pragma: no cover + ) -> Dict[str, Any]: + ... # pragma: no cover @overload def execute_sync( @@ -214,7 +216,8 @@ def execute_sync( *, get_execution_result: Literal[True], **kwargs, - ) -> ExecutionResult: ... # pragma: no cover + ) -> ExecutionResult: + ... # pragma: no cover @overload def execute_sync( @@ -227,7 +230,8 @@ def execute_sync( *, get_execution_result: bool, **kwargs, - ) -> Union[Dict[str, Any], ExecutionResult]: ... # pragma: no cover + ) -> Union[Dict[str, Any], ExecutionResult]: + ... # pragma: no cover def execute_sync( self, @@ -260,7 +264,8 @@ def execute_batch_sync( parse_result: Optional[bool] = None, get_execution_result: Literal[False], **kwargs, - ) -> List[Dict[str, Any]]: ... # pragma: no cover + ) -> List[Dict[str, Any]]: + ... # pragma: no cover @overload def execute_batch_sync( @@ -271,7 +276,8 @@ def execute_batch_sync( parse_result: Optional[bool] = None, get_execution_result: Literal[True], **kwargs, - ) -> List[ExecutionResult]: ... # pragma: no cover + ) -> List[ExecutionResult]: + ... # pragma: no cover @overload def execute_batch_sync( @@ -282,7 +288,8 @@ def execute_batch_sync( parse_result: Optional[bool] = None, get_execution_result: bool, **kwargs, - ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: ... # pragma: no cover + ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: + ... # pragma: no cover def execute_batch_sync( self, @@ -314,7 +321,8 @@ async def execute_async( *, # https://github.com/python/mypy/issues/7333#issuecomment-788255229 get_execution_result: Literal[False] = ..., **kwargs, - ) -> Dict[str, Any]: ... # pragma: no cover + ) -> Dict[str, Any]: + ... # pragma: no cover @overload async def execute_async( @@ -327,7 +335,8 @@ async def execute_async( *, get_execution_result: Literal[True], **kwargs, - ) -> ExecutionResult: ... # pragma: no cover + ) -> ExecutionResult: + ... # pragma: no cover @overload async def execute_async( @@ -340,7 +349,8 @@ async def execute_async( *, get_execution_result: bool, **kwargs, - ) -> Union[Dict[str, Any], ExecutionResult]: ... # pragma: no cover + ) -> Union[Dict[str, Any], ExecutionResult]: + ... # pragma: no cover async def execute_async( self, @@ -375,7 +385,8 @@ def execute( *, # https://github.com/python/mypy/issues/7333#issuecomment-788255229 get_execution_result: Literal[False] = ..., **kwargs, - ) -> Dict[str, Any]: ... # pragma: no cover + ) -> Dict[str, Any]: + ... # pragma: no cover @overload def execute( @@ -388,7 +399,8 @@ def execute( *, get_execution_result: Literal[True], **kwargs, - ) -> ExecutionResult: ... # pragma: no cover + ) -> ExecutionResult: + ... # pragma: no cover @overload def execute( @@ -401,7 +413,8 @@ def execute( *, get_execution_result: bool, **kwargs, - ) -> Union[Dict[str, Any], ExecutionResult]: ... # pragma: no cover + ) -> Union[Dict[str, Any], ExecutionResult]: + ... # pragma: no cover def execute( self, @@ -487,7 +500,8 @@ def execute_batch( parse_result: Optional[bool] = None, get_execution_result: Literal[False], **kwargs, - ) -> List[Dict[str, Any]]: ... # pragma: no cover + ) -> List[Dict[str, Any]]: + ... # pragma: no cover @overload def execute_batch( @@ -498,7 +512,8 @@ def execute_batch( parse_result: Optional[bool] = None, get_execution_result: Literal[True], **kwargs, - ) -> List[ExecutionResult]: ... # pragma: no cover + ) -> List[ExecutionResult]: + ... # pragma: no cover @overload def execute_batch( @@ -509,7 +524,8 @@ def execute_batch( parse_result: Optional[bool] = None, get_execution_result: bool, **kwargs, - ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: ... # pragma: no cover + ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: + ... # pragma: no cover def execute_batch( self, @@ -565,7 +581,8 @@ def subscribe_async( *, get_execution_result: Literal[False] = ..., **kwargs, - ) -> AsyncGenerator[Dict[str, Any], None]: ... # pragma: no cover + ) -> AsyncGenerator[Dict[str, Any], None]: + ... # pragma: no cover @overload def subscribe_async( @@ -578,7 +595,8 @@ def subscribe_async( *, get_execution_result: Literal[True], **kwargs, - ) -> AsyncGenerator[ExecutionResult, None]: ... # pragma: no cover + ) -> AsyncGenerator[ExecutionResult, None]: + ... # pragma: no cover @overload def subscribe_async( @@ -593,7 +611,8 @@ def subscribe_async( **kwargs, ) -> Union[ AsyncGenerator[Dict[str, Any], None], AsyncGenerator[ExecutionResult, None] - ]: ... # pragma: no cover + ]: + ... # pragma: no cover async def subscribe_async( self, @@ -633,7 +652,8 @@ def subscribe( *, get_execution_result: Literal[False] = ..., **kwargs, - ) -> Generator[Dict[str, Any], None, None]: ... # pragma: no cover + ) -> Generator[Dict[str, Any], None, None]: + ... # pragma: no cover @overload def subscribe( @@ -646,7 +666,8 @@ def subscribe( *, get_execution_result: Literal[True], **kwargs, - ) -> Generator[ExecutionResult, None, None]: ... # pragma: no cover + ) -> Generator[ExecutionResult, None, None]: + ... # pragma: no cover @overload def subscribe( @@ -661,7 +682,8 @@ def subscribe( **kwargs, ) -> Union[ Generator[Dict[str, Any], None, None], Generator[ExecutionResult, None, None] - ]: ... # pragma: no cover + ]: + ... # pragma: no cover def subscribe( self, @@ -931,7 +953,8 @@ def execute( *, get_execution_result: Literal[False] = ..., **kwargs, - ) -> Dict[str, Any]: ... # pragma: no cover + ) -> Dict[str, Any]: + ... # pragma: no cover @overload def execute( @@ -944,7 +967,8 @@ def execute( *, get_execution_result: Literal[True], **kwargs, - ) -> ExecutionResult: ... # pragma: no cover + ) -> ExecutionResult: + ... # pragma: no cover @overload def execute( @@ -957,7 +981,8 @@ def execute( *, get_execution_result: bool, **kwargs, - ) -> Union[Dict[str, Any], ExecutionResult]: ... # pragma: no cover + ) -> Union[Dict[str, Any], ExecutionResult]: + ... # pragma: no cover def execute( self, @@ -1082,7 +1107,8 @@ def execute_batch( parse_result: Optional[bool] = None, get_execution_result: Literal[False], **kwargs, - ) -> List[Dict[str, Any]]: ... # pragma: no cover + ) -> List[Dict[str, Any]]: + ... # pragma: no cover @overload def execute_batch( @@ -1093,7 +1119,8 @@ def execute_batch( parse_result: Optional[bool] = None, get_execution_result: Literal[True], **kwargs, - ) -> List[ExecutionResult]: ... # pragma: no cover + ) -> List[ExecutionResult]: + ... # pragma: no cover @overload def execute_batch( @@ -1104,7 +1131,8 @@ def execute_batch( parse_result: Optional[bool] = None, get_execution_result: bool, **kwargs, - ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: ... # pragma: no cover + ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: + ... # pragma: no cover def execute_batch( self, @@ -1329,13 +1357,13 @@ async def _subscribe( ) # Subscribe to the transport - inner_generator: AsyncGenerator[ExecutionResult, None] = ( - self.transport.subscribe( - document, - variable_values=variable_values, - operation_name=operation_name, - **kwargs, - ) + inner_generator: AsyncGenerator[ + ExecutionResult, None + ] = self.transport.subscribe( + document, + variable_values=variable_values, + operation_name=operation_name, + **kwargs, ) # Keep a reference to the inner generator to allow the user to call aclose() @@ -1371,7 +1399,8 @@ def subscribe( *, get_execution_result: Literal[False] = ..., **kwargs, - ) -> AsyncGenerator[Dict[str, Any], None]: ... # pragma: no cover + ) -> AsyncGenerator[Dict[str, Any], None]: + ... # pragma: no cover @overload def subscribe( @@ -1384,7 +1413,8 @@ def subscribe( *, get_execution_result: Literal[True], **kwargs, - ) -> AsyncGenerator[ExecutionResult, None]: ... # pragma: no cover + ) -> AsyncGenerator[ExecutionResult, None]: + ... # pragma: no cover @overload def subscribe( @@ -1399,7 +1429,8 @@ def subscribe( **kwargs, ) -> Union[ AsyncGenerator[Dict[str, Any], None], AsyncGenerator[ExecutionResult, None] - ]: ... # pragma: no cover + ]: + ... # pragma: no cover async def subscribe( self, @@ -1535,7 +1566,8 @@ async def execute( *, get_execution_result: Literal[False] = ..., **kwargs, - ) -> Dict[str, Any]: ... # pragma: no cover + ) -> Dict[str, Any]: + ... # pragma: no cover @overload async def execute( @@ -1548,7 +1580,8 @@ async def execute( *, get_execution_result: Literal[True], **kwargs, - ) -> ExecutionResult: ... # pragma: no cover + ) -> ExecutionResult: + ... # pragma: no cover @overload async def execute( @@ -1561,7 +1594,8 @@ async def execute( *, get_execution_result: bool, **kwargs, - ) -> Union[Dict[str, Any], ExecutionResult]: ... # pragma: no cover + ) -> Union[Dict[str, Any], ExecutionResult]: + ... # pragma: no cover async def execute( self, diff --git a/gql/dsl.py b/gql/dsl.py index bc32d875..8f0c412c 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -6,6 +6,9 @@ import logging import re from abc import ABC, abstractmethod +from math import isfinite +from typing import Any, Dict, Iterable, Mapping, Optional, Tuple, Union, cast + from graphql import ( ArgumentNode, BooleanValueNode, @@ -60,8 +63,6 @@ print_ast, ) from graphql.pyutils import inspect -from math import isfinite -from typing import Any, Dict, Iterable, Mapping, Optional, Tuple, Union, cast from .utils import to_camel_case diff --git a/gql/graphql_request.py b/gql/graphql_request.py index 41504dcd..b0c68f5c 100644 --- a/gql/graphql_request.py +++ b/gql/graphql_request.py @@ -1,7 +1,8 @@ from dataclasses import dataclass -from graphql import DocumentNode, GraphQLSchema from typing import Any, Dict, Optional +from graphql import DocumentNode, GraphQLSchema + from .utilities import serialize_variable_values diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index 1269b097..0258b091 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -1,18 +1,18 @@ -from ssl import SSLContext - -import aiohttp import asyncio import functools import io import json import logging +from ssl import SSLContext +from typing import Any, AsyncGenerator, Callable, Dict, Optional, Tuple, Type, Union + +import aiohttp from aiohttp.client_exceptions import ClientResponseError from aiohttp.client_reqrep import Fingerprint from aiohttp.helpers import BasicAuth from aiohttp.typedefs import LooseCookies, LooseHeaders from graphql import DocumentNode, ExecutionResult, print_ast from multidict import CIMultiDictProxy -from typing import Any, AsyncGenerator, Callable, Dict, Optional, Tuple, Type, Union from ..utils import extract_files from .appsync_auth import AppSyncAuthentication diff --git a/gql/transport/aiohttp_websockets.py b/gql/transport/aiohttp_websockets.py index 607b379c..2cb0a439 100644 --- a/gql/transport/aiohttp_websockets.py +++ b/gql/transport/aiohttp_websockets.py @@ -649,8 +649,9 @@ async def _receive_data_loop(self) -> None: except (ConnectionResetError, TransportProtocolError) as e: await self._fail(e, clean_close=False) break - except TransportClosed: - break + except TransportClosed as e: + await self._fail(e, clean_close=False) + raise e # Parse the answer try: @@ -720,10 +721,7 @@ async def connect(self) -> None: finally: self._connecting = False - try: - self.response_headers = self.websocket._response.headers - except AttributeError: - self.response_headers = CIMultiDictProxy(CIMultiDict()) + self.response_headers = self.websocket._response.headers await self._after_connect() diff --git a/gql/transport/appsync_websockets.py b/gql/transport/appsync_websockets.py index 655acb19..66091747 100644 --- a/gql/transport/appsync_websockets.py +++ b/gql/transport/appsync_websockets.py @@ -1,11 +1,11 @@ -from ssl import SSLContext - import json import logging -from graphql import DocumentNode, ExecutionResult, print_ast +from ssl import SSLContext from typing import Any, Dict, Optional, Tuple, Union, cast from urllib.parse import urlparse +from graphql import DocumentNode, ExecutionResult, print_ast + from .appsync_auth import AppSyncAuthentication, AppSyncIAMAuthentication from .exceptions import TransportProtocolError, TransportServerError from .websockets import WebsocketsTransport, WebsocketsTransportBase diff --git a/gql/transport/async_transport.py b/gql/transport/async_transport.py index 2d180b65..4cecc9f9 100644 --- a/gql/transport/async_transport.py +++ b/gql/transport/async_transport.py @@ -1,7 +1,8 @@ import abc -from graphql import DocumentNode, ExecutionResult from typing import Any, AsyncGenerator, Dict, Optional +from graphql import DocumentNode, ExecutionResult + class AsyncTransport(abc.ABC): @abc.abstractmethod diff --git a/gql/transport/httpx.py b/gql/transport/httpx.py index 4f8d8334..811601b8 100644 --- a/gql/transport/httpx.py +++ b/gql/transport/httpx.py @@ -1,8 +1,6 @@ -import httpx import io import json import logging -from graphql import DocumentNode, ExecutionResult, print_ast from typing import ( Any, AsyncGenerator, @@ -16,6 +14,9 @@ cast, ) +import httpx +from graphql import DocumentNode, ExecutionResult, print_ast + from ..utils import extract_files from . import AsyncTransport, Transport from .exceptions import ( diff --git a/gql/transport/local_schema.py b/gql/transport/local_schema.py index 787e1491..04ed4ff1 100644 --- a/gql/transport/local_schema.py +++ b/gql/transport/local_schema.py @@ -1,8 +1,9 @@ import asyncio -from graphql import DocumentNode, ExecutionResult, GraphQLSchema, execute, subscribe from inspect import isawaitable from typing import AsyncGenerator, Awaitable, cast +from graphql import DocumentNode, ExecutionResult, GraphQLSchema, execute, subscribe + from gql.transport import AsyncTransport diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py index fae8cd3a..b8226234 100644 --- a/gql/transport/phoenix_channel_websockets.py +++ b/gql/transport/phoenix_channel_websockets.py @@ -1,8 +1,9 @@ import asyncio import json import logging -from graphql import DocumentNode, ExecutionResult, print_ast from typing import Any, Dict, Optional, Tuple + +from graphql import DocumentNode, ExecutionResult, print_ast from websockets.exceptions import ConnectionClosed from .exceptions import ( diff --git a/gql/transport/requests.py b/gql/transport/requests.py index b6d19292..0c6eb3fc 100644 --- a/gql/transport/requests.py +++ b/gql/transport/requests.py @@ -1,13 +1,14 @@ import io import json import logging +from typing import Any, Callable, Collection, Dict, List, Optional, Tuple, Type, Union + import requests from graphql import DocumentNode, ExecutionResult, print_ast from requests.adapters import HTTPAdapter, Retry from requests.auth import AuthBase from requests.cookies import RequestsCookieJar from requests_toolbelt.multipart.encoder import MultipartEncoder -from typing import Any, Callable, Collection, Dict, List, Optional, Tuple, Type, Union from gql.transport import Transport diff --git a/gql/transport/transport.py b/gql/transport/transport.py index cb04d4d8..a5bd7100 100644 --- a/gql/transport/transport.py +++ b/gql/transport/transport.py @@ -1,7 +1,8 @@ import abc -from graphql import DocumentNode, ExecutionResult from typing import List +from graphql import DocumentNode, ExecutionResult + from ..graphql_request import GraphQLRequest diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index e127dc37..c385d3d7 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -1,11 +1,11 @@ -from ssl import SSLContext - import asyncio import json import logging from contextlib import suppress -from graphql import DocumentNode, ExecutionResult, print_ast +from ssl import SSLContext from typing import Any, Dict, List, Optional, Tuple, Union, cast + +from graphql import DocumentNode, ExecutionResult, print_ast from websockets.datastructures import HeadersLike from websockets.typing import Subprotocol diff --git a/gql/transport/websockets_base.py b/gql/transport/websockets_base.py index a952611a..45c96d3e 100644 --- a/gql/transport/websockets_base.py +++ b/gql/transport/websockets_base.py @@ -1,13 +1,13 @@ -from ssl import SSLContext - import asyncio import logging import warnings -import websockets from abc import abstractmethod from contextlib import suppress -from graphql import DocumentNode, ExecutionResult +from ssl import SSLContext from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union, cast + +import websockets +from graphql import DocumentNode, ExecutionResult from websockets.client import WebSocketClientProtocol from websockets.datastructures import Headers, HeadersLike from websockets.exceptions import ConnectionClosed diff --git a/gql/utilities/get_introspection_query_ast.py b/gql/utilities/get_introspection_query_ast.py index 0abbec30..d35a2a75 100644 --- a/gql/utilities/get_introspection_query_ast.py +++ b/gql/utilities/get_introspection_query_ast.py @@ -1,6 +1,7 @@ -from graphql import DocumentNode, GraphQLSchema from itertools import repeat +from graphql import DocumentNode, GraphQLSchema + from gql.dsl import DSLFragment, DSLMetaField, DSLQuery, DSLSchema, dsl_gql diff --git a/gql/utilities/node_tree.py b/gql/utilities/node_tree.py index a8369b1a..c307d937 100644 --- a/gql/utilities/node_tree.py +++ b/gql/utilities/node_tree.py @@ -1,6 +1,7 @@ -from graphql import Node from typing import Any, Iterable, List, Optional, Sized +from graphql import Node + def _node_tree_recursive( obj: Any, diff --git a/gql/utilities/parse_result.py b/gql/utilities/parse_result.py index c626f196..02355425 100644 --- a/gql/utilities/parse_result.py +++ b/gql/utilities/parse_result.py @@ -1,4 +1,6 @@ import logging +from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Union, cast + from graphql import ( IDLE, REMOVE, @@ -27,7 +29,6 @@ ) from graphql.language.visitor import VisitorActionEnum from graphql.pyutils import inspect -from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Union, cast log = logging.getLogger(__name__) diff --git a/gql/utilities/serialize_variable_values.py b/gql/utilities/serialize_variable_values.py index cc8740c3..38ad1995 100644 --- a/gql/utilities/serialize_variable_values.py +++ b/gql/utilities/serialize_variable_values.py @@ -1,3 +1,5 @@ +from typing import Any, Dict, Optional + from graphql import ( DocumentNode, GraphQLEnumType, @@ -13,7 +15,6 @@ type_from_ast, ) from graphql.pyutils import inspect -from typing import Any, Dict, Optional def _get_document_operation( diff --git a/gql/utilities/update_schema_enum.py b/gql/utilities/update_schema_enum.py index 2888ae08..80c73862 100644 --- a/gql/utilities/update_schema_enum.py +++ b/gql/utilities/update_schema_enum.py @@ -1,7 +1,8 @@ from enum import Enum -from graphql import GraphQLEnumType, GraphQLSchema from typing import Any, Dict, Mapping, Type, Union, cast +from graphql import GraphQLEnumType, GraphQLSchema + def update_schema_enum( schema: GraphQLSchema, diff --git a/gql/utilities/update_schema_scalars.py b/gql/utilities/update_schema_scalars.py index 8ba366b3..db3adb17 100644 --- a/gql/utilities/update_schema_scalars.py +++ b/gql/utilities/update_schema_scalars.py @@ -1,6 +1,7 @@ -from graphql import GraphQLScalarType, GraphQLSchema from typing import Iterable, List +from graphql import GraphQLScalarType, GraphQLSchema + def update_schema_scalar(schema: GraphQLSchema, name: str, scalar: GraphQLScalarType): """Update the scalar in a schema with the scalar provided. diff --git a/tests/conftest.py b/tests/conftest.py index 8cdf6c9f..bd68982b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -503,6 +503,7 @@ async def client_and_graphqlws_server(graphqlws_server): # Yield both client session and server yield session, graphqlws_server + @pytest_asyncio.fixture async def client_and_aiohttp_websocket_graphql_server(graphqlws_server): """Helper fixture to start a server with the graphql-ws prototocol @@ -523,6 +524,7 @@ async def client_and_aiohttp_websocket_graphql_server(graphqlws_server): # Yield both client session and server yield session, graphqlws_server + @pytest_asyncio.fixture async def run_sync_test(): async def run_sync_test_inner(event_loop, server, test_function): diff --git a/tests/custom_scalars/test_datetime.py b/tests/custom_scalars/test_datetime.py index 61d3a9e3..b3e717c5 100644 --- a/tests/custom_scalars/test_datetime.py +++ b/tests/custom_scalars/test_datetime.py @@ -1,5 +1,7 @@ -import pytest from datetime import datetime, timedelta +from typing import Any, Dict, Optional + +import pytest from graphql.error import GraphQLError from graphql.language import ValueNode from graphql.pyutils import inspect @@ -15,7 +17,6 @@ GraphQLSchema, ) from graphql.utilities import value_from_ast_untyped -from typing import Any, Dict, Optional from gql import Client, gql diff --git a/tests/custom_scalars/test_enum_colors.py b/tests/custom_scalars/test_enum_colors.py index 9ddc7df3..2f15a8ca 100644 --- a/tests/custom_scalars/test_enum_colors.py +++ b/tests/custom_scalars/test_enum_colors.py @@ -1,5 +1,6 @@ -import pytest from enum import Enum + +import pytest from graphql import ( GraphQLArgument, GraphQLEnumType, diff --git a/tests/custom_scalars/test_json.py b/tests/custom_scalars/test_json.py index 4c9505cc..d3eae3b8 100644 --- a/tests/custom_scalars/test_json.py +++ b/tests/custom_scalars/test_json.py @@ -1,3 +1,5 @@ +from typing import Any, Dict, Optional + import pytest from graphql import ( GraphQLArgument, @@ -12,7 +14,6 @@ ) from graphql.language import ValueNode from graphql.utilities import value_from_ast_untyped -from typing import Any, Dict, Optional from gql import Client, gql from gql.dsl import DSLSchema diff --git a/tests/custom_scalars/test_money.py b/tests/custom_scalars/test_money.py index 234e6cb9..c2d0e3d4 100644 --- a/tests/custom_scalars/test_money.py +++ b/tests/custom_scalars/test_money.py @@ -1,4 +1,7 @@ import asyncio +from math import isfinite +from typing import Any, Dict, NamedTuple, Optional + import pytest from graphql import ExecutionResult, graphql_sync from graphql.error import GraphQLError @@ -16,8 +19,6 @@ GraphQLSchema, ) from graphql.utilities import value_from_ast_untyped -from math import isfinite -from typing import Any, Dict, NamedTuple, Optional from gql import Client, GraphQLRequest, gql from gql.transport.exceptions import TransportQueryError diff --git a/tests/nested_input/schema.py b/tests/nested_input/schema.py index d8a2f929..ccdebb4a 100644 --- a/tests/nested_input/schema.py +++ b/tests/nested_input/schema.py @@ -1,4 +1,5 @@ import json + from graphql import ( GraphQLArgument, GraphQLField, diff --git a/tests/starwars/schema.py b/tests/starwars/schema.py index ef196213..4b672ad3 100644 --- a/tests/starwars/schema.py +++ b/tests/starwars/schema.py @@ -1,4 +1,5 @@ import asyncio + from graphql import ( GraphQLArgument, GraphQLEnumType, diff --git a/tests/starwars/test_subscription.py b/tests/starwars/test_subscription.py index 7c1be4cf..0f412acc 100644 --- a/tests/starwars/test_subscription.py +++ b/tests/starwars/test_subscription.py @@ -1,4 +1,5 @@ import asyncio + import pytest from graphql import ExecutionResult, GraphQLError, subscribe diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index 1ed708cd..fd1f449c 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -1,8 +1,9 @@ import io import json -import pytest from typing import Mapping +import pytest + from gql import Client, gql from gql.cli import get_parser, main from gql.transport.exceptions import ( diff --git a/tests/test_aiohttp_online.py b/tests/test_aiohttp_online.py index 74e00aee..39b8a9d2 100644 --- a/tests/test_aiohttp_online.py +++ b/tests/test_aiohttp_online.py @@ -1,8 +1,9 @@ import asyncio -import pytest import sys from typing import Dict +import pytest + from gql import Client, gql from gql.transport.exceptions import TransportQueryError diff --git a/tests/test_aiohttp_websocket_subscription.py b/tests/test_aiohttp_websocket_subscription.py index cf6111b5..645ad4c8 100644 --- a/tests/test_aiohttp_websocket_subscription.py +++ b/tests/test_aiohttp_websocket_subscription.py @@ -710,6 +710,7 @@ def test_code(): await run_sync_test(event_loop, server, test_code) + @pytest.mark.aiohttp_websockets @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_starwars], indirect=True) @@ -759,9 +760,7 @@ async def test_async_aiohttp_client_validation( @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) -async def test_subscribe_on_closing_transport( - event_loop, server, subscription_str -): +async def test_subscribe_on_closing_transport(event_loop, server, subscription_str): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport @@ -775,19 +774,18 @@ async def test_subscribe_on_closing_transport( async with client as session: session.transport.websocket._writer._closing = True - + with pytest.raises(ConnectionResetError) as e: async for _ in session.subscribe(subscription): pass - + assert e.value.args[0] == "Cannot write to closing transport" + @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) -async def test_subscribe_on_null_transport( - event_loop, server, subscription_str -): +async def test_subscribe_on_null_transport(event_loop, server, subscription_str): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport @@ -806,6 +804,5 @@ async def test_subscribe_on_null_transport( with pytest.raises(TransportClosed) as e: async for _ in session.subscribe(subscription): pass - - assert e.value.args[0] == "WebSocket connection is closed" + assert e.value.args[0] == "WebSocket connection is closed" diff --git a/tests/test_appsync_http.py b/tests/test_appsync_http.py index 21d2c8ea..3fa9bf93 100644 --- a/tests/test_appsync_http.py +++ b/tests/test_appsync_http.py @@ -1,4 +1,5 @@ import json + import pytest from gql import Client, gql diff --git a/tests/test_appsync_websockets.py b/tests/test_appsync_websockets.py index 25cbe200..e05bb6a9 100644 --- a/tests/test_appsync_websockets.py +++ b/tests/test_appsync_websockets.py @@ -1,10 +1,11 @@ import asyncio import json -import pytest from base64 import b64decode from typing import List from urllib import parse +import pytest + from gql import Client, gql from .conftest import MS, WebSocketServerHelper diff --git a/tests/test_async_client_validation.py b/tests/test_async_client_validation.py index b2e7588d..d39019e8 100644 --- a/tests/test_async_client_validation.py +++ b/tests/test_async_client_validation.py @@ -1,6 +1,7 @@ import asyncio -import graphql import json + +import graphql import pytest from gql import Client, gql diff --git a/tests/test_cli.py b/tests/test_cli.py index a6f0d0d8..75fe8757 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,4 +1,5 @@ import logging + import pytest from gql import __version__ diff --git a/tests/test_client.py b/tests/test_client.py index 955c2780..ada129c6 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,7 +1,8 @@ -import mock import os -import pytest from contextlib import suppress + +import mock +import pytest from graphql import build_ast_schema, parse from gql import Client, GraphQLRequest, gql diff --git a/tests/test_graphql_request.py b/tests/test_graphql_request.py index 00628e02..4c9e7d76 100644 --- a/tests/test_graphql_request.py +++ b/tests/test_graphql_request.py @@ -1,4 +1,7 @@ import asyncio +from math import isfinite +from typing import Any, Dict, NamedTuple, Optional + import pytest from graphql.error import GraphQLError from graphql.language import ValueNode @@ -14,8 +17,6 @@ GraphQLSchema, ) from graphql.utilities import value_from_ast_untyped -from math import isfinite -from typing import Any, Dict, NamedTuple, Optional from gql import GraphQLRequest, gql diff --git a/tests/test_graphqlws_exceptions.py b/tests/test_graphqlws_exceptions.py index 37de6e2e..b0bf37e1 100644 --- a/tests/test_graphqlws_exceptions.py +++ b/tests/test_graphqlws_exceptions.py @@ -1,7 +1,8 @@ import asyncio -import pytest from typing import List +import pytest + from gql import Client, gql from gql.transport.exceptions import ( TransportClosed, diff --git a/tests/test_graphqlws_subscription.py b/tests/test_graphqlws_subscription.py index 51eb2da9..e818be35 100644 --- a/tests/test_graphqlws_subscription.py +++ b/tests/test_graphqlws_subscription.py @@ -1,11 +1,12 @@ import asyncio import json -import pytest import sys import warnings -from parse import search from typing import List +import pytest +from parse import search + from gql import Client, gql from gql.transport.exceptions import TransportServerError diff --git a/tests/test_httpx.py b/tests/test_httpx.py index 95b16a54..f066a5dc 100644 --- a/tests/test_httpx.py +++ b/tests/test_httpx.py @@ -1,6 +1,7 @@ -import pytest from typing import Mapping +import pytest + from gql import Client, gql from gql.transport.exceptions import ( TransportAlreadyConnected, diff --git a/tests/test_httpx_async.py b/tests/test_httpx_async.py index 2066d964..888c025a 100644 --- a/tests/test_httpx_async.py +++ b/tests/test_httpx_async.py @@ -1,8 +1,9 @@ import io import json -import pytest from typing import Mapping +import pytest + from gql import Client, gql from gql.cli import get_parser, main from gql.transport.exceptions import ( diff --git a/tests/test_httpx_online.py b/tests/test_httpx_online.py index dfa19fde..23d28dcc 100644 --- a/tests/test_httpx_online.py +++ b/tests/test_httpx_online.py @@ -1,8 +1,9 @@ import asyncio -import pytest import sys from typing import Dict +import pytest + from gql import Client, gql from gql.transport.exceptions import TransportQueryError diff --git a/tests/test_phoenix_channel_exceptions.py b/tests/test_phoenix_channel_exceptions.py index f59245e7..ce2ce996 100644 --- a/tests/test_phoenix_channel_exceptions.py +++ b/tests/test_phoenix_channel_exceptions.py @@ -1,4 +1,5 @@ import asyncio + import pytest from gql import Client, gql @@ -18,7 +19,9 @@ def ensure_list(s): return ( s if s is None or isinstance(s, list) - else list(s) if isinstance(s, tuple) else [s] + else list(s) + if isinstance(s, tuple) + else [s] ) diff --git a/tests/test_phoenix_channel_subscription.py b/tests/test_phoenix_channel_subscription.py index 127e3a20..6367945d 100644 --- a/tests/test_phoenix_channel_subscription.py +++ b/tests/test_phoenix_channel_subscription.py @@ -1,7 +1,8 @@ import asyncio import json -import pytest import sys + +import pytest from parse import search from gql import Client, gql diff --git a/tests/test_requests.py b/tests/test_requests.py index 7d5d237d..cfbce0d1 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -1,6 +1,7 @@ -import pytest from typing import Mapping +import pytest + from gql import Client, gql from gql.transport.exceptions import ( TransportAlreadyConnected, diff --git a/tests/test_requests_batch.py b/tests/test_requests_batch.py index 7be46fd7..bc69d37d 100644 --- a/tests/test_requests_batch.py +++ b/tests/test_requests_batch.py @@ -1,6 +1,7 @@ -import pytest from typing import Mapping +import pytest + from gql import Client, GraphQLRequest, gql from gql.transport.exceptions import ( TransportClosed, diff --git a/tests/test_transport.py b/tests/test_transport.py index 27730c07..e554955a 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -1,4 +1,5 @@ import os + import pytest from gql import Client, gql diff --git a/tests/test_transport_batch.py b/tests/test_transport_batch.py index abd2152e..7c108ec3 100644 --- a/tests/test_transport_batch.py +++ b/tests/test_transport_batch.py @@ -1,4 +1,5 @@ import os + import pytest from gql import Client, GraphQLRequest, gql diff --git a/tests/test_websocket_exceptions.py b/tests/test_websocket_exceptions.py index 719a948c..94090ae0 100644 --- a/tests/test_websocket_exceptions.py +++ b/tests/test_websocket_exceptions.py @@ -1,9 +1,10 @@ import asyncio import json -import pytest import types from typing import List +import pytest + from gql import Client, gql from gql.transport.exceptions import ( TransportAlreadyConnected, diff --git a/tests/test_websocket_online.py b/tests/test_websocket_online.py index 45564d55..b5fca837 100644 --- a/tests/test_websocket_online.py +++ b/tests/test_websocket_online.py @@ -1,9 +1,10 @@ import asyncio import logging -import pytest import sys from typing import Dict +import pytest + from gql import Client, gql from gql.transport.exceptions import TransportError, TransportQueryError diff --git a/tests/test_websocket_query.py b/tests/test_websocket_query.py index 3fb76b58..b0b88eb6 100644 --- a/tests/test_websocket_query.py +++ b/tests/test_websocket_query.py @@ -1,11 +1,11 @@ -import ssl - import asyncio import json -import pytest +import ssl import sys from typing import Dict, Mapping +import pytest + from gql import Client, gql from gql.transport.exceptions import ( TransportAlreadyConnected, diff --git a/tests/test_websocket_subscription.py b/tests/test_websocket_subscription.py index 43129c4f..4419783b 100644 --- a/tests/test_websocket_subscription.py +++ b/tests/test_websocket_subscription.py @@ -1,11 +1,12 @@ import asyncio import json -import pytest import sys import warnings +from typing import List + +import pytest from graphql import ExecutionResult from parse import search -from typing import List from gql import Client, gql from gql.transport.exceptions import TransportServerError From f28c4fefcf70c3a85e8a58c11509f06696aa61ff Mon Sep 17 00:00:00 2001 From: Taylor Lowery Date: Mon, 8 Jul 2024 20:54:55 +0000 Subject: [PATCH 27/61] remove unused import --- gql/transport/aiohttp_websockets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gql/transport/aiohttp_websockets.py b/gql/transport/aiohttp_websockets.py index 2cb0a439..c740d500 100644 --- a/gql/transport/aiohttp_websockets.py +++ b/gql/transport/aiohttp_websockets.py @@ -21,7 +21,7 @@ from aiohttp import BasicAuth, Fingerprint, WSMsgType, hdrs from aiohttp.typedefs import LooseHeaders, StrOrURL from graphql import DocumentNode, ExecutionResult, print_ast -from multidict import CIMultiDict, CIMultiDictProxy +from multidict import CIMultiDictProxy from gql.transport.async_transport import AsyncTransport from gql.transport.exceptions import ( From 2b25cb8eea284e4f1fe04e752a1036505ba38949 Mon Sep 17 00:00:00 2001 From: Taylor Lowery Date: Fri, 12 Jul 2024 19:10:56 +0000 Subject: [PATCH 28/61] remove unnecessary timeout --- tests/test_aiohttp_websocket_subscription.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_aiohttp_websocket_subscription.py b/tests/test_aiohttp_websocket_subscription.py index 645ad4c8..1884d37b 100644 --- a/tests/test_aiohttp_websocket_subscription.py +++ b/tests/test_aiohttp_websocket_subscription.py @@ -783,6 +783,7 @@ async def test_subscribe_on_closing_transport(event_loop, server, subscription_s @pytest.mark.asyncio +@pytest.mark.skip @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_subscribe_on_null_transport(event_loop, server, subscription_str): @@ -791,7 +792,7 @@ async def test_subscribe_on_null_transport(event_loop, server, subscription_str) url = f"ws://{server.hostname}:{server.port}/graphql" - transport = AIOHTTPWebsocketsTransport(url=url, receive_timeout=0.1) + transport = AIOHTTPWebsocketsTransport(url=url) client = Client(transport=transport) count = 1 @@ -800,7 +801,6 @@ async def test_subscribe_on_null_transport(event_loop, server, subscription_str) async with client as session: session.transport.websocket = None - with pytest.raises(TransportClosed) as e: async for _ in session.subscribe(subscription): pass From b158614aa6d60677aae6e2ce4e2b6225b06f64c8 Mon Sep 17 00:00:00 2001 From: Taylor Lowery Date: Fri, 12 Jul 2024 21:03:06 +0000 Subject: [PATCH 29/61] remove hanging test --- tests/test_aiohttp_websocket_subscription.py | 26 -------------------- 1 file changed, 26 deletions(-) diff --git a/tests/test_aiohttp_websocket_subscription.py b/tests/test_aiohttp_websocket_subscription.py index 1884d37b..9f8159e3 100644 --- a/tests/test_aiohttp_websocket_subscription.py +++ b/tests/test_aiohttp_websocket_subscription.py @@ -780,29 +780,3 @@ async def test_subscribe_on_closing_transport(event_loop, server, subscription_s pass assert e.value.args[0] == "Cannot write to closing transport" - - -@pytest.mark.asyncio -@pytest.mark.skip -@pytest.mark.parametrize("server", [server_countdown], indirect=True) -@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) -async def test_subscribe_on_null_transport(event_loop, server, subscription_str): - - from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport - - url = f"ws://{server.hostname}:{server.port}/graphql" - - transport = AIOHTTPWebsocketsTransport(url=url) - - client = Client(transport=transport) - count = 1 - subscription = gql(subscription_str.format(count=count)) - - async with client as session: - - session.transport.websocket = None - with pytest.raises(TransportClosed) as e: - async for _ in session.subscribe(subscription): - pass - - assert e.value.args[0] == "WebSocket connection is closed" From bae308f07657ae32c077d86698e3da526f7e3264 Mon Sep 17 00:00:00 2001 From: Taylor Lowery Date: Fri, 12 Jul 2024 21:06:55 +0000 Subject: [PATCH 30/61] re-add hanging test with skip marker and explanation --- tests/test_aiohttp_websocket_subscription.py | 26 ++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/test_aiohttp_websocket_subscription.py b/tests/test_aiohttp_websocket_subscription.py index 9f8159e3..45d71a5a 100644 --- a/tests/test_aiohttp_websocket_subscription.py +++ b/tests/test_aiohttp_websocket_subscription.py @@ -780,3 +780,29 @@ async def test_subscribe_on_closing_transport(event_loop, server, subscription_s pass assert e.value.args[0] == "Cannot write to closing transport" + + +@pytest.mark.asyncio +@pytest.mark.skip(reason="This test hangs with WebsocketsTransport or AIOHTTPWebsocketsTransport") +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_subscribe_on_null_transport(event_loop, server, subscription_str): + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + + transport = AIOHTTPWebsocketsTransport(url=url) + + client = Client(transport=transport) + count = 1 + subscription = gql(subscription_str.format(count=count)) + + async with client as session: + + session.transport.websocket = None + with pytest.raises(TransportClosed) as e: + async for _ in session.subscribe(subscription): + pass + + assert e.value.args[0] == "WebSocket connection is closed" From 5acb518460d41396048e7b614a63a87e563f7661 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Sat, 13 Jul 2024 16:15:18 +0200 Subject: [PATCH 31/61] Modify _receive method to be more like the websockets one --- gql/transport/aiohttp_websockets.py | 47 +++++++++++++++++------------ 1 file changed, 27 insertions(+), 20 deletions(-) diff --git a/gql/transport/aiohttp_websockets.py b/gql/transport/aiohttp_websockets.py index c740d500..d56f239e 100644 --- a/gql/transport/aiohttp_websockets.py +++ b/gql/transport/aiohttp_websockets.py @@ -1,10 +1,10 @@ """Websockets Client for asyncio.""" import asyncio +import json import logging import warnings from contextlib import suppress -from json.decoder import JSONDecodeError from ssl import SSLContext from typing import ( Any, @@ -310,15 +310,22 @@ def _parse_answer_apollo( return answer_type, answer_id, execution_result def _parse_answer( - self, answer: Dict[str, Any] + self, answer: str ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: """Parse the answer received from the server depending on the detected subprotocol. """ + try: + json_answer = json.loads(answer) + except ValueError: + raise TransportProtocolError( + f"Server did not return a GraphQL result: {answer}" + ) + if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: - return self._parse_answer_graphqlws(answer) + return self._parse_answer_graphqlws(json_answer) - return self._parse_answer_apollo(answer) + return self._parse_answer_apollo(json_answer) async def _wait_ack(self) -> None: """Wait for the connection_ack message. Keep alive messages are ignored""" @@ -540,27 +547,27 @@ async def _send(self, message: Dict[str, Any]) -> None: await self._fail(e, clean_close=False) raise e - async def _receive(self) -> Dict[str, Any]: - log.debug("Entering _receive()") + async def _receive(self) -> str: + """Wait the next message from the websocket connection and log the answer""" + # It is possible that the websocket has been already closed in another task if self.websocket is None: - raise TransportClosed("WebSocket connection is closed") + raise TransportClosed("Transport is already closed") - try: - answer = await self.websocket.receive_json() - except TypeError as e: - answer = await self.websocket.receive() - if answer.type in (WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.CLOSING): - self._fail(e, clean_close=True) - raise ConnectionResetError - else: - self._fail(e, clean_close=False) - except JSONDecodeError as e: - self._fail(e) + ws_message = await self.websocket.receive() - log.info("<<< %s", answer) + if ws_message.type in (WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.CLOSING): + raise ConnectionResetError + elif ws_message.type is WSMsgType.BINARY: + raise TransportProtocolError("Binary data received in the websocket") - log.debug("Exiting _receive()") + # Note: ws_message could also be a low level PING or PONG type here + # but we don't enable those + assert ws_message.type is WSMsgType.TEXT + + answer: str = ws_message.data + + log.info("<<< %s", answer) return answer From 4ec25161d14d6244010cbad1d25d98eef1e5ec8b Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Sat, 13 Jul 2024 16:34:42 +0200 Subject: [PATCH 32/61] Running make check --- docs/code_examples/fastapi_async.py | 1 + tests/test_aiohttp_websocket_subscription.py | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/docs/code_examples/fastapi_async.py b/docs/code_examples/fastapi_async.py index 80920252..3bedd187 100644 --- a/docs/code_examples/fastapi_async.py +++ b/docs/code_examples/fastapi_async.py @@ -10,6 +10,7 @@ from fastapi import FastAPI, HTTPException from fastapi.responses import HTMLResponse + from gql import Client, gql from gql.transport.aiohttp import AIOHTTPTransport diff --git a/tests/test_aiohttp_websocket_subscription.py b/tests/test_aiohttp_websocket_subscription.py index 45d71a5a..e41a320b 100644 --- a/tests/test_aiohttp_websocket_subscription.py +++ b/tests/test_aiohttp_websocket_subscription.py @@ -783,7 +783,9 @@ async def test_subscribe_on_closing_transport(event_loop, server, subscription_s @pytest.mark.asyncio -@pytest.mark.skip(reason="This test hangs with WebsocketsTransport or AIOHTTPWebsocketsTransport") +@pytest.mark.skip( + reason="This test hangs with WebsocketsTransport or AIOHTTPWebsocketsTransport" +) @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_subscribe_on_null_transport(event_loop, server, subscription_str): From 6336eaa438b873a2f383438cff147e2c4c0a9649 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Sat, 13 Jul 2024 17:32:56 +0200 Subject: [PATCH 33/61] Revert all code changes outside of the scope of this PR --- docs/code_examples/httpx_async_trio.py | 1 + gql/cli.py | 3 +- gql/client.py | 8 ++--- gql/dsl.py | 9 ++---- gql/transport/aiohttp.py | 6 ++-- tests/custom_scalars/test_money.py | 12 ++++---- tests/starwars/fixtures.py | 3 +- tests/test_aiohttp.py | 37 ++---------------------- tests/test_appsync_auth.py | 7 ++--- tests/test_appsync_http.py | 3 +- tests/test_appsync_websockets.py | 15 ++++------ tests/test_cli.py | 4 +-- tests/test_graphqlws_exceptions.py | 1 - tests/test_graphqlws_subscription.py | 3 +- tests/test_httpx.py | 17 ----------- tests/test_phoenix_channel_exceptions.py | 3 +- tests/test_requests.py | 19 ------------ tests/test_requests_batch.py | 15 +--------- tests/test_transport.py | 1 - tests/test_transport_batch.py | 1 - tests/test_websocket_exceptions.py | 4 +-- tests/test_websocket_query.py | 5 +--- 22 files changed, 35 insertions(+), 142 deletions(-) diff --git a/docs/code_examples/httpx_async_trio.py b/docs/code_examples/httpx_async_trio.py index b76dab42..058b952b 100644 --- a/docs/code_examples/httpx_async_trio.py +++ b/docs/code_examples/httpx_async_trio.py @@ -1,4 +1,5 @@ import trio + from gql import Client, gql from gql.transport.httpx import HTTPXAsyncTransport diff --git a/gql/cli.py b/gql/cli.py index 55e03ccb..dd991546 100644 --- a/gql/cli.py +++ b/gql/cli.py @@ -358,9 +358,8 @@ def get_transport(args: Namespace) -> Optional[AsyncTransport]: auth = AppSyncJWTAuthentication(host=url.host, jwt=args.jwt) else: - from botocore.exceptions import NoRegionError - from gql.transport.appsync_auth import AppSyncIAMAuthentication + from botocore.exceptions import NoRegionError try: auth = AppSyncIAMAuthentication(host=url.host) diff --git a/gql/client.py b/gql/client.py index dd9c2c5b..0d9e36c7 100644 --- a/gql/client.py +++ b/gql/client.py @@ -1075,11 +1075,9 @@ def _execute_batch( serialize_variables is None and self.client.serialize_variables ): requests = [ - ( - req.serialize_variable_values(self.client.schema) - if req.variable_values is not None - else req - ) + req.serialize_variable_values(self.client.schema) + if req.variable_values is not None + else req for req in requests ] diff --git a/gql/dsl.py b/gql/dsl.py index 8f0c412c..536a8b8b 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -2,7 +2,6 @@ .. image:: http://www.plantuml.com/plantuml/png/ZLAzJWCn3Dxz51vXw1im50ag8L4XwC1OkLTJ8gMvAd4GwEYxGuC8pTbKtUxy_TZEvsaIYfAt7e1MII9rWfsdbF1cSRzWpvtq4GT0JENduX8GXr_g7brQlf5tw-MBOx_-HlS0LV_Kzp8xr1kZav9PfCsMWvolEA_1VylHoZCExKwKv4Tg2s_VkSkca2kof2JDb0yxZYIk3qMZYUe1B1uUZOROXn96pQMugEMUdRnUUqUf6DBXQyIz2zu5RlgUQAFVNYaeRfBI79_JrUTaeg9JZFQj5MmUc69PDmNGE2iU61fDgfri3x36gxHw3gDHD6xqqQ7P4vjKqz2-602xtkO7uo17SCLhVSv25VjRjUAFcUE73Sspb8ADBl8gTT7j2cFAOPst_Wi0 # noqa :alt: UML diagram """ - import logging import re from abc import ABC, abstractmethod @@ -596,11 +595,9 @@ def get_ast_definitions(self) -> Tuple[VariableDefinitionNode, ...]: VariableDefinitionNode( type=var.ast_variable_type, variable=var.ast_variable_name, - default_value=( - None - if var.default_value is None - else ast_from_value(var.default_value, var.type) - ), + default_value=None + if var.default_value is None + else ast_from_value(var.default_value, var.type), directives=(), ) for var in self.variables.values() diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index 0258b091..be22ce9c 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -101,9 +101,9 @@ async def connect(self) -> None: client_session_args: Dict[str, Any] = { "cookies": self.cookies, "headers": self.headers, - "auth": ( - None if isinstance(self.auth, AppSyncAuthentication) else self.auth - ), + "auth": None + if isinstance(self.auth, AppSyncAuthentication) + else self.auth, "json_serialize": self.json_serialize, } diff --git a/tests/custom_scalars/test_money.py b/tests/custom_scalars/test_money.py index c2d0e3d4..374c70e6 100644 --- a/tests/custom_scalars/test_money.py +++ b/tests/custom_scalars/test_money.py @@ -441,9 +441,9 @@ def handle_single(data: Dict[str, Any]) -> ExecutionResult: [ { "data": result.data, - "errors": ( - [str(e) for e in result.errors] if result.errors else None - ), + "errors": [str(e) for e in result.errors] + if result.errors + else None, } for result in results ] @@ -453,9 +453,9 @@ def handle_single(data: Dict[str, Any]) -> ExecutionResult: return web.json_response( { "data": result.data, - "errors": ( - [str(e) for e in result.errors] if result.errors else None - ), + "errors": [str(e) for e in result.errors] + if result.errors + else None, } ) diff --git a/tests/starwars/fixtures.py b/tests/starwars/fixtures.py index 1d179f60..59d7ddfa 100644 --- a/tests/starwars/fixtures.py +++ b/tests/starwars/fixtures.py @@ -148,9 +148,8 @@ def create_review(episode, review): async def make_starwars_backend(aiohttp_server): from aiohttp import web - from graphql import graphql_sync - from .schema import StarWarsSchema + from graphql import graphql_sync async def handler(request): data = await request.json() diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index fd1f449c..b16964d0 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -43,7 +43,6 @@ @pytest.mark.asyncio async def test_aiohttp_query(event_loop, aiohttp_server): from aiohttp import web - from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -83,7 +82,6 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_ignore_backend_content_type(event_loop, aiohttp_server): from aiohttp import web - from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -113,7 +111,6 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_cookies(event_loop, aiohttp_server): from aiohttp import web - from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -147,7 +144,6 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_error_code_401(event_loop, aiohttp_server): from aiohttp import web - from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -179,7 +175,6 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_error_code_429(event_loop, aiohttp_server): from aiohttp import web - from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -227,7 +222,6 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_error_code_500(event_loop, aiohttp_server): from aiohttp import web - from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -263,7 +257,6 @@ async def handler(request): @pytest.mark.parametrize("query_error", transport_query_error_responses) async def test_aiohttp_error_code(event_loop, aiohttp_server, query_error): from aiohttp import web - from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -319,7 +312,6 @@ async def handler(request): @pytest.mark.parametrize("param", invalid_protocol_responses) async def test_aiohttp_invalid_protocol(event_loop, aiohttp_server, param): from aiohttp import web - from gql.transport.aiohttp import AIOHTTPTransport response = param["response"] @@ -348,7 +340,6 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_subscribe_not_supported(event_loop, aiohttp_server): from aiohttp import web - from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -374,7 +365,6 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_cannot_connect_twice(event_loop, aiohttp_server): from aiohttp import web - from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -397,7 +387,6 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_cannot_execute_if_not_connected(event_loop, aiohttp_server): from aiohttp import web - from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -420,7 +409,6 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_extra_args(event_loop, aiohttp_server): from aiohttp import web - from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -468,7 +456,6 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_query_variable_values(event_loop, aiohttp_server): from aiohttp import web - from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -505,7 +492,6 @@ async def test_aiohttp_query_variable_values_fix_issue_292(event_loop, aiohttp_s See https://github.com/graphql-python/gql/issues/292""" from aiohttp import web - from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -538,7 +524,6 @@ async def test_aiohttp_execute_running_in_thread( event_loop, aiohttp_server, run_sync_test ): from aiohttp import web - from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -567,7 +552,6 @@ async def test_aiohttp_subscribe_running_in_thread( event_loop, aiohttp_server, run_sync_test ): from aiohttp import web - from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -654,7 +638,6 @@ async def single_upload_handler(request): @pytest.mark.asyncio async def test_aiohttp_file_upload(event_loop, aiohttp_server): from aiohttp import web - from gql.transport.aiohttp import AIOHTTPTransport app = web.Application() @@ -720,7 +703,6 @@ async def single_upload_handler_with_content_type(request): @pytest.mark.asyncio async def test_aiohttp_file_upload_with_content_type(event_loop, aiohttp_server): from aiohttp import web - from gql.transport.aiohttp import AIOHTTPTransport app = web.Application() @@ -761,7 +743,6 @@ async def test_aiohttp_file_upload_without_session( event_loop, aiohttp_server, run_sync_test ): from aiohttp import web - from gql.transport.aiohttp import AIOHTTPTransport app = web.Application() @@ -830,7 +811,6 @@ async def binary_upload_handler(request): @pytest.mark.asyncio async def test_aiohttp_binary_file_upload(event_loop, aiohttp_server): from aiohttp import web - from gql.transport.aiohttp import AIOHTTPTransport app = web.Application() @@ -865,8 +845,7 @@ async def test_aiohttp_binary_file_upload(event_loop, aiohttp_server): @pytest.mark.asyncio async def test_aiohttp_stream_reader_upload(event_loop, aiohttp_server): - from aiohttp import ClientSession, web - + from aiohttp import web, ClientSession from gql.transport.aiohttp import AIOHTTPTransport async def binary_data_handler(request): @@ -905,7 +884,6 @@ async def binary_data_handler(request): async def test_aiohttp_async_generator_upload(event_loop, aiohttp_server): import aiofiles from aiohttp import web - from gql.transport.aiohttp import AIOHTTPTransport app = web.Application() @@ -968,7 +946,6 @@ async def file_sender(file_name): @pytest.mark.asyncio async def test_aiohttp_file_upload_two_files(event_loop, aiohttp_server): from aiohttp import web - from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -1060,7 +1037,6 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_file_upload_list_of_two_files(event_loop, aiohttp_server): from aiohttp import web - from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -1282,7 +1258,6 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_query_with_extensions(event_loop, aiohttp_server): from aiohttp import web - from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -1311,7 +1286,6 @@ async def handler(request): @pytest.mark.parametrize("ssl_close_timeout", [0, 10]) async def test_aiohttp_query_https(event_loop, ssl_aiohttp_server, ssl_close_timeout): from aiohttp import web - from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -1346,7 +1320,6 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_error_fetching_schema(event_loop, aiohttp_server): from aiohttp import web - from gql.transport.aiohttp import AIOHTTPTransport error_answer = """ @@ -1390,7 +1363,6 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_reconnecting_session(event_loop, aiohttp_server): from aiohttp import web - from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -1431,7 +1403,6 @@ async def test_aiohttp_reconnecting_session_retries( event_loop, aiohttp_server, retries ): from aiohttp import web - from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -1465,7 +1436,6 @@ async def test_aiohttp_reconnecting_session_start_connecting_task_twice( event_loop, aiohttp_server, caplog ): from aiohttp import web - from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -1499,7 +1469,6 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_json_serializer(event_loop, aiohttp_server, caplog): from aiohttp import web - from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -1558,7 +1527,6 @@ async def test_aiohttp_json_deserializer(event_loop, aiohttp_server): from aiohttp import web from decimal import Decimal from functools import partial - from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -1595,8 +1563,7 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_connector_owner_false(event_loop, aiohttp_server): - from aiohttp import TCPConnector, web - + from aiohttp import web, TCPConnector from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): diff --git a/tests/test_appsync_auth.py b/tests/test_appsync_auth.py index 89591426..cb279ae5 100644 --- a/tests/test_appsync_auth.py +++ b/tests/test_appsync_auth.py @@ -23,7 +23,6 @@ def test_appsync_init_with_minimal_args(fake_session_factory): @pytest.mark.botocore def test_appsync_init_with_no_credentials(caplog, fake_session_factory): import botocore.exceptions - from gql.transport.appsync_websockets import AppSyncWebsocketsTransport with pytest.raises(botocore.exceptions.NoCredentialsError): @@ -73,7 +72,6 @@ def test_appsync_init_with_apikey_auth(): @pytest.mark.botocore def test_appsync_init_with_iam_auth_without_creds(fake_session_factory): import botocore.exceptions - from gql.transport.appsync_auth import AppSyncIAMAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport @@ -110,10 +108,9 @@ def test_appsync_init_with_iam_auth_and_no_region( - you have the AWS_DEFAULT_REGION environment variable set """ - import logging - from botocore.exceptions import NoRegionError - from gql.transport.appsync_websockets import AppSyncWebsocketsTransport + from botocore.exceptions import NoRegionError + import logging caplog.set_level(logging.WARNING) diff --git a/tests/test_appsync_http.py b/tests/test_appsync_http.py index 3fa9bf93..ca3a3fcb 100644 --- a/tests/test_appsync_http.py +++ b/tests/test_appsync_http.py @@ -12,10 +12,9 @@ async def test_appsync_iam_mutation( event_loop, aiohttp_server, fake_credentials_factory ): from aiohttp import web - from urllib.parse import urlparse - from gql.transport.aiohttp import AIOHTTPTransport from gql.transport.appsync_auth import AppSyncIAMAuthentication + from urllib.parse import urlparse async def handler(request): data = { diff --git a/tests/test_appsync_websockets.py b/tests/test_appsync_websockets.py index e05bb6a9..14c40e75 100644 --- a/tests/test_appsync_websockets.py +++ b/tests/test_appsync_websockets.py @@ -424,10 +424,9 @@ async def test_appsync_subscription_api_key(event_loop, server): @pytest.mark.parametrize("server", [realtime_appsync_server], indirect=True) async def test_appsync_subscription_iam_with_token(event_loop, server): - from botocore.credentials import Credentials - from gql.transport.appsync_auth import AppSyncIAMAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport + from botocore.credentials import Credentials path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" @@ -452,10 +451,9 @@ async def test_appsync_subscription_iam_with_token(event_loop, server): @pytest.mark.parametrize("server", [realtime_appsync_server], indirect=True) async def test_appsync_subscription_iam_without_token(event_loop, server): - from botocore.credentials import Credentials - from gql.transport.appsync_auth import AppSyncIAMAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport + from botocore.credentials import Credentials path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" @@ -479,10 +477,9 @@ async def test_appsync_subscription_iam_without_token(event_loop, server): @pytest.mark.parametrize("server", [realtime_appsync_server], indirect=True) async def test_appsync_execute_method_not_allowed(event_loop, server): - from botocore.credentials import Credentials - from gql.transport.appsync_auth import AppSyncIAMAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport + from botocore.credentials import Credentials path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" @@ -527,10 +524,9 @@ async def test_appsync_execute_method_not_allowed(event_loop, server): @pytest.mark.botocore async def test_appsync_fetch_schema_from_transport_not_allowed(event_loop): - from botocore.credentials import Credentials - from gql.transport.appsync_auth import AppSyncIAMAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport + from botocore.credentials import Credentials dummy_credentials = Credentials( access_key=DUMMY_ACCESS_KEY_ID, @@ -581,11 +577,10 @@ async def test_appsync_subscription_api_key_unauthorized(event_loop, server): @pytest.mark.parametrize("server", [realtime_appsync_server], indirect=True) async def test_appsync_subscription_iam_not_allowed(event_loop, server): - from botocore.credentials import Credentials - from gql.transport.appsync_auth import AppSyncIAMAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport from gql.transport.exceptions import TransportQueryError + from botocore.credentials import Credentials path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" diff --git a/tests/test_cli.py b/tests/test_cli.py index 75fe8757..f0534957 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -270,8 +270,8 @@ async def test_cli_main_appsync_websockets_iam(parser, url): ) def test_cli_get_transport_appsync_websockets_api_key(parser, url): - from gql.transport.appsync_auth import AppSyncApiKeyAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport + from gql.transport.appsync_auth import AppSyncApiKeyAuthentication args = parser.parse_args( [url, "--transport", "appsync_websockets", "--api-key", "test-api-key"] @@ -291,8 +291,8 @@ def test_cli_get_transport_appsync_websockets_api_key(parser, url): ) def test_cli_get_transport_appsync_websockets_jwt(parser, url): - from gql.transport.appsync_auth import AppSyncJWTAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport + from gql.transport.appsync_auth import AppSyncJWTAuthentication args = parser.parse_args( [url, "--transport", "appsync_websockets", "--jwt", "test-jwt"] diff --git a/tests/test_graphqlws_exceptions.py b/tests/test_graphqlws_exceptions.py index b0bf37e1..ca689c47 100644 --- a/tests/test_graphqlws_exceptions.py +++ b/tests/test_graphqlws_exceptions.py @@ -234,7 +234,6 @@ async def server_closing_directly(ws, path): @pytest.mark.parametrize("graphqlws_server", [server_closing_directly], indirect=True) async def test_graphqlws_server_closing_directly(event_loop, graphqlws_server): import websockets - from gql.transport.websockets import WebsocketsTransport url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql" diff --git a/tests/test_graphqlws_subscription.py b/tests/test_graphqlws_subscription.py index e818be35..cb705368 100644 --- a/tests/test_graphqlws_subscription.py +++ b/tests/test_graphqlws_subscription.py @@ -816,9 +816,8 @@ async def test_graphqlws_subscription_reconnecting_session( ): import websockets - - from gql.transport.exceptions import TransportClosed from gql.transport.websockets import WebsocketsTransport + from gql.transport.exceptions import TransportClosed path = "/graphql" url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" diff --git a/tests/test_httpx.py b/tests/test_httpx.py index f066a5dc..af12f717 100644 --- a/tests/test_httpx.py +++ b/tests/test_httpx.py @@ -38,7 +38,6 @@ @pytest.mark.asyncio async def test_httpx_query(event_loop, aiohttp_server, run_sync_test): from aiohttp import web - from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -82,7 +81,6 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_cookies(event_loop, aiohttp_server, run_sync_test): from aiohttp import web - from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -120,7 +118,6 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_error_code_401(event_loop, aiohttp_server, run_sync_test): from aiohttp import web - from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -156,7 +153,6 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_error_code_429(event_loop, aiohttp_server, run_sync_test): from aiohttp import web - from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -206,7 +202,6 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_error_code_500(event_loop, aiohttp_server, run_sync_test): from aiohttp import web - from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -239,7 +234,6 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_error_code(event_loop, aiohttp_server, run_sync_test): from aiohttp import web - from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -280,7 +274,6 @@ async def test_httpx_invalid_protocol( event_loop, aiohttp_server, response, run_sync_test ): from aiohttp import web - from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -309,7 +302,6 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_cannot_connect_twice(event_loop, aiohttp_server, run_sync_test): from aiohttp import web - from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -338,7 +330,6 @@ async def test_httpx_cannot_execute_if_not_connected( event_loop, aiohttp_server, run_sync_test ): from aiohttp import web - from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -376,7 +367,6 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_query_with_extensions(event_loop, aiohttp_server, run_sync_test): from aiohttp import web - from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -432,7 +422,6 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_file_upload(event_loop, aiohttp_server, run_sync_test): from aiohttp import web - from gql.transport.httpx import HTTPXTransport async def single_upload_handler(request): @@ -495,7 +484,6 @@ async def test_httpx_file_upload_with_content_type( event_loop, aiohttp_server, run_sync_test ): from aiohttp import web - from gql.transport.httpx import HTTPXTransport async def single_upload_handler(request): @@ -564,7 +552,6 @@ async def test_httpx_file_upload_additional_headers( event_loop, aiohttp_server, run_sync_test ): from aiohttp import web - from gql.transport.httpx import HTTPXTransport async def single_upload_handler(request): @@ -627,7 +614,6 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_binary_file_upload(event_loop, aiohttp_server, run_sync_test): from aiohttp import web - from gql.transport.httpx import HTTPXTransport # This is a sample binary file content containing all possible byte values @@ -701,7 +687,6 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_file_upload_two_files(event_loop, aiohttp_server, run_sync_test): from aiohttp import web - from gql.transport.httpx import HTTPXTransport file_upload_mutation_2 = """ @@ -802,7 +787,6 @@ async def test_httpx_file_upload_list_of_two_files( event_loop, aiohttp_server, run_sync_test ): from aiohttp import web - from gql.transport.httpx import HTTPXTransport file_upload_mutation_3 = """ @@ -892,7 +876,6 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_error_fetching_schema(event_loop, aiohttp_server, run_sync_test): from aiohttp import web - from gql.transport.httpx import HTTPXTransport error_answer = """ diff --git a/tests/test_phoenix_channel_exceptions.py b/tests/test_phoenix_channel_exceptions.py index ce2ce996..e2bf0091 100644 --- a/tests/test_phoenix_channel_exceptions.py +++ b/tests/test_phoenix_channel_exceptions.py @@ -360,9 +360,8 @@ def subscription_server( data_answers=default_subscription_data_answer, unsubscribe_answers=default_subscription_unsubscribe_answer, ): - import json - from .conftest import PhoenixChannelServerHelper + import json async def phoenix_server(ws, path): await PhoenixChannelServerHelper.send_connection_ack(ws) diff --git a/tests/test_requests.py b/tests/test_requests.py index cfbce0d1..ba666243 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -38,7 +38,6 @@ @pytest.mark.asyncio async def test_requests_query(event_loop, aiohttp_server, run_sync_test): from aiohttp import web - from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -82,7 +81,6 @@ def test_code(): @pytest.mark.asyncio async def test_requests_cookies(event_loop, aiohttp_server, run_sync_test): from aiohttp import web - from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -120,7 +118,6 @@ def test_code(): @pytest.mark.asyncio async def test_requests_error_code_401(event_loop, aiohttp_server, run_sync_test): from aiohttp import web - from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -156,7 +153,6 @@ def test_code(): @pytest.mark.asyncio async def test_requests_error_code_429(event_loop, aiohttp_server, run_sync_test): from aiohttp import web - from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -206,7 +202,6 @@ def test_code(): @pytest.mark.asyncio async def test_requests_error_code_500(event_loop, aiohttp_server, run_sync_test): from aiohttp import web - from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -239,7 +234,6 @@ def test_code(): @pytest.mark.asyncio async def test_requests_error_code(event_loop, aiohttp_server, run_sync_test): from aiohttp import web - from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -280,7 +274,6 @@ async def test_requests_invalid_protocol( event_loop, aiohttp_server, response, run_sync_test ): from aiohttp import web - from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -309,7 +302,6 @@ def test_code(): @pytest.mark.asyncio async def test_requests_cannot_connect_twice(event_loop, aiohttp_server, run_sync_test): from aiohttp import web - from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -338,7 +330,6 @@ async def test_requests_cannot_execute_if_not_connected( event_loop, aiohttp_server, run_sync_test ): from aiohttp import web - from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -378,7 +369,6 @@ async def test_requests_query_with_extensions( event_loop, aiohttp_server, run_sync_test ): from aiohttp import web - from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -434,7 +424,6 @@ def test_code(): @pytest.mark.asyncio async def test_requests_file_upload(event_loop, aiohttp_server, run_sync_test): from aiohttp import web - from gql.transport.requests import RequestsHTTPTransport async def single_upload_handler(request): @@ -497,7 +486,6 @@ async def test_requests_file_upload_with_content_type( event_loop, aiohttp_server, run_sync_test ): from aiohttp import web - from gql.transport.requests import RequestsHTTPTransport async def single_upload_handler(request): @@ -566,7 +554,6 @@ async def test_requests_file_upload_additional_headers( event_loop, aiohttp_server, run_sync_test ): from aiohttp import web - from gql.transport.requests import RequestsHTTPTransport async def single_upload_handler(request): @@ -629,7 +616,6 @@ def test_code(): @pytest.mark.asyncio async def test_requests_binary_file_upload(event_loop, aiohttp_server, run_sync_test): from aiohttp import web - from gql.transport.requests import RequestsHTTPTransport # This is a sample binary file content containing all possible byte values @@ -705,7 +691,6 @@ async def test_requests_file_upload_two_files( event_loop, aiohttp_server, run_sync_test ): from aiohttp import web - from gql.transport.requests import RequestsHTTPTransport file_upload_mutation_2 = """ @@ -806,7 +791,6 @@ async def test_requests_file_upload_list_of_two_files( event_loop, aiohttp_server, run_sync_test ): from aiohttp import web - from gql.transport.requests import RequestsHTTPTransport file_upload_mutation_3 = """ @@ -898,7 +882,6 @@ async def test_requests_error_fetching_schema( event_loop, aiohttp_server, run_sync_test ): from aiohttp import web - from gql.transport.requests import RequestsHTTPTransport error_answer = """ @@ -949,7 +932,6 @@ async def test_requests_json_serializer( ): import json from aiohttp import web - from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -1012,7 +994,6 @@ async def test_requests_json_deserializer(event_loop, aiohttp_server, run_sync_t from aiohttp import web from decimal import Decimal from functools import partial - from gql.transport.requests import RequestsHTTPTransport async def handler(request): diff --git a/tests/test_requests_batch.py b/tests/test_requests_batch.py index bc69d37d..4d8bf27e 100644 --- a/tests/test_requests_batch.py +++ b/tests/test_requests_batch.py @@ -50,7 +50,6 @@ @pytest.mark.asyncio async def test_requests_query(event_loop, aiohttp_server, run_sync_test): from aiohttp import web - from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -96,7 +95,6 @@ async def test_requests_query_auto_batch_enabled( event_loop, aiohttp_server, run_sync_test ): from aiohttp import web - from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -145,9 +143,8 @@ async def test_requests_query_auto_batch_enabled_two_requests( event_loop, aiohttp_server, run_sync_test ): from aiohttp import web - from threading import Thread - from gql.transport.requests import RequestsHTTPTransport + from threading import Thread async def handler(request): return web.Response( @@ -204,7 +201,6 @@ def test_thread(): @pytest.mark.asyncio async def test_requests_cookies(event_loop, aiohttp_server, run_sync_test): from aiohttp import web - from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -244,7 +240,6 @@ def test_code(): @pytest.mark.asyncio async def test_requests_error_code_401(event_loop, aiohttp_server, run_sync_test): from aiohttp import web - from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -282,7 +277,6 @@ async def test_requests_error_code_401_auto_batch_enabled( event_loop, aiohttp_server, run_sync_test ): from aiohttp import web - from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -321,7 +315,6 @@ def test_code(): @pytest.mark.asyncio async def test_requests_error_code_429(event_loop, aiohttp_server, run_sync_test): from aiohttp import web - from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -371,7 +364,6 @@ def test_code(): @pytest.mark.asyncio async def test_requests_error_code_500(event_loop, aiohttp_server, run_sync_test): from aiohttp import web - from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -404,7 +396,6 @@ def test_code(): @pytest.mark.asyncio async def test_requests_error_code(event_loop, aiohttp_server, run_sync_test): from aiohttp import web - from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -450,7 +441,6 @@ async def test_requests_invalid_protocol( event_loop, aiohttp_server, response, run_sync_test ): from aiohttp import web - from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -481,7 +471,6 @@ async def test_requests_cannot_execute_if_not_connected( event_loop, aiohttp_server, run_sync_test ): from aiohttp import web - from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -523,7 +512,6 @@ async def test_requests_query_with_extensions( event_loop, aiohttp_server, run_sync_test ): from aiohttp import web - from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -563,7 +551,6 @@ def test_code(): def test_requests_sync_batch_auto(): from threading import Thread - from gql.transport.requests import RequestsHTTPTransport client = Client( diff --git a/tests/test_transport.py b/tests/test_transport.py index e554955a..d9a3eced 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -28,7 +28,6 @@ def use_cassette(name): @pytest.fixture def client(): import requests - from gql.transport.requests import RequestsHTTPTransport with use_cassette("client"): diff --git a/tests/test_transport_batch.py b/tests/test_transport_batch.py index 7c108ec3..a9b21e6a 100644 --- a/tests/test_transport_batch.py +++ b/tests/test_transport_batch.py @@ -28,7 +28,6 @@ def use_cassette(name): @pytest.fixture def client(): import requests - from gql.transport.requests import RequestsHTTPTransport with use_cassette("client"): diff --git a/tests/test_websocket_exceptions.py b/tests/test_websocket_exceptions.py index 94090ae0..72db8a87 100644 --- a/tests/test_websocket_exceptions.py +++ b/tests/test_websocket_exceptions.py @@ -273,7 +273,6 @@ async def server_closing_directly(ws, path): @pytest.mark.parametrize("server", [server_closing_directly], indirect=True) async def test_websocket_server_closing_directly(event_loop, server): import websockets - from gql.transport.websockets import WebsocketsTransport url = f"ws://{server.hostname}:{server.port}/graphql" @@ -373,10 +372,9 @@ async def test_websocket_using_cli_invalid_query( url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") + from gql.cli import main, get_parser import io - from gql.cli import get_parser, main - parser = get_parser(with_examples=True) args = parser.parse_args([url]) diff --git a/tests/test_websocket_query.py b/tests/test_websocket_query.py index b0b88eb6..e8b7a022 100644 --- a/tests/test_websocket_query.py +++ b/tests/test_websocket_query.py @@ -53,7 +53,6 @@ @pytest.mark.parametrize("server", [server1_answers], indirect=True) async def test_websocket_starting_client_in_context_manager(event_loop, server): import websockets - from gql.transport.websockets import WebsocketsTransport url = f"ws://{server.hostname}:{server.port}/graphql" @@ -94,7 +93,6 @@ async def test_websocket_starting_client_in_context_manager(event_loop, server): @pytest.mark.parametrize("ws_ssl_server", [server1_answers], indirect=True) async def test_websocket_using_ssl_connection(event_loop, ws_ssl_server): import websockets - from gql.transport.websockets import WebsocketsTransport server = ws_ssl_server @@ -549,11 +547,10 @@ async def test_websocket_using_cli(event_loop, server, monkeypatch, capsys): url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") + from gql.cli import main, get_parser import io import json - from gql.cli import get_parser, main - parser = get_parser(with_examples=True) args = parser.parse_args([url]) From 0736dc04777e82ff99a63429c7a352b14bfa6a07 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Sat, 13 Jul 2024 18:02:56 +0200 Subject: [PATCH 34/61] Fix aiohttp_websocket ssl attribute typing --- gql/transport/aiohttp_websockets.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/gql/transport/aiohttp_websockets.py b/gql/transport/aiohttp_websockets.py index d56f239e..89bb10ee 100644 --- a/gql/transport/aiohttp_websockets.py +++ b/gql/transport/aiohttp_websockets.py @@ -11,6 +11,7 @@ AsyncGenerator, Collection, Dict, + Literal, Mapping, Optional, Tuple, @@ -60,7 +61,7 @@ def __init__( headers: Optional[LooseHeaders] = None, proxy: Optional[StrOrURL] = None, proxy_auth: Optional[BasicAuth] = None, - ssl: Union[SSLContext, bool, Fingerprint] = True, + ssl: Optional[Union[SSLContext, Literal[False], Fingerprint]] = None, ssl_context: Optional[SSLContext] = None, verify_ssl: Optional[bool] = True, proxy_headers: Optional[LooseHeaders] = None, @@ -91,7 +92,7 @@ def __init__( self.proxy_auth: Optional[BasicAuth] = proxy_auth self.proxy_headers: Optional[LooseHeaders] = proxy_headers self.receive_timeout: Optional[float] = receive_timeout - self.ssl: Union[SSLContext, bool, Fingerprint] = ssl + self.ssl: Optional[Union[SSLContext, Literal[False], Fingerprint]] = ssl self.ssl_context: Optional[SSLContext] = ssl_context self.timeout: float = timeout self.verify_ssl: Optional[bool] = verify_ssl From 172d65d22084186f1cdaa24817abe5548a9ca7ee Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Sat, 13 Jul 2024 18:05:50 +0200 Subject: [PATCH 35/61] Revert all code changes outside of the scope of this PR - 2 --- docs/code_examples/fastapi_async.py | 1 - tests/test_httpx_async.py | 31 +---------------------------- 2 files changed, 1 insertion(+), 31 deletions(-) diff --git a/docs/code_examples/fastapi_async.py b/docs/code_examples/fastapi_async.py index 3bedd187..80920252 100644 --- a/docs/code_examples/fastapi_async.py +++ b/docs/code_examples/fastapi_async.py @@ -10,7 +10,6 @@ from fastapi import FastAPI, HTTPException from fastapi.responses import HTMLResponse - from gql import Client, gql from gql.transport.aiohttp import AIOHTTPTransport diff --git a/tests/test_httpx_async.py b/tests/test_httpx_async.py index 888c025a..3665f5d8 100644 --- a/tests/test_httpx_async.py +++ b/tests/test_httpx_async.py @@ -44,7 +44,6 @@ @pytest.mark.asyncio async def test_httpx_query(event_loop, aiohttp_server): from aiohttp import web - from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -85,7 +84,6 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_ignore_backend_content_type(event_loop, aiohttp_server): from aiohttp import web - from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -116,7 +114,6 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_cookies(event_loop, aiohttp_server): from aiohttp import web - from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -151,7 +148,6 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_error_code_401(event_loop, aiohttp_server): from aiohttp import web - from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -184,7 +180,6 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_error_code_429(event_loop, aiohttp_server): from aiohttp import web - from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -233,7 +228,6 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_error_code_500(event_loop, aiohttp_server): from aiohttp import web - from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -270,7 +264,6 @@ async def handler(request): @pytest.mark.parametrize("query_error", transport_query_error_responses) async def test_httpx_error_code(event_loop, aiohttp_server, query_error): from aiohttp import web - from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -327,7 +320,6 @@ async def handler(request): @pytest.mark.parametrize("param", invalid_protocol_responses) async def test_httpx_invalid_protocol(event_loop, aiohttp_server, param): from aiohttp import web - from gql.transport.httpx import HTTPXAsyncTransport response = param["response"] @@ -357,7 +349,6 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_subscribe_not_supported(event_loop, aiohttp_server): from aiohttp import web - from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -384,7 +375,6 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_cannot_connect_twice(event_loop, aiohttp_server): from aiohttp import web - from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -408,7 +398,6 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_cannot_execute_if_not_connected(event_loop, aiohttp_server): from aiohttp import web - from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -431,10 +420,9 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio async def test_httpx_extra_args(event_loop, aiohttp_server): - import httpx from aiohttp import web - from gql.transport.httpx import HTTPXAsyncTransport + import httpx async def handler(request): return web.Response(text=query1_server_answer, content_type="application/json") @@ -478,7 +466,6 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_query_variable_values(event_loop, aiohttp_server): from aiohttp import web - from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -516,7 +503,6 @@ async def test_httpx_query_variable_values_fix_issue_292(event_loop, aiohttp_ser See https://github.com/graphql-python/gql/issues/292""" from aiohttp import web - from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -550,7 +536,6 @@ async def test_httpx_execute_running_in_thread( event_loop, aiohttp_server, run_sync_test ): from aiohttp import web - from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -580,7 +565,6 @@ async def test_httpx_subscribe_running_in_thread( event_loop, aiohttp_server, run_sync_test ): from aiohttp import web - from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -668,7 +652,6 @@ async def single_upload_handler(request): @pytest.mark.asyncio async def test_httpx_file_upload(event_loop, aiohttp_server): from aiohttp import web - from gql.transport.httpx import HTTPXAsyncTransport app = web.Application() @@ -707,7 +690,6 @@ async def test_httpx_file_upload_without_session( event_loop, aiohttp_server, run_sync_test ): from aiohttp import web - from gql.transport.httpx import HTTPXAsyncTransport app = web.Application() @@ -777,7 +759,6 @@ async def binary_upload_handler(request): @pytest.mark.asyncio async def test_httpx_binary_file_upload(event_loop, aiohttp_server): from aiohttp import web - from gql.transport.httpx import HTTPXAsyncTransport app = web.Application() @@ -836,7 +817,6 @@ async def test_httpx_binary_file_upload(event_loop, aiohttp_server): @pytest.mark.asyncio async def test_httpx_file_upload_two_files(event_loop, aiohttp_server): from aiohttp import web - from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -929,7 +909,6 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_file_upload_list_of_two_files(event_loop, aiohttp_server): from aiohttp import web - from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -1156,7 +1135,6 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_query_with_extensions(event_loop, aiohttp_server): from aiohttp import web - from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -1185,7 +1163,6 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_query_https(event_loop, ssl_aiohttp_server): from aiohttp import web - from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -1221,7 +1198,6 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_error_fetching_schema(event_loop, aiohttp_server): from aiohttp import web - from gql.transport.httpx import HTTPXAsyncTransport error_answer = """ @@ -1266,7 +1242,6 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_reconnecting_session(event_loop, aiohttp_server): from aiohttp import web - from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -1306,7 +1281,6 @@ async def handler(request): @pytest.mark.parametrize("retries", [False, lambda e: e]) async def test_httpx_reconnecting_session_retries(event_loop, aiohttp_server, retries): from aiohttp import web - from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -1341,7 +1315,6 @@ async def test_httpx_reconnecting_session_start_connecting_task_twice( event_loop, aiohttp_server, caplog ): from aiohttp import web - from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -1376,7 +1349,6 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_json_serializer(event_loop, aiohttp_server, caplog): from aiohttp import web - from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -1436,7 +1408,6 @@ async def test_httpx_json_deserializer(event_loop, aiohttp_server): from aiohttp import web from decimal import Decimal from functools import partial - from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): From 5653d56797243367aa04d445d819756cf529f89e Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Sat, 13 Jul 2024 18:09:20 +0200 Subject: [PATCH 36/61] Remove invalid aiohttp_websockets dependency from github workflow --- .github/workflows/tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 62fbe24a..30e8289c 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -44,7 +44,7 @@ jobs: strategy: fail-fast: false matrix: - dependency: ["aiohttp", "requests", "httpx", "websockets", "aiohttp_websockets"] + dependency: ["aiohttp", "requests", "httpx", "websockets"] steps: - uses: actions/checkout@v3 From f1fb25d59e26a3d83e7aa35f0ae5660a94a79976 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Sat, 13 Jul 2024 18:44:52 +0200 Subject: [PATCH 37/61] Add the aiohttp AND websockets mark to the aiohttp_websocket tests --- tests/test_aiohttp_websocket_exceptions.py | 4 ++-- tests/test_aiohttp_websocket_graphql_exceptions.py | 4 ++-- tests/test_aiohttp_websocket_graphql_subscription.py | 4 ++-- tests/test_aiohttp_websocket_query.py | 4 ++-- tests/test_aiohttp_websocket_subscription.py | 7 +++---- 5 files changed, 11 insertions(+), 12 deletions(-) diff --git a/tests/test_aiohttp_websocket_exceptions.py b/tests/test_aiohttp_websocket_exceptions.py index 48312544..7c1b35ed 100644 --- a/tests/test_aiohttp_websocket_exceptions.py +++ b/tests/test_aiohttp_websocket_exceptions.py @@ -15,8 +15,8 @@ from .conftest import MS, WebSocketServerHelper -# Marking all tests in this file with the websockets marker -pytestmark = pytest.mark.aiohttp_websockets +# Marking all tests in this file with the aiohttp AND websockets marker +pytestmark = [pytest.mark.aiohttp, pytest.mark.websockets] invalid_query_str = """ query getContinents { diff --git a/tests/test_aiohttp_websocket_graphql_exceptions.py b/tests/test_aiohttp_websocket_graphql_exceptions.py index d49ee7d0..577ddc6b 100644 --- a/tests/test_aiohttp_websocket_graphql_exceptions.py +++ b/tests/test_aiohttp_websocket_graphql_exceptions.py @@ -12,8 +12,8 @@ from .conftest import WebSocketServerHelper -# Marking all tests in this file with the websockets marker -pytestmark = pytest.mark.aiohttp_websockets +# Marking all tests in this file with the aiohttp AND websockets marker +pytestmark = [pytest.mark.aiohttp, pytest.mark.websockets] invalid_query_str = """ query getContinents { diff --git a/tests/test_aiohttp_websocket_graphql_subscription.py b/tests/test_aiohttp_websocket_graphql_subscription.py index 1c3348bc..bb5529a1 100644 --- a/tests/test_aiohttp_websocket_graphql_subscription.py +++ b/tests/test_aiohttp_websocket_graphql_subscription.py @@ -12,8 +12,8 @@ from .conftest import MS, WebSocketServerHelper -# Marking all tests in this file with the websockets marker -pytestmark = pytest.mark.aiohttp_websockets +# Marking all tests in this file with the aiohttp AND websockets marker +pytestmark = [pytest.mark.aiohttp, pytest.mark.websockets] countdown_server_answer = ( '{{"type":"next","id":"{query_id}","payload":{{"data":{{"number":{number}}}}}}}' diff --git a/tests/test_aiohttp_websocket_query.py b/tests/test_aiohttp_websocket_query.py index 2899038e..e119af5a 100644 --- a/tests/test_aiohttp_websocket_query.py +++ b/tests/test_aiohttp_websocket_query.py @@ -16,8 +16,8 @@ from .conftest import MS, WebSocketServerHelper -# Marking all tests in this file with the websockets marker -pytestmark = pytest.mark.aiohttp_websockets +# Marking all tests in this file with the aiohttp AND websockets marker +pytestmark = [pytest.mark.aiohttp, pytest.mark.websockets] query1_str = """ query getContinents { diff --git a/tests/test_aiohttp_websocket_subscription.py b/tests/test_aiohttp_websocket_subscription.py index e41a320b..1eaf7c98 100644 --- a/tests/test_aiohttp_websocket_subscription.py +++ b/tests/test_aiohttp_websocket_subscription.py @@ -14,6 +14,9 @@ from .conftest import MS, WebSocketServerHelper from .starwars.schema import StarWarsIntrospection, StarWarsSchema, StarWarsTypeDef +# Marking all tests in this file with the aiohttp AND websockets marker +pytestmark = [pytest.mark.aiohttp, pytest.mark.websockets] + starwars_expected_one = { "stars": 3, "commentary": "Was expecting more stuff", @@ -77,9 +80,6 @@ async def server_starwars(ws, path): } """ -# Marking all tests in this file with the websockets marker -pytestmark = pytest.mark.aiohttp_websockets - countdown_server_answer = ( '{{"type":"data","id":"{query_id}","payload":{{"data":{{"number":{number}}}}}}}' ) @@ -711,7 +711,6 @@ def test_code(): await run_sync_test(event_loop, server, test_code) -@pytest.mark.aiohttp_websockets @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_starwars], indirect=True) @pytest.mark.parametrize("subscription_str", [starwars_subscription_str]) From 3a762aec814663104d9487390f4f48fc93810e0b Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Sat, 13 Jul 2024 18:46:26 +0200 Subject: [PATCH 38/61] Copy ListenerQueue in aiohttp_websockets.py It is necessary to avoid needing the websockets dependency when using the AIOHTTPWebsockets transport --- gql/transport/aiohttp_websockets.py | 55 ++++++++++++++++++++++++++++- 1 file changed, 54 insertions(+), 1 deletion(-) diff --git a/gql/transport/aiohttp_websockets.py b/gql/transport/aiohttp_websockets.py index 89bb10ee..f4f29772 100644 --- a/gql/transport/aiohttp_websockets.py +++ b/gql/transport/aiohttp_websockets.py @@ -32,10 +32,63 @@ TransportQueryError, TransportServerError, ) -from gql.transport.websockets_base import ListenerQueue log = logging.getLogger("gql.transport.aiohttp_websockets") +ParsedAnswer = Tuple[str, Optional[ExecutionResult]] + + +class ListenerQueue: + """Special queue used for each query waiting for server answers + + If the server is stopped while the listener is still waiting, + Then we send an exception to the queue and this exception will be raised + to the consumer once all the previous messages have been consumed from the queue + """ + + def __init__(self, query_id: int, send_stop: bool) -> None: + self.query_id: int = query_id + self.send_stop: bool = send_stop + self._queue: asyncio.Queue = asyncio.Queue() + self._closed: bool = False + + async def get(self) -> ParsedAnswer: + + try: + item = self._queue.get_nowait() + except asyncio.QueueEmpty: + item = await self._queue.get() + + self._queue.task_done() + + # If we receive an exception when reading the queue, we raise it + if isinstance(item, Exception): + self._closed = True + raise item + + # Don't need to save new answers or + # send the stop message if we already received the complete message + answer_type, execution_result = item + if answer_type == "complete": + self.send_stop = False + self._closed = True + + return item + + async def put(self, item: ParsedAnswer) -> None: + + if not self._closed: + await self._queue.put(item) + + async def set_exception(self, exception: Exception) -> None: + + # Put the exception in the queue + await self._queue.put(exception) + + # Don't need to send stop messages in case of error + self.send_stop = False + self._closed = True + class AIOHTTPWebsocketsTransport(AsyncTransport): From fcd59dca1d1c74c263ad22b6989a3e21b46dadbe Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Sat, 13 Jul 2024 18:56:31 +0200 Subject: [PATCH 39/61] Fix importing Literal for Python 3.7 --- gql/transport/aiohttp_websockets.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/gql/transport/aiohttp_websockets.py b/gql/transport/aiohttp_websockets.py index f4f29772..1224c873 100644 --- a/gql/transport/aiohttp_websockets.py +++ b/gql/transport/aiohttp_websockets.py @@ -3,6 +3,7 @@ import asyncio import json import logging +import sys import warnings from contextlib import suppress from ssl import SSLContext @@ -11,7 +12,6 @@ AsyncGenerator, Collection, Dict, - Literal, Mapping, Optional, Tuple, @@ -33,6 +33,16 @@ TransportServerError, ) +""" +Load the appropriate instance of the Literal type +Note: we cannot use try: except ImportError because of the following mypy issue: +https://github.com/python/mypy/issues/8520 +""" +if sys.version_info[:2] >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal # pragma: no cover + log = logging.getLogger("gql.transport.aiohttp_websockets") ParsedAnswer = Tuple[str, Optional[ExecutionResult]] From 64d6ffd294bf2069d470c2a8aec010d73f1a4345 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Sat, 13 Jul 2024 23:45:51 +0200 Subject: [PATCH 40/61] assertion instead of TransportClosed exception if websocket is None --- gql/transport/aiohttp_websockets.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/gql/transport/aiohttp_websockets.py b/gql/transport/aiohttp_websockets.py index 1224c873..1c9b7d42 100644 --- a/gql/transport/aiohttp_websockets.py +++ b/gql/transport/aiohttp_websockets.py @@ -614,9 +614,7 @@ async def _send(self, message: Dict[str, Any]) -> None: async def _receive(self) -> str: """Wait the next message from the websocket connection and log the answer""" - # It is possible that the websocket has been already closed in another task - if self.websocket is None: - raise TransportClosed("Transport is already closed") + assert self.websocket is not None ws_message = await self.websocket.receive() @@ -720,9 +718,6 @@ async def _receive_data_loop(self) -> None: except (ConnectionResetError, TransportProtocolError) as e: await self._fail(e, clean_close=False) break - except TransportClosed as e: - await self._fail(e, clean_close=False) - raise e # Parse the answer try: From 87b55736ef71134819609c28142e54f8188807ff Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Sat, 13 Jul 2024 23:50:37 +0200 Subject: [PATCH 41/61] small cleaning --- gql/transport/aiohttp_websockets.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/gql/transport/aiohttp_websockets.py b/gql/transport/aiohttp_websockets.py index 1c9b7d42..9ea7c0cd 100644 --- a/gql/transport/aiohttp_websockets.py +++ b/gql/transport/aiohttp_websockets.py @@ -1,5 +1,3 @@ -"""Websockets Client for asyncio.""" - import asyncio import json import logging @@ -443,7 +441,6 @@ async def _after_connect(self): has been established. """ # Find the backend subprotocol returned in the response headers - # TODO: find the equivalent of response_headers in aiohttp websocket response response_headers = self.websocket._response.headers log.debug(f"Response headers: {response_headers!r}") try: From 270d8ecc896e7e355cf6fa0cea3b12bff9b7936f Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Sun, 14 Jul 2024 00:39:32 +0200 Subject: [PATCH 42/61] Close aiohttp session properly, inspired from aiohttp transport --- gql/transport/aiohttp_websockets.py | 39 +++++++++++++++++++++++++++-- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git a/gql/transport/aiohttp_websockets.py b/gql/transport/aiohttp_websockets.py index 9ea7c0cd..dc4ee757 100644 --- a/gql/transport/aiohttp_websockets.py +++ b/gql/transport/aiohttp_websockets.py @@ -22,6 +22,7 @@ from graphql import DocumentNode, ExecutionResult, print_ast from multidict import CIMultiDictProxy +from gql.transport.aiohttp import AIOHTTPTransport from gql.transport.async_transport import AsyncTransport from gql.transport.exceptions import ( TransportAlreadyConnected, @@ -113,6 +114,7 @@ def __init__( protocols: Collection[str] = (), timeout: float = 10.0, receive_timeout: Optional[float] = None, + ssl_close_timeout: Optional[Union[int, float]] = 10, autoclose: bool = True, autoping: bool = True, heartbeat: Optional[float] = None, @@ -136,6 +138,7 @@ def __init__( ping_interval: Optional[Union[int, float]] = None, pong_timeout: Optional[Union[int, float]] = None, answer_pings: bool = True, + client_session_args: Optional[Dict[str, Any]] = None, ) -> None: self.url: StrOrURL = url self.headers: Optional[LooseHeaders] = headers @@ -153,6 +156,7 @@ def __init__( self.proxy_auth: Optional[BasicAuth] = proxy_auth self.proxy_headers: Optional[LooseHeaders] = proxy_headers self.receive_timeout: Optional[float] = receive_timeout + self.ssl_close_timeout: Optional[Union[int, float]] = ssl_close_timeout self.ssl: Optional[Union[SSLContext, Literal[False], Fingerprint]] = ssl self.ssl_context: Optional[SSLContext] = ssl_context self.timeout: float = timeout @@ -222,6 +226,7 @@ def __init__( self.GRAPHQLWS_SUBPROTOCOL, ) self.close_exception: Optional[Exception] = None + self.client_session_args = client_session_args def _parse_answer_graphqlws( self, answer: Dict[str, Any] @@ -753,7 +758,13 @@ async def connect(self) -> None: log.debug("connect: starting") if self.session is None: - self.session = aiohttp.ClientSession() + client_session_args: Dict[str, Any] = {} + + # Adding custom parameters passed from init + if self.client_session_args: + client_session_args.update(self.client_session_args) # type: ignore + + self.session = aiohttp.ClientSession(**client_session_args) if self.websocket is None and not self._connecting: self._connecting = True @@ -893,8 +904,32 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: log.debug("_close_coro: close websocket connection") await self.websocket.close() + self.websocket = None + + log.debug("_close_coro: close aiohttp session") + + if ( + self.client_session_args + and self.client_session_args.get("connector_owner") is False + ): + + log.debug("connector_owner is False -> not closing connector") + + else: + assert self.session is not None + + closed_event = AIOHTTPTransport.create_aiohttp_closed_event( + self.session + ) + await self.session.close() + try: + await asyncio.wait_for(closed_event.wait(), self.ssl_close_timeout) + except asyncio.TimeoutError: + pass + + self.session = None - log.debug("_close_coro: websocket connection closed") + log.debug("_close_coro: aiohttp session closed") except Exception as exc: # pragma: no cover log.warning("Exception catched in _close_coro: " + repr(exc)) From ddcc77f11e17d9f256ef0b1de335cc6c7c458b8e Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Sun, 14 Jul 2024 16:44:29 +0200 Subject: [PATCH 43/61] Add test for ssl_close_timeout parameter --- tests/test_aiohttp_websocket_query.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/test_aiohttp_websocket_query.py b/tests/test_aiohttp_websocket_query.py index e119af5a..a69bca32 100644 --- a/tests/test_aiohttp_websocket_query.py +++ b/tests/test_aiohttp_websocket_query.py @@ -87,7 +87,10 @@ async def test_aiohttp_websocket_starting_client_in_context_manager(event_loop, @pytest.mark.asyncio @pytest.mark.parametrize("ws_ssl_server", [server1_answers], indirect=True) -async def test_aiohttp_websocket_using_ssl_connection(event_loop, ws_ssl_server): +@pytest.mark.parametrize("ssl_close_timeout", [0, 10]) +async def test_aiohttp_websocket_using_ssl_connection( + event_loop, ws_ssl_server, ssl_close_timeout +): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport @@ -99,7 +102,9 @@ async def test_aiohttp_websocket_using_ssl_connection(event_loop, ws_ssl_server) ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ssl_context.load_verify_locations(ws_ssl_server.testcert) - transport = AIOHTTPWebsocketsTransport(url=url, ssl=ssl_context) + transport = AIOHTTPWebsocketsTransport( + url=url, ssl=ssl_context, ssl_close_timeout=ssl_close_timeout + ) async with Client(transport=transport) as session: From 05e61b860106828aa821c385c64a8338b5c868a2 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Sun, 14 Jul 2024 17:00:01 +0200 Subject: [PATCH 44/61] Add test for client_session_args and connector_owner_false --- tests/test_aiohttp_websocket_query.py | 41 +++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/tests/test_aiohttp_websocket_query.py b/tests/test_aiohttp_websocket_query.py index a69bca32..9d72b6c6 100644 --- a/tests/test_aiohttp_websocket_query.py +++ b/tests/test_aiohttp_websocket_query.py @@ -612,3 +612,44 @@ async def test_aiohttp_websocket_simple_query_with_extensions( execution_result = await session.execute(query, get_execution_result=True) assert execution_result.extensions["key1"] == "val1" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server1_answers], indirect=True) +async def test_aiohttp_websocket_connector_owner_false(event_loop, server): + from aiohttp import TCPConnector + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + connector = TCPConnector() + transport = AIOHTTPWebsocketsTransport( + url=url, + timeout=10, + client_session_args={ + "connector": connector, + "connector_owner": False, + }, + ) + + for _ in range(2): + async with Client(transport=transport) as session: + + query1 = gql(query1_str) + + result = await session.execute(query1) + + print("Client received:", result) + + assert isinstance(result, Dict) + + continents = result["continents"] + africa = continents[0] + + assert africa["code"] == "AF" + + # Check client is disconnect here + assert transport.websocket is None + + await connector.close() From 2af7d40e6dda0f925710fe719b2d7db8afe8338f Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Sun, 14 Jul 2024 18:59:10 +0200 Subject: [PATCH 45/61] Fix handling of connection_error during init --- gql/transport/aiohttp_websockets.py | 6 +++++- tests/test_aiohttp_websocket_query.py | 12 ++++++++---- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/gql/transport/aiohttp_websockets.py b/gql/transport/aiohttp_websockets.py index dc4ee757..ab250937 100644 --- a/gql/transport/aiohttp_websockets.py +++ b/gql/transport/aiohttp_websockets.py @@ -810,7 +810,11 @@ async def connect(self) -> None: await self._initialize() except ConnectionResetError as e: raise e - except (TransportProtocolError, asyncio.TimeoutError) as e: + except ( + TransportProtocolError, + TransportServerError, + asyncio.TimeoutError, + ) as e: await self._fail(e, clean_close=False) raise e diff --git a/tests/test_aiohttp_websocket_query.py b/tests/test_aiohttp_websocket_query.py index 9d72b6c6..7bc603b5 100644 --- a/tests/test_aiohttp_websocket_query.py +++ b/tests/test_aiohttp_websocket_query.py @@ -438,11 +438,15 @@ async def test_aiohttp_websocket_connect_failed_with_authentication_in_connectio transport = AIOHTTPWebsocketsTransport(url=url, init_payload=init_payload) - with pytest.raises(TransportServerError): - async with Client(transport=transport) as session: - query1 = gql(query_str) + for _ in range(2): + with pytest.raises(TransportServerError): + async with Client(transport=transport) as session: + query1 = gql(query_str) + + await session.execute(query1) - await session.execute(query1) + assert transport.session is None + assert transport.websocket is None @pytest.mark.parametrize("server", [server1_answers], indirect=True) From b529ea2217ed8ebe399bf3c9fc4829612afe8e4f Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Mon, 15 Jul 2024 12:12:35 +0200 Subject: [PATCH 46/61] Fix RuntimeError event loop is closed exception shown during tests --- tests/test_aiohttp_websocket_exceptions.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/test_aiohttp_websocket_exceptions.py b/tests/test_aiohttp_websocket_exceptions.py index 7c1b35ed..ea48824f 100644 --- a/tests/test_aiohttp_websocket_exceptions.py +++ b/tests/test_aiohttp_websocket_exceptions.py @@ -7,7 +7,6 @@ from gql import Client, gql from gql.transport.exceptions import ( - TransportAlreadyConnected, TransportClosed, TransportProtocolError, TransportQueryError, @@ -366,8 +365,10 @@ async def client_connect(client): connect_task1 = asyncio.ensure_future(client_connect(client)) connect_task2 = asyncio.ensure_future(client_connect(client)) - with pytest.raises(TransportAlreadyConnected): - await asyncio.gather(connect_task1, connect_task2) + result = await asyncio.gather(connect_task1, connect_task2, return_exceptions=True) + + assert result[0] is None + assert type(result[1]).__name__ == "TransportAlreadyConnected" @pytest.mark.asyncio From 170846f4c830f1ce0da93e4aa47250c043a2459c Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Mon, 15 Jul 2024 12:36:10 +0200 Subject: [PATCH 47/61] Revert "assertion instead of TransportClosed exception if websocket is None" This reverts commit 64d6ffd294bf2069d470c2a8aec010d73f1a4345. --- gql/transport/aiohttp_websockets.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/gql/transport/aiohttp_websockets.py b/gql/transport/aiohttp_websockets.py index ab250937..da34033d 100644 --- a/gql/transport/aiohttp_websockets.py +++ b/gql/transport/aiohttp_websockets.py @@ -616,7 +616,9 @@ async def _send(self, message: Dict[str, Any]) -> None: async def _receive(self) -> str: """Wait the next message from the websocket connection and log the answer""" - assert self.websocket is not None + # It is possible that the websocket has been already closed in another task + if self.websocket is None: + raise TransportClosed("Transport is already closed") ws_message = await self.websocket.receive() @@ -720,6 +722,9 @@ async def _receive_data_loop(self) -> None: except (ConnectionResetError, TransportProtocolError) as e: await self._fail(e, clean_close=False) break + except TransportClosed as e: + await self._fail(e, clean_close=False) + raise e # Parse the answer try: From 9a10f0d1c467440eff2130a8c4990d8e9cb6f0ec Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Mon, 15 Jul 2024 12:42:28 +0200 Subject: [PATCH 48/61] Using _wait_closed.is_set() to determine if we need to wait or close the transport --- gql/transport/aiohttp_websockets.py | 7 ++++--- tests/test_aiohttp_websocket_subscription.py | 3 --- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/gql/transport/aiohttp_websockets.py b/gql/transport/aiohttp_websockets.py index da34033d..6502fe4c 100644 --- a/gql/transport/aiohttp_websockets.py +++ b/gql/transport/aiohttp_websockets.py @@ -959,8 +959,8 @@ async def _fail(self, e: Exception, clean_close: bool = True) -> None: if self.close_task is None: - if self.websocket is None: - log.debug("_fail started with self.websocket == None -> already closed") + if self._wait_closed.is_set(): + log.debug("_fail started but transport is already closed") else: self.close_task = asyncio.shield( asyncio.ensure_future(self._close_coro(e, clean_close=clean_close)) @@ -984,7 +984,8 @@ async def close(self) -> None: async def wait_closed(self) -> None: log.debug("wait_close: starting") - await self._wait_closed.wait() + if not self._wait_closed.is_set(): + await self._wait_closed.wait() log.debug("wait_close: done") diff --git a/tests/test_aiohttp_websocket_subscription.py b/tests/test_aiohttp_websocket_subscription.py index 1eaf7c98..c5b6d504 100644 --- a/tests/test_aiohttp_websocket_subscription.py +++ b/tests/test_aiohttp_websocket_subscription.py @@ -782,9 +782,6 @@ async def test_subscribe_on_closing_transport(event_loop, server, subscription_s @pytest.mark.asyncio -@pytest.mark.skip( - reason="This test hangs with WebsocketsTransport or AIOHTTPWebsocketsTransport" -) @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_subscribe_on_null_transport(event_loop, server, subscription_str): From 4aaf86a7e2f87b90afcad20ddfc4ed0d45c2d3f5 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Mon, 15 Jul 2024 15:30:35 +0200 Subject: [PATCH 49/61] Ignore low-level ping pong messages --- gql/transport/aiohttp_websockets.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/gql/transport/aiohttp_websockets.py b/gql/transport/aiohttp_websockets.py index 6502fe4c..2dd5ec20 100644 --- a/gql/transport/aiohttp_websockets.py +++ b/gql/transport/aiohttp_websockets.py @@ -620,15 +620,23 @@ async def _receive(self) -> str: if self.websocket is None: raise TransportClosed("Transport is already closed") - ws_message = await self.websocket.receive() + while True: + ws_message = await self.websocket.receive() + + # Ignore low-level ping and pong received + if ws_message.type not in (WSMsgType.PING, WSMsgType.PONG): + break - if ws_message.type in (WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.CLOSING): + if ws_message.type in ( + WSMsgType.CLOSE, + WSMsgType.CLOSED, + WSMsgType.CLOSING, + WSMsgType.ERROR, + ): raise ConnectionResetError elif ws_message.type is WSMsgType.BINARY: raise TransportProtocolError("Binary data received in the websocket") - # Note: ws_message could also be a low level PING or PONG type here - # but we don't enable those assert ws_message.type is WSMsgType.TEXT answer: str = ws_message.data From 277fd5d7d1a789f2f6b1d02df1b082c31f50be96 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Mon, 15 Jul 2024 16:17:18 +0200 Subject: [PATCH 50/61] Rename timeout to websocket_close_timeout, reorder init params --- gql/transport/aiohttp_websockets.py | 20 ++++++++++---------- tests/test_aiohttp_websocket_query.py | 3 +-- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/gql/transport/aiohttp_websockets.py b/gql/transport/aiohttp_websockets.py index 2dd5ec20..9c28f233 100644 --- a/gql/transport/aiohttp_websockets.py +++ b/gql/transport/aiohttp_websockets.py @@ -112,9 +112,6 @@ def __init__( *, method: str = hdrs.METH_GET, protocols: Collection[str] = (), - timeout: float = 10.0, - receive_timeout: Optional[float] = None, - ssl_close_timeout: Optional[Union[int, float]] = 10, autoclose: bool = True, autoping: bool = True, heartbeat: Optional[float] = None, @@ -130,6 +127,9 @@ def __init__( proxy_headers: Optional[LooseHeaders] = None, compress: int = 0, max_msg_size: int = 4 * 1024 * 1024, + websocket_close_timeout: float = 10.0, + receive_timeout: Optional[float] = None, + ssl_close_timeout: Optional[Union[int, float]] = 10, connect_timeout: Optional[Union[int, float]] = 10, close_timeout: Optional[Union[int, float]] = 10, ack_timeout: Optional[Union[int, float]] = 10, @@ -155,11 +155,15 @@ def __init__( self.proxy: Optional[StrOrURL] = proxy self.proxy_auth: Optional[BasicAuth] = proxy_auth self.proxy_headers: Optional[LooseHeaders] = proxy_headers - self.receive_timeout: Optional[float] = receive_timeout self.ssl_close_timeout: Optional[Union[int, float]] = ssl_close_timeout self.ssl: Optional[Union[SSLContext, Literal[False], Fingerprint]] = ssl self.ssl_context: Optional[SSLContext] = ssl_context - self.timeout: float = timeout + self.websocket_close_timeout: float = websocket_close_timeout + self.receive_timeout: Optional[float] = receive_timeout + self.connect_timeout: Optional[Union[int, float]] = connect_timeout + self.close_timeout: Optional[Union[int, float]] = close_timeout + self.ack_timeout: Optional[Union[int, float]] = ack_timeout + self.keep_alive_timeout: Optional[Union[int, float]] = keep_alive_timeout self.verify_ssl: Optional[bool] = verify_ssl self.init_payload: Dict[str, Any] = init_payload @@ -175,10 +179,6 @@ def __init__( self._loop = asyncio.new_event_loop() asyncio.set_event_loop(self._loop) - self.connect_timeout: Optional[Union[int, float]] = connect_timeout - self.close_timeout: Optional[Union[int, float]] = close_timeout - self.ack_timeout: Optional[Union[int, float]] = ack_timeout - self.keep_alive_timeout: Optional[Union[int, float]] = keep_alive_timeout self._next_keep_alive_message: asyncio.Event = asyncio.Event() self._next_keep_alive_message.set() @@ -802,7 +802,7 @@ async def connect(self) -> None: receive_timeout=self.receive_timeout, ssl=self.ssl, ssl_context=None, - timeout=self.timeout, + timeout=self.websocket_close_timeout, verify_ssl=self.verify_ssl, ) finally: diff --git a/tests/test_aiohttp_websocket_query.py b/tests/test_aiohttp_websocket_query.py index 7bc603b5..8d6fbab9 100644 --- a/tests/test_aiohttp_websocket_query.py +++ b/tests/test_aiohttp_websocket_query.py @@ -58,7 +58,7 @@ async def test_aiohttp_websocket_starting_client_in_context_manager(event_loop, url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - transport = AIOHTTPWebsocketsTransport(url=url, timeout=10) + transport = AIOHTTPWebsocketsTransport(url=url, websocket_close_timeout=10) async with Client(transport=transport) as session: @@ -630,7 +630,6 @@ async def test_aiohttp_websocket_connector_owner_false(event_loop, server): connector = TCPConnector() transport = AIOHTTPWebsocketsTransport( url=url, - timeout=10, client_session_args={ "connector": connector, "connector_owner": False, From a8b276dd2714277d032287563b41a111b8fbcb50 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Mon, 15 Jul 2024 17:04:37 +0200 Subject: [PATCH 51/61] Modify transport init parameters Adding connect_args parameter to be able to provide any argument to the ws_connect method Removing the following parameters (they can now be provided in the connect_args dict): - autoclose - autoping - compress - max_msg_size - verify_ssl - method Renaming protocols to subprotocols to be more similar to the websockets transport --- gql/transport/aiohttp_websockets.py | 51 +++++++++++++-------------- tests/conftest.py | 2 +- tests/test_aiohttp_websocket_query.py | 6 ++-- 3 files changed, 29 insertions(+), 30 deletions(-) diff --git a/gql/transport/aiohttp_websockets.py b/gql/transport/aiohttp_websockets.py index 9c28f233..6186610f 100644 --- a/gql/transport/aiohttp_websockets.py +++ b/gql/transport/aiohttp_websockets.py @@ -17,7 +17,7 @@ ) import aiohttp -from aiohttp import BasicAuth, Fingerprint, WSMsgType, hdrs +from aiohttp import BasicAuth, Fingerprint, WSMsgType from aiohttp.typedefs import LooseHeaders, StrOrURL from graphql import DocumentNode, ExecutionResult, print_ast from multidict import CIMultiDictProxy @@ -110,10 +110,7 @@ def __init__( self, url: StrOrURL, *, - method: str = hdrs.METH_GET, - protocols: Collection[str] = (), - autoclose: bool = True, - autoping: bool = True, + subprotocols: Optional[Collection[str]] = None, heartbeat: Optional[float] = None, auth: Optional[BasicAuth] = None, origin: Optional[str] = None, @@ -121,12 +118,9 @@ def __init__( headers: Optional[LooseHeaders] = None, proxy: Optional[StrOrURL] = None, proxy_auth: Optional[BasicAuth] = None, + proxy_headers: Optional[LooseHeaders] = None, ssl: Optional[Union[SSLContext, Literal[False], Fingerprint]] = None, ssl_context: Optional[SSLContext] = None, - verify_ssl: Optional[bool] = True, - proxy_headers: Optional[LooseHeaders] = None, - compress: int = 0, - max_msg_size: int = 4 * 1024 * 1024, websocket_close_timeout: float = 10.0, receive_timeout: Optional[float] = None, ssl_close_timeout: Optional[Union[int, float]] = 10, @@ -139,32 +133,31 @@ def __init__( pong_timeout: Optional[Union[int, float]] = None, answer_pings: bool = True, client_session_args: Optional[Dict[str, Any]] = None, + connect_args: Dict[str, Any] = {}, ) -> None: self.url: StrOrURL = url - self.headers: Optional[LooseHeaders] = headers - self.auth: Optional[BasicAuth] = auth - self.autoclose: bool = autoclose - self.autoping: bool = autoping - self.compress: int = compress self.heartbeat: Optional[float] = heartbeat - self.max_msg_size: int = max_msg_size - self.method: str = method + self.auth: Optional[BasicAuth] = auth self.origin: Optional[str] = origin self.params: Optional[Mapping[str, str]] = params - self.protocols: Collection[str] = protocols + self.headers: Optional[LooseHeaders] = headers + self.proxy: Optional[StrOrURL] = proxy self.proxy_auth: Optional[BasicAuth] = proxy_auth self.proxy_headers: Optional[LooseHeaders] = proxy_headers - self.ssl_close_timeout: Optional[Union[int, float]] = ssl_close_timeout + self.ssl: Optional[Union[SSLContext, Literal[False], Fingerprint]] = ssl self.ssl_context: Optional[SSLContext] = ssl_context + self.websocket_close_timeout: float = websocket_close_timeout self.receive_timeout: Optional[float] = receive_timeout + + self.ssl_close_timeout: Optional[Union[int, float]] = ssl_close_timeout self.connect_timeout: Optional[Union[int, float]] = connect_timeout self.close_timeout: Optional[Union[int, float]] = close_timeout self.ack_timeout: Optional[Union[int, float]] = ack_timeout self.keep_alive_timeout: Optional[Union[int, float]] = keep_alive_timeout - self.verify_ssl: Optional[bool] = verify_ssl + self.init_payload: Dict[str, Any] = init_payload # We need to set an event loop here if there is none @@ -221,12 +214,15 @@ def __init__( """pong_received is an asyncio Event which will fire each time a pong is received with the graphql-ws protocol""" - self.supported_subprotocols: Collection[str] = protocols or ( + self.supported_subprotocols: Collection[str] = subprotocols or ( self.APOLLO_SUBPROTOCOL, self.GRAPHQLWS_SUBPROTOCOL, ) + self.close_exception: Optional[Exception] = None + self.client_session_args = client_session_args + self.connect_args = connect_args def _parse_answer_graphqlws( self, answer: Dict[str, Any] @@ -782,28 +778,29 @@ async def connect(self) -> None: if self.websocket is None and not self._connecting: self._connecting = True + connect_args: Dict[str, Any] = {} + + # Adding custom parameters passed from init + if self.connect_args: + connect_args.update(self.connect_args) + try: self.websocket = await self.session.ws_connect( - method=self.method, url=self.url, headers=self.headers, auth=self.auth, - autoclose=self.autoclose, - autoping=self.autoping, - compress=self.compress, heartbeat=self.heartbeat, - max_msg_size=self.max_msg_size, origin=self.origin, params=self.params, protocols=self.supported_subprotocols, proxy=self.proxy, proxy_auth=self.proxy_auth, proxy_headers=self.proxy_headers, + timeout=self.websocket_close_timeout, receive_timeout=self.receive_timeout, ssl=self.ssl, ssl_context=None, - timeout=self.websocket_close_timeout, - verify_ssl=self.verify_ssl, + **connect_args, ) finally: self._connecting = False diff --git a/tests/conftest.py b/tests/conftest.py index bd68982b..ee288eea 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -516,7 +516,7 @@ async def client_and_aiohttp_websocket_graphql_server(graphqlws_server): url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" sample_transport = AIOHTTPWebsocketsTransport( url=url, - protocols=[AIOHTTPWebsocketsTransport.GRAPHQLWS_SUBPROTOCOL], + subprotocols=[AIOHTTPWebsocketsTransport.GRAPHQLWS_SUBPROTOCOL], ) async with Client(transport=sample_transport) as session: diff --git a/tests/test_aiohttp_websocket_query.py b/tests/test_aiohttp_websocket_query.py index 8d6fbab9..6fb8eafa 100644 --- a/tests/test_aiohttp_websocket_query.py +++ b/tests/test_aiohttp_websocket_query.py @@ -499,10 +499,12 @@ async def test_aiohttp_websocket_add_extra_parameters_to_connect(event_loop, ser url = f"ws://{server.hostname}:{server.port}/graphql" - # Increase max payload size to avoid websockets.exceptions.PayloadTooBig exceptions + # Increase max payload size transport = AIOHTTPWebsocketsTransport( url=url, - max_msg_size=(2**21), + connect_args={ + "max_msg_size": 2**21, + }, ) query = gql(query1_str) From 6906ad2db96654fd9c1de2f74bb0e221bbf2c586 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Mon, 15 Jul 2024 18:24:54 +0200 Subject: [PATCH 52/61] Use connect_timeout param --- gql/transport/aiohttp_websockets.py | 38 +++++++++++++++++------------ 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/gql/transport/aiohttp_websockets.py b/gql/transport/aiohttp_websockets.py index 6186610f..b3054be0 100644 --- a/gql/transport/aiohttp_websockets.py +++ b/gql/transport/aiohttp_websockets.py @@ -785,22 +785,28 @@ async def connect(self) -> None: connect_args.update(self.connect_args) try: - self.websocket = await self.session.ws_connect( - url=self.url, - headers=self.headers, - auth=self.auth, - heartbeat=self.heartbeat, - origin=self.origin, - params=self.params, - protocols=self.supported_subprotocols, - proxy=self.proxy, - proxy_auth=self.proxy_auth, - proxy_headers=self.proxy_headers, - timeout=self.websocket_close_timeout, - receive_timeout=self.receive_timeout, - ssl=self.ssl, - ssl_context=None, - **connect_args, + # Connection to the specified url + # Generate a TimeoutError if taking more than connect_timeout seconds + # Set the _connecting flag to False after in all cases + self.websocket = await asyncio.wait_for( + self.session.ws_connect( + url=self.url, + headers=self.headers, + auth=self.auth, + heartbeat=self.heartbeat, + origin=self.origin, + params=self.params, + protocols=self.supported_subprotocols, + proxy=self.proxy, + proxy_auth=self.proxy_auth, + proxy_headers=self.proxy_headers, + timeout=self.websocket_close_timeout, + receive_timeout=self.receive_timeout, + ssl=self.ssl, + ssl_context=None, + **connect_args, + ), + self.connect_timeout, ) finally: self._connecting = False From 4ea698ff11e5fa78d61b784e80d3b483c4398bdb Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Mon, 15 Jul 2024 19:08:48 +0200 Subject: [PATCH 53/61] Add reference documentation + remove deprecated ssl_context param --- docs/modules/gql.rst | 1 + docs/modules/transport_aiohttp_websockets.rst | 7 ++ gql/transport/aiohttp_websockets.py | 82 ++++++++++++++++++- 3 files changed, 87 insertions(+), 3 deletions(-) create mode 100644 docs/modules/transport_aiohttp_websockets.rst diff --git a/docs/modules/gql.rst b/docs/modules/gql.rst index 5f9edebe..b7c13c7c 100644 --- a/docs/modules/gql.rst +++ b/docs/modules/gql.rst @@ -21,6 +21,7 @@ Sub-Packages client transport transport_aiohttp + transport_aiohttp_websockets transport_appsync_auth transport_appsync_websockets transport_exceptions diff --git a/docs/modules/transport_aiohttp_websockets.rst b/docs/modules/transport_aiohttp_websockets.rst new file mode 100644 index 00000000..efa7e1bc --- /dev/null +++ b/docs/modules/transport_aiohttp_websockets.rst @@ -0,0 +1,7 @@ +gql.transport.aiohttp_websockets +================================ + +.. currentmodule:: gql.transport.aiohttp_websockets + +.. automodule:: gql.transport.aiohttp_websockets + :member-order: bysource diff --git a/gql/transport/aiohttp_websockets.py b/gql/transport/aiohttp_websockets.py index b3054be0..225e67a7 100644 --- a/gql/transport/aiohttp_websockets.py +++ b/gql/transport/aiohttp_websockets.py @@ -120,7 +120,6 @@ def __init__( proxy_auth: Optional[BasicAuth] = None, proxy_headers: Optional[LooseHeaders] = None, ssl: Optional[Union[SSLContext, Literal[False], Fingerprint]] = None, - ssl_context: Optional[SSLContext] = None, websocket_close_timeout: float = 10.0, receive_timeout: Optional[float] = None, ssl_close_timeout: Optional[Union[int, float]] = 10, @@ -135,6 +134,85 @@ def __init__( client_session_args: Optional[Dict[str, Any]] = None, connect_args: Dict[str, Any] = {}, ) -> None: + """Initialize the transport with the given parameters. + + :param url: The GraphQL server URL. Example: 'wss://server.com:PORT/graphql'. + :param subprotocols: list of subprotocols sent to the + backend in the 'subprotocols' http header. + By default: both apollo and graphql-ws subprotocols. + :param float heartbeat: Send low level `ping` message every `heartbeat` + seconds and wait `pong` response, close + connection if `pong` response is not + received. The timer is reset on any data reception. + :param auth: An object that represents HTTP Basic Authorization. + :class:`~aiohttp.BasicAuth` (optional) + :param str origin: Origin header to send to server(optional) + :param params: Mapping, iterable of tuple of *key*/*value* pairs or + string to be sent as parameters in the query + string of the new request. Ignored for subsequent + redirected requests (optional) + + Allowed values are: + + - :class:`collections.abc.Mapping` e.g. :class:`dict`, + :class:`multidict.MultiDict` or + :class:`multidict.MultiDictProxy` + - :class:`collections.abc.Iterable` e.g. :class:`tuple` or + :class:`list` + - :class:`str` with preferably url-encoded content + (**Warning:** content will not be encoded by *aiohttp*) + :param headers: HTTP Headers that sent with every request + May be either *iterable of key-value pairs* or + :class:`~collections.abc.Mapping` + (e.g. :class:`dict`, + :class:`~multidict.CIMultiDict`). + :param proxy: Proxy URL, :class:`str` or :class:`~yarl.URL` (optional) + :param aiohttp.BasicAuth proxy_auth: an object that represents proxy HTTP + Basic Authorization (optional) + :param ssl: SSL validation mode. ``True`` for default SSL check + (:func:`ssl.create_default_context` is used), + ``False`` for skip SSL certificate validation, + :class:`aiohttp.Fingerprint` for fingerprint + validation, :class:`ssl.SSLContext` for custom SSL + certificate validation. + :param float websocket_close_timeout: Timeout for websocket to close. + ``10`` seconds by default + :param float receive_timeout: Timeout for websocket to receive + complete message. ``None`` (unlimited) + seconds by default + :param ssl_close_timeout: Timeout in seconds to wait for the ssl connection + to close properly + :param connect_timeout: Timeout in seconds for the establishment + of the websocket connection. If None is provided this will wait forever. + :param close_timeout: Timeout in seconds for the close. If None is provided + this will wait forever. + :param ack_timeout: Timeout in seconds to wait for the connection_ack message + from the server. If None is provided this will wait forever. + :param keep_alive_timeout: Optional Timeout in seconds to receive + a sign of liveness from the server. + :param init_payload: Dict of the payload sent in the connection_init message. + :param ping_interval: Delay in seconds between pings sent by the client to + the backend for the graphql-ws protocol. None (by default) means that + we don't send pings. Note: there are also pings sent by the underlying + websockets protocol. See the + :ref:`keepalive documentation ` + for more information about this. + :param pong_timeout: Delay in seconds to receive a pong from the backend + after we sent a ping (only for the graphql-ws protocol). + By default equal to half of the ping_interval. + :param answer_pings: Whether the client answers the pings from the backend + (for the graphql-ws protocol). + By default: True + :param client_session_args: Dict of extra args passed to + `aiohttp.ClientSession`_ + :param connect_args: Dict of extra args passed to + `aiohttp.ClientSession.ws_connect`_ + + .. _aiohttp.ClientSession.ws_connect: + https://docs.aiohttp.org/en/stable/client_reference.html#aiohttp.ClientSession.ws_connect + .. _aiohttp.ClientSession: + https://docs.aiohttp.org/en/stable/client_reference.html#aiohttp.ClientSession + """ self.url: StrOrURL = url self.heartbeat: Optional[float] = heartbeat self.auth: Optional[BasicAuth] = auth @@ -147,7 +225,6 @@ def __init__( self.proxy_headers: Optional[LooseHeaders] = proxy_headers self.ssl: Optional[Union[SSLContext, Literal[False], Fingerprint]] = ssl - self.ssl_context: Optional[SSLContext] = ssl_context self.websocket_close_timeout: float = websocket_close_timeout self.receive_timeout: Optional[float] = receive_timeout @@ -803,7 +880,6 @@ async def connect(self) -> None: timeout=self.websocket_close_timeout, receive_timeout=self.receive_timeout, ssl=self.ssl, - ssl_context=None, **connect_args, ), self.connect_timeout, From 97ac985e22235dcd0752cf3acfb6849e9565e422 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Tue, 16 Jul 2024 19:37:42 +0200 Subject: [PATCH 54/61] Allow to use the new transport from the cli --- gql/cli.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/gql/cli.py b/gql/cli.py index dd991546..a7d129e2 100644 --- a/gql/cli.py +++ b/gql/cli.py @@ -159,6 +159,7 @@ def get_parser(with_examples: bool = False) -> ArgumentParser: "aiohttp", "phoenix", "websockets", + "aiohttp_websockets", "appsync_http", "appsync_websockets", ], @@ -286,7 +287,12 @@ def autodetect_transport(url: URL) -> str: """Detects which transport should be used depending on url.""" if url.scheme in ["ws", "wss"]: - transport_name = "websockets" + try: + import websockets # noqa: F401 + + transport_name = "websockets" + except ImportError: # pragma: no cover + transport_name = "aiohttp_websockets" else: assert url.scheme in ["http", "https"] @@ -338,6 +344,11 @@ def get_transport(args: Namespace) -> Optional[AsyncTransport]: return WebsocketsTransport(url=args.server, **transport_args) + elif transport_name == "aiohttp_websockets": + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + return AIOHTTPWebsocketsTransport(url=args.server, **transport_args) + else: from gql.transport.appsync_auth import AppSyncAuthentication From cfee44c6ad4a1744800e0df3aad3814ddc706bb4 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Tue, 16 Jul 2024 19:39:37 +0200 Subject: [PATCH 55/61] Adding a new ws test server using aiohttp to do some tests without websockets --- tests/conftest.py | 175 ++++++++++++++++++++++++++ tests/test_aiohttp_websocket_query.py | 101 +++++++++++---- 2 files changed, 249 insertions(+), 27 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index ee288eea..f4775345 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -211,6 +211,142 @@ async def stop(self): print("Server stopped\n\n\n") +class AIOHTTPWebsocketServer: + def __init__(self, with_ssl=False): + self.runner = None + self.site = None + self.port = None + self.hostname = "127.0.0.1" + self.with_ssl = with_ssl + self.ssl_context = None + if with_ssl: + _, self.ssl_context = get_localhost_ssl_context() + + def get_default_server_handler(answers): + async def default_server_handler(request): + + import aiohttp + import aiohttp.web + from aiohttp import WSMsgType + + ws = aiohttp.web.WebSocketResponse() + ws.headers.update({"dummy": "test1234"}) + await ws.prepare(request) + + try: + # Init and ack + msg = await anext(ws) + assert msg.type == WSMsgType.TEXT + result = msg.data + json_result = json.loads(result) + assert json_result["type"] == "connection_init" + await ws.send_str('{"type":"connection_ack"}') + query_id = 1 + + # Wait for queries and send answers + for answer in answers: + msg = await anext(ws) + if msg.type == WSMsgType.TEXT: + result = msg.data + + print(f"Server received: {result}", file=sys.stderr) + if isinstance(answer, str) and "{query_id}" in answer: + answer_format_params = {"query_id": query_id} + formatted_answer = answer.format(**answer_format_params) + else: + formatted_answer = answer + await ws.send_str(formatted_answer) + await ws.send_str( + f'{{"type":"complete","id":"{query_id}","payload":null}}' + ) + query_id += 1 + + elif msg.type == WSMsgType.ERROR: + print(f"WebSocket connection closed with: {ws.exception()}") + raise ws.exception() + elif msg.type in ( + WSMsgType.CLOSE, + WSMsgType.CLOSED, + WSMsgType.CLOSING, + ): + print("WebSocket connection closed") + raise ConnectionResetError + + # Wait for connection_terminate + msg = await anext(ws) + result = msg.data + json_result = json.loads(result) + assert json_result["type"] == "connection_terminate" + + # Wait for connection close + msg = await anext(ws) + + except Exception as e: + print(f"Server exception {e!s}", file=sys.stderr) + + await ws.close() + return ws + + return default_server_handler + + async def shutdown_server(self, app): + print("Shutting down server...") + await app.shutdown() + await app.cleanup() + + async def start(self, handler): + import aiohttp + import aiohttp.web + + app = aiohttp.web.Application() + app.router.add_get("/graphql", handler) + self.runner = aiohttp.web.AppRunner(app) + await self.runner.setup() + + # Use port 0 to bind to an available port + self.site = aiohttp.web.TCPSite( + self.runner, self.hostname, 0, ssl_context=self.ssl_context + ) + await self.site.start() + + # Retrieve the actual port the server is listening on + sockets = self.site._server.sockets + if sockets: + self.port = sockets[0].getsockname()[1] + protocol = "https" if self.with_ssl else "http" + print(f"Server started at {protocol}://{self.hostname}:{self.port}") + + async def stop(self): + if self.site: + await self.site.stop() + if self.runner: + await self.runner.cleanup() + + +@pytest_asyncio.fixture +async def aiohttp_ws_server(request): + """Fixture used to start a dummy server to test the client behaviour + using the aiohttp dependency. + + It can take as argument either a handler function for the websocket server for + complete control OR an array of answers to be sent by the default server handler. + """ + + server_handler = get_aiohttp_ws_server_handler(request) + + try: + test_server = AIOHTTPWebsocketServer() + + # Starting the server with the fixture param as the handler function + await test_server.start(server_handler) + + yield test_server + except Exception as e: + print("Exception received in server fixture:", e) + finally: + await test_server.stop() + + class WebSocketServerHelper: @staticmethod async def send_complete(ws, query_id): @@ -307,6 +443,23 @@ def __exit__(self, type, value, traceback): os.unlink(self.filename) +def get_aiohttp_ws_server_handler(request): + """Get the server handler for the aiohttp websocket server. + + Either get it from test or use the default server handler + if the test provides only an array of answers. + """ + + if isinstance(request.param, types.FunctionType): + server_handler = request.param + + else: + answers = request.param + server_handler = AIOHTTPWebsocketServer.get_default_server_handler(answers) + + return server_handler + + def get_server_handler(request): """Get the server handler. @@ -483,6 +636,28 @@ async def aiohttp_client_and_server(server): yield session, server +@pytest_asyncio.fixture +async def aiohttp_client_and_aiohttp_ws_server(aiohttp_ws_server): + """ + Helper fixture to start an aiohttp websocket server and + a client connected to its port with an aiohttp websockets transport. + """ + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + server = aiohttp_ws_server + + # Generate transport to connect to the server fixture + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + sample_transport = AIOHTTPWebsocketsTransport(url=url) + + async with Client(transport=sample_transport) as session: + + # Yield both client session and server + yield session, server + + @pytest_asyncio.fixture async def client_and_graphqlws_server(graphqlws_server): """Helper fixture to start a server with the graphql-ws prototocol diff --git a/tests/test_aiohttp_websocket_query.py b/tests/test_aiohttp_websocket_query.py index 6fb8eafa..f154386b 100644 --- a/tests/test_aiohttp_websocket_query.py +++ b/tests/test_aiohttp_websocket_query.py @@ -17,7 +17,7 @@ from .conftest import MS, WebSocketServerHelper # Marking all tests in this file with the aiohttp AND websockets marker -pytestmark = [pytest.mark.aiohttp, pytest.mark.websockets] +pytestmark = pytest.mark.aiohttp query1_str = """ query getContinents { @@ -50,9 +50,12 @@ @pytest.mark.asyncio -@pytest.mark.parametrize("server", [server1_answers], indirect=True) -async def test_aiohttp_websocket_starting_client_in_context_manager(event_loop, server): +@pytest.mark.parametrize("aiohttp_ws_server", [server1_answers], indirect=True) +async def test_aiohttp_websocket_starting_client_in_context_manager( + event_loop, aiohttp_ws_server +): + server = aiohttp_ws_server from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport url = f"ws://{server.hostname}:{server.port}/graphql" @@ -86,6 +89,7 @@ async def test_aiohttp_websocket_starting_client_in_context_manager(event_loop, @pytest.mark.asyncio +@pytest.mark.websockets @pytest.mark.parametrize("ws_ssl_server", [server1_answers], indirect=True) @pytest.mark.parametrize("ssl_close_timeout", [0, 10]) async def test_aiohttp_websocket_using_ssl_connection( @@ -127,6 +131,7 @@ async def test_aiohttp_websocket_using_ssl_connection( @pytest.mark.asyncio +@pytest.mark.websockets @pytest.mark.parametrize("server", [server1_answers], indirect=True) @pytest.mark.parametrize("query_str", [query1_str]) async def test_aiohttp_websocket_simple_query( @@ -149,13 +154,15 @@ async def test_aiohttp_websocket_simple_query( @pytest.mark.asyncio -@pytest.mark.parametrize("server", [server1_two_answers_in_series], indirect=True) +@pytest.mark.parametrize( + "aiohttp_ws_server", [server1_two_answers_in_series], indirect=True +) @pytest.mark.parametrize("query_str", [query1_str]) async def test_aiohttp_websocket_two_queries_in_series( - event_loop, aiohttp_client_and_server, query_str + event_loop, aiohttp_client_and_aiohttp_ws_server, query_str ): - session, server = aiohttp_client_and_server + session, server = aiohttp_client_and_aiohttp_ws_server query = gql(query_str) @@ -185,6 +192,7 @@ async def server1_two_queries_in_parallel(ws, path): @pytest.mark.asyncio +@pytest.mark.websockets @pytest.mark.parametrize("server", [server1_two_queries_in_parallel], indirect=True) @pytest.mark.parametrize("query_str", [query1_str]) async def test_aiohttp_websocket_two_queries_in_parallel( @@ -230,6 +238,7 @@ async def server_closing_while_we_are_doing_something_else(ws, path): @pytest.mark.asyncio +@pytest.mark.websockets @pytest.mark.parametrize( "server", [server_closing_while_we_are_doing_something_else], indirect=True ) @@ -262,13 +271,15 @@ async def test_aiohttp_websocket_server_closing_after_first_query( @pytest.mark.asyncio -@pytest.mark.parametrize("server", [ignore_invalid_id_answers], indirect=True) +@pytest.mark.parametrize( + "aiohttp_ws_server", [ignore_invalid_id_answers], indirect=True +) @pytest.mark.parametrize("query_str", [query1_str]) async def test_aiohttp_websocket_ignore_invalid_id( - event_loop, aiohttp_client_and_server, query_str + event_loop, aiohttp_client_and_aiohttp_ws_server, query_str ): - session, server = aiohttp_client_and_server + session, server = aiohttp_client_and_aiohttp_ws_server query = gql(query_str) @@ -300,8 +311,12 @@ async def assert_client_is_working(session): @pytest.mark.asyncio -@pytest.mark.parametrize("server", [server1_answers], indirect=True) -async def test_aiohttp_websocket_multiple_connections_in_series(event_loop, server): +@pytest.mark.parametrize("aiohttp_ws_server", [server1_answers], indirect=True) +async def test_aiohttp_websocket_multiple_connections_in_series( + event_loop, aiohttp_ws_server +): + + server = aiohttp_ws_server from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport @@ -324,8 +339,12 @@ async def test_aiohttp_websocket_multiple_connections_in_series(event_loop, serv @pytest.mark.asyncio -@pytest.mark.parametrize("server", [server1_answers], indirect=True) -async def test_aiohttp_websocket_multiple_connections_in_parallel(event_loop, server): +@pytest.mark.parametrize("aiohttp_ws_server", [server1_answers], indirect=True) +async def test_aiohttp_websocket_multiple_connections_in_parallel( + event_loop, aiohttp_ws_server +): + + server = aiohttp_ws_server from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport @@ -344,10 +363,12 @@ async def task_coro(): @pytest.mark.asyncio -@pytest.mark.parametrize("server", [server1_answers], indirect=True) +@pytest.mark.parametrize("aiohttp_ws_server", [server1_answers], indirect=True) async def test_aiohttp_websocket_trying_to_connect_to_already_connected_transport( - event_loop, server + event_loop, aiohttp_ws_server ): + server = aiohttp_ws_server + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport url = f"ws://{server.hostname}:{server.port}/graphql" @@ -389,6 +410,7 @@ async def server_with_authentication_in_connection_init_payload(ws, path): @pytest.mark.asyncio +@pytest.mark.websockets @pytest.mark.parametrize( "server", [server_with_authentication_in_connection_init_payload], indirect=True ) @@ -423,6 +445,7 @@ async def test_aiohttp_websocket_connect_success_with_authentication_in_connecti @pytest.mark.asyncio +@pytest.mark.websockets @pytest.mark.parametrize( "server", [server_with_authentication_in_connection_init_payload], indirect=True ) @@ -449,8 +472,10 @@ async def test_aiohttp_websocket_connect_failed_with_authentication_in_connectio assert transport.websocket is None -@pytest.mark.parametrize("server", [server1_answers], indirect=True) -def test_aiohttp_websocket_execute_sync(server): +@pytest.mark.parametrize("aiohttp_ws_server", [server1_answers], indirect=True) +def test_aiohttp_websocket_execute_sync(aiohttp_ws_server): + server = aiohttp_ws_server + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport url = f"ws://{server.hostname}:{server.port}/graphql" @@ -492,8 +517,12 @@ def test_aiohttp_websocket_execute_sync(server): @pytest.mark.asyncio -@pytest.mark.parametrize("server", [server1_answers], indirect=True) -async def test_aiohttp_websocket_add_extra_parameters_to_connect(event_loop, server): +@pytest.mark.parametrize("aiohttp_ws_server", [server1_answers], indirect=True) +async def test_aiohttp_websocket_add_extra_parameters_to_connect( + event_loop, aiohttp_ws_server +): + + server = aiohttp_ws_server from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport @@ -527,6 +556,7 @@ async def server_sending_keep_alive_before_connection_ack(ws, path): @pytest.mark.asyncio +@pytest.mark.websockets @pytest.mark.parametrize( "server", [server_sending_keep_alive_before_connection_ack], indirect=True ) @@ -554,8 +584,19 @@ async def test_aiohttp_websocket_non_regression_bug_108( @pytest.mark.asyncio -@pytest.mark.parametrize("server", [server1_answers], indirect=True) -async def test_aiohttp_websocket_using_cli(event_loop, server, monkeypatch, capsys): +@pytest.mark.parametrize("aiohttp_ws_server", [server1_answers], indirect=True) +@pytest.mark.parametrize("transport_arg", [[], ["--transport=aiohttp_websockets"]]) +async def test_aiohttp_websocket_using_cli( + event_loop, aiohttp_ws_server, transport_arg, monkeypatch, capsys +): + + """ + Note: depending on the transport_arg parameter, if there is no transport argument, + then we will use WebsocketsTransport if the websockets dependency is installed, + or AIOHTTPWebsocketsTransport if that is not the case. + """ + + server = aiohttp_ws_server url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") @@ -566,7 +607,7 @@ async def test_aiohttp_websocket_using_cli(event_loop, server, monkeypatch, caps from gql.cli import get_parser, main parser = get_parser(with_examples=True) - args = parser.parse_args([url]) + args = parser.parse_args([url, *transport_arg]) # Monkeypatching sys.stdin to simulate getting the query # via the standard input @@ -605,13 +646,15 @@ async def test_aiohttp_websocket_using_cli(event_loop, server, monkeypatch, caps @pytest.mark.asyncio -@pytest.mark.parametrize("server", [server1_answers_with_extensions], indirect=True) +@pytest.mark.parametrize( + "aiohttp_ws_server", [server1_answers_with_extensions], indirect=True +) @pytest.mark.parametrize("query_str", [query1_str]) async def test_aiohttp_websocket_simple_query_with_extensions( - event_loop, aiohttp_client_and_server, query_str + event_loop, aiohttp_client_and_aiohttp_ws_server, query_str ): - session, server = aiohttp_client_and_server + session, server = aiohttp_client_and_aiohttp_ws_server query = gql(query_str) @@ -621,9 +664,13 @@ async def test_aiohttp_websocket_simple_query_with_extensions( @pytest.mark.asyncio -@pytest.mark.parametrize("server", [server1_answers], indirect=True) -async def test_aiohttp_websocket_connector_owner_false(event_loop, server): +@pytest.mark.parametrize("aiohttp_ws_server", [server1_answers], indirect=True) +async def test_aiohttp_websocket_connector_owner_false(event_loop, aiohttp_ws_server): + + server = aiohttp_ws_server + from aiohttp import TCPConnector + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport url = f"ws://{server.hostname}:{server.port}/graphql" From d43224f6161d23fd64e084817d67c4ee5d3c107e Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Tue, 16 Jul 2024 19:46:16 +0200 Subject: [PATCH 56/61] Try to fix ConnectionResetError exception --- tests/conftest.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index f4775345..783dacbd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -281,6 +281,9 @@ async def default_server_handler(request): # Wait for connection close msg = await anext(ws) + except ConnectionResetError: + pass + except Exception as e: print(f"Server exception {e!s}", file=sys.stderr) From 04cd47c9811d43bccdb5bc0eb0c7fb1d18c1a7d6 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Tue, 16 Jul 2024 20:33:13 +0200 Subject: [PATCH 57/61] Fix anext not existing on Python 3.8 --- tests/conftest.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 783dacbd..c164c355 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -235,7 +235,7 @@ async def default_server_handler(request): try: # Init and ack - msg = await anext(ws) + msg = await ws.__anext__() assert msg.type == WSMsgType.TEXT result = msg.data json_result = json.loads(result) @@ -245,7 +245,7 @@ async def default_server_handler(request): # Wait for queries and send answers for answer in answers: - msg = await anext(ws) + msg = await ws.__anext__() if msg.type == WSMsgType.TEXT: result = msg.data @@ -273,13 +273,13 @@ async def default_server_handler(request): raise ConnectionResetError # Wait for connection_terminate - msg = await anext(ws) + msg = await ws.__anext__() result = msg.data json_result = json.loads(result) assert json_result["type"] == "connection_terminate" # Wait for connection close - msg = await anext(ws) + msg = await ws.__anext__() except ConnectionResetError: pass From 933348d6d515e7ea70f5d4e4447042c4bc30d169 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Wed, 17 Jul 2024 17:21:34 +0200 Subject: [PATCH 58/61] Adding the transport to the docs --- .../code_examples/aiohttp_websockets_async.py | 50 +++++++++++++++++++ docs/intro.rst | 40 ++++++++------- docs/transports/aiohttp.rst | 4 +- docs/transports/aiohttp_websockets.rst | 31 ++++++++++++ docs/transports/async_transports.rst | 1 + 5 files changed, 106 insertions(+), 20 deletions(-) create mode 100644 docs/code_examples/aiohttp_websockets_async.py create mode 100644 docs/transports/aiohttp_websockets.rst diff --git a/docs/code_examples/aiohttp_websockets_async.py b/docs/code_examples/aiohttp_websockets_async.py new file mode 100644 index 00000000..69520053 --- /dev/null +++ b/docs/code_examples/aiohttp_websockets_async.py @@ -0,0 +1,50 @@ +import asyncio +import logging + +from gql import Client, gql +from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + +logging.basicConfig(level=logging.INFO) + + +async def main(): + + transport = AIOHTTPWebsocketsTransport( + url="wss://countries.trevorblades.com/graphql" + ) + + # Using `async with` on the client will start a connection on the transport + # and provide a `session` variable to execute queries on this connection + async with Client( + transport=transport, + ) as session: + + # Execute single query + query = gql( + """ + query getContinents { + continents { + code + name + } + } + """ + ) + result = await session.execute(query) + print(result) + + # Request subscription + subscription = gql( + """ + subscription { + somethingChanged { + id + } + } + """ + ) + async for result in session.subscribe(subscription): + print(result) + + +asyncio.run(main()) diff --git a/docs/intro.rst b/docs/intro.rst index 8f59ed16..21de16bd 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -36,25 +36,27 @@ which needs the :code:`aiohttp` dependency, then you can install GQL with:: The corresponding between extra dependencies required and the GQL classes is: -+---------------------+----------------------------------------------------------------+ -| Extra dependencies | Classes | -+=====================+================================================================+ -| aiohttp | :ref:`AIOHTTPTransport ` | -+---------------------+----------------------------------------------------------------+ -| websockets | :ref:`WebsocketsTransport ` | -| | | -| | :ref:`PhoenixChannelWebsocketsTransport ` | -| | | -| | :ref:`AppSyncWebsocketsTransport ` | -+---------------------+----------------------------------------------------------------+ -| requests | :ref:`RequestsHTTPTransport ` | -+---------------------+----------------------------------------------------------------+ -| httpx | :ref:`HTTPTXTransport ` | -| | | -| | :ref:`HTTPXAsyncTransport ` | -+---------------------+----------------------------------------------------------------+ -| botocore | :ref:`AppSyncIAMAuthentication ` | -+---------------------+----------------------------------------------------------------+ ++---------------------+------------------------------------------------------------------+ +| Extra dependencies | Classes | ++=====================+==================================================================+ +| aiohttp | :ref:`AIOHTTPTransport ` | +| | | +| | :ref:`AIOHTTPWebsocketsTransport ` | ++---------------------+------------------------------------------------------------------+ +| websockets | :ref:`WebsocketsTransport ` | +| | | +| | :ref:`PhoenixChannelWebsocketsTransport ` | +| | | +| | :ref:`AppSyncWebsocketsTransport ` | ++---------------------+------------------------------------------------------------------+ +| requests | :ref:`RequestsHTTPTransport ` | ++---------------------+------------------------------------------------------------------+ +| httpx | :ref:`HTTPTXTransport ` | +| | | +| | :ref:`HTTPXAsyncTransport ` | ++---------------------+------------------------------------------------------------------+ +| botocore | :ref:`AppSyncIAMAuthentication ` | ++---------------------+------------------------------------------------------------------+ .. note:: diff --git a/docs/transports/aiohttp.rst b/docs/transports/aiohttp.rst index 68b3eb99..b852108b 100644 --- a/docs/transports/aiohttp.rst +++ b/docs/transports/aiohttp.rst @@ -10,7 +10,9 @@ Reference: :class:`gql.transport.aiohttp.AIOHTTPTransport` .. note:: GraphQL subscriptions are not supported on the HTTP transport. - For subscriptions you should use the :ref:`websockets transport `. + For subscriptions you should use a websockets transport: + :ref:`WebsocketsTransport ` or + :ref:`AIOHTTPWebsocketsTransport `. .. literalinclude:: ../code_examples/aiohttp_async.py diff --git a/docs/transports/aiohttp_websockets.rst b/docs/transports/aiohttp_websockets.rst new file mode 100644 index 00000000..def3372e --- /dev/null +++ b/docs/transports/aiohttp_websockets.rst @@ -0,0 +1,31 @@ +.. _aiohttp_websockets_transport: + +AIOHTTPWebsocketsTransport +========================== + +The AIOHTTPWebsocketsTransport is an alternative to the :ref:`websockets_transport`, +using the `aiohttp` dependency instead of the `websockets` dependency. + +It also supports both: + + - the `Apollo websockets transport protocol`_. + - the `GraphQL-ws websockets transport protocol`_ + +It will propose both subprotocols to the backend and detect the supported protocol +from the response http headers returned by the backend. + +.. note:: + For some backends (graphql-ws before `version 5.6.1`_ without backwards compatibility), it may be necessary to specify + only one subprotocol to the backend. It can be done by using + :code:`subprotocols=[AIOHTTPWebsocketsTransport.GRAPHQLWS_SUBPROTOCOL]` + or :code:`subprotocols=[AIOHTTPWebsocketsTransport.APOLLO_SUBPROTOCOL]` in the transport arguments. + +This transport allows to do multiple queries, mutations and subscriptions on the same websocket connection. + +Reference: :class:`gql.transport.aiohttp_websockets.AIOHTTPWebsocketsTransport` + +.. literalinclude:: ../code_examples/aiohttp_websockets_async.py + +.. _version 5.6.1: https://github.com/enisdenjo/graphql-ws/releases/tag/v5.6.1 +.. _Apollo websockets transport protocol: https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md +.. _GraphQL-ws websockets transport protocol: https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md diff --git a/docs/transports/async_transports.rst b/docs/transports/async_transports.rst index 7d751df0..ba5ca136 100644 --- a/docs/transports/async_transports.rst +++ b/docs/transports/async_transports.rst @@ -12,5 +12,6 @@ Async transports are transports which are using an underlying async library. The aiohttp httpx_async websockets + aiohttp_websockets phoenix appsync From 22e8a3ac4a278a179ef6b7278fdbdaea8b1d871c Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Wed, 17 Jul 2024 18:05:06 +0200 Subject: [PATCH 59/61] Fix tests still using WebsocketsTransport instead of new transport --- ...aiohttp_websocket_graphqlws_exceptions.py} | 41 +++++++++--------- ...ohttp_websocket_graphqlws_subscription.py} | 36 ++++++++-------- tests/test_aiohttp_websocket_subscription.py | 42 +++++++++---------- 3 files changed, 62 insertions(+), 57 deletions(-) rename tests/{test_aiohttp_websocket_graphql_exceptions.py => test_aiohttp_websocket_graphqlws_exceptions.py} (84%) rename tests/{test_aiohttp_websocket_graphql_subscription.py => test_aiohttp_websocket_graphqlws_subscription.py} (95%) diff --git a/tests/test_aiohttp_websocket_graphql_exceptions.py b/tests/test_aiohttp_websocket_graphqlws_exceptions.py similarity index 84% rename from tests/test_aiohttp_websocket_graphql_exceptions.py rename to tests/test_aiohttp_websocket_graphqlws_exceptions.py index 577ddc6b..d87315c9 100644 --- a/tests/test_aiohttp_websocket_graphql_exceptions.py +++ b/tests/test_aiohttp_websocket_graphqlws_exceptions.py @@ -38,7 +38,7 @@ @pytest.mark.asyncio @pytest.mark.parametrize("graphqlws_server", [invalid_query1_server], indirect=True) @pytest.mark.parametrize("query_str", [invalid_query_str]) -async def test_aiohttp_graphqlws_invalid_query( +async def test_aiohttp_websocket_graphqlws_invalid_query( event_loop, client_and_aiohttp_websocket_graphql_server, query_str ): @@ -81,7 +81,7 @@ async def server_invalid_subscription(ws, path): "graphqlws_server", [server_invalid_subscription], indirect=True ) @pytest.mark.parametrize("query_str", [invalid_subscription_str]) -async def test_aiohttp_graphqlws_invalid_subscription( +async def test_aiohttp_websocket_graphqlws_invalid_subscription( event_loop, client_and_aiohttp_websocket_graphql_server, query_str ): @@ -109,17 +109,17 @@ async def server_no_ack(ws, path): @pytest.mark.asyncio @pytest.mark.parametrize("graphqlws_server", [server_no_ack], indirect=True) @pytest.mark.parametrize("query_str", [invalid_query_str]) -async def test_aiohttp_graphqlws_server_does_not_send_ack( +async def test_aiohttp_websocket_graphqlws_server_does_not_send_ack( event_loop, graphqlws_server, query_str ): - from gql.transport.websockets import WebsocketsTransport + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql" - sample_transport = WebsocketsTransport(url=url, ack_timeout=1) + transport = AIOHTTPWebsocketsTransport(url=url, ack_timeout=1) with pytest.raises(asyncio.TimeoutError): - async with Client(transport=sample_transport): + async with Client(transport=transport): pass @@ -141,7 +141,7 @@ async def server_invalid_query(ws, path): @pytest.mark.asyncio @pytest.mark.parametrize("graphqlws_server", [server_invalid_query], indirect=True) -async def test_aiohttp_graphqlws_sending_invalid_query( +async def test_aiohttp_websocket_graphqlws_sending_invalid_query( event_loop, client_and_aiohttp_websocket_graphql_server ): @@ -195,7 +195,7 @@ async def test_aiohttp_graphqlws_sending_invalid_query( ], indirect=True, ) -async def test_aiohttp_graphqlws_transport_protocol_errors( +async def test_aiohttp_websocket_graphqlws_transport_protocol_errors( event_loop, client_and_aiohttp_websocket_graphql_server ): @@ -215,16 +215,18 @@ async def server_without_ack(ws, path): @pytest.mark.asyncio @pytest.mark.parametrize("graphqlws_server", [server_without_ack], indirect=True) -async def test_aiohttp_graphqlws_server_does_not_ack(event_loop, graphqlws_server): - from gql.transport.websockets import WebsocketsTransport +async def test_aiohttp_websocket_graphqlws_server_does_not_ack( + event_loop, graphqlws_server +): + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql" print(f"url = {url}") - sample_transport = WebsocketsTransport(url=url) + transport = AIOHTTPWebsocketsTransport(url=url) with pytest.raises(TransportProtocolError): - async with Client(transport=sample_transport): + async with Client(transport=transport): pass @@ -234,18 +236,19 @@ async def server_closing_directly(ws, path): @pytest.mark.asyncio @pytest.mark.parametrize("graphqlws_server", [server_closing_directly], indirect=True) -async def test_aiohttp_graphqlws_server_closing_directly(event_loop, graphqlws_server): - import websockets +async def test_aiohttp_websocket_graphqlws_server_closing_directly( + event_loop, graphqlws_server +): - from gql.transport.websockets import WebsocketsTransport + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql" print(f"url = {url}") - sample_transport = WebsocketsTransport(url=url) + transport = AIOHTTPWebsocketsTransport(url=url) - with pytest.raises(websockets.exceptions.ConnectionClosed): - async with Client(transport=sample_transport): + with pytest.raises(ConnectionResetError): + async with Client(transport=transport): pass @@ -256,7 +259,7 @@ async def server_closing_after_ack(ws, path): @pytest.mark.asyncio @pytest.mark.parametrize("graphqlws_server", [server_closing_after_ack], indirect=True) -async def test_aiohttp_graphqlws_server_closing_after_ack( +async def test_aiohttp_websocket_graphqlws_server_closing_after_ack( event_loop, client_and_aiohttp_websocket_graphql_server ): diff --git a/tests/test_aiohttp_websocket_graphql_subscription.py b/tests/test_aiohttp_websocket_graphqlws_subscription.py similarity index 95% rename from tests/test_aiohttp_websocket_graphql_subscription.py rename to tests/test_aiohttp_websocket_graphqlws_subscription.py index bb5529a1..e5db7ca1 100644 --- a/tests/test_aiohttp_websocket_graphql_subscription.py +++ b/tests/test_aiohttp_websocket_graphqlws_subscription.py @@ -228,7 +228,7 @@ async def server_countdown_disconnect(ws, path): @pytest.mark.asyncio @pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) -async def test_aiohttp_graphqlws_subscription( +async def test_aiohttp_websocket_graphqlws_subscription( event_loop, client_and_aiohttp_websocket_graphql_server, subscription_str ): @@ -251,7 +251,7 @@ async def test_aiohttp_graphqlws_subscription( @pytest.mark.asyncio @pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) -async def test_aiohttp_graphqlws_subscription_break( +async def test_aiohttp_websocket_graphqlws_subscription_break( event_loop, client_and_aiohttp_websocket_graphql_server, subscription_str ): @@ -281,7 +281,7 @@ async def test_aiohttp_graphqlws_subscription_break( @pytest.mark.asyncio @pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) -async def test_aiohttp_graphqlws_subscription_task_cancel( +async def test_aiohttp_websocket_graphqlws_subscription_task_cancel( event_loop, client_and_aiohttp_websocket_graphql_server, subscription_str ): @@ -320,7 +320,7 @@ async def cancel_task_coro(): @pytest.mark.asyncio @pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) -async def test_aiohttp_graphqlws_subscription_close_transport( +async def test_aiohttp_websocket_graphqlws_subscription_close_transport( event_loop, client_and_aiohttp_websocket_graphql_server, subscription_str ): @@ -385,7 +385,7 @@ async def server_countdown_close_connection_in_middle(ws, path): "graphqlws_server", [server_countdown_close_connection_in_middle], indirect=True ) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) -async def test_aiohttp_graphqlws_subscription_server_connection_closed( +async def test_aiohttp_websocket_graphqlws_subscription_server_connection_closed( event_loop, client_and_aiohttp_websocket_graphql_server, subscription_str ): session, _ = client_and_aiohttp_websocket_graphql_server @@ -406,7 +406,7 @@ async def test_aiohttp_graphqlws_subscription_server_connection_closed( @pytest.mark.asyncio @pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) -async def test_aiohttp_graphqlws_subscription_with_operation_name( +async def test_aiohttp_websocket_graphqlws_subscription_with_operation_name( event_loop, client_and_aiohttp_websocket_graphql_server, subscription_str ): @@ -436,7 +436,7 @@ async def test_aiohttp_graphqlws_subscription_with_operation_name( "graphqlws_server", [server_countdown_keepalive], indirect=True ) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) -async def test_aiohttp_graphqlws_subscription_with_keepalive( +async def test_aiohttp_websocket_graphqlws_subscription_with_keepalive( event_loop, client_and_aiohttp_websocket_graphql_server, subscription_str ): @@ -466,7 +466,7 @@ async def test_aiohttp_graphqlws_subscription_with_keepalive( "graphqlws_server", [server_countdown_keepalive], indirect=True ) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) -async def test_aiohttp_graphqlws_subscription_with_keepalive_with_timeout_ok( +async def test_aiohttp_websocket_graphqlws_subscription_with_keepalive_with_timeout_ok( event_loop, graphqlws_server, subscription_str ): @@ -500,7 +500,7 @@ async def test_aiohttp_graphqlws_subscription_with_keepalive_with_timeout_ok( "graphqlws_server", [server_countdown_keepalive], indirect=True ) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) -async def test_aiohttp_graphqlws_subscription_with_keepalive_with_timeout_nok( +async def test_aiohttp_websocket_graphqlws_subscription_with_keepalive_with_timeout_nok( event_loop, graphqlws_server, subscription_str ): @@ -535,7 +535,7 @@ async def test_aiohttp_graphqlws_subscription_with_keepalive_with_timeout_nok( "graphqlws_server", [server_countdown_keepalive], indirect=True ) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) -async def test_aiohttp_graphqlws_subscription_with_ping_interval_ok( +async def test_aiohttp_websocket_graphqlws_subscription_with_ping_interval_ok( event_loop, graphqlws_server, subscription_str ): @@ -571,7 +571,7 @@ async def test_aiohttp_graphqlws_subscription_with_ping_interval_ok( "graphqlws_server", [server_countdown_dont_answer_pings], indirect=True ) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) -async def test_aiohttp_graphqlws_subscription_with_ping_interval_nok( +async def test_aiohttp_websocket_graphqlws_subscription_with_ping_interval_nok( event_loop, graphqlws_server, subscription_str ): @@ -604,7 +604,7 @@ async def test_aiohttp_graphqlws_subscription_with_ping_interval_nok( "graphqlws_server", [server_countdown_keepalive], indirect=True ) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) -async def test_aiohttp_graphqlws_subscription_manual_pings_with_payload( +async def test_aiohttp_websocket_graphqlws_subscription_manual_pings_with_payload( event_loop, graphqlws_server, subscription_str ): @@ -646,7 +646,7 @@ async def test_aiohttp_graphqlws_subscription_manual_pings_with_payload( "graphqlws_server", [server_countdown_keepalive], indirect=True ) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) -async def test_aiohttp_graphqlws_subscription_manual_pong_answers_with_payload( +async def test_aiohttp_websocket_graphqlws_subscription_manual_pong_with_payload( event_loop, graphqlws_server, subscription_str ): @@ -690,7 +690,9 @@ async def answer_ping_coro(): "graphqlws_server", [server_countdown_keepalive], indirect=True ) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) -def test_aiohttp_graphqlws_subscription_sync(graphqlws_server, subscription_str): +def test_aiohttp_websocket_graphqlws_subscription_sync( + graphqlws_server, subscription_str +): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql" @@ -719,7 +721,7 @@ def test_aiohttp_graphqlws_subscription_sync(graphqlws_server, subscription_str) "graphqlws_server", [server_countdown_keepalive], indirect=True ) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) -def test_aiohttp_graphqlws_subscription_sync_graceful_shutdown( +def test_aiohttp_websocket_graphqlws_subscription_sync_graceful_shutdown( graphqlws_server, subscription_str ): """Note: this test will simulate a control-C happening while a sync subscription @@ -777,7 +779,7 @@ def test_aiohttp_graphqlws_subscription_sync_graceful_shutdown( "graphqlws_server", [server_countdown_keepalive], indirect=True ) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) -async def test_aiohttp_graphqlws_subscription_running_in_thread( +async def test_aiohttp_websocket_graphqlws_subscription_running_in_thread( event_loop, graphqlws_server, subscription_str, run_sync_test ): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport @@ -811,7 +813,7 @@ def test_code(): ) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) @pytest.mark.parametrize("execute_instead_of_subscribe", [False, True]) -async def test_aiohttp_graphqlws_subscription_reconnecting_session( +async def test_aiohttp_websocket_graphqlws_subscription_reconnecting_session( event_loop, graphqlws_server, subscription_str, execute_instead_of_subscribe ): diff --git a/tests/test_aiohttp_websocket_subscription.py b/tests/test_aiohttp_websocket_subscription.py index c5b6d504..7aa2fcb1 100644 --- a/tests/test_aiohttp_websocket_subscription.py +++ b/tests/test_aiohttp_websocket_subscription.py @@ -483,9 +483,9 @@ async def test_aiohttp_websocket_subscription_with_keepalive_with_timeout_ok( path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" - sample_transport = AIOHTTPWebsocketsTransport(url=url, keep_alive_timeout=(20 * MS)) + transport = AIOHTTPWebsocketsTransport(url=url, keep_alive_timeout=(20 * MS)) - client = Client(transport=sample_transport) + client = Client(transport=transport) count = 10 subscription = gql(subscription_str.format(count=count)) @@ -513,9 +513,9 @@ async def test_aiohttp_websocket_subscription_with_keepalive_with_timeout_nok( path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" - sample_transport = AIOHTTPWebsocketsTransport(url=url, keep_alive_timeout=(1 * MS)) + transport = AIOHTTPWebsocketsTransport(url=url, keep_alive_timeout=(1 * MS)) - client = Client(transport=sample_transport) + client = Client(transport=transport) count = 10 subscription = gql(subscription_str.format(count=count)) @@ -536,14 +536,14 @@ async def test_aiohttp_websocket_subscription_with_keepalive_with_timeout_nok( @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) def test_aiohttp_websocket_subscription_sync(server, subscription_str): - from gql.transport.websockets import WebsocketsTransport + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - sample_transport = WebsocketsTransport(url=url) + transport = AIOHTTPWebsocketsTransport(url=url) - client = Client(transport=sample_transport) + client = Client(transport=transport) count = 10 subscription = gql(subscription_str.format(count=count)) @@ -562,14 +562,14 @@ def test_aiohttp_websocket_subscription_sync(server, subscription_str): @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) def test_aiohttp_websocket_subscription_sync_user_exception(server, subscription_str): - from gql.transport.websockets import WebsocketsTransport + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - sample_transport = WebsocketsTransport(url=url) + transport = AIOHTTPWebsocketsTransport(url=url) - client = Client(transport=sample_transport) + client = Client(transport=transport) count = 10 subscription = gql(subscription_str.format(count=count)) @@ -593,14 +593,14 @@ def test_aiohttp_websocket_subscription_sync_user_exception(server, subscription @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) def test_aiohttp_websocket_subscription_sync_break(server, subscription_str): - from gql.transport.websockets import WebsocketsTransport + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - sample_transport = WebsocketsTransport(url=url) + transport = AIOHTTPWebsocketsTransport(url=url) - client = Client(transport=sample_transport) + client = Client(transport=transport) count = 10 subscription = gql(subscription_str.format(count=count)) @@ -636,14 +636,14 @@ def test_aiohttp_websocket_subscription_sync_graceful_shutdown( This test does not work on Windows but the behaviour with Windows is correct. """ - from gql.transport.websockets import WebsocketsTransport + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - sample_transport = WebsocketsTransport(url=url) + transport = AIOHTTPWebsocketsTransport(url=url) - client = Client(transport=sample_transport) + client = Client(transport=transport) count = 10 subscription = gql(subscription_str.format(count=count)) @@ -686,14 +686,14 @@ def test_aiohttp_websocket_subscription_sync_graceful_shutdown( async def test_aiohttp_websocket_subscription_running_in_thread( event_loop, server, subscription_str, run_sync_test ): - from gql.transport.websockets import WebsocketsTransport + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport def test_code(): path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" - sample_transport = WebsocketsTransport(url=url) + transport = AIOHTTPWebsocketsTransport(url=url) - client = Client(transport=sample_transport) + client = Client(transport=transport) count = 10 subscription = gql(subscription_str.format(count=count)) @@ -730,9 +730,9 @@ async def test_async_aiohttp_client_validation( url = f"ws://{server.hostname}:{server.port}/graphql" - sample_transport = AIOHTTPWebsocketsTransport(url=url) + transport = AIOHTTPWebsocketsTransport(url=url) - client = Client(transport=sample_transport, **client_params) + client = Client(transport=transport, **client_params) async with client as session: From db9f5db1fa78c9942d5eb844877ed1534832faf1 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Wed, 17 Jul 2024 19:22:28 +0200 Subject: [PATCH 60/61] Fix countdown server not closing properly in some cases --- tests/test_aiohttp_websocket_subscription.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_aiohttp_websocket_subscription.py b/tests/test_aiohttp_websocket_subscription.py index 7aa2fcb1..3ebf4dbc 100644 --- a/tests/test_aiohttp_websocket_subscription.py +++ b/tests/test_aiohttp_websocket_subscription.py @@ -149,12 +149,15 @@ async def keepalive_coro(): break stopping_task = asyncio.ensure_future(stopping_coro()) - keepalive_task = asyncio.ensure_future(keepalive_coro()) + if WITH_KEEPALIVE: + keepalive_task = asyncio.ensure_future(keepalive_coro()) try: await counting_task except asyncio.CancelledError: print("Now counting task is cancelled") + except Exception as exc: + print(f"Exception in counting task: {exc!s}") stopping_task.cancel() From 9eaafa4a5eee1600d6cb0b47fa4530050140729d Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Wed, 17 Jul 2024 19:23:12 +0200 Subject: [PATCH 61/61] More resilient _close_coro cleanup --- gql/transport/aiohttp_websockets.py | 73 ++++++++++++++++++++--------- 1 file changed, 50 insertions(+), 23 deletions(-) diff --git a/gql/transport/aiohttp_websockets.py b/gql/transport/aiohttp_websockets.py index 225e67a7..ff310a82 100644 --- a/gql/transport/aiohttp_websockets.py +++ b/gql/transport/aiohttp_websockets.py @@ -967,18 +967,23 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: try: - # We should always have an active websocket connection here - assert self.websocket is not None - - # Properly shut down liveness checker if enabled - if self.check_keep_alive_task is not None: - # More info: https://stackoverflow.com/a/43810272/1113207 - self.check_keep_alive_task.cancel() - with suppress(asyncio.CancelledError): - await self.check_keep_alive_task + try: + # Properly shut down liveness checker if enabled + if self.check_keep_alive_task is not None: + # More info: https://stackoverflow.com/a/43810272/1113207 + self.check_keep_alive_task.cancel() + with suppress(asyncio.CancelledError): + await self.check_keep_alive_task + except Exception as exc: # pragma: no cover + log.warning( + "_close_coro cancel keep alive task exception: " + repr(exc) + ) - # Calling the subclass close hook - await self._close_hook() + try: + # Calling the subclass close hook + await self._close_hook() + except Exception as exc: # pragma: no cover + log.warning("_close_coro close_hook exception: " + repr(exc)) # Saving exception to raise it later if trying to use the transport # after it has already closed. @@ -999,8 +1004,13 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: log.debug("_close_coro: close websocket connection") - await self.websocket.close() - self.websocket = None + try: + assert self.websocket is not None + + await self.websocket.close() + self.websocket = None + except Exception as exc: + log.warning("_close_coro websocket close exception: " + repr(exc)) log.debug("_close_coro: close aiohttp session") @@ -1012,31 +1022,48 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: log.debug("connector_owner is False -> not closing connector") else: - assert self.session is not None - - closed_event = AIOHTTPTransport.create_aiohttp_closed_event( - self.session - ) - await self.session.close() try: - await asyncio.wait_for(closed_event.wait(), self.ssl_close_timeout) - except asyncio.TimeoutError: - pass + assert self.session is not None + + closed_event = AIOHTTPTransport.create_aiohttp_closed_event( + self.session + ) + await self.session.close() + try: + await asyncio.wait_for( + closed_event.wait(), self.ssl_close_timeout + ) + except asyncio.TimeoutError: + pass + except Exception as exc: # pragma: no cover + log.warning("_close_coro session close exception: " + repr(exc)) self.session = None log.debug("_close_coro: aiohttp session closed") + try: + assert self.receive_data_task is not None + + self.receive_data_task.cancel() + with suppress(asyncio.CancelledError): + await self.receive_data_task + except Exception as exc: # pragma: no cover + log.warning( + "_close_coro cancel receive data task exception: " + repr(exc) + ) + except Exception as exc: # pragma: no cover log.warning("Exception catched in _close_coro: " + repr(exc)) finally: - log.debug("_close_coro: start cleanup") + log.debug("_close_coro: final cleanup") self.websocket = None self.close_task = None self.check_keep_alive_task = None + self.receive_data_task = None self._wait_closed.set() log.debug("_close_coro: exiting")