Skip to content

Commit

Permalink
collection.drop + database/collection info and related methods/proper…
Browse files Browse the repository at this point in the history
…ties (metadata) (#247)

* Adding drop function to the collection

* Adding drop collection to collection

* Moving the tests into the right test files

* Removing async test because I was getting an event loop issue

* Adding database attributes

* Adding attributes to asyncdatabase

* Removing unneeded attributes, fixing up collection.drop()

* Running poetry black

* Setting the property as a separate piece

* Fixing up the async version for get_database_info

* error message against db.nonmethod() and getattr/getitem of async are non-async functions

* command() method for raw POSTs

* with_options as sugar for copy() methods

* rework database/collection info as per agreed conventions

* regexp to parse api_endpoint into database_id

---------

Co-authored-by: Kirsten Hunter <synedra@gmail.com>
  • Loading branch information
hemidactylus and synedra committed Mar 8, 2024
1 parent ef6f3d9 commit 91261e7
Show file tree
Hide file tree
Showing 5 changed files with 253 additions and 33 deletions.
81 changes: 57 additions & 24 deletions astrapy/idiomatic/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
BulkWriteResult,
)
from astrapy.idiomatic.cursors import AsyncCursor, Cursor
from astrapy.idiomatic.info import CollectionInfo


if TYPE_CHECKING:
Expand Down Expand Up @@ -79,18 +80,6 @@ def __init__(
namespace=self._astra_db_collection.astra_db.namespace
)

@property
def database(self) -> Database:
return self._database

@property
def namespace(self) -> str:
return self.database.namespace

@property
def name(self) -> str:
return self._astra_db_collection.collection_name

def __repr__(self) -> str:
return f'{self.__class__.__name__}[_astra_db_collection="{self._astra_db_collection}"]'

Expand Down Expand Up @@ -176,6 +165,31 @@ def options(self) -> Dict[str, Any]:
else:
raise ValueError(f"Collection {self.namespace}.{self.name} not found.")

@property
def info(self) -> CollectionInfo:
return CollectionInfo(
database_info=self.database.info,
namespace=self.namespace,
name=self.name,
full_name=self.full_name,
)

@property
def database(self) -> Database:
return self._database

@property
def namespace(self) -> str:
return self.database.namespace

@property
def name(self) -> str:
return self._astra_db_collection.collection_name

@property
def full_name(self) -> str:
return f"{self.namespace}.{self.name}"

def insert_one(
self,
document: DocumentType,
Expand Down Expand Up @@ -566,6 +580,9 @@ def bulk_write(
]
return reduce_bulk_write_results(bulk_write_results)

def drop(self) -> Dict[str, Any]:
return self.database.drop_collection(self)


class AsyncCollection:
def __init__(
Expand All @@ -589,18 +606,6 @@ def __init__(
namespace=self._astra_db_collection.astra_db.namespace
)

@property
def database(self) -> AsyncDatabase:
return self._database

@property
def namespace(self) -> str:
return self.database.namespace

@property
def name(self) -> str:
return self._astra_db_collection.collection_name

def __repr__(self) -> str:
return f'{self.__class__.__name__}[_astra_db_collection="{self._astra_db_collection}"]'

Expand Down Expand Up @@ -686,6 +691,31 @@ async def options(self) -> Dict[str, Any]:
else:
raise ValueError(f"Collection {self.namespace}.{self.name} not found.")

@property
def info(self) -> CollectionInfo:
return CollectionInfo(
database_info=self.database.info,
namespace=self.namespace,
name=self.name,
full_name=self.full_name,
)

@property
def database(self) -> AsyncDatabase:
return self._database

@property
def namespace(self) -> str:
return self.database.namespace

@property
def name(self) -> str:
return self._astra_db_collection.collection_name

@property
def full_name(self) -> str:
return f"{self.namespace}.{self.name}"

async def insert_one(
self,
document: DocumentType,
Expand Down Expand Up @@ -1088,3 +1118,6 @@ async def concurrent_execute_operation(
]
bulk_write_results = await asyncio.gather(*tasks)
return reduce_bulk_write_results(bulk_write_results)

async def drop(self) -> Dict[str, Any]:
return await self.database.drop_collection(self)
56 changes: 47 additions & 9 deletions astrapy/idiomatic/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from astrapy.db import AstraDB, AsyncAstraDB
from astrapy.idiomatic.cursors import AsyncCommandCursor, CommandCursor

from astrapy.idiomatic.info import DatabaseInfo, get_database_info

if TYPE_CHECKING:
from astrapy.idiomatic.collection import AsyncCollection, Collection
Expand Down Expand Up @@ -94,6 +94,7 @@ def __init__(
caller_name=caller_name,
caller_version=caller_version,
)
self._database_info: Optional[DatabaseInfo] = None

def __getattr__(self, collection_name: str) -> Collection:
return self.get_collection(name=collection_name)
Expand All @@ -110,10 +111,6 @@ def __eq__(self, other: Any) -> bool:
else:
return False

@property
def namespace(self) -> str:
return self._astra_db.namespace

def copy(
self,
*,
Expand Down Expand Up @@ -179,6 +176,28 @@ def set_caller(
caller_version=caller_version,
)

@property
def info(self) -> DatabaseInfo:
if self._database_info is None:
self._database_info = get_database_info(
self._astra_db.api_endpoint,
token=self._astra_db.token,
namespace=self.namespace,
)
return self._database_info

@property
def id(self) -> Optional[str]:
return self.info.id

@property
def name(self) -> Optional[str]:
return self.info.name

@property
def namespace(self) -> str:
return self._astra_db.namespace

def get_collection(
self, name: str, *, namespace: Optional[str] = None
) -> Collection:
Expand Down Expand Up @@ -338,6 +357,7 @@ def __init__(
caller_name=caller_name,
caller_version=caller_version,
)
self._database_info: Optional[DatabaseInfo] = None

def __getattr__(self, collection_name: str) -> AsyncCollection:
return self.to_sync().get_collection(name=collection_name).to_async()
Expand Down Expand Up @@ -369,10 +389,6 @@ async def __aexit__(
traceback=traceback,
)

@property
def namespace(self) -> str:
return self._astra_db.namespace

def copy(
self,
*,
Expand Down Expand Up @@ -438,6 +454,28 @@ def set_caller(
caller_version=caller_version,
)

@property
def info(self) -> DatabaseInfo:
if self._database_info is None:
self._database_info = get_database_info(
self._astra_db.api_endpoint,
token=self._astra_db.token,
namespace=self.namespace,
)
return self._database_info

@property
def id(self) -> Optional[str]:
return self.info.id

@property
def name(self) -> Optional[str]:
return self.info.name

@property
def namespace(self) -> str:
return self._astra_db.namespace

async def get_collection(
self, name: str, *, namespace: Optional[str] = None
) -> AsyncCollection:
Expand Down
83 changes: 83 additions & 0 deletions astrapy/idiomatic/info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import re
from dataclasses import dataclass
from typing import Any, Dict, Optional

from astrapy.ops import AstraDBOps


database_id_finder = re.compile(
"https://([0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12})"
)


def find_database_id(api_endpoint: str) -> Optional[str]:
match = database_id_finder.match(api_endpoint)
if match and match.groups():
return match.groups()[0]
else:
return None


@dataclass
class DatabaseInfo:
id: Optional[str]
region: Optional[str]
namespace: str
name: Optional[str]
raw_info: Optional[Dict[str, Any]]


@dataclass
class CollectionInfo:
database_info: DatabaseInfo
namespace: str
name: str
full_name: str


def get_database_info(api_endpoint: str, token: str, namespace: str) -> DatabaseInfo:
try:
astra_db_ops = AstraDBOps(token=token)
database_id = find_database_id(api_endpoint)
if database_id:
gd_response = astra_db_ops.get_database(database=database_id)
raw_info = gd_response["info"]
return DatabaseInfo(
id=database_id,
region=raw_info["region"],
namespace=namespace,
name=raw_info["name"],
raw_info=raw_info,
)
else:
return DatabaseInfo(
id=None,
region=None,
namespace=namespace,
name=None,
raw_info=None,
)
except Exception:
return DatabaseInfo(
id=None,
region=None,
namespace=namespace,
name=None,
raw_info=None,
)
35 changes: 35 additions & 0 deletions tests/idiomatic/integration/test_ddl_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
TEST_COLLECTION_NAME,
)
from astrapy.api import APIRequestError
from astrapy.idiomatic.info import DatabaseInfo
from astrapy import AsyncCollection, AsyncDatabase


Expand Down Expand Up @@ -64,6 +65,40 @@ async def test_collection_lifecycle_async(
assert dc_response2 == {"ok": 1}
await async_database.drop_collection(TEST_LOCAL_COLLECTION_NAME_B)

@pytest.mark.describe("test of collection drop, async")
async def test_collection_drop_async(self, async_database: AsyncDatabase) -> None:
col = await async_database.create_collection(
name="async_collection_to_drop", dimension=2
)
del_res = await col.drop()
assert del_res["ok"] == 1
assert "async_collection_to_drop" not in (
await async_database.list_collection_names()
)

@pytest.mark.describe("test of database metainformation, async")
async def test_get_database_info_async(
self,
async_database: AsyncDatabase,
astra_db_credentials_kwargs: AstraDBCredentials,
) -> None:
assert isinstance(async_database.id, str)
assert isinstance(async_database.name, str)
assert async_database.namespace == astra_db_credentials_kwargs["namespace"]
assert isinstance(async_database.info, DatabaseInfo)
assert isinstance(async_database.info.raw_info, dict)

@pytest.mark.describe("test of collection metainformation, async")
async def test_get_collection_info_async(
self,
async_collection: AsyncCollection,
) -> None:
info = async_collection.info
assert info.namespace == async_collection.namespace
assert (
info.namespace == async_collection._astra_db_collection.astra_db.namespace
)

@pytest.mark.describe("test of check_exists for create_collection, async")
async def test_create_collection_check_exists_async(
self,
Expand Down
Loading

0 comments on commit 91261e7

Please sign in to comment.