Add stream support for asyncio
perklet committed Oct 1, 2023
1 parent 3475645 commit bb9dcbb
Showing 4 changed files with 166 additions and 41 deletions.
5 changes: 3 additions & 2 deletions curl_cffi/aio.py
@@ -8,7 +8,7 @@

from ._wrapper import ffi, lib # type: ignore
from .const import CurlMOpt
from .curl import Curl, CurlError, CurlInfo
from .curl import Curl

DEFAULT_CACERT = os.path.join(os.path.dirname(__file__), "cacert.pem")

@@ -72,6 +72,7 @@ def socket_function(curl, sockfd: int, what: int, clientp: Any, data: Any):
class AsyncCurl:
"""Wrapper around curl_multi handle to provide asyncio support. It uses the libcurl
socket_action APIs."""

def __init__(self, cacert: str = DEFAULT_CACERT, loop=None):
self._curlm = lib.curl_multi_init()
self._cacert = cacert
@@ -103,7 +104,7 @@ def close(self):
lib.curl_multi_cleanup(self._curlm)
self._curlm = None
# Remove all readers and writers
for sockfd in self._sockfds:
self.loop.remove_reader(sockfd)
self.loop.remove_writer(sockfd)
# Cancel all time functions
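A short note on what the aio.py code above is built around: AsyncCurl registers the sockets libcurl asks for with the asyncio loop, and feeds readiness events back through curl_multi_socket_action. The sketch below illustrates that pattern only; it is not the real AsyncCurl, and the `multi` object with its `socket_action(...)` method is an assumption standing in for the ffi calls (lib.curl_multi_socket_action and friends) used in curl_cffi/aio.py.

    import asyncio

    class TinySocketActionDriver:
        """Illustration of the socket_action pattern, not the library implementation."""

        def __init__(self, multi, loop=None):
            self.multi = multi                      # hypothetical multi-handle wrapper
            self.loop = loop or asyncio.get_event_loop()
            self._sockfds = set()                   # sockets we registered, as in AsyncCurl

        def _on_event(self, sockfd, ev_bitmask):
            # A socket became readable/writable: let the multi handle make progress
            # on whatever transfers are bound to it.
            self.multi.socket_action(sockfd, ev_bitmask)

        def watch(self, sockfd, what):
            # Would be driven by libcurl's SOCKETFUNCTION callback.
            if what & 1:   # CURL_POLL_IN
                self.loop.add_reader(sockfd, self._on_event, sockfd, 1)   # CURL_CSELECT_IN
            if what & 2:   # CURL_POLL_OUT
                self.loop.add_writer(sockfd, self._on_event, sockfd, 2)   # CURL_CSELECT_OUT
            self._sockfds.add(sockfd)

        def close(self):
            # Mirrors AsyncCurl.close(): drop every reader/writer we registered.
            for sockfd in self._sockfds:
                self.loop.remove_reader(sockfd)
                self.loop.remove_writer(sockfd)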
59 changes: 48 additions & 11 deletions curl_cffi/requests/models.py
@@ -61,6 +61,7 @@ def __init__(self, curl: Optional[Curl] = None, request: Optional[Request] = Non
self.http_version = 0
self.history = []
self.queue: Optional[queue.Queue] = None
self.stream_task = None

@property
def text(self) -> str:
@@ -101,19 +102,55 @@ def iter_content(self, chunk_size=None, decode_unicode=False):
warnings.warn("chunk_size is ignored, there is no way to tell curl that.")
if decode_unicode:
raise NotImplementedError()
try:
while True:
chunk = self.queue.get() # type: ignore
if chunk is None:
return
yield chunk
finally:
# If anything happens, always free the memory
self.curl.reset() # type: ignore
clear_queue(self.queue) # type: ignore
while True:
chunk = self.queue.get() # type: ignore
if chunk is None:
return
yield chunk

def json(self, **kw):
return loads(self.content, **kw)

def close(self):
warnings.warn("Deprecated, use Session.close")
self.stream_task.result() # type: ignore

async def aiter_lines(self, chunk_size=None, decode_unicode=False, delimiter=None):
"""
Copied from: https://requests.readthedocs.io/en/latest/_modules/requests/models/
which is under the License: Apache 2.0
"""
pending = None

async for chunk in self.aiter_content(
chunk_size=chunk_size, decode_unicode=decode_unicode
):
if pending is not None:
chunk = pending + chunk
if delimiter:
lines = chunk.split(delimiter)
else:
lines = chunk.splitlines()
if lines and lines[-1] and chunk and lines[-1][-1] == chunk[-1]:
pending = lines.pop()
else:
pending = None

for line in lines:
yield line

if pending is not None:
yield pending

async def aiter_content(self, chunk_size=None, decode_unicode=False):
if chunk_size:
warnings.warn("chunk_size is ignored, there is no way to tell curl that.")
if decode_unicode:
raise NotImplementedError()
while True:
chunk = await self.queue.get() # type: ignore
if chunk is None:
return
yield chunk

async def aclose(self):
await self.stream_task # type: ignore
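
The trickiest part of the models.py changes is the carry-over rule in aiter_lines: a line that straddles two chunks is held in `pending` until the rest of it arrives. The standalone snippet below (not part of the library) replays that exact logic on two hand-made chunks:

    chunks = [b"alpha\nbra", b"vo\ncharlie\n"]
    pending = None
    out = []
    for chunk in chunks:
        if pending is not None:
            chunk = pending + chunk
        lines = chunk.splitlines()
        # If the chunk does not end on a line boundary, the last piece is incomplete
        # and is carried over to the next iteration.
        if lines and lines[-1] and chunk and lines[-1][-1] == chunk[-1]:
            pending = lines.pop()
        else:
            pending = None
        out.extend(lines)
    if pending is not None:
        out.append(pending)
    assert out == [b"alpha", b"bravo", b"charlie"]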
84 changes: 59 additions & 25 deletions curl_cffi/requests/session.py
@@ -8,7 +8,7 @@
from functools import partialmethod
from io import BytesIO
from json import dumps
from typing import Callable, Dict, List, Optional, Tuple, Union, cast
from typing import Callable, Dict, List, Any, Optional, Tuple, Union, cast
from urllib.parse import ParseResult, parse_qsl, unquote, urlencode, urlparse
from concurrent.futures import ThreadPoolExecutor

@@ -188,6 +188,7 @@ def _set_curl_options(
http_version: Optional[CurlHttpVersion] = None,
interface: Optional[str] = None,
stream: bool = False,
queue_class: Any = None,
):
c = curl

@@ -359,8 +360,10 @@ def _set_curl_options(
c.setopt(k, v)

buffer = None
q = None
if stream:
c.setopt(CurlOpt.WRITEFUNCTION, self.queue.put) # type: ignore
q = queue_class() # type: ignore
c.setopt(CurlOpt.WRITEFUNCTION, q.put_nowait) # type: ignore
elif content_callback is not None:
c.setopt(CurlOpt.WRITEFUNCTION, content_callback)
else:
@@ -377,7 +380,7 @@ def _set_curl_options(
if interface:
c.setopt(CurlOpt.INTERFACE, interface.encode())

return req, buffer, header_buffer
return req, buffer, header_buffer, q

def _parse_response(self, curl, buffer, header_buffer):
c = curl
@@ -505,17 +508,6 @@ def executor(self):
self._executor = ThreadPoolExecutor()
return self._executor

@property
def queue(self):
if self._use_thread_local_curl:
if getattr(self._local, "queue", None) is None:
self._local.queue = queue.Queue()
return self._local.queue
else:
if self._queue is None:
self._queue = queue.Queue()
return self._queue

def __enter__(self):
return self

@@ -553,7 +545,7 @@ def request(
) -> Response:
"""Send the request, see [curl_cffi.requests.request](/api/curl_cffi.requests/#curl_cffi.requests.request) for details on parameters."""
c = self.curl
req, buffer, header_buffer = self._set_curl_options(
req, buffer, header_buffer, q = self._set_curl_options(
c,
method=method,
url=url,
@@ -577,6 +569,7 @@
http_version=http_version,
interface=interface,
stream=stream,
queue_class=queue.Queue,
)
try:
if self._thread == "eventlet":
@@ -587,13 +580,17 @@
gevent.get_hub().threadpool.spawn(c.perform).get()
else:
if stream:
queue = self.queue # using queue from current thread

def perform():
c.perform()
queue.put(None) # sentinel
# None acts as a sentinel
q.put(None) # type: ignore

def cleanup(fut):
c.reset()

self.executor.submit(perform)
stream_task = self.executor.submit(perform)
stream_task.add_done_callback(cleanup)
else:
c.perform()
except CurlError as e:
@@ -603,7 +600,9 @@ def perform():
else:
rsp = self._parse_response(c, buffer, header_buffer)
rsp.request = req
rsp.queue = self.queue
if stream:
rsp.stream_task = stream_task # type: ignore
rsp.queue = q
return rsp
finally:
if not stream:
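
The synchronous streaming path above boils down to a producer/consumer pair with a None sentinel: the worker submitted to the executor performs the transfer, pushes chunks into a per-request queue, and finally pushes None; iter_content drains the queue until it sees the sentinel. A stripped-down sketch of that pattern (names are illustrative, not the real API):

    import queue
    from concurrent.futures import ThreadPoolExecutor

    def produce(q):
        # Stands in for c.perform() writing chunks via q.put_nowait.
        for part in (b"chunk-1", b"chunk-2", b"chunk-3"):
            q.put(part)
        q.put(None)                 # sentinel: the transfer is finished

    def iter_chunks(q):
        # What Response.iter_content() does with the queue.
        while True:
            chunk = q.get()
            if chunk is None:
                return
            yield chunk

    q = queue.Queue()
    with ThreadPoolExecutor() as executor:
        task = executor.submit(produce, q)
        for chunk in iter_chunks(q):
            print(chunk)
        task.result()               # like Response.close(): re-raise worker errors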
@@ -660,6 +659,7 @@ def __init__(
self.loop = loop
self._acurl = async_curl
self.max_clients = max_clients
self._closed = False
self.init_pool()
if sys.version_info >= (3, 8) and sys.platform.lower().startswith("win"):
if isinstance(
@@ -705,6 +705,23 @@ async def __aexit__(self, *args):
def close(self):
"""Close the session."""
self.acurl.close()
self._closed = True
while True:
try:
curl = self.pool.get_nowait()
if curl:
curl.close()
except asyncio.QueueEmpty:
break

def release_curl(self, curl):
curl.clean_after_perform()
if not self._closed:
self.acurl.remove_handle(curl)
curl.reset()
self.push_curl(curl)
else:
curl.close()

async def request(
self,
@@ -729,10 +746,11 @@ async def request(
default_headers: Optional[bool] = None,
http_version: Optional[CurlHttpVersion] = None,
interface: Optional[str] = None,
stream: bool = False,
):
"""Send the request, see [curl_cffi.requests.request](/api/curl_cffi.requests/#curl_cffi.requests.request) for details on parameters."""
curl = await self.pop_curl()
req, buffer, header_buffer = self._set_curl_options(
req, buffer, header_buffer, q = self._set_curl_options(
curl=curl,
method=method,
url=url,
@@ -755,11 +773,26 @@
default_headers=default_headers,
http_version=http_version,
interface=interface,
stream=stream,
queue_class=asyncio.Queue,
)
try:
# curl.debug()
task = self.acurl.add_handle(curl)
await task
if stream:

async def perform():
await task
# None acts as a sentinel
await q.put(None) # type: ignore

def cleanup(fut):
self.release_curl(curl)

stream_task = asyncio.create_task(perform())
stream_task.add_done_callback(cleanup)
else:
await task
# print(curl.getinfo(CurlInfo.CAINFO))
except CurlError as e:
rsp = self._parse_response(curl, buffer, header_buffer)
@@ -768,12 +801,13 @@
else:
rsp = self._parse_response(curl, buffer, header_buffer)
rsp.request = req
rsp.queue = q
if stream:
rsp.stream_task = stream_task # type: ignore
return rsp
finally:
curl.clean_after_perform()
self.acurl.remove_handle(curl)
curl.reset()
self.push_curl(curl)
if not stream:
self.release_curl(curl)

head = partialmethod(request, "HEAD")
get = partialmethod(request, "GET")
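The asyncio streaming path in session.py follows the same shape, except the transfer runs as a background task that feeds an asyncio.Queue, and a done-callback releases the curl handle via release_curl. A reduced sketch of that wiring, with `fake_transfer` standing in for awaiting the real transfer task:

    import asyncio

    async def fake_transfer(q):
        # Stands in for `await task` followed by the sentinel put in perform().
        for part in (b"chunk-1", b"chunk-2"):
            await q.put(part)
        await q.put(None)                        # sentinel

    async def main():
        q = asyncio.Queue()
        stream_task = asyncio.create_task(fake_transfer(q))
        # Like cleanup() above, which calls self.release_curl(curl).
        stream_task.add_done_callback(lambda fut: print("curl handle released"))
        while True:                              # what Response.aiter_content() does
            chunk = await q.get()
            if chunk is None:
                break
            print(chunk)
        await stream_task                        # what Response.aclose() awaits

    asyncio.run(main())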
59 changes: 56 additions & 3 deletions examples/stream.py
@@ -1,13 +1,66 @@
import asyncio
from curl_cffi import requests
from contextlib import closing

try:
# Python 3.10+
from contextlib import aclosing
except ImportError:
from contextlib import asynccontextmanager

@asynccontextmanager
async def aclosing(thing):
try:
yield thing
finally:
await thing.aclose()


URL = "https://httpbin.org/stream/20"

with requests.Session() as s:
r = s.get("https://httpbin.org/stream/20", stream=True)
print("\n======================================================================")
print("Iterating over chunks")
print("=====================================================================\n")
r = s.get(URL, stream=True)
for chunk in r.iter_content():
print("CHUNK", chunk)
r.close()


with requests.Session() as s:
r = s.get("https://httpbin.org/stream/20", stream=True)
print("\n=====================================================================")
print("Iterating on a line basis")
print("=====================================================================\n")
r = s.get(URL, stream=True)
for line in r.iter_lines():
print("LINE", line.decode())
r.close()


print("\n=====================================================================")
print("Better, using closing to ensure the response is closed")
print("=====================================================================\n")
with closing(s.get(URL, stream=True)) as r:
for chunk in r.iter_content():
print("CHUNK", chunk)


async def async_examples():
async with requests.AsyncSession() as s:
print("\n====================================================================")
print("Using asyncio")
print("====================================================================\n")
r = await s.get(URL, stream=True)
async for chunk in r.aiter_content():
print("CHUNK", chunk)
await r.aclose()

print("\n====================================================================")
print("Better, using aclosing to ensure the response is closed")
print("====================================================================\n")
async with aclosing(await s.get(URL, stream=True)) as r:
async for chunk in r.aiter_content():
print("CHUNK", chunk)


asyncio.run(async_examples())
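
One thing examples/stream.py does not show is line-based iteration on the async side. A hedged usage sketch, assuming the same httpbin endpoint, would look like this:

    import asyncio
    from curl_cffi import requests

    async def aiter_lines_example():
        async with requests.AsyncSession() as s:
            r = await s.get("https://httpbin.org/stream/20", stream=True)
            async for line in r.aiter_lines():
                print("LINE", line.decode())
            await r.aclose()

    asyncio.run(aiter_lines_example())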
