diff --git a/CHANGES/11717.feature.rst b/CHANGES/11717.feature.rst new file mode 100644 index 00000000000..e38b91fc804 --- /dev/null +++ b/CHANGES/11717.feature.rst @@ -0,0 +1 @@ +Add parser factory to response handler and a writer factory to client request base to prepare for future protocol support -- by :user:`Vizonex` diff --git a/aiohttp/client_proto.py b/aiohttp/client_proto.py index 144cb42d52b..8d8f067ac31 100644 --- a/aiohttp/client_proto.py +++ b/aiohttp/client_proto.py @@ -25,7 +25,11 @@ class ResponseHandler(BaseProtocol, DataQueue[tuple[RawResponseMessage, StreamReader]]): """Helper class to adapt between Protocol and StreamReader.""" - def __init__(self, loop: asyncio.AbstractEventLoop) -> None: + def __init__( + self, + loop: asyncio.AbstractEventLoop, + parser_factory: type[HttpResponseParser] = HttpResponseParser, + ) -> None: BaseProtocol.__init__(self, loop=loop) DataQueue.__init__(self, loop) @@ -39,6 +43,9 @@ def __init__(self, loop: asyncio.AbstractEventLoop) -> None: self._tail = b"" self._upgraded = False + # parser_factory is important because it will + # allow for other protocols to be added in the future. + self._parser_factory = parser_factory self._parser: HttpResponseParser | None = None self._read_timeout: float | None = None @@ -237,7 +244,7 @@ def set_response_params( self._timeout_ceil_threshold = timeout_ceil_threshold - self._parser = HttpResponseParser( + self._parser = self._parser_factory( self, self._loop, read_bufsize, diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index 551b3374c6a..bcae4daf962 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -701,7 +701,7 @@ class ClientRequestBase: method = "GET" _writer_task: asyncio.Task[None] | None = None # async task for streaming data - + _writer_factory: type[StreamWriter] = StreamWriter # allowing http/2 and http/3 _skip_auto_headers: "CIMultiDict[None] | None" = None # N.B. @@ -719,6 +719,8 @@ def __init__( loop: asyncio.AbstractEventLoop, ssl: SSLContext | bool | Fingerprint, trust_env: bool = False, + writer_factory: type[StreamWriter] | None = None, + version: HttpVersion | None = None, ): if match := _CONTAINS_CONTROL_CHAR_RE.search(method): raise ValueError( @@ -740,6 +742,13 @@ def __init__( self._update_headers(headers) self._update_auth(auth, trust_env) + # setup in case of newer protocols. + if version: + self.version = version + + if writer_factory: + self._writer_factory = writer_factory + def _reset_writer(self, _: object = None) -> None: self._writer_task = None @@ -844,11 +853,12 @@ def _create_response(self, task: asyncio.Task[None] | None) -> ClientResponse: ) def _create_writer(self, protocol: BaseProtocol) -> StreamWriter: - return StreamWriter(protocol, self.loop) + return self._writer_factory(protocol, self.loop) def _should_write(self, protocol: BaseProtocol) -> bool: return protocol.writing_paused + # TODO: Make _send swappable for http/2 or http/3 ? async def _send(self, conn: "Connection") -> ClientResponse: # Specify request target: # - CONNECT request must send authority form URI