Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

commandcursor (+async), used in list_collections #243

Merged
merged 1 commit into from
Mar 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions astrapy/idiomatic/cursors.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import (
Any,
Dict,
Generic,
List,
Optional,
TypeVar,
Expand All @@ -36,6 +37,7 @@


BC = TypeVar("BC", bound="BaseCursor")
T = TypeVar("T")

FIND_PREFETCH = 20

Expand Down Expand Up @@ -403,3 +405,87 @@ async def distinct(self, key: str) -> List[Any]:
if key in document
}
)


class CommandCursor(Generic[T]):
def __init__(self, address: str, items: List[T]) -> None:
self._address = address
self.items = items
self.iterable = items.__iter__()
self._alive = True

def __iter__(self) -> CommandCursor[T]:
self._ensure_alive()
return self

def __next__(self) -> T:
try:
item = self.iterable.__next__()
return item
except StopIteration:
self._alive = False
raise

@property
def address(self) -> str:
return self._address

@property
def alive(self) -> bool:
return self._alive

@property
def cursor_id(self) -> int:
return id(self)

def _ensure_alive(self) -> None:
if not self._alive:
raise ValueError("Cursor is closed.")

def try_next(self) -> T:
return self.__next__()

def close(self) -> None:
self._alive = False


class AsyncCommandCursor(Generic[T]):
def __init__(self, address: str, items: List[T]) -> None:
self._address = address
self.items = items
self.iterable = items.__iter__()
self._alive = True

def __aiter__(self) -> AsyncCommandCursor[T]:
self._ensure_alive()
return self

async def __anext__(self) -> T:
try:
item = self.iterable.__next__()
return item
except StopIteration:
self._alive = False
raise StopAsyncIteration

@property
def address(self) -> str:
return self._address

@property
def alive(self) -> bool:
return self._alive

@property
def cursor_id(self) -> int:
return id(self)

def _ensure_alive(self) -> None:
if not self._alive:
raise ValueError("Cursor is closed.")

async def try_next(self) -> T:
return await self.__anext__()

def close(self) -> None:
self._alive = False
31 changes: 19 additions & 12 deletions astrapy/idiomatic/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from typing import Any, Dict, List, Optional, Type, Union, TYPE_CHECKING

from astrapy.db import AstraDB, AsyncAstraDB
from astrapy.idiomatic.cursors import AsyncCommandCursor, CommandCursor


if TYPE_CHECKING:
Expand Down Expand Up @@ -258,7 +259,7 @@ def list_collections(
self,
*,
namespace: Optional[str] = None,
) -> List[Dict[str, Any]]:
) -> CommandCursor[Dict[str, Any]]:
if namespace:
_client = self._astra_db.copy(namespace=namespace)
else:
Expand All @@ -271,10 +272,13 @@ def list_collections(
)
else:
# we know this is a list of dicts which need a little adjusting
return [
_recast_api_collection_dict(col_dict)
for col_dict in gc_response["status"]["collections"]
]
return CommandCursor(
address=self._astra_db.base_url,
items=[
_recast_api_collection_dict(col_dict)
for col_dict in gc_response["status"]["collections"]
],
)

def list_collection_names(
self,
Expand Down Expand Up @@ -510,28 +514,31 @@ async def drop_collection(
dc_response = await self._astra_db.delete_collection(name_or_collection)
return dc_response.get("status", {}) # type: ignore[no-any-return]

async def list_collections(
def list_collections(
self,
*,
namespace: Optional[str] = None,
) -> List[Dict[str, Any]]:
) -> AsyncCommandCursor[Dict[str, Any]]:
_client: AsyncAstraDB
if namespace:
_client = self._astra_db.copy(namespace=namespace)
else:
_client = self._astra_db
gc_response = await _client.get_collections(options={"explain": True})
gc_response = _client.to_sync().get_collections(options={"explain": True})
if "collections" not in gc_response.get("status", {}):
raise ValueError(
"Could not complete a get_collections operation. "
f"(gotten '${json.dumps(gc_response)}')"
)
else:
# we know this is a list of dicts which need a little adjusting
return [
_recast_api_collection_dict(col_dict)
for col_dict in gc_response["status"]["collections"]
]
return AsyncCommandCursor(
address=self._astra_db.base_url,
items=[
_recast_api_collection_dict(col_dict)
for col_dict in gc_response["status"]["collections"]
],
)

async def list_collection_names(
self,
Expand Down
2 changes: 1 addition & 1 deletion tests/idiomatic/integration/test_ddl_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ async def test_collection_lifecycle_async(
TEST_LOCAL_COLLECTION_NAME_B,
indexing={"allow": ["z"]},
)
lc_response = await async_database.list_collections()
lc_response = [col async for col in async_database.list_collections()]
#
expected_coll_dict = {
"name": TEST_LOCAL_COLLECTION_NAME,
Expand Down
2 changes: 1 addition & 1 deletion tests/idiomatic/integration/test_ddl_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_collection_lifecycle_sync(
TEST_LOCAL_COLLECTION_NAME_B,
indexing={"allow": ["z"]},
)
lc_response = sync_database.list_collections()
lc_response = list(sync_database.list_collections())
#
expected_coll_dict = {
"name": TEST_LOCAL_COLLECTION_NAME,
Expand Down
Loading