From bb9dcbb28ae44286c230f9f638de1693bed5ce20 Mon Sep 17 00:00:00 2001 From: Yifei Kong Date: Sun, 1 Oct 2023 15:33:22 +0800 Subject: [PATCH] Add stream support for asyncio --- curl_cffi/aio.py | 5 ++- curl_cffi/requests/models.py | 59 +++++++++++++++++++----- curl_cffi/requests/session.py | 84 ++++++++++++++++++++++++----------- examples/stream.py | 59 ++++++++++++++++++++++-- 4 files changed, 166 insertions(+), 41 deletions(-) diff --git a/curl_cffi/aio.py b/curl_cffi/aio.py index 52508d4a..46cae583 100644 --- a/curl_cffi/aio.py +++ b/curl_cffi/aio.py @@ -8,7 +8,7 @@ from ._wrapper import ffi, lib # type: ignore from .const import CurlMOpt -from .curl import Curl, CurlError, CurlInfo +from .curl import Curl DEFAULT_CACERT = os.path.join(os.path.dirname(__file__), "cacert.pem") @@ -72,6 +72,7 @@ def socket_function(curl, sockfd: int, what: int, clientp: Any, data: Any): class AsyncCurl: """Wrapper around curl_multi handle to provide asyncio support. It uses the libcurl socket_action APIs.""" + def __init__(self, cacert: str = DEFAULT_CACERT, loop=None): self._curlm = lib.curl_multi_init() self._cacert = cacert @@ -103,7 +104,7 @@ def close(self): lib.curl_multi_cleanup(self._curlm) self._curlm = None # Remove add readers and writers - for sockfd in self._sockfds: + for sockfd in self._sockfds: self.loop.remove_reader(sockfd) self.loop.remove_writer(sockfd) # Cancel all time functions diff --git a/curl_cffi/requests/models.py b/curl_cffi/requests/models.py index 8cdd5450..e304c183 100644 --- a/curl_cffi/requests/models.py +++ b/curl_cffi/requests/models.py @@ -61,6 +61,7 @@ def __init__(self, curl: Optional[Curl] = None, request: Optional[Request] = Non self.http_version = 0 self.history = [] self.queue: Optional[queue.Queue] = None + self.stream_task = None @property def text(self) -> str: @@ -101,19 +102,55 @@ def iter_content(self, chunk_size=None, decode_unicode=False): warnings.warn("chunk_size is ignored, there is no way to tell curl that.") if decode_unicode: raise NotImplementedError() - try: - while True: - chunk = self.queue.get() # type: ignore - if chunk is None: - return - yield chunk - finally: - # If anything happens, always free the memory - self.curl.reset() # type: ignore - clear_queue(self.queue) # type: ignore + while True: + chunk = self.queue.get() # type: ignore + if chunk is None: + return + yield chunk def json(self, **kw): return loads(self.content, **kw) def close(self): - warnings.warn("Deprecated, use Session.close") + self.stream_task.result() # type: ignore + + async def aiter_lines(self, chunk_size=None, decode_unicode=False, delimiter=None): + """ + Copied from: https://requests.readthedocs.io/en/latest/_modules/requests/models/ + which is under the License: Apache 2.0 + """ + pending = None + + async for chunk in self.aiter_content( + chunk_size=chunk_size, decode_unicode=decode_unicode + ): + if pending is not None: + chunk = pending + chunk + if delimiter: + lines = chunk.split(delimiter) + else: + lines = chunk.splitlines() + if lines and lines[-1] and chunk and lines[-1][-1] == chunk[-1]: + pending = lines.pop() + else: + pending = None + + for line in lines: + yield line + + if pending is not None: + yield pending + + async def aiter_content(self, chunk_size=None, decode_unicode=False): + if chunk_size: + warnings.warn("chunk_size is ignored, there is no way to tell curl that.") + if decode_unicode: + raise NotImplementedError() + while True: + chunk = await self.queue.get() # type: ignore + if chunk is None: + return + yield chunk + + async def aclose(self): + await self.stream_task # type: ignore diff --git a/curl_cffi/requests/session.py b/curl_cffi/requests/session.py index ff48952d..62b68084 100644 --- a/curl_cffi/requests/session.py +++ b/curl_cffi/requests/session.py @@ -8,7 +8,7 @@ from functools import partialmethod from io import BytesIO from json import dumps -from typing import Callable, Dict, List, Optional, Tuple, Union, cast +from typing import Callable, Dict, List, Any, Optional, Tuple, Union, cast from urllib.parse import ParseResult, parse_qsl, unquote, urlencode, urlparse from concurrent.futures import ThreadPoolExecutor @@ -188,6 +188,7 @@ def _set_curl_options( http_version: Optional[CurlHttpVersion] = None, interface: Optional[str] = None, stream: bool = False, + queue_class: Any = None, ): c = curl @@ -359,8 +360,10 @@ def _set_curl_options( c.setopt(k, v) buffer = None + q = None if stream: - c.setopt(CurlOpt.WRITEFUNCTION, self.queue.put) # type: ignore + q = queue_class() # type: ignore + c.setopt(CurlOpt.WRITEFUNCTION, q.put_nowait) # type: ignore elif content_callback is not None: c.setopt(CurlOpt.WRITEFUNCTION, content_callback) else: @@ -377,7 +380,7 @@ def _set_curl_options( if interface: c.setopt(CurlOpt.INTERFACE, interface.encode()) - return req, buffer, header_buffer + return req, buffer, header_buffer, q def _parse_response(self, curl, buffer, header_buffer): c = curl @@ -505,17 +508,6 @@ def executor(self): self._executor = ThreadPoolExecutor() return self._executor - @property - def queue(self): - if self._use_thread_local_curl: - if getattr(self._local, "queue", None) is None: - self._local.queue = queue.Queue() - return self._local.queue - else: - if self._queue is None: - self._queue = queue.Queue() - return self._queue - def __enter__(self): return self @@ -553,7 +545,7 @@ def request( ) -> Response: """Send the request, see [curl_cffi.requests.request](/api/curl_cffi.requests/#curl_cffi.requests.request) for details on parameters.""" c = self.curl - req, buffer, header_buffer = self._set_curl_options( + req, buffer, header_buffer, q = self._set_curl_options( c, method=method, url=url, @@ -577,6 +569,7 @@ def request( http_version=http_version, interface=interface, stream=stream, + queue_class=queue.Queue, ) try: if self._thread == "eventlet": @@ -587,13 +580,17 @@ def request( gevent.get_hub().threadpool.spawn(c.perform).get() else: if stream: - queue = self.queue # using queue from current thread def perform(): c.perform() - queue.put(None) # sentinel + # None acts as a sentinel + q.put(None) # type: ignore + + def cleanup(fut): + c.reset() - self.executor.submit(perform) + stream_task = self.executor.submit(perform) + stream_task.add_done_callback(cleanup) else: c.perform() except CurlError as e: @@ -603,7 +600,9 @@ def perform(): else: rsp = self._parse_response(c, buffer, header_buffer) rsp.request = req - rsp.queue = self.queue + if stream: + rsp.stream_task = stream_task # type: ignore + rsp.queue = q return rsp finally: if not stream: @@ -660,6 +659,7 @@ def __init__( self.loop = loop self._acurl = async_curl self.max_clients = max_clients + self._closed = False self.init_pool() if sys.version_info >= (3, 8) and sys.platform.lower().startswith("win"): if isinstance( @@ -705,6 +705,23 @@ async def __aexit__(self, *args): def close(self): """Close the session.""" self.acurl.close() + self._closed = True + while True: + try: + curl = self.pool.get_nowait() + if curl: + curl.close() + except asyncio.QueueEmpty: + break + + def release_curl(self, curl): + curl.clean_after_perform() + if not self._closed: + self.acurl.remove_handle(curl) + curl.reset() + self.push_curl(curl) + else: + curl.close() async def request( self, @@ -729,10 +746,11 @@ async def request( default_headers: Optional[bool] = None, http_version: Optional[CurlHttpVersion] = None, interface: Optional[str] = None, + stream: bool = False, ): """Send the request, see [curl_cffi.requests.request](/api/curl_cffi.requests/#curl_cffi.requests.request) for details on parameters.""" curl = await self.pop_curl() - req, buffer, header_buffer = self._set_curl_options( + req, buffer, header_buffer, q = self._set_curl_options( curl=curl, method=method, url=url, @@ -755,11 +773,26 @@ async def request( default_headers=default_headers, http_version=http_version, interface=interface, + stream=stream, + queue_class=asyncio.Queue, ) try: # curl.debug() task = self.acurl.add_handle(curl) - await task + if stream: + + async def perform(): + await task + # None acts as a sentinel + await q.put(None) # type: ignore + + def cleanup(fut): + self.release_curl(curl) + + stream_task = asyncio.create_task(perform()) + stream_task.add_done_callback(cleanup) + else: + await task # print(curl.getinfo(CurlInfo.CAINFO)) except CurlError as e: rsp = self._parse_response(curl, buffer, header_buffer) @@ -768,12 +801,13 @@ async def request( else: rsp = self._parse_response(curl, buffer, header_buffer) rsp.request = req + rsp.queue = q + if stream: + rsp.stream_task = stream_task # type: ignore return rsp finally: - curl.clean_after_perform() - self.acurl.remove_handle(curl) - curl.reset() - self.push_curl(curl) + if not stream: + self.release_curl(curl) head = partialmethod(request, "HEAD") get = partialmethod(request, "GET") diff --git a/examples/stream.py b/examples/stream.py index f9f204d5..a7a427dc 100644 --- a/examples/stream.py +++ b/examples/stream.py @@ -1,13 +1,66 @@ +import asyncio from curl_cffi import requests +from contextlib import closing +try: + # Python 3.10+ + from contextlib import aclosing +except ImportError: + from contextlib import asynccontextmanager + + @asynccontextmanager + async def aclosing(thing): + try: + yield thing + finally: + await thing.aclose() + + +URL = "https://httpbin.org/stream/20" with requests.Session() as s: - r = s.get("https://httpbin.org/stream/20", stream=True) + print("\n======================================================================") + print("Iterating over chunks") + print("=====================================================================\n") + r = s.get(URL, stream=True) for chunk in r.iter_content(): print("CHUNK", chunk) + r.close() -with requests.Session() as s: - r = s.get("https://httpbin.org/stream/20", stream=True) + print("\n=====================================================================") + print("Iterating on a line basis") + print("=====================================================================\n") + r = s.get(URL, stream=True) for line in r.iter_lines(): print("LINE", line.decode()) + r.close() + + + print("\n=====================================================================") + print("Better, using closing to ensure the response is closed") + print("=====================================================================\n") + with closing(s.get(URL, stream=True)) as r: + for chunk in r.iter_content(): + print("CHUNK", chunk) + + +async def async_examples(): + async with requests.AsyncSession() as s: + print("\n====================================================================") + print("Using asyncio") + print("====================================================================\n") + r = await s.get(URL, stream=True) + async for chunk in r.aiter_content(): + print("CHUNK", chunk) + await r.aclose() + + print("\n====================================================================") + print("Better, using aclosing to ensure the response is closed") + print("====================================================================\n") + async with aclosing(await s.get(URL, stream=True)) as r: + async for chunk in r.aiter_content(): + print("CHUNK", chunk) + + +asyncio.run(async_examples())