Skip to content

Commit

Permalink
Add optional pre-fetching of paginate results
Browse files Browse the repository at this point in the history
Fix #145
  • Loading branch information
cbornet committed Jan 1, 2024
1 parent 6a83a1e commit 5c92836
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 23 deletions.
102 changes: 85 additions & 17 deletions astrapy/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import asyncio
import logging
import json
import threading
from functools import partial
from queue import Queue
from types import TracebackType
from typing import (
Any,
Expand Down Expand Up @@ -252,36 +256,69 @@ def vector_find(

@staticmethod
def paginate(
*, request_method: PaginableRequestMethod, options: Optional[Dict[str, Any]]
*,
request_method: PaginableRequestMethod,
options: Optional[Dict[str, Any]],
prefetched: Optional[int] = None,
) -> Iterable[API_DOC]:
"""
Generate paginated results for a given database query method.
Args:
request_method (function): The database query method to paginate.
options (dict): Options for the database query.
kwargs: Additional arguments to pass to the database query method.
options (dict, optional): Options for the database query.
prefetched (int, optional): Number of pre-fetched documents.
Yields:
dict: The next document in the paginated result set.
"""
_options = options or {}
response0 = request_method(options=_options)
next_page_state = response0["data"]["nextPageState"]
options0 = _options
for document in response0["data"]["documents"]:
yield document
while next_page_state is not None:
if next_page_state is not None and prefetched:

def queued_paginate(
queue: Queue[Optional[API_DOC]],
request_method: PaginableRequestMethod,
options: Optional[Dict[str, Any]],
):
try:
for row in AstraDBCollection.paginate(
request_method=request_method, options=options
):
queue.put(row)
finally:
queue.put(None)

queue = Queue(prefetched)
options1 = {**options0, **{"pageState": next_page_state}}
response1 = request_method(options=options1)
for document in response1["data"]["documents"]:
t = threading.Thread(
target=queued_paginate, args=(queue, request_method, options1)
)
t.start()
for document in response0["data"]["documents"]:
yield document
next_page_state = response1["data"]["nextPageState"]
doc = queue.get()
while doc is not None:
yield doc
doc = queue.get()
t.join()
else:
for document in response0["data"]["documents"]:
yield document
while next_page_state is not None and not prefetched:
options1 = {**options0, **{"pageState": next_page_state}}
response1 = request_method(options=options1)
for document in response1["data"]["documents"]:
yield document
next_page_state = response1["data"]["nextPageState"]

def paginated_find(
self,
filter: Optional[Dict[str, Any]] = None,
projection: Optional[Dict[str, Any]] = None,
sort: Optional[Dict[str, Any]] = None,
options: Optional[Dict[str, Any]] = None,
prefetched: Optional[int] = None,
) -> Iterable[API_DOC]:
"""
Perform a paginated search in the collection.
Expand All @@ -290,6 +327,7 @@ def paginated_find(
projection (dict, optional): Specifies the fields to return.
sort (dict, optional): Specifies the order in which to return matching documents.
options (dict, optional): Additional options for the query.
prefetched (int, optional): Number of pre-fetched documents.
Returns:
generator: A generator yielding documents in the paginated result set.
"""
Expand All @@ -302,6 +340,7 @@ def paginated_find(
return self.paginate(
request_method=partialed_find,
options=options,
prefetched=prefetched,
)

def pop(
Expand Down Expand Up @@ -969,35 +1008,62 @@ async def paginate(
*,
request_method: AsyncPaginableRequestMethod,
options: Optional[Dict[str, Any]],
prefetched: Optional[int] = None,
) -> AsyncGenerator[API_DOC, None]:
"""
Generate paginated results for a given database query method.
Args:
request_method (function): The database query method to paginate.
options (dict): Options for the database query.
kwargs: Additional arguments to pass to the database query method.
options (dict, optional): Options for the database query.
prefetched (int, optional): Number of pre-fetched documents.
Yields:
dict: The next document in the paginated result set.
"""
_options = options or {}
response0 = await request_method(options=_options)
next_page_state = response0["data"]["nextPageState"]
options0 = _options
for document in response0["data"]["documents"]:
yield document
while next_page_state is not None:
if next_page_state is not None and prefetched:

async def queued_paginate(
queue: asyncio.Queue[Optional[API_DOC]],
request_method: AsyncPaginableRequestMethod,
options: Optional[Dict[str, Any]],
):
try:
async for doc in AsyncAstraDBCollection.paginate(
request_method=request_method, options=options
):
await queue.put(doc)
finally:
await queue.put(None)

queue = asyncio.Queue(prefetched)
options1 = {**options0, **{"pageState": next_page_state}}
response1 = await request_method(options=options1)
for document in response1["data"]["documents"]:
asyncio.create_task(queued_paginate(queue, request_method, options1))
for document in response0["data"]["documents"]:
yield document
doc = await queue.get()
while doc is not None:
yield doc
doc = await queue.get()
else:
for document in response0["data"]["documents"]:
yield document
next_page_state = response1["data"]["nextPageState"]
while next_page_state is not None:
options1 = {**options0, **{"pageState": next_page_state}}
response1 = await request_method(options=options1)
for document in response1["data"]["documents"]:
yield document
next_page_state = response1["data"]["nextPageState"]

def paginated_find(
self,
filter: Optional[Dict[str, Any]] = None,
projection: Optional[Dict[str, Any]] = None,
sort: Optional[Dict[str, Any]] = None,
options: Optional[Dict[str, Any]] = None,
prefetched: Optional[int] = None,
) -> AsyncIterable[API_DOC]:
"""
Perform a paginated search in the collection.
Expand All @@ -1006,6 +1072,7 @@ def paginated_find(
projection (dict, optional): Specifies the fields to return.
sort (dict, optional): Specifies the order in which to return matching documents.
options (dict, optional): Additional options for the query.
prefetched (int, optional): Number of pre-fetched documents.
Returns:
generator: A generator yielding documents in the paginated result set.
"""
Expand All @@ -1018,6 +1085,7 @@ def paginated_find(
return self.paginate(
request_method=partialed_find,
options=options,
prefetched=prefetched,
)

async def pop(
Expand Down
15 changes: 12 additions & 3 deletions tests/astrapy/test_async_db_dml_pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
INSERT_BATCH_SIZE = 20 # max 20, fixed by API constraints
N = 200 # must be EVEN
FIND_LIMIT = 183 # Keep this > 20 and <= N to actually put pagination to test
PREFETCHED = 42 # Keep this > 20 and <= FIND_LIMIT to actually trigger prefetching

T = TypeVar("T")

Expand Down Expand Up @@ -83,13 +84,21 @@ async def pag_test_collection(
@pytest.mark.describe(
"should retrieve the required amount of documents, all different, through pagination"
)
async def test_find_paginated(pag_test_collection: AsyncAstraDBCollection) -> None:
@pytest.mark.parametrize(
"prefetched",
[
pytest.param(None, id="without pre-fetching"),
pytest.param(PREFETCHED, id="with pre-fetching"),
],
)
async def test_find_paginated(
prefetched: Optional[int], pag_test_collection: AsyncAstraDBCollection
) -> None:
options = {"limit": FIND_LIMIT}
projection = {"$vector": 0}

paginated_documents = pag_test_collection.paginated_find(
projection=projection,
options=options,
projection=projection, options=options, prefetched=prefetched
)
paginated_ids = [doc["_id"] async for doc in paginated_documents]
assert len(paginated_ids) == FIND_LIMIT
Expand Down
15 changes: 12 additions & 3 deletions tests/astrapy/test_db_dml_pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
INSERT_BATCH_SIZE = 20 # max 20, fixed by API constraints
N = 200 # must be EVEN
FIND_LIMIT = 183 # Keep this > 20 and <= N to actually put pagination to test
PREFETCHED = 42 # Keep this > 20 and <= FIND_LIMIT to actually trigger prefetching

T = TypeVar("T")

Expand Down Expand Up @@ -80,13 +81,21 @@ def pag_test_collection(
@pytest.mark.describe(
"should retrieve the required amount of documents, all different, through pagination"
)
def test_find_paginated(pag_test_collection: AstraDBCollection) -> None:
@pytest.mark.parametrize(
"prefetched",
[
pytest.param(None, id="without pre-fetching"),
pytest.param(PREFETCHED, id="with pre-fetching"),
],
)
def test_find_paginated(
prefetched: Optional[int], pag_test_collection: AstraDBCollection
) -> None:
options = {"limit": FIND_LIMIT}
projection = {"$vector": 0}

paginated_documents = pag_test_collection.paginated_find(
projection=projection,
options=options,
projection=projection, options=options, prefetched=prefetched
)
paginated_ids = [doc["_id"] for doc in paginated_documents]
assert len(paginated_ids) == FIND_LIMIT
Expand Down

0 comments on commit 5c92836

Please sign in to comment.