Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sl collateral commands #242

Merged
merged 4 commits into from
Mar 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions astrapy/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,13 @@ def _request(
response = restore_from_api(direct_response)
return response

def post_raw_request(self, body: Dict[str, Any]) -> API_RESPONSE:
return self._request(
method=http_methods.POST,
path=self.base_path,
json_data=body,
)

def _get(
self, path: Optional[str] = None, options: Optional[Dict[str, Any]] = None
) -> Optional[API_RESPONSE]:
Expand Down Expand Up @@ -1280,6 +1287,13 @@ async def _request(
response = restore_from_api(adirect_response)
return response

async def post_raw_request(self, body: Dict[str, Any]) -> API_RESPONSE:
return await self._request(
method=http_methods.POST,
path=self.base_path,
json_data=body,
)

async def _get(
self, path: Optional[str] = None, options: Optional[Dict[str, Any]] = None
) -> Optional[API_RESPONSE]:
Expand Down Expand Up @@ -2290,6 +2304,13 @@ def _request(
response = restore_from_api(direct_response)
return response

def post_raw_request(self, body: Dict[str, Any]) -> API_RESPONSE:
return self._request(
method=http_methods.POST,
path=self.base_path,
json_data=body,
)

def collection(self, collection_name: str) -> AstraDBCollection:
"""
Retrieve a collection from the database.
Expand Down Expand Up @@ -2588,6 +2609,13 @@ async def _request(
response = restore_from_api(adirect_response)
return response

async def post_raw_request(self, body: Dict[str, Any]) -> API_RESPONSE:
return await self._request(
method=http_methods.POST,
path=self.base_path,
json_data=body,
)

async def collection(self, collection_name: str) -> AsyncAstraDBCollection:
"""
Retrieve a collection from the database.
Expand Down
42 changes: 42 additions & 0 deletions astrapy/idiomatic/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,14 @@ def __eq__(self, other: Any) -> bool:
else:
return False

def __call__(self, *pargs: Any, **kwargs: Any) -> None:
raise TypeError(
f"'{self.__class__.__name__}' object is not callable. If you "
f"meant to call the '{self.name}' method on a "
f"'{self.database.__class__.__name__}' object "
"it is failing because no such method exists."
)

def copy(
self,
*,
Expand All @@ -115,6 +123,19 @@ def copy(
caller_version=caller_version or self._astra_db_collection.caller_version,
)

def with_options(
self,
*,
name: Optional[str] = None,
caller_name: Optional[str] = None,
caller_version: Optional[str] = None,
) -> Collection:
return self.copy(
name=name,
caller_name=caller_name,
caller_version=caller_version,
)

def to_async(
self,
*,
Expand Down Expand Up @@ -570,6 +591,14 @@ def __eq__(self, other: Any) -> bool:
else:
return False

def __call__(self, *pargs: Any, **kwargs: Any) -> None:
raise TypeError(
f"'{self.__class__.__name__}' object is not callable. If you "
f"meant to call the '{self.name}' method on a "
f"'{self.database.__class__.__name__}' object "
"it is failing because no such method exists."
)

def copy(
self,
*,
Expand All @@ -587,6 +616,19 @@ def copy(
caller_version=caller_version or self._astra_db_collection.caller_version,
)

def with_options(
self,
*,
name: Optional[str] = None,
caller_name: Optional[str] = None,
caller_version: Optional[str] = None,
) -> AsyncCollection:
return self.copy(
name=name,
caller_name=caller_name,
caller_version=caller_version,
)

def to_sync(
self,
*,
Expand Down
68 changes: 64 additions & 4 deletions astrapy/idiomatic/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,19 @@ def copy(
api_version=api_version or self._astra_db.api_version,
)

def with_options(
self,
*,
namespace: Optional[str] = None,
caller_name: Optional[str] = None,
caller_version: Optional[str] = None,
) -> Database:
return self.copy(
namespace=namespace,
caller_name=caller_name,
caller_version=caller_version,
)

def to_async(
self,
*,
Expand Down Expand Up @@ -282,6 +295,23 @@ def list_collection_names(
# we know this is a list of strings
return gc_response["status"]["collections"] # type: ignore[no-any-return]

def command(
self,
body: Dict[str, Any],
*,
namespace: Optional[str] = None,
collection_name: Optional[str] = None,
) -> Dict[str, Any]:
if namespace:
_client = self._astra_db.copy(namespace=namespace)
else:
_client = self._astra_db
if collection_name:
_collection = _client.collection(collection_name)
return _collection.post_raw_request(body=body)
else:
return _client.post_raw_request(body=body)


class AsyncDatabase:
def __init__(
Expand All @@ -305,11 +335,11 @@ def __init__(
caller_version=caller_version,
)

async def __getattr__(self, collection_name: str) -> AsyncCollection:
return await self.get_collection(name=collection_name)
def __getattr__(self, collection_name: str) -> AsyncCollection:
return self.to_sync().get_collection(name=collection_name).to_async()

async def __getitem__(self, collection_name: str) -> AsyncCollection:
return await self.get_collection(name=collection_name)
def __getitem__(self, collection_name: str) -> AsyncCollection:
return self.to_sync().get_collection(name=collection_name).to_async()

def __repr__(self) -> str:
return f'{self.__class__.__name__}[_astra_db={self._astra_db}"]'
Expand Down Expand Up @@ -360,6 +390,19 @@ def copy(
api_version=api_version or self._astra_db.api_version,
)

def with_options(
self,
*,
namespace: Optional[str] = None,
caller_name: Optional[str] = None,
caller_version: Optional[str] = None,
) -> AsyncDatabase:
return self.copy(
namespace=namespace,
caller_name=caller_name,
caller_version=caller_version,
)

def to_sync(
self,
*,
Expand Down Expand Up @@ -504,3 +547,20 @@ async def list_collection_names(
else:
# we know this is a list of strings
return gc_response["status"]["collections"] # type: ignore[no-any-return]

async def command(
self,
body: Dict[str, Any],
*,
namespace: Optional[str] = None,
collection_name: Optional[str] = None,
) -> Dict[str, Any]:
if namespace:
_client = self._astra_db.copy(namespace=namespace)
else:
_client = self._astra_db
if collection_name:
_collection = await _client.collection(collection_name)
return await _collection.post_raw_request(body=body)
else:
return await _client.post_raw_request(body=body)
31 changes: 31 additions & 0 deletions tests/idiomatic/integration/test_ddl_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,34 @@ async def test_collection_namespace_async(
TEST_LOCAL_COLLECTION_NAME2
not in await database_on_secondary.list_collection_names()
)

@pytest.mark.describe("test of collection command, async")
async def test_collection_command_async(
self,
async_database: AsyncDatabase,
async_collection: AsyncCollection,
) -> None:
cmd1 = await async_database.command(
{"countDocuments": {}}, collection_name=async_collection.name
)
assert isinstance(cmd1, dict)
assert isinstance(cmd1["status"]["count"], int)
cmd2 = await async_database.copy(namespace="...").command(
{"countDocuments": {}},
namespace=async_collection.namespace,
collection_name=async_collection.name,
)
assert cmd2 == cmd1

@pytest.mark.describe("test of database command, async")
async def test_database_command_async(
self,
async_database: AsyncDatabase,
) -> None:
cmd1 = await async_database.command({"findCollections": {}})
assert isinstance(cmd1, dict)
assert isinstance(cmd1["status"]["collections"], list)
cmd2 = await async_database.copy(namespace="...").command(
{"findCollections": {}}, namespace=async_database.namespace
)
assert cmd2 == cmd1
31 changes: 31 additions & 0 deletions tests/idiomatic/integration/test_ddl_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,34 @@ def test_collection_namespace_sync(
TEST_LOCAL_COLLECTION_NAME2
not in database_on_secondary.list_collection_names()
)

@pytest.mark.describe("test of collection command, sync")
def test_collection_command_sync(
self,
sync_database: Database,
sync_collection: Collection,
) -> None:
cmd1 = sync_database.command(
{"countDocuments": {}}, collection_name=sync_collection.name
)
assert isinstance(cmd1, dict)
assert isinstance(cmd1["status"]["count"], int)
cmd2 = sync_database.copy(namespace="...").command(
{"countDocuments": {}},
namespace=sync_collection.namespace,
collection_name=sync_collection.name,
)
assert cmd2 == cmd1

@pytest.mark.describe("test of database command, sync")
def test_database_command_sync(
self,
sync_database: Database,
) -> None:
cmd1 = sync_database.command({"findCollections": {}})
assert isinstance(cmd1, dict)
assert isinstance(cmd1["status"]["collections"], list)
cmd2 = sync_database.copy(namespace="...").command(
{"findCollections": {}}, namespace=sync_database.namespace
)
assert cmd2 == cmd1
15 changes: 15 additions & 0 deletions tests/idiomatic/unit/test_collections_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ async def test_convert_collection_async(
caller_version="c_v",
)
assert col1 == col1.copy()
assert col1 == col1.with_options()
assert col1 == col1.to_sync().to_async()

@pytest.mark.describe("test of Collection rich copy, async")
Expand Down Expand Up @@ -89,6 +90,20 @@ async def test_rich_copy_collection_async(
)
assert col3 == col1

assert col1.with_options(name="x") != col1
assert (
col1.with_options(name="x").with_options(name="id_test_collection") == col1
)
assert col1.with_options(caller_name="x") != col1
assert (
col1.with_options(caller_name="x").with_options(caller_name="c_n") == col1
)
assert col1.with_options(caller_version="x") != col1
assert (
col1.with_options(caller_version="x").with_options(caller_version="c_v")
== col1
)

@pytest.mark.describe("test of Collection rich conversions, async")
async def test_rich_convert_collection_async(
self,
Expand Down
15 changes: 15 additions & 0 deletions tests/idiomatic/unit/test_collections_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def test_convert_collection_sync(
caller_version="c_v",
)
assert col1 == col1.copy()
assert col1 == col1.with_options()
assert col1 == col1.to_async().to_sync()

@pytest.mark.describe("test of Collection rich copy, sync")
Expand Down Expand Up @@ -89,6 +90,20 @@ def test_rich_copy_collection_sync(
)
assert col3 == col1

assert col1.with_options(name="x") != col1
assert (
col1.with_options(name="x").with_options(name="id_test_collection") == col1
)
assert col1.with_options(caller_name="x") != col1
assert (
col1.with_options(caller_name="x").with_options(caller_name="c_n") == col1
)
assert col1.with_options(caller_version="x") != col1
assert (
col1.with_options(caller_version="x").with_options(caller_version="c_v")
== col1
)

@pytest.mark.describe("test of Collection rich conversions, sync")
def test_rich_convert_collection_sync(
self,
Expand Down
19 changes: 15 additions & 4 deletions tests/idiomatic/unit/test_databases_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ async def test_convert_database_async(
**astra_db_credentials_kwargs,
)
assert db1 == db1.copy()
assert db1 == db1.with_options()
assert db1 == db1.to_sync().to_async()

@pytest.mark.describe("test of Database rich copy, async")
Expand Down Expand Up @@ -99,6 +100,18 @@ async def test_rich_copy_database_async(
)
assert db3 == db1

assert db1.with_options(namespace="x") != db1
assert (
db1.with_options(namespace="x").with_options(namespace="namespace") == db1
)
assert db1.with_options(caller_name="x") != db1
assert db1.with_options(caller_name="x").with_options(caller_name="c_n") == db1
assert db1.with_options(caller_version="x") != db1
assert (
db1.with_options(caller_version="x").with_options(caller_version="c_v")
== db1
)

@pytest.mark.describe("test of Database rich conversions, async")
async def test_rich_convert_database_async(
self,
Expand Down Expand Up @@ -174,10 +187,8 @@ async def test_database_get_collection_async(
collection = await async_database.get_collection(TEST_COLLECTION_INSTANCE_NAME)
assert collection == async_collection_instance

assert (
await getattr(async_database, TEST_COLLECTION_INSTANCE_NAME) == collection
)
assert await async_database[TEST_COLLECTION_INSTANCE_NAME] == collection
assert getattr(async_database, TEST_COLLECTION_INSTANCE_NAME) == collection
assert async_database[TEST_COLLECTION_INSTANCE_NAME] == collection

NAMESPACE_2 = "other_namespace"
collection_ns2 = await async_database.get_collection(
Expand Down
Loading
Loading