Skip to content
1 change: 1 addition & 0 deletions CHANGES/11717.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add parser factory to response handler and a writer factory to client request base to prepare for future protocol support -- by :user:`Vizonex`
11 changes: 9 additions & 2 deletions aiohttp/client_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@
class ResponseHandler(BaseProtocol, DataQueue[tuple[RawResponseMessage, StreamReader]]):
"""Helper class to adapt between Protocol and StreamReader."""

def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
def __init__(
self,
loop: asyncio.AbstractEventLoop,
parser_factory: type[HttpResponseParser] = HttpResponseParser,
) -> None:
BaseProtocol.__init__(self, loop=loop)
DataQueue.__init__(self, loop)

Expand All @@ -39,6 +43,9 @@ def __init__(self, loop: asyncio.AbstractEventLoop) -> None:

self._tail = b""
self._upgraded = False
# parser_factory is important because it will
# allow for other protocols to be added in the future.
self._parser_factory = parser_factory
self._parser: HttpResponseParser | None = None

self._read_timeout: float | None = None
Expand Down Expand Up @@ -237,7 +244,7 @@ def set_response_params(

self._timeout_ceil_threshold = timeout_ceil_threshold

self._parser = HttpResponseParser(
self._parser = self._parser_factory(
self,
self._loop,
read_bufsize,
Expand Down
14 changes: 12 additions & 2 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,7 +701,7 @@
method = "GET"

_writer_task: asyncio.Task[None] | None = None # async task for streaming data

_writer_factory: type[StreamWriter] = StreamWriter # allowing http/2 and http/3
_skip_auto_headers: "CIMultiDict[None] | None" = None

# N.B.
Expand All @@ -719,6 +719,8 @@
loop: asyncio.AbstractEventLoop,
ssl: SSLContext | bool | Fingerprint,
trust_env: bool = False,
writer_factory: type[StreamWriter] | None = None,
version: HttpVersion | None = None,
):
if match := _CONTAINS_CONTROL_CHAR_RE.search(method):
raise ValueError(
Expand All @@ -740,6 +742,13 @@
self._update_headers(headers)
self._update_auth(auth, trust_env)

# setup in case of newer protocols.
if version:
self.version = version

if writer_factory:
self._writer_factory = writer_factory

def _reset_writer(self, _: object = None) -> None:
self._writer_task = None

Expand Down Expand Up @@ -844,11 +853,12 @@
)

def _create_writer(self, protocol: BaseProtocol) -> StreamWriter:
return StreamWriter(protocol, self.loop)
return self._writer_factory(protocol, self.loop)

def _should_write(self, protocol: BaseProtocol) -> bool:
return protocol.writing_paused

# TODO: Make _send swappable for http/2 or http/3 ?
async def _send(self, conn: "Connection") -> ClientResponse:
# Specify request target:
# - CONNECT request must send authority form URI
Expand Down Expand Up @@ -1016,7 +1026,7 @@

self._update_auto_headers(skip_auto_headers)
self._update_cookies(cookies)
self._update_content_encoding(data, compress)

Check warning

Code scanning / CodeQL

Overwriting attribute in super-class or sub-class Warning

Assignment overwrites attribute version, which was previously defined in superclass
ClientRequestBase
.
self._update_proxy(proxy, proxy_auth, proxy_headers)

self._update_body_from_data(data)
Expand Down
Loading