From a3b3e5679bb15ee3a0a1e5f7291411d43e5c72ef Mon Sep 17 00:00:00 2001
From: Christophe Bornet
Date: Fri, 5 Jan 2024 15:54:38 +0100
Subject: [PATCH] Add async insert_many_chunked and upsert_many

---
 astrapy/db.py                      | 68 ++++++++++++++++++++++++
 astrapy/defaults.py                |  2 +
 tests/astrapy/test_async_db_dml.py | 85 +++++++++++++++++++++++++++++-
 3 files changed, 154 insertions(+), 1 deletion(-)

diff --git a/astrapy/db.py b/astrapy/db.py
index dbf89872..14a8684f 100644
--- a/astrapy/db.py
+++ b/astrapy/db.py
@@ -41,6 +41,7 @@
     DEFAULT_JSON_API_PATH,
     DEFAULT_JSON_API_VERSION,
     DEFAULT_KEYSPACE_NAME,
+    MAX_INSERT_BATCH_SIZE,
 )
 from astrapy.utils import (
     convert_vector_to_floats,
@@ -1440,6 +1441,48 @@ async def insert_many(
 
         return response
 
+    async def insert_many_chunked(
+        self,
+        documents: List[API_DOC],
+        options: Optional[Dict[str, Any]] = None,
+        partial_failures_allowed: bool = False,
+        chunk_size: int = MAX_INSERT_BATCH_SIZE,
+        concurrency: int = 1,
+    ) -> List[Union[API_RESPONSE, BaseException]]:
+        """
+        Insert multiple documents into the collection.
+        The input list of documents is split into chunks of at most chunk_size documents before being sent to the JSON API.
+        The API calls can be issued concurrently.
+        Args:
+            documents (list): A list of documents to insert.
+            options (dict, optional): Additional options for the insert operation.
+            partial_failures_allowed (bool, optional): Whether to allow partial failures.
+            chunk_size (int, optional): Override the default insertion chunk size.
+            concurrency (int, optional): The number of concurrent calls.
+        Returns:
+            list: The responses from the database for each chunk (or an exception if the call failed).
+        """
+        sem = asyncio.Semaphore(concurrency)
+
+        async def concurrent_insert_many(
+            docs: List[API_DOC], index: int
+        ) -> API_RESPONSE:
+            async with sem:
+                logger.debug(f"Processing chunk #{index + 1} of size {len(docs)}")
+                return await self.insert_many(
+                    documents=docs,
+                    options=options,
+                    partial_failures_allowed=partial_failures_allowed,
+                )
+
+        tasks = [
+            asyncio.create_task(
+                concurrent_insert_many(documents[i : i + chunk_size], i // chunk_size)
+            )
+            for i in range(0, len(documents), chunk_size)
+        ]
+        return await asyncio.gather(*tasks, return_exceptions=partial_failures_allowed)
+
     async def update_one(
         self, filter: Dict[str, Any], update: Dict[str, Any]
     ) -> API_RESPONSE:
@@ -1566,6 +1609,31 @@ async def upsert(self, document: API_DOC) -> str:
 
         return upserted_id
 
+    async def upsert_many(
+        self,
+        documents: List[API_DOC],
+        concurrency: int = 1,
+        partial_failures_allowed: bool = False,
+    ) -> List[Union[str, BaseException]]:
+        """
+        Emulate an upsert operation for multiple documents in the collection.
+        This method attempts to insert the documents; if a document with the same _id already exists, the existing document is updated.
+        Args:
+            documents (List[dict]): The documents to insert or update.
+            concurrency (int, optional): The number of concurrent upserts.
+            partial_failures_allowed (bool, optional): Whether to allow partial failures in the batch.
+        Returns:
+            List[Union[str, BaseException]]: A list of the _ids of the inserted or updated documents (or an exception if the call failed).
+        """
+        sem = asyncio.Semaphore(concurrency)
+
+        async def concurrent_upsert(doc: API_DOC) -> str:
+            async with sem:
+                return await self.upsert(document=doc)
+
+        tasks = [asyncio.create_task(concurrent_upsert(doc)) for doc in documents]
+        return await asyncio.gather(*tasks, return_exceptions=partial_failures_allowed)
+
 
 class AstraDB:
     # Initialize the shared httpx client as a class attribute
diff --git a/astrapy/defaults.py b/astrapy/defaults.py
index 8fa05edf..32540469 100644
--- a/astrapy/defaults.py
+++ b/astrapy/defaults.py
@@ -11,3 +11,5 @@
 DEFAULT_AUTH_HEADER = "Token"
 DEFAULT_KEYSPACE_NAME = "default_keyspace"
 DEFAULT_REGION = "us-east1"
+
+MAX_INSERT_BATCH_SIZE = 20
diff --git a/tests/astrapy/test_async_db_dml.py b/tests/astrapy/test_async_db_dml.py
index ef7520cb..11b8785d 100644
--- a/tests/astrapy/test_async_db_dml.py
+++ b/tests/astrapy/test_async_db_dml.py
@@ -307,7 +307,7 @@ async def test_insert_many(
         {
             "_id": _id2,
             "name": "Ciccio",
-            "description": "The thid in this list",
+            "description": "The third in this list",
             "$vector": [0.4, 0.3],
         },
     ]
@@ -319,6 +319,25 @@ async def test_insert_many(
     assert isinstance(list(inserted_ids - {_id0, _id2})[0], str)
 
 
+@pytest.mark.describe("insert_many_chunked")
+async def test_insert_many_chunked(
+    async_writable_vector_collection: AsyncAstraDBCollection,
+) -> None:
+    documents: List[API_DOC] = [{"name": "Abba"}] * 30
+    response = await async_writable_vector_collection.insert_many_chunked(
+        documents, concurrency=2
+    )
+    assert response is not None
+    inserted_ids_list = [
+        set(r["status"]["insertedIds"])
+        for r in response
+        if not isinstance(r, BaseException)
+    ]
+    inserted_ids = set.union(*inserted_ids_list)
+    assert len(inserted_ids) == 30
+    assert isinstance(list(inserted_ids)[0], str)
+
+
 @pytest.mark.describe("insert_many with 'ordered' set to False")
 async def test_insert_many_ordered_false(
     async_writable_vector_collection: AsyncAstraDBCollection,
@@ -419,6 +438,70 @@ async def test_upsert_document(
     assert response1["data"]["document"] == document1
 
 
+@pytest.mark.describe("upsert many")
+async def test_upsert_many(
+    async_writable_vector_collection: AsyncAstraDBCollection,
+) -> None:
+    _id0 = str(uuid.uuid4())
+    _id1 = str(uuid.uuid4())
+
+    documents = [
+        {
+            "_id": _id0,
+            "addresses": {
+                "work": {
+                    "city": "Seattle",
+                    "state": "WA",
+                },
+            },
+        },
+        {
+            "_id": _id1,
+            "addresses": {
+                "work": {
+                    "city": "Seattle",
+                    "state": "WA",
+                },
+            },
+        },
+    ]
+    upsert_result0 = await async_writable_vector_collection.upsert_many(
+        documents, concurrency=2, partial_failures_allowed=True
+    )
+    assert upsert_result0[0] == _id0
+    assert upsert_result0[1] == _id1
+
+    response0 = await async_writable_vector_collection.find_one(filter={"_id": _id0})
+    assert response0 is not None
+    assert response0["data"]["document"] == documents[0]
+
+    response1 = await async_writable_vector_collection.find_one(filter={"_id": _id1})
+    assert response1 is not None
+    assert response1["data"]["document"] == documents[1]
+
+    documents2 = [
+        {
+            "_id": _id0,
+            "addresses": {
+                "work": {
+                    "state": "MN",
+                    "floor": 12,
+                },
+            },
+            "hobbies": [
+                "ice skating",
+                "accounting",
+            ],
+        }
+    ]
+    upsert_result2 = await async_writable_vector_collection.upsert_many(documents2)
+    assert upsert_result2[0] == _id0
+
+    response2 = await async_writable_vector_collection.find_one(filter={"_id": _id0})
+    assert response2 is not None
+    assert response2["data"]["document"] == documents2[0]
+
+
 @pytest.mark.describe("update_one to create a subdocument, not through vector")
 async def test_update_one_create_subdocument_novector(
     async_disposable_vector_collection: AsyncAstraDBCollection,
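
Usage note: below is a minimal sketch, not part of the patch, of how the new insert_many_chunked could be driven from application code. The `collection` object is assumed to be an already-initialized AsyncAstraDBCollection; how it is obtained is out of scope here.

    async def demo_insert_many_chunked(collection):
        # 45 documents are split into chunks of at most MAX_INSERT_BATCH_SIZE
        # (20) documents each: three API calls, at most two in flight at a time.
        docs = [{"idx": n} for n in range(45)]
        responses = await collection.insert_many_chunked(
            docs,
            concurrency=2,
            partial_failures_allowed=True,
        )
        # With partial_failures_allowed=True, a failed chunk surfaces as an
        # exception in the result list instead of cancelling the other chunks.
        for r in responses:
            if isinstance(r, BaseException):
                print(f"chunk failed: {r}")
            else:
                print(f"inserted ids: {r['status']['insertedIds']}")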
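
Similarly, a sketch of upsert_many under the same assumption (an existing AsyncAstraDBCollection named `collection`). Since asyncio.gather preserves input order, the returned _ids line up with the input documents.

    async def demo_upsert_many(collection):
        docs = [
            {"_id": "doc-1", "status": "new"},
            {"_id": "doc-2", "status": "new"},
        ]
        # The first call inserts; running it again with modified documents
        # updates them in place, since upsert() matches on _id.
        ids = await collection.upsert_many(docs, concurrency=2)
        print(ids)  # -> ["doc-1", "doc-2"]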