diff --git a/astrapy/db.py b/astrapy/db.py index 9e7f477a..83904678 100644 --- a/astrapy/db.py +++ b/astrapy/db.py @@ -12,9 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations + +import asyncio import logging import json +import threading from functools import partial +from queue import Queue from types import TracebackType from typing import ( Any, @@ -252,14 +256,17 @@ def vector_find( @staticmethod def paginate( - *, request_method: PaginableRequestMethod, options: Optional[Dict[str, Any]] + *, + request_method: PaginableRequestMethod, + options: Optional[Dict[str, Any]], + prefetched: Optional[int] = None, ) -> Iterable[API_DOC]: """ Generate paginated results for a given database query method. Args: request_method (function): The database query method to paginate. - options (dict): Options for the database query. - kwargs: Additional arguments to pass to the database query method. + options (dict, optional): Options for the database query. + prefetched (int, optional): Number of pre-fetched documents. Yields: dict: The next document in the paginated result set. """ @@ -267,14 +274,43 @@ def paginate( response0 = request_method(options=_options) next_page_state = response0["data"]["nextPageState"] options0 = _options - for document in response0["data"]["documents"]: - yield document - while next_page_state is not None: + if next_page_state is not None and prefetched: + + def queued_paginate( + queue: Queue[Optional[API_DOC]], + request_method: PaginableRequestMethod, + options: Optional[Dict[str, Any]], + ) -> None: + try: + for row in AstraDBCollection.paginate( + request_method=request_method, options=options + ): + queue.put(row) + finally: + queue.put(None) + + queue: Queue[Optional[API_DOC]] = Queue(prefetched) options1 = {**options0, **{"pageState": next_page_state}} - response1 = request_method(options=options1) - for document in response1["data"]["documents"]: + t = threading.Thread( + target=queued_paginate, args=(queue, request_method, options1) + ) + t.start() + for document in response0["data"]["documents"]: yield document - next_page_state = response1["data"]["nextPageState"] + doc = queue.get() + while doc is not None: + yield doc + doc = queue.get() + t.join() + else: + for document in response0["data"]["documents"]: + yield document + while next_page_state is not None and not prefetched: + options1 = {**options0, **{"pageState": next_page_state}} + response1 = request_method(options=options1) + for document in response1["data"]["documents"]: + yield document + next_page_state = response1["data"]["nextPageState"] def paginated_find( self, @@ -282,6 +318,7 @@ def paginated_find( projection: Optional[Dict[str, Any]] = None, sort: Optional[Dict[str, Any]] = None, options: Optional[Dict[str, Any]] = None, + prefetched: Optional[int] = None, ) -> Iterable[API_DOC]: """ Perform a paginated search in the collection. @@ -290,6 +327,7 @@ def paginated_find( projection (dict, optional): Specifies the fields to return. sort (dict, optional): Specifies the order in which to return matching documents. options (dict, optional): Additional options for the query. + prefetched (int, optional): Number of pre-fetched documents. Returns: generator: A generator yielding documents in the paginated result set. """ @@ -302,6 +340,7 @@ def paginated_find( return self.paginate( request_method=partialed_find, options=options, + prefetched=prefetched, ) def pop( @@ -969,13 +1008,14 @@ async def paginate( *, request_method: AsyncPaginableRequestMethod, options: Optional[Dict[str, Any]], + prefetched: Optional[int] = None, ) -> AsyncGenerator[API_DOC, None]: """ Generate paginated results for a given database query method. Args: request_method (function): The database query method to paginate. - options (dict): Options for the database query. - kwargs: Additional arguments to pass to the database query method. + options (dict, optional): Options for the database query. + prefetched (int, optional): Number of pre-fetched documents. Yields: dict: The next document in the paginated result set. """ @@ -983,14 +1023,39 @@ async def paginate( response0 = await request_method(options=_options) next_page_state = response0["data"]["nextPageState"] options0 = _options - for document in response0["data"]["documents"]: - yield document - while next_page_state is not None: + if next_page_state is not None and prefetched: + + async def queued_paginate( + queue: asyncio.Queue[Optional[API_DOC]], + request_method: AsyncPaginableRequestMethod, + options: Optional[Dict[str, Any]], + ) -> None: + try: + async for doc in AsyncAstraDBCollection.paginate( + request_method=request_method, options=options + ): + await queue.put(doc) + finally: + await queue.put(None) + + queue: asyncio.Queue[Optional[API_DOC]] = asyncio.Queue(prefetched) options1 = {**options0, **{"pageState": next_page_state}} - response1 = await request_method(options=options1) - for document in response1["data"]["documents"]: + asyncio.create_task(queued_paginate(queue, request_method, options1)) + for document in response0["data"]["documents"]: + yield document + doc = await queue.get() + while doc is not None: + yield doc + doc = await queue.get() + else: + for document in response0["data"]["documents"]: yield document - next_page_state = response1["data"]["nextPageState"] + while next_page_state is not None: + options1 = {**options0, **{"pageState": next_page_state}} + response1 = await request_method(options=options1) + for document in response1["data"]["documents"]: + yield document + next_page_state = response1["data"]["nextPageState"] def paginated_find( self, @@ -998,6 +1063,7 @@ def paginated_find( projection: Optional[Dict[str, Any]] = None, sort: Optional[Dict[str, Any]] = None, options: Optional[Dict[str, Any]] = None, + prefetched: Optional[int] = None, ) -> AsyncIterable[API_DOC]: """ Perform a paginated search in the collection. @@ -1006,6 +1072,7 @@ def paginated_find( projection (dict, optional): Specifies the fields to return. sort (dict, optional): Specifies the order in which to return matching documents. options (dict, optional): Additional options for the query. + prefetched (int, optional): Number of pre-fetched documents Returns: generator: A generator yielding documents in the paginated result set. """ @@ -1018,6 +1085,7 @@ def paginated_find( return self.paginate( request_method=partialed_find, options=options, + prefetched=prefetched, ) async def pop( diff --git a/tests/astrapy/test_async_db_dml_pagination.py b/tests/astrapy/test_async_db_dml_pagination.py index 7048bd01..3b93f7e5 100644 --- a/tests/astrapy/test_async_db_dml_pagination.py +++ b/tests/astrapy/test_async_db_dml_pagination.py @@ -31,6 +31,7 @@ INSERT_BATCH_SIZE = 20 # max 20, fixed by API constraints N = 200 # must be EVEN FIND_LIMIT = 183 # Keep this > 20 and <= N to actually put pagination to test +PREFETCHED = 42 # Keep this > 20 and <= FIND_LIMIT to actually trigger prefetching T = TypeVar("T") @@ -83,13 +84,21 @@ async def pag_test_collection( @pytest.mark.describe( "should retrieve the required amount of documents, all different, through pagination" ) -async def test_find_paginated(pag_test_collection: AsyncAstraDBCollection) -> None: +@pytest.mark.parametrize( + "prefetched", + [ + pytest.param(None, id="without pre-fetching"), + pytest.param(PREFETCHED, id="with pre-fetching"), + ], +) +async def test_find_paginated( + prefetched: Optional[int], pag_test_collection: AsyncAstraDBCollection +) -> None: options = {"limit": FIND_LIMIT} projection = {"$vector": 0} paginated_documents = pag_test_collection.paginated_find( - projection=projection, - options=options, + projection=projection, options=options, prefetched=prefetched ) paginated_ids = [doc["_id"] async for doc in paginated_documents] assert len(paginated_ids) == FIND_LIMIT diff --git a/tests/astrapy/test_db_dml_pagination.py b/tests/astrapy/test_db_dml_pagination.py index 29823cbd..27cb099c 100644 --- a/tests/astrapy/test_db_dml_pagination.py +++ b/tests/astrapy/test_db_dml_pagination.py @@ -32,6 +32,7 @@ INSERT_BATCH_SIZE = 20 # max 20, fixed by API constraints N = 200 # must be EVEN FIND_LIMIT = 183 # Keep this > 20 and <= N to actually put pagination to test +PREFETCHED = 42 # Keep this > 20 and <= FIND_LIMIT to actually trigger prefetching T = TypeVar("T") @@ -80,13 +81,21 @@ def pag_test_collection( @pytest.mark.describe( "should retrieve the required amount of documents, all different, through pagination" ) -def test_find_paginated(pag_test_collection: AstraDBCollection) -> None: +@pytest.mark.parametrize( + "prefetched", + [ + pytest.param(None, id="without pre-fetching"), + pytest.param(PREFETCHED, id="with pre-fetching"), + ], +) +def test_find_paginated( + prefetched: Optional[int], pag_test_collection: AstraDBCollection +) -> None: options = {"limit": FIND_LIMIT} projection = {"$vector": 0} paginated_documents = pag_test_collection.paginated_find( - projection=projection, - options=options, + projection=projection, options=options, prefetched=prefetched ) paginated_ids = [doc["_id"] for doc in paginated_documents] assert len(paginated_ids) == FIND_LIMIT