Fix find prefetching #299

Merged: 4 commits, Aug 22, 2024
Changes from all commits
126 changes: 85 additions & 41 deletions astrapy/core/db.py
@@ -17,11 +17,12 @@
import asyncio
import json
import logging
import queue
import threading
from collections.abc import AsyncGenerator, AsyncIterator, Generator, Iterator
import weakref
from collections.abc import AsyncGenerator, AsyncIterator, Iterator
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from queue import Queue
from types import TracebackType
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast

@@ -54,6 +55,72 @@
logger = logging.getLogger(__name__)


class _PrefetchIterator(Iterator[API_DOC]):
    def __init__(
        self,
        prefetched: int,
        request_method: PaginableRequestMethod,
        options: Optional[Dict[str, Any]],
        raw_response_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
    ):
        self.queue: queue.Queue[Optional[API_DOC]] = queue.Queue(prefetched)
        self.request_method = request_method
        self.options = options
        self.raw_response_callback = raw_response_callback
        self.initialised = threading.Event()
        self.stop = threading.Event()
        self.thread = threading.Thread(
            target=_PrefetchIterator.queued_paginate, args=(weakref.proxy(self),)
        )
        self.thread.start()
        # wait until the exception handler in queued_paginate can deal with the
        # object being deleted.
        self.initialised.wait()

    def __iter__(self) -> Iterator[API_DOC]:
        return self

    @staticmethod
    def queue_put(
        q: queue.Queue[Optional[API_DOC]],
        item: Optional[API_DOC],
        stop: threading.Event,
    ) -> None:
        while not stop.is_set():
            try:
                q.put(item, timeout=1)
                break
            except queue.Full:
                # Wait until there is space in the queue or the thread
                # is stopped
                pass

    def queued_paginate(self) -> None:
        self.initialised.set()
        try:
            for row in AstraDBCollection.paginate(
                request_method=self.request_method,
                options=self.options,
                raw_response_callback=self.raw_response_callback,
            ):
                self.queue_put(self.queue, row, self.stop)
        except ReferenceError:
            logger.debug("queued_paginate terminated")
            return
        finally:
            self.queue_put(self.queue, None, self.stop)
        logger.debug("queued_paginate end")

    def __next__(self) -> API_DOC:
        doc = self.queue.get()
        if doc is None:
            raise StopIteration
        return doc

    def __del__(self) -> None:
        self.stop.set()


class AstraDBCollection:
    # Initialize the shared httpx client as a class attribute
    client = httpx.Client()
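
Note: _PrefetchIterator above is a standard producer/consumer prefetch: a worker thread pushes documents into a bounded queue while the caller iterates, and a None sentinel marks the end of the stream. A stripped-down sketch of that pattern, not part of this diff (prefetching_iter and the plain-dict items are invented for illustration):

import queue
import threading
from typing import Dict, Iterator, Optional

def prefetching_iter(source: Iterator[Dict], prefetched: int) -> Iterator[Dict]:
    # Bounded queue: the worker blocks once `prefetched` items are waiting.
    buffer: queue.Queue[Optional[Dict]] = queue.Queue(prefetched)

    def worker() -> None:
        try:
            for item in source:
                buffer.put(item)
        finally:
            buffer.put(None)  # sentinel: tells the consumer the stream is over

    threading.Thread(target=worker, daemon=True).start()
    while (item := buffer.get()) is not None:
        yield item

Here the worker is simply a daemon thread; the class above instead combines a stop Event with a weakref.proxy so the worker also notices when the consuming iterator is garbage-collected.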
@@ -403,9 +470,8 @@ def paginate(
*,
request_method: PaginableRequestMethod,
options: Optional[Dict[str, Any]],
prefetched: Optional[int] = None,
raw_response_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
) -> Generator[API_DOC, None, None]:
) -> Iterator[API_DOC]:
"""
Generate paginated results for a given database query method.

@@ -426,45 +492,17 @@
raw_response_callback(response0)
next_page_state = response0["data"]["nextPageState"]
options0 = _options
if next_page_state is not None and prefetched:

def queued_paginate(
queue: Queue[Optional[API_DOC]],
request_method: PaginableRequestMethod,
options: Optional[Dict[str, Any]],
) -> None:
try:
for row in AstraDBCollection.paginate(
request_method=request_method, options=options
):
queue.put(row)
finally:
queue.put(None)

queue: Queue[Optional[API_DOC]] = Queue(prefetched)
for document in response0["data"]["documents"]:
yield document
while next_page_state is not None:
options1 = {**options0, **{"pageState": next_page_state}}
t = threading.Thread(
target=queued_paginate, args=(queue, request_method, options1)
)
t.start()
for document in response0["data"]["documents"]:
response1 = request_method(options=options1)
if raw_response_callback:
raw_response_callback(response1)
for document in response1["data"]["documents"]:
yield document
doc = queue.get()
while doc is not None:
yield doc
doc = queue.get()
t.join()
else:
for document in response0["data"]["documents"]:
yield document
while next_page_state is not None and not prefetched:
options1 = {**options0, **{"pageState": next_page_state}}
response1 = request_method(options=options1)
if raw_response_callback:
raw_response_callback(response1)
for document in response1["data"]["documents"]:
yield document
next_page_state = response1["data"]["nextPageState"]
next_page_state = response1["data"]["nextPageState"]

def paginated_find(
self,
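
The reworked paginate above always walks pages synchronously; prefetching now lives entirely in _PrefetchIterator. The paging contract it relies on is that every response carries data.documents plus data.nextPageState, and the next request repeats the options with that pageState until it comes back as None. A rough sketch of that contract against a canned, in-memory request method (the fake pages and fake_request_method are invented; a real caller passes a functools.partial of the collection's find method, as paginated_find does below):

# Two canned "API responses", keyed by the pageState they are requested with.
_PAGES = {
    None: {"data": {"documents": [{"_id": "a"}, {"_id": "b"}], "nextPageState": "p2"}},
    "p2": {"data": {"documents": [{"_id": "c"}], "nextPageState": None}},
}

def fake_request_method(options=None):
    # A paginable request method only needs to accept `options` and return a
    # response dict shaped like the Data API's.
    return _PAGES[(options or {}).get("pageState")]

docs = list(
    AstraDBCollection.paginate(request_method=fake_request_method, options={})
)
# Expected to yield the three documents in order, one page at a time.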
@@ -505,10 +543,16 @@ def paginated_find(
sort=sort,
timeout_info=timeout_info,
)
if prefetched:
return _PrefetchIterator(
request_method=partialed_find,
options=options,
prefetched=prefetched,
raw_response_callback=raw_response_callback,
)
return self.paginate(
request_method=partialed_find,
options=options,
prefetched=prefetched,
raw_response_callback=raw_response_callback,
)
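
A usage sketch of the two branches above; the collection setup and the option values are hypothetical:

# Hypothetical setup; in practice the collection comes from an AstraDB instance.
collection = AstraDBCollection(collection_name="my_collection", astra_db=my_astra_db)

# Without prefetched: a lazy iterator, pages are requested as iteration reaches them.
for doc in collection.paginated_find(projection={"$vector": 0}, options={"limit": 50}):
    print(doc["_id"])

# With prefetched: a _PrefetchIterator whose worker thread keeps up to 10
# documents queued ahead of the consumer.
for doc in collection.paginated_find(
    projection={"$vector": 0}, options={"limit": 50}, prefetched=10
):
    print(doc["_id"])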

37 changes: 34 additions & 3 deletions tests/core/test_db_dml_pagination.py
@@ -19,6 +19,7 @@
from __future__ import annotations

import logging
import time
from typing import Optional

import pytest
@@ -43,16 +44,46 @@
],
)
def test_find_paginated(
prefetched: Optional[int], pagination_v_collection: AstraDBCollection
prefetched: Optional[int],
pagination_v_collection: AstraDBCollection,
caplog: pytest.LogCaptureFixture,
) -> None:
options = {"limit": FIND_LIMIT}
projection = {"$vector": 0}
caplog.set_level(logging.INFO)

paginated_documents_gen = pagination_v_collection.paginated_find(
paginated_documents_it = pagination_v_collection.paginated_find(
projection=projection, options=options, prefetched=prefetched
)
paginated_documents = list(paginated_documents_gen)

time.sleep(1)
if prefetched:
# If prefetched is set, requests are performed eagerly by the background
# thread, before the iterator is consumed
assert caplog.text.count("HTTP Request: POST") == 3
else:
assert caplog.text.count("HTTP Request: POST") == 0

paginated_documents = list(paginated_documents_it)
paginated_ids = [doc["_id"] for doc in paginated_documents]
assert all(["$vector" not in doc for doc in paginated_documents])
assert len(paginated_ids) == FIND_LIMIT
assert len(paginated_ids) == len(set(paginated_ids))


def test_prefetched_thread_terminated(
    pagination_v_collection: AstraDBCollection, caplog: pytest.LogCaptureFixture
) -> None:
    options = {"limit": FIND_LIMIT}
    projection = {"$vector": 0}
    caplog.set_level(logging.DEBUG)

    paginated_documents_it = pagination_v_collection.paginated_find(
        projection=projection, options=options, prefetched=PREFETCHED
    )

    assert next(paginated_documents_it) is not None
    del paginated_documents_it

    time.sleep(1)

    assert caplog.text.count("queued_paginate terminated") == 1
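
The termination test relies on the ownership scheme set up in _PrefetchIterator.__init__: the worker thread only ever holds a weakref.proxy to the iterator, so once the last strong reference is dropped, the next attribute access inside queued_paginate raises ReferenceError and the thread exits. A self-contained sketch of that mechanism (the Consumer class and its attributes are invented for illustration):

import gc
import threading
import time
import weakref

class Consumer:
    def __init__(self) -> None:
        self.alive = True
        # The worker only receives a weakref.proxy, so it cannot keep this
        # object alive; attribute access raises ReferenceError after deletion.
        self.thread = threading.Thread(
            target=Consumer.work, args=(weakref.proxy(self),)
        )
        self.thread.start()

    @staticmethod
    def work(proxy: "Consumer") -> None:
        try:
            while proxy.alive:
                time.sleep(0.01)
        except ReferenceError:
            print("worker terminated")  # analogous to the "queued_paginate terminated" log

consumer = Consumer()
del consumer
gc.collect()  # once the object is collected, the proxy is dead and the worker exits
time.sleep(0.1)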