Skip to content

Commit

Permalink
commandcursor (+async), used in list_collections (#243)
Browse files Browse the repository at this point in the history
  • Loading branch information
hemidactylus authored Mar 7, 2024
1 parent 85cc7cc commit 9bfa169
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 14 deletions.
86 changes: 86 additions & 0 deletions astrapy/idiomatic/cursors.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import (
Any,
Dict,
Generic,
List,
Optional,
TypeVar,
Expand All @@ -36,6 +37,7 @@


BC = TypeVar("BC", bound="BaseCursor")
T = TypeVar("T")

FIND_PREFETCH = 20

Expand Down Expand Up @@ -403,3 +405,87 @@ async def distinct(self, key: str) -> List[Any]:
if key in document
}
)


class CommandCursor(Generic[T]):
def __init__(self, address: str, items: List[T]) -> None:
self._address = address
self.items = items
self.iterable = items.__iter__()
self._alive = True

def __iter__(self) -> CommandCursor[T]:
self._ensure_alive()
return self

def __next__(self) -> T:
try:
item = self.iterable.__next__()
return item
except StopIteration:
self._alive = False
raise

@property
def address(self) -> str:
return self._address

@property
def alive(self) -> bool:
return self._alive

@property
def cursor_id(self) -> int:
return id(self)

def _ensure_alive(self) -> None:
if not self._alive:
raise ValueError("Cursor is closed.")

def try_next(self) -> T:
return self.__next__()

def close(self) -> None:
self._alive = False


class AsyncCommandCursor(Generic[T]):
def __init__(self, address: str, items: List[T]) -> None:
self._address = address
self.items = items
self.iterable = items.__iter__()
self._alive = True

def __aiter__(self) -> AsyncCommandCursor[T]:
self._ensure_alive()
return self

async def __anext__(self) -> T:
try:
item = self.iterable.__next__()
return item
except StopIteration:
self._alive = False
raise StopAsyncIteration

@property
def address(self) -> str:
return self._address

@property
def alive(self) -> bool:
return self._alive

@property
def cursor_id(self) -> int:
return id(self)

def _ensure_alive(self) -> None:
if not self._alive:
raise ValueError("Cursor is closed.")

async def try_next(self) -> T:
return await self.__anext__()

def close(self) -> None:
self._alive = False
31 changes: 19 additions & 12 deletions astrapy/idiomatic/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from typing import Any, Dict, List, Optional, Type, Union, TYPE_CHECKING

from astrapy.db import AstraDB, AsyncAstraDB
from astrapy.idiomatic.cursors import AsyncCommandCursor, CommandCursor


if TYPE_CHECKING:
Expand Down Expand Up @@ -258,7 +259,7 @@ def list_collections(
self,
*,
namespace: Optional[str] = None,
) -> List[Dict[str, Any]]:
) -> CommandCursor[Dict[str, Any]]:
if namespace:
_client = self._astra_db.copy(namespace=namespace)
else:
Expand All @@ -271,10 +272,13 @@ def list_collections(
)
else:
# we know this is a list of dicts which need a little adjusting
return [
_recast_api_collection_dict(col_dict)
for col_dict in gc_response["status"]["collections"]
]
return CommandCursor(
address=self._astra_db.base_url,
items=[
_recast_api_collection_dict(col_dict)
for col_dict in gc_response["status"]["collections"]
],
)

def list_collection_names(
self,
Expand Down Expand Up @@ -510,28 +514,31 @@ async def drop_collection(
dc_response = await self._astra_db.delete_collection(name_or_collection)
return dc_response.get("status", {}) # type: ignore[no-any-return]

async def list_collections(
def list_collections(
self,
*,
namespace: Optional[str] = None,
) -> List[Dict[str, Any]]:
) -> AsyncCommandCursor[Dict[str, Any]]:
_client: AsyncAstraDB
if namespace:
_client = self._astra_db.copy(namespace=namespace)
else:
_client = self._astra_db
gc_response = await _client.get_collections(options={"explain": True})
gc_response = _client.to_sync().get_collections(options={"explain": True})
if "collections" not in gc_response.get("status", {}):
raise ValueError(
"Could not complete a get_collections operation. "
f"(gotten '${json.dumps(gc_response)}')"
)
else:
# we know this is a list of dicts which need a little adjusting
return [
_recast_api_collection_dict(col_dict)
for col_dict in gc_response["status"]["collections"]
]
return AsyncCommandCursor(
address=self._astra_db.base_url,
items=[
_recast_api_collection_dict(col_dict)
for col_dict in gc_response["status"]["collections"]
],
)

async def list_collection_names(
self,
Expand Down
2 changes: 1 addition & 1 deletion tests/idiomatic/integration/test_ddl_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ async def test_collection_lifecycle_async(
TEST_LOCAL_COLLECTION_NAME_B,
indexing={"allow": ["z"]},
)
lc_response = await async_database.list_collections()
lc_response = [col async for col in async_database.list_collections()]
#
expected_coll_dict = {
"name": TEST_LOCAL_COLLECTION_NAME,
Expand Down
2 changes: 1 addition & 1 deletion tests/idiomatic/integration/test_ddl_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_collection_lifecycle_sync(
TEST_LOCAL_COLLECTION_NAME_B,
indexing={"allow": ["z"]},
)
lc_response = sync_database.list_collections()
lc_response = list(sync_database.list_collections())
#
expected_coll_dict = {
"name": TEST_LOCAL_COLLECTION_NAME,
Expand Down

0 comments on commit 9bfa169

Please sign in to comment.