From 68604a7c271f886b327c8bcd79da3fc3d3590668 Mon Sep 17 00:00:00 2001 From: Louis Maddox Date: Wed, 11 Aug 2021 18:39:59 +0100 Subject: [PATCH] Add async fetcher (RangeStream only) and tests for async single request RangeStream creation --- requirements.txt | 1 + src/range_streams/__init__.py | 2 + src/range_streams/async_utils.py | 187 +++++++++++++++++++++++++++++++ src/range_streams/log_utils.py | 21 ++++ tests/async_test.py | 80 +++++++++++++ 5 files changed, 291 insertions(+) create mode 100644 src/range_streams/async_utils.py create mode 100644 src/range_streams/log_utils.py create mode 100644 tests/async_test.py diff --git a/requirements.txt b/requirements.txt index d8bc93f..a24a045 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,3 +2,4 @@ aiostream httpx python-ranges pyzstd +tqdm diff --git a/src/range_streams/__init__.py b/src/range_streams/__init__.py index 6ac60c1..0c0dfd4 100644 --- a/src/range_streams/__init__.py +++ b/src/range_streams/__init__.py @@ -125,6 +125,7 @@ # Get classes into package namespace but exclude from __all__ so Sphinx can access types from . 
import codecs, http_utils, overlaps, range_utils +from .async_utils import AsyncFetcher from .request import RangeRequest from .response import RangeResponse from .stream import RangeStream @@ -137,6 +138,7 @@ "overlaps", "range_utils", "codecs", + "async_utils", ] __author__ = "Louis Maddox" diff --git a/src/range_streams/async_utils.py b/src/range_streams/async_utils.py new file mode 100644 index 0000000..71779f2 --- /dev/null +++ b/src/range_streams/async_utils.py @@ -0,0 +1,187 @@ +from __future__ import annotations + +import asyncio +import time +from asyncio.events import AbstractEventLoop +from functools import partial +from signal import SIGINT, SIGTERM, Signals +from sys import stderr +from typing import TYPE_CHECKING, Callable, Coroutine, Iterator + +from aiostream import stream +from ranges import Range, RangeSet + +MYPY = False # when using mypy will be overrided as True +if MYPY or not TYPE_CHECKING: # pragma: no cover + import httpx # avoid importing to Sphinx type checker + +import tqdm +from tqdm.asyncio import tqdm_asyncio + +from .log_utils import log, set_up_logging +from .stream import RangeStream + +__all__ = ["SignalHaltError", "AsyncFetcher"] + + +class AsyncFetcher: + def __init__( + self, + urls: list[str], + callback: Callable | None = None, + verbose: bool = False, + show_progress_bar: bool = True, + timeout_s: float = 5.0, + client=None, + ): + """ + Args: + callback : A function to be passed 3 values: the AsyncFetcher which is calling + it, the awaited RangeStream, and its source URL (a ``httpx.URL``, + which can be coerced to a string). 
+ """ + if urls == []: + raise ValueError("The list of URLs to fetch cannot be empty") + self.url_list = urls + self.callback = callback + self.n = len(urls) + self.verbose = verbose + self.show_progress_bar = show_progress_bar and not self.verbose + self.client = client + self.timeout = httpx.Timeout(timeout=timeout_s) + self.completed = RangeSet() + set_up_logging(quiet=not verbose) + + def make_calls(self): + """ + The method called to run the event loop to fetch URLs, after initialisation + and/or repeatedly upon exitting the loop (i.e. it can recover from errors). + """ + urlset = (u for u in self.filtered_url_list) # single use URL generator + if self.show_progress_bar: + self.set_up_progress_bar() + self.fetch_things(urls=urlset) + if self.show_progress_bar: + self.pbar.close() + + async def process_stream(self, rstream: RangeStream): + """ + Process an awaited RangeStream within an async fetch loop, calling the callback + set on the `~range_streams.async_utils.AsyncFetcher.callback` attribute. 
+ + Args: + rstream : The awaited RangeStream + """ + monostream_response = rstream._ranges[rstream.total_range] + resp = monostream_response.request.response # httpx.Response + source_url = resp.history[0].url if resp.history else resp.url + # Map the response back to the thing it came from in the url_list + i = next(i for (i, u) in enumerate(self.url_list) if source_url == u) + if self.callback is not None: + await self.callback(self, stream, source_url) + if self.verbose: + log.debug(f"Processed URL in async callback: {source_url}") + if self.show_progress_bar: + self.pbar.update() + self.completed.add(Range(i, i + 1)) + await resp.aclose() + + @property + def filtered_url_list(self) -> list[str]: + if self.completed.isempty(): + urls = self.url_list + else: + urls = [u for (i, u) in enumerate(self.url_list) if i not in self.completed] + return urls + + def set_up_progress_bar(self): + n_already_fetched = self.n - len(self.filtered_url_list) + self.pbar = tqdm_asyncio(total=self.n) + if n_already_fetched: + self.pbar.update(n_already_fetched) + self.pbar.refresh() + + def fetch_things(self, urls: Iterator[str]): + try: + return asyncio.run(self.async_fetch_urlset(urls)) + except SignalHaltError as exc: + if self.show_progress_bar: + self.pbar.disable = True + self.pbar.close() + + async def fetch(self, client: httpx.AsyncClient, url: httpx.URL) -> RangeStream: + s = RangeStream( + url=str(url), client=client, single_request=True, force_async=True + ) + await s.add_async() + return s + + async def async_fetch_urlset( + self, + urls: Iterator[str], + ) -> Coroutine: + """ + If the `~range_streams.async_utils.AsyncFetcher.client` is ``None``, create one + in a contextmanager block (i.e. close it immediately after use), otherwise use + the one provided, not in a contextmanager block (i.e. leave it up to the user to + close the client). 
+ """ + await self.set_async_signal_handlers() + if self.client is None: + async with httpx.AsyncClient() as client: + processed = await self.fetch_and_process(urls=urls, client=client) + else: + if self.client.is_closed: + msg = ( + "Cannot use a closed client to fetch.\n\nDid you attempt to retry " + " after using the client in a contextmanager block (which implicitly" + " closes after exiting the block) perhaps?" + ) + raise ValueError(msg) + # assert self.client is not None # give mypy a clue + processed = await self.fetch_and_process(urls=urls, client=client) + return processed + + async def fetch_and_process(self, urls: Iterator[str], client): + assert isinstance(client, httpx.AsyncClient) # Not type checked due to Sphinx + client.timeout = self.timeout + ws = stream.repeat(client) + xs = stream.zip(ws, stream.iterate(urls)) + ys = stream.starmap(xs, self.fetch, ordered=False, task_limit=20) + zs = stream.map(ys, self.process_stream) + return await zs + + def immediate_exit(self, signal_enum: Signals, loop: AbstractEventLoop) -> None: + loop.stop() + halt_error = SignalHaltError(signal_enum=signal_enum) + raise halt_error + + async def set_async_signal_handlers(self) -> None: + loop = asyncio.get_running_loop() + for signal_enum in [SIGINT, SIGTERM]: + exit_func = partial(self.immediate_exit, signal_enum=signal_enum, loop=loop) + loop.add_signal_handler(signal_enum, exit_func) + + +class SignalHaltError(SystemExit): + def __init__(self, signal_enum: Signals): + self.signal_enum = signal_enum + print("", file=stderr) # Newline after the signal sequence printed to console + log.critical(msg=repr(self)) + super().__init__(self.exit_code) + + @property + def exit_code(self) -> int: + return self.signal_enum.value + + def __repr__(self) -> str: + return f"Exitted due to {self.signal_enum.name}" + + +# def demo_fetch(url_list): +# fetched = AsyncFetcher(urls=url_list, verbose=False) +# try: +# fetched.make_calls() +# except Exception as exc: +# log.debug("DEBUG ::" 
# (Trailing fragment of the commented-out demo_fetch example removed:
#  dead, commented-out code.)

# --- src/range_streams/log_utils.py (new file) ---
import logging

__all__ = ["log", "set_up_logging"]

log = logging.getLogger()  # Provided for ease of access in other modules


def set_up_logging(quiet: bool = True):
    """
    Initialise the log

    Args:
      quiet : Change this flag to True/False to turn off/on console logging
    """
    log.setLevel(logging.DEBUG)
    log_format = logging.Formatter("[%(asctime)s] [%(levelname)s] - %(message)s")
    # Fix: guard against attaching a duplicate StreamHandler — every
    # AsyncFetcher.__init__ calls set_up_logging, and the original added a
    # fresh handler each time, duplicating every console log line.
    if not quiet and not log.handlers:
        console = logging.StreamHandler()
        console.setLevel(logging.DEBUG)
        console.setFormatter(log_format)
        log.addHandler(console)


# --- tests/async_test.py (new file) ---
import asyncio
from signal import SIGINT

from pytest import fixture, mark, raises
from ranges import Range

from range_streams import _EXAMPLE_PNG_URL, _EXAMPLE_ZIP_URL, RangeStream
from range_streams.async_utils import AsyncFetcher, SignalHaltError

from .data import EXAMPLE_FILE_LENGTH, EXAMPLE_URL

# https://tonybaloney.github.io/posts/async-test-patterns-for-pytest-and-unittest.html

THREE_URLS = [EXAMPLE_URL, _EXAMPLE_PNG_URL, _EXAMPLE_ZIP_URL]


class CallbackMutatedClass:
    # Shared class-level store: callbacks append the URLs they were called with
    values = []

    @classmethod
    def reset(cls):
        """
        Reset the class attribute where tests store the URLs they called back from
        """
        cls.values = []


async def demo_callback_func(fetcher, range_stream, url):
    """Minimal callback: record the URL on the shared class attribute."""
    return CallbackMutatedClass.values.append(url)


async def sigint_callback_func(fetcher, range_stream, url):
    """
    Mimic the act of sending the signal interrupt by raising it in a callback
    """
    await demo_callback_func(fetcher, range_stream, url)
    loop = asyncio.get_running_loop()
    fetcher.immediate_exit(signal_enum=SIGINT, loop=loop)


@mark.parametrize("callback", [None, demo_callback_func])
@mark.parametrize("verbose", [True, False])
@mark.parametrize("error_msg", ["The list of URLs to fetch cannot be empty"])
@mark.parametrize("urls", [([]), (THREE_URLS)])
def test_fetcher(urls, error_msg, verbose, callback):
    """
    Fetch lists of 0 or 3 URLs asynchronously, with/out a callback, verbosely/quietly.
    """
    args = dict(callback=callback, urls=urls, verbose=verbose, show_progress_bar=False)
    if urls == []:
        with raises(ValueError, match=error_msg):
            fetched = AsyncFetcher(**args)
    else:
        fetched = AsyncFetcher(**args)
        fetched.make_calls()
        # With no callback nothing is recorded; otherwise every URL must be seen
        expected_values = set() if callback is None else set(urls)
        assert set(CallbackMutatedClass.values) == expected_values
        CallbackMutatedClass.reset()


@mark.parametrize("callback", [sigint_callback_func])
@mark.parametrize("urls", [(THREE_URLS)])
def test_fetcher_sigint(urls, callback):
    """
    Fetch lists of 3 URLs asynchronously, with/out a callback, verbosely/quietly.
    Cannot figure out how to emulate passing the SIGINT from this test so can't catch,
    best I can do here is to check that the loop is stopped at the first callback when
    ``immediate_exit`` is called.

    (Fix: dropped the unused ``error_msg`` parametrize and the stale
    commented-out ``raises`` check that referenced it.)
    """
    args = dict(callback=callback, urls=urls, show_progress_bar=False)
    fetched = AsyncFetcher(**args)
    fetched.make_calls()
    stored_urls = CallbackMutatedClass.values
    # Exactly one callback ran before the halt, and it saw one of our URLs
    assert len(stored_urls) == 1
    assert set(stored_urls) < set(urls)
    CallbackMutatedClass.reset()