Skip to content

Commit

Permalink
enhanced copy methods for astrapy objects
Browse files Browse the repository at this point in the history
  • Loading branch information
hemidactylus committed Feb 28, 2024
1 parent 42bee6b commit 9d4c586
Show file tree
Hide file tree
Showing 3 changed files with 282 additions and 28 deletions.
124 changes: 98 additions & 26 deletions astrapy/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,13 @@ def __init__(
caller_name=caller_name,
caller_version=caller_version,
)
else:
# if astra_db passed, copy and apply possible overrides
astra_db = astra_db.copy(
token=token,
api_endpoint=api_endpoint,
namespace=namespace,
)

# Set the remaining instance attributes
self.astra_db = astra_db
Expand All @@ -128,12 +135,31 @@ def __eq__(self, other: Any) -> bool:
else:
return False

def copy(self) -> AstraDBCollection:
def copy(
self,
*,
collection_name: Optional[str] = None,
token: Optional[str] = None,
api_endpoint: Optional[str] = None,
api_path: Optional[str] = None,
api_version: Optional[str] = None,
namespace: Optional[str] = None,
caller_name: Optional[str] = None,
caller_version: Optional[str] = None,
) -> AstraDBCollection:
return AstraDBCollection(
collection_name=self.collection_name,
astra_db=self.astra_db.copy(),
caller_name=self.caller_name,
caller_version=self.caller_version,
collection_name=collection_name or self.collection_name,
astra_db=self.astra_db.copy(
token=token,
api_endpoint=api_endpoint,
api_path=api_path,
api_version=api_version,
namespace=namespace,
caller_name=caller_name,
caller_version=caller_version,
),
caller_name=caller_name or self.caller_name,
caller_version=caller_version or self.caller_version,
)

def to_async(self) -> AsyncAstraDBCollection:
Expand Down Expand Up @@ -1092,6 +1118,13 @@ def __init__(
caller_name=caller_name,
caller_version=caller_version,
)
else:
# if astra_db passed, copy and apply possible overrides
astra_db = astra_db.copy(
token=token,
api_endpoint=api_endpoint,
namespace=namespace,
)

# Set the remaining instance attributes
self.astra_db: AsyncAstraDB = astra_db
Expand All @@ -1117,12 +1150,31 @@ def __eq__(self, other: Any) -> bool:
else:
return False

def copy(self) -> AsyncAstraDBCollection:
def copy(
self,
*,
collection_name: Optional[str] = None,
token: Optional[str] = None,
api_endpoint: Optional[str] = None,
api_path: Optional[str] = None,
api_version: Optional[str] = None,
namespace: Optional[str] = None,
caller_name: Optional[str] = None,
caller_version: Optional[str] = None,
) -> AsyncAstraDBCollection:
return AsyncAstraDBCollection(
collection_name=self.collection_name,
astra_db=self.astra_db.copy(),
caller_name=self.caller_name,
caller_version=self.caller_version,
collection_name=collection_name or self.collection_name,
astra_db=self.astra_db.copy(
token=token,
api_endpoint=api_endpoint,
api_path=api_path,
api_version=api_version,
namespace=namespace,
caller_name=caller_name,
caller_version=caller_version,
),
caller_name=caller_name or self.caller_name,
caller_version=caller_version or self.caller_version,
)

def set_caller(
Expand Down Expand Up @@ -2063,15 +2115,25 @@ def __eq__(self, other: Any) -> bool:
else:
return False

def copy(self) -> AstraDB:
def copy(
self,
*,
token: Optional[str] = None,
api_endpoint: Optional[str] = None,
api_path: Optional[str] = None,
api_version: Optional[str] = None,
namespace: Optional[str] = None,
caller_name: Optional[str] = None,
caller_version: Optional[str] = None,
) -> AstraDB:
return AstraDB(
token=self.token,
api_endpoint=self.base_url,
api_path=self.api_path,
api_version=self.api_version,
namespace=self.namespace,
caller_name=self.caller_name,
caller_version=self.caller_version,
token=token or self.token,
api_endpoint=api_endpoint or self.base_url,
api_path=api_path or self.api_path,
api_version=api_version or self.api_version,
namespace=namespace or self.namespace,
caller_name=caller_name or self.caller_name,
caller_version=caller_version or self.caller_version,
)

def to_async(self) -> AsyncAstraDB:
Expand Down Expand Up @@ -2349,15 +2411,25 @@ async def __aexit__(
) -> None:
await self.client.aclose()

def copy(self) -> AsyncAstraDB:
def copy(
self,
*,
token: Optional[str] = None,
api_endpoint: Optional[str] = None,
api_path: Optional[str] = None,
api_version: Optional[str] = None,
namespace: Optional[str] = None,
caller_name: Optional[str] = None,
caller_version: Optional[str] = None,
) -> AsyncAstraDB:
return AsyncAstraDB(
token=self.token,
api_endpoint=self.base_url,
api_path=self.api_path,
api_version=self.api_version,
namespace=self.namespace,
caller_name=self.caller_name,
caller_version=self.caller_version,
token=token or self.token,
api_endpoint=api_endpoint or self.base_url,
api_path=api_path or self.api_path,
api_version=api_version or self.api_version,
namespace=namespace or self.namespace,
caller_name=caller_name or self.caller_name,
caller_version=caller_version or self.caller_version,
)

def to_sync(self) -> AstraDB:
Expand Down
19 changes: 17 additions & 2 deletions astrapy/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,23 @@ def __eq__(self, other: Any) -> bool:
else:
return False

def copy(self) -> AstraDBOps:
return AstraDBOps(**self.constructor_params)
def copy(
self,
*,
token: Optional[str] = None,
dev_ops_url: Optional[str] = None,
dev_ops_api_version: Optional[str] = None,
caller_name: Optional[str] = None,
caller_version: Optional[str] = None,
) -> AstraDBOps:
return AstraDBOps(
token=token or self.constructor_params["token"],
dev_ops_url=dev_ops_url or self.constructor_params["dev_ops_url"],
dev_ops_api_version=dev_ops_api_version
or self.constructor_params["dev_ops_api_version"],
caller_name=caller_name or self.constructor_params["caller_name"],
caller_version=caller_version or self.constructor_params["caller_version"],
)

def set_caller(
self,
Expand Down
167 changes: 167 additions & 0 deletions tests/astrapy/test_conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,173 @@ def test_copy_methods() -> None:
assert c_adb_ops is not adb_ops


@pytest.mark.describe("test parameter override in copy methods")
def test_parameter_override_copy_methods() -> None:
sync_astradb = AstraDB(
token="token",
api_endpoint="api_endpoint",
api_path="api_path",
api_version="api_version",
namespace="namespace",
caller_name="caller_name",
caller_version="caller_version",
)
sync_astradb2 = AstraDB(
token="token2",
api_endpoint="api_endpoint2",
api_path="api_path2",
api_version="api_version2",
namespace="namespace2",
caller_name="caller_name2",
caller_version="caller_version2",
)
c_sync_astradb = sync_astradb.copy(
token="token2",
api_endpoint="api_endpoint2",
api_path="api_path2",
api_version="api_version2",
namespace="namespace2",
caller_name="caller_name2",
caller_version="caller_version2",
)
assert c_sync_astradb == sync_astradb2

async_astradb = AsyncAstraDB(
token="token",
api_endpoint="api_endpoint",
api_path="api_path",
api_version="api_version",
namespace="namespace",
caller_name="caller_name",
caller_version="caller_version",
)
async_astradb2 = AsyncAstraDB(
token="token2",
api_endpoint="api_endpoint2",
api_path="api_path2",
api_version="api_version2",
namespace="namespace2",
caller_name="caller_name2",
caller_version="caller_version2",
)
c_async_astradb = async_astradb.copy(
token="token2",
api_endpoint="api_endpoint2",
api_path="api_path2",
api_version="api_version2",
namespace="namespace2",
caller_name="caller_name2",
caller_version="caller_version2",
)
assert c_async_astradb == async_astradb2

sync_adbcollection = AstraDBCollection(
collection_name="collection_name",
astra_db=sync_astradb,
caller_name="caller_name",
caller_version="caller_version",
)
sync_adbcollection2 = AstraDBCollection(
collection_name="collection_name2",
astra_db=sync_astradb2,
caller_name="caller_name2",
caller_version="caller_version2",
)
c_sync_adbcollection = sync_adbcollection.copy(
collection_name="collection_name2",
token="token2",
api_endpoint="api_endpoint2",
api_path="api_path2",
api_version="api_version2",
namespace="namespace2",
caller_name="caller_name2",
caller_version="caller_version2",
)
assert c_sync_adbcollection == sync_adbcollection2

async_adbcollection = AsyncAstraDBCollection(
collection_name="collection_name",
astra_db=async_astradb,
caller_name="caller_name",
caller_version="caller_version",
)
async_adbcollection2 = AsyncAstraDBCollection(
collection_name="collection_name2",
astra_db=async_astradb2,
caller_name="caller_name2",
caller_version="caller_version2",
)
c_async_adbcollection = async_adbcollection.copy(
collection_name="collection_name2",
token="token2",
api_endpoint="api_endpoint2",
api_path="api_path2",
api_version="api_version2",
namespace="namespace2",
caller_name="caller_name2",
caller_version="caller_version2",
)
assert c_async_adbcollection == async_adbcollection2

adb_ops = AstraDBOps(
token="token",
dev_ops_url="dev_ops_url",
dev_ops_api_version="dev_ops_api_version",
caller_name="caller_name",
caller_version="caller_version",
)
adb_ops2 = AstraDBOps(
token="token2",
dev_ops_url="dev_ops_url2",
dev_ops_api_version="dev_ops_api_version2",
caller_name="caller_name2",
caller_version="caller_version2",
)
c_adb_ops = adb_ops.copy(
token="token2",
dev_ops_url="dev_ops_url2",
dev_ops_api_version="dev_ops_api_version2",
caller_name="caller_name2",
caller_version="caller_version2",
)
assert c_adb_ops == adb_ops2


@pytest.mark.describe("test parameter override when instantiating collections")
def test_parameter_override_collection_instances() -> None:
astradb0 = AstraDB(token="t0", api_endpoint="a0")
astradb1 = AstraDB(token="t1", api_endpoint="a1", namespace="n1")
col0 = AstraDBCollection(
collection_name="col0",
astra_db=astradb0,
)
col1 = AstraDBCollection(
collection_name="col0",
astra_db=astradb0,
token="t1",
api_endpoint="a1",
namespace="n1",
)
assert col0 != col1
assert col1 == AstraDBCollection(collection_name="col0", astra_db=astradb1)

a_astradb0 = AsyncAstraDB(token="t0", api_endpoint="a0")
a_astradb1 = AsyncAstraDB(token="t1", api_endpoint="a1", namespace="n1")
a_col0 = AsyncAstraDBCollection(
collection_name="col0",
astra_db=a_astradb0,
)
a_col1 = AsyncAstraDBCollection(
collection_name="col0",
astra_db=a_astradb0,
token="t1",
api_endpoint="a1",
namespace="n1",
)
assert a_col0 != a_col1
assert a_col1 == AsyncAstraDBCollection(collection_name="col0", astra_db=a_astradb1)


@pytest.mark.describe("test set_caller works in place for clients")
def test_set_caller_clients() -> None:
astradb0 = AstraDB(token="t1", api_endpoint="a1")
Expand Down

0 comments on commit 9d4c586

Please sign in to comment.