diff --git a/astrapy/idiomatic/cursors.py b/astrapy/idiomatic/cursors.py index 0c0843a6..106efbe5 100644 --- a/astrapy/idiomatic/cursors.py +++ b/astrapy/idiomatic/cursors.py @@ -18,6 +18,7 @@ from typing import ( Any, Dict, + Generic, List, Optional, TypeVar, @@ -36,6 +37,7 @@ BC = TypeVar("BC", bound="BaseCursor") +T = TypeVar("T") FIND_PREFETCH = 20 @@ -403,3 +405,87 @@ async def distinct(self, key: str) -> List[Any]: if key in document } ) + + +class CommandCursor(Generic[T]): + def __init__(self, address: str, items: List[T]) -> None: + self._address = address + self.items = items + self.iterable = items.__iter__() + self._alive = True + + def __iter__(self) -> CommandCursor[T]: + self._ensure_alive() + return self + + def __next__(self) -> T: + try: + item = self.iterable.__next__() + return item + except StopIteration: + self._alive = False + raise + + @property + def address(self) -> str: + return self._address + + @property + def alive(self) -> bool: + return self._alive + + @property + def cursor_id(self) -> int: + return id(self) + + def _ensure_alive(self) -> None: + if not self._alive: + raise ValueError("Cursor is closed.") + + def try_next(self) -> T: + return self.__next__() + + def close(self) -> None: + self._alive = False + + +class AsyncCommandCursor(Generic[T]): + def __init__(self, address: str, items: List[T]) -> None: + self._address = address + self.items = items + self.iterable = items.__iter__() + self._alive = True + + def __aiter__(self) -> AsyncCommandCursor[T]: + self._ensure_alive() + return self + + async def __anext__(self) -> T: + try: + item = self.iterable.__next__() + return item + except StopIteration: + self._alive = False + raise StopAsyncIteration + + @property + def address(self) -> str: + return self._address + + @property + def alive(self) -> bool: + return self._alive + + @property + def cursor_id(self) -> int: + return id(self) + + def _ensure_alive(self) -> None: + if not self._alive: + raise ValueError("Cursor is closed.") + + async def try_next(self) -> T: + return await self.__anext__() + + def close(self) -> None: + self._alive = False diff --git a/astrapy/idiomatic/database.py b/astrapy/idiomatic/database.py index 8eab8849..e5ada2d6 100644 --- a/astrapy/idiomatic/database.py +++ b/astrapy/idiomatic/database.py @@ -19,6 +19,7 @@ from typing import Any, Dict, List, Optional, Type, Union, TYPE_CHECKING from astrapy.db import AstraDB, AsyncAstraDB +from astrapy.idiomatic.cursors import AsyncCommandCursor, CommandCursor if TYPE_CHECKING: @@ -258,7 +259,7 @@ def list_collections( self, *, namespace: Optional[str] = None, - ) -> List[Dict[str, Any]]: + ) -> CommandCursor[Dict[str, Any]]: if namespace: _client = self._astra_db.copy(namespace=namespace) else: @@ -271,10 +272,13 @@ def list_collections( ) else: # we know this is a list of dicts which need a little adjusting - return [ - _recast_api_collection_dict(col_dict) - for col_dict in gc_response["status"]["collections"] - ] + return CommandCursor( + address=self._astra_db.base_url, + items=[ + _recast_api_collection_dict(col_dict) + for col_dict in gc_response["status"]["collections"] + ], + ) def list_collection_names( self, @@ -510,17 +514,17 @@ async def drop_collection( dc_response = await self._astra_db.delete_collection(name_or_collection) return dc_response.get("status", {}) # type: ignore[no-any-return] - async def list_collections( + def list_collections( self, *, namespace: Optional[str] = None, - ) -> List[Dict[str, Any]]: + ) -> AsyncCommandCursor[Dict[str, Any]]: _client: AsyncAstraDB if namespace: _client = self._astra_db.copy(namespace=namespace) else: _client = self._astra_db - gc_response = await _client.get_collections(options={"explain": True}) + gc_response = _client.to_sync().get_collections(options={"explain": True}) if "collections" not in gc_response.get("status", {}): raise ValueError( "Could not complete a get_collections operation. " @@ -528,10 +532,13 @@ async def list_collections( ) else: # we know this is a list of dicts which need a little adjusting - return [ - _recast_api_collection_dict(col_dict) - for col_dict in gc_response["status"]["collections"] - ] + return AsyncCommandCursor( + address=self._astra_db.base_url, + items=[ + _recast_api_collection_dict(col_dict) + for col_dict in gc_response["status"]["collections"] + ], + ) async def list_collection_names( self, diff --git a/tests/idiomatic/integration/test_ddl_async.py b/tests/idiomatic/integration/test_ddl_async.py index cfff6bf0..20927c62 100644 --- a/tests/idiomatic/integration/test_ddl_async.py +++ b/tests/idiomatic/integration/test_ddl_async.py @@ -41,7 +41,7 @@ async def test_collection_lifecycle_async( TEST_LOCAL_COLLECTION_NAME_B, indexing={"allow": ["z"]}, ) - lc_response = await async_database.list_collections() + lc_response = [col async for col in async_database.list_collections()] # expected_coll_dict = { "name": TEST_LOCAL_COLLECTION_NAME, diff --git a/tests/idiomatic/integration/test_ddl_sync.py b/tests/idiomatic/integration/test_ddl_sync.py index ea611ae8..5a05b31d 100644 --- a/tests/idiomatic/integration/test_ddl_sync.py +++ b/tests/idiomatic/integration/test_ddl_sync.py @@ -41,7 +41,7 @@ def test_collection_lifecycle_sync( TEST_LOCAL_COLLECTION_NAME_B, indexing={"allow": ["z"]}, ) - lc_response = sync_database.list_collections() + lc_response = list(sync_database.list_collections()) # expected_coll_dict = { "name": TEST_LOCAL_COLLECTION_NAME,