Skip to content

Commit

Permalink
fix bug in copy and to_[a]sync when set_caller is later used (#226)
Browse files Browse the repository at this point in the history
  • Loading branch information
hemidactylus authored Feb 29, 2024
1 parent 13ee9bf commit 8001f3f
Show file tree
Hide file tree
Showing 8 changed files with 235 additions and 8 deletions.
12 changes: 10 additions & 2 deletions astrapy/idiomatic/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ def __eq__(self, other: Any) -> bool:
return False

def copy(self) -> Collection:
return Collection(**self._constructor_params)
return Collection(
**self._constructor_params,
)

def to_async(self) -> AsyncCollection:
return AsyncCollection(
Expand All @@ -93,6 +95,8 @@ def set_caller(
caller_name=caller_name,
caller_version=caller_version,
)
self._constructor_params["caller_name"] = caller_name
self._constructor_params["caller_version"] = caller_version

def insert_one(
self,
Expand Down Expand Up @@ -303,7 +307,9 @@ def __eq__(self, other: Any) -> bool:
return False

def copy(self) -> AsyncCollection:
return AsyncCollection(**self._constructor_params)
return AsyncCollection(
**self._constructor_params,
)

def to_sync(self) -> Collection:
return Collection(
Expand All @@ -322,6 +328,8 @@ def set_caller(
caller_name=caller_name,
caller_version=caller_version,
)
self._constructor_params["caller_name"] = caller_name
self._constructor_params["caller_version"] = caller_version

async def insert_one(
self,
Expand Down
28 changes: 24 additions & 4 deletions astrapy/idiomatic/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ def __init__(
caller_version=caller_version,
)

@property
def namespace(self) -> str:
return self._astra_db.namespace

def __repr__(self) -> str:
return f'{self.__class__.__name__}[_astra_db={self._astra_db}"]'

Expand All @@ -102,10 +106,14 @@ def __eq__(self, other: Any) -> bool:
return False

def copy(self) -> Database:
return Database(**self._constructor_params)
return Database(
**self._constructor_params,
)

def to_async(self) -> AsyncDatabase:
return AsyncDatabase(**self._constructor_params)
return AsyncDatabase(
**self._constructor_params,
)

def set_caller(
self,
Expand All @@ -114,6 +122,8 @@ def set_caller(
) -> None:
self._astra_db.caller_name = caller_name
self._astra_db.caller_version = caller_version
self._constructor_params["caller_name"] = caller_name
self._constructor_params["caller_version"] = caller_version

def get_collection(
self, name: str, *, namespace: Optional[str] = None
Expand Down Expand Up @@ -245,6 +255,10 @@ def __init__(
caller_version=caller_version,
)

@property
def namespace(self) -> str:
return self._astra_db.namespace

def __repr__(self) -> str:
return f'{self.__class__.__name__}[_astra_db={self._astra_db}"]'

Expand All @@ -270,10 +284,14 @@ async def __aexit__(
)

def copy(self) -> AsyncDatabase:
return AsyncDatabase(**self._constructor_params)
return AsyncDatabase(
**self._constructor_params,
)

def to_sync(self) -> Database:
return Database(**self._constructor_params)
return Database(
**self._constructor_params,
)

def set_caller(
self,
Expand All @@ -282,6 +300,8 @@ def set_caller(
) -> None:
self._astra_db.caller_name = caller_name
self._astra_db.caller_version = caller_version
self._constructor_params["caller_name"] = caller_name
self._constructor_params["caller_version"] = caller_version

async def get_collection(
self, name: str, *, namespace: Optional[str] = None
Expand Down
4 changes: 2 additions & 2 deletions astrapy/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ def copy(
dev_ops_url=dev_ops_url or self.constructor_params["dev_ops_url"],
dev_ops_api_version=dev_ops_api_version
or self.constructor_params["dev_ops_api_version"],
caller_name=caller_name or self.constructor_params["caller_name"],
caller_version=caller_version or self.constructor_params["caller_version"],
caller_name=caller_name or self.caller_name,
caller_version=caller_version or self.caller_version,
)

def set_caller(
Expand Down
107 changes: 107 additions & 0 deletions tests/astrapy/test_conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,113 @@ def test_copy_methods() -> None:
assert c_adb_ops is not adb_ops


@pytest.mark.describe("test copy methods respect mutable caller identity")
def test_copy_methods_mutable_caller() -> None:
sync_astradb = AstraDB(
token="token",
api_endpoint="api_endpoint",
api_path="api_path",
api_version="api_version",
namespace="namespace",
caller_name="caller_name",
caller_version="caller_version",
)
sync_astradb.set_caller(
caller_name="caller_name2",
caller_version="caller_version2",
)
sync_astradb2 = AstraDB(
token="token",
api_endpoint="api_endpoint",
api_path="api_path",
api_version="api_version",
namespace="namespace",
caller_name="caller_name2",
caller_version="caller_version2",
)
assert sync_astradb.copy() == sync_astradb2

async_astradb = AsyncAstraDB(
token="token",
api_endpoint="api_endpoint",
api_path="api_path",
api_version="api_version",
namespace="namespace",
caller_name="caller_name",
caller_version="caller_version",
)
async_astradb.set_caller(
caller_name="caller_name2",
caller_version="caller_version2",
)
async_astradb2 = AsyncAstraDB(
token="token",
api_endpoint="api_endpoint",
api_path="api_path",
api_version="api_version",
namespace="namespace",
caller_name="caller_name2",
caller_version="caller_version2",
)
assert async_astradb.copy() == async_astradb2

sync_adbcollection = AstraDBCollection(
collection_name="collection_name",
astra_db=sync_astradb,
caller_name="caller_name",
caller_version="caller_version",
)
sync_adbcollection.set_caller(
caller_name="caller_name2",
caller_version="caller_version2",
)
sync_adbcollection2 = AstraDBCollection(
collection_name="collection_name",
astra_db=sync_astradb,
caller_name="caller_name2",
caller_version="caller_version2",
)
assert sync_adbcollection.copy() == sync_adbcollection2

async_adbcollection = AsyncAstraDBCollection(
collection_name="collection_name",
astra_db=async_astradb,
caller_name="caller_name",
caller_version="caller_version",
)
async_adbcollection.set_caller(
caller_name="caller_name2",
caller_version="caller_version2",
)
async_adbcollection2 = AsyncAstraDBCollection(
collection_name="collection_name",
astra_db=async_astradb,
caller_name="caller_name2",
caller_version="caller_version2",
)
assert async_adbcollection.copy() == async_adbcollection2

adb_ops = AstraDBOps(
token="token",
dev_ops_url="dev_ops_url",
dev_ops_api_version="dev_ops_api_version",
caller_name="caller_name",
caller_version="caller_version",
)
adb_ops.set_caller(
caller_name="caller_name2",
caller_version="caller_version2",
)
adb_ops2 = AstraDBOps(
token="token",
dev_ops_url="dev_ops_url",
dev_ops_api_version="dev_ops_api_version",
caller_name="caller_name2",
caller_version="caller_version2",
)
assert adb_ops.copy() == adb_ops2


@pytest.mark.describe("test parameter override in copy methods")
def test_parameter_override_copy_methods() -> None:
sync_astradb = AstraDB(
Expand Down
24 changes: 24 additions & 0 deletions tests/idiomatic/integration/test_collections_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,27 @@ async def test_collection_unsupported_methods_async(
await async_collection_instance.update_search_index(1, "x")
with pytest.raises(TypeError):
await async_collection_instance.distinct(1, "x")

@pytest.mark.describe("test collection conversions with caller mutableness, async")
async def test_collection_conversions_caller_mutableness_async(
self,
async_database: AsyncDatabase,
) -> None:
col1 = AsyncCollection(
async_database,
"id_test_collection",
caller_name="c_n1",
caller_version="c_v1",
)
col1.set_caller(
caller_name="c_n2",
caller_version="c_v2",
)
col2 = AsyncCollection(
async_database,
"id_test_collection",
caller_name="c_n2",
caller_version="c_v2",
)
assert col1.copy() == col2
assert col1.to_sync().to_async() == col2
24 changes: 24 additions & 0 deletions tests/idiomatic/integration/test_collections_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,27 @@ def test_collection_unsupported_methods_sync(
sync_collection_instance.update_search_index(1, "x")
with pytest.raises(TypeError):
sync_collection_instance.distinct(1, "x")

@pytest.mark.describe("test collection conversions with caller mutableness, sync")
def test_collection_conversions_caller_mutableness_sync(
self,
sync_database: Database,
) -> None:
col1 = Collection(
sync_database,
"id_test_collection",
caller_name="c_n1",
caller_version="c_v1",
)
col1.set_caller(
caller_name="c_n2",
caller_version="c_v2",
)
col2 = Collection(
sync_database,
"id_test_collection",
caller_name="c_n2",
caller_version="c_v2",
)
assert col1.copy() == col2
assert col1.to_async().to_sync() == col2
22 changes: 22 additions & 0 deletions tests/idiomatic/integration/test_databases_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,25 @@ async def test_database_get_collection_async(
async_database, TEST_COLLECTION_INSTANCE_NAME, namespace=NAMESPACE_2
)
assert collection_ns2._astra_db_collection.astra_db.namespace == NAMESPACE_2

@pytest.mark.describe("test database conversions with caller mutableness, async")
async def test_database_conversions_caller_mutableness_async(
self,
astra_db_credentials_kwargs: AstraDBCredentials,
) -> None:
db1 = AsyncDatabase(
caller_name="c_n1",
caller_version="c_v1",
**astra_db_credentials_kwargs,
)
db1.set_caller(
caller_name="c_n2",
caller_version="c_v2",
)
db2 = AsyncDatabase(
caller_name="c_n2",
caller_version="c_v2",
**astra_db_credentials_kwargs,
)
assert db1.to_sync().to_async() == db2
assert db1.copy() == db2
22 changes: 22 additions & 0 deletions tests/idiomatic/integration/test_databases_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,25 @@ def test_database_get_collection_sync(
sync_database, TEST_COLLECTION_INSTANCE_NAME, namespace=NAMESPACE_2
)
assert collection_ns2._astra_db_collection.astra_db.namespace == NAMESPACE_2

@pytest.mark.describe("test database conversions with caller mutableness, sync")
def test_database_conversions_caller_mutableness_sync(
self,
astra_db_credentials_kwargs: AstraDBCredentials,
) -> None:
db1 = Database(
caller_name="c_n1",
caller_version="c_v1",
**astra_db_credentials_kwargs,
)
db1.set_caller(
caller_name="c_n2",
caller_version="c_v2",
)
db2 = Database(
caller_name="c_n2",
caller_version="c_v2",
**astra_db_credentials_kwargs,
)
assert db1.to_async().to_sync() == db2
assert db1.copy() == db2

0 comments on commit 8001f3f

Please sign in to comment.