Skip to content

Commit

Permalink
Add async insert_many_chunked and upsert_many
Browse files Browse the repository at this point in the history
  • Loading branch information
cbornet committed Jan 5, 2024
1 parent 7275930 commit a3b3e56
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 1 deletion.
68 changes: 68 additions & 0 deletions astrapy/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
DEFAULT_JSON_API_PATH,
DEFAULT_JSON_API_VERSION,
DEFAULT_KEYSPACE_NAME,
MAX_INSERT_BATCH_SIZE,
)
from astrapy.utils import (
convert_vector_to_floats,
Expand Down Expand Up @@ -1440,6 +1441,48 @@ async def insert_many(

return response

async def insert_many_chunked(
self,
documents: List[API_DOC],
options: Optional[Dict[str, Any]] = None,
partial_failures_allowed: bool = False,
chunk_size: int = MAX_INSERT_BATCH_SIZE,
concurrency: int = 1,
) -> List[Union[API_RESPONSE, BaseException]]:
"""
Insert multiple documents into the collection.
The input list of documents is split into chunks of chunk_size size before sending to the JSON API.
The calls to the API can be done concurrently.
Args:
documents (list): A list of documents to insert.
options (dict, optional): Additional options for the insert operation.
partial_failures_allowed (bool, optional): Whether to allow partial failures.
chunk_size (int, optional): Override the default insertion chunk size.
concurrency (int, optional): The number of concurrent calls.
Returns:
list: The responses from the database for each chunk (or an exception if the call failed).
"""
sem = asyncio.Semaphore(concurrency)

async def concurrent_insert_many(
docs: List[API_DOC], index: int
) -> API_RESPONSE:
async with sem:
logger.debug(f"Processing chunk #{index + 1} of size {len(docs)}")
return await self.insert_many(
documents=docs,
options=options,
partial_failures_allowed=partial_failures_allowed,
)

tasks = [
asyncio.create_task(
concurrent_insert_many(documents[i : i + chunk_size], i)
)
for i in range(0, len(documents), chunk_size)
]
return await asyncio.gather(*tasks, return_exceptions=partial_failures_allowed)

async def update_one(
self, filter: Dict[str, Any], update: Dict[str, Any]
) -> API_RESPONSE:
Expand Down Expand Up @@ -1566,6 +1609,31 @@ async def upsert(self, document: API_DOC) -> str:

return upserted_id

async def upsert_many(
self,
documents: list[API_DOC],
concurrency: int = 1,
partial_failures_allowed: bool = False,
) -> List[Union[str, BaseException]]:
"""
Emulate an upsert operation for multiple documents in the collection.
This method attempts to insert the documents. If a document with the same _id exists, it updates the existing document.
Args:
documents (List[dict]): The documents to insert or update.
concurrency (int, optional): The number of concurrent updates.
partial_failures_allowed (bool, optional): Whether to allow partial failures in the batch.
Returns:
List[Union[str, BaseException]]: A list of "_id"s of the inserted or updated documents (or an exception if the call failed)
"""
sem = asyncio.Semaphore(concurrency)

async def concurrent_upsert(doc: API_DOC) -> str:
async with sem:
return await self.upsert(document=doc)

tasks = [asyncio.create_task(concurrent_upsert(doc)) for doc in documents]
return await asyncio.gather(*tasks, return_exceptions=partial_failures_allowed)


class AstraDB:
# Initialize the shared httpx client as a class attribute
Expand Down
2 changes: 2 additions & 0 deletions astrapy/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,5 @@
DEFAULT_AUTH_HEADER = "Token"
DEFAULT_KEYSPACE_NAME = "default_keyspace"
DEFAULT_REGION = "us-east1"

# Default chunk size for chunked inserts (insert_many_chunked) — at most this
# many documents are sent per insert_many call. Presumably mirrors the JSON
# API's per-request insert limit; confirm against the server-side limit.
MAX_INSERT_BATCH_SIZE = 20
85 changes: 84 additions & 1 deletion tests/astrapy/test_async_db_dml.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ async def test_insert_many(
{
"_id": _id2,
"name": "Ciccio",
"description": "The thid in this list",
"description": "The third in this list",
"$vector": [0.4, 0.3],
},
]
Expand All @@ -319,6 +319,25 @@ async def test_insert_many(
assert isinstance(list(inserted_ids - {_id0, _id2})[0], str)


@pytest.mark.describe("insert_many_chunked")
async def test_insert_many_chunked(
    async_writable_vector_collection: AsyncAstraDBCollection,
) -> None:
    """
    Insert 30 documents via insert_many_chunked (two chunks at the default
    chunk size, with concurrency) and check that 30 distinct _ids come back.
    """
    # Build 30 DISTINCT dict objects. The previous `[{...}] * 30` produced
    # 30 references to one shared dict, which breaks if the client mutates
    # documents in place (e.g. injecting a generated "_id").
    documents: List[API_DOC] = [{"name": "Abba"} for _ in range(30)]
    response = await async_writable_vector_collection.insert_many_chunked(
        documents, concurrency=2
    )
    assert response is not None
    # Collect the insertedIds reported by each successful chunk response.
    inserted_ids_list = [
        set(r["status"]["insertedIds"])
        for r in response
        if not isinstance(r, BaseException)
    ]
    inserted_ids = set.union(*inserted_ids_list)
    # All 30 documents must have been inserted, each with a string _id.
    assert len(inserted_ids) == 30
    assert isinstance(list(inserted_ids)[0], str)


@pytest.mark.describe("insert_many with 'ordered' set to False")
async def test_insert_many_ordered_false(
async_writable_vector_collection: AsyncAstraDBCollection,
Expand Down Expand Up @@ -419,6 +438,70 @@ async def test_upsert_document(
assert response1["data"]["document"] == document1


@pytest.mark.describe("upsert many")
async def test_upsert_many(
    async_writable_vector_collection: AsyncAstraDBCollection,
) -> None:
    """
    Upsert two new documents concurrently, verify both inserts, then upsert
    a replacement for the first and verify the stored document was replaced.
    """
    _id0 = str(uuid.uuid4())
    _id1 = str(uuid.uuid4())

    documents = [
        {
            "_id": _id0,
            "addresses": {
                "work": {
                    "city": "Seattle",
                    "state": "WA",
                },
            },
        },
        {
            "_id": _id1,
            "addresses": {
                "work": {
                    "city": "Seattle",
                    "state": "WA",
                },
            },
        },
    ]
    # First pass: both documents are new, so this behaves as plain inserts.
    upsert_result0 = await async_writable_vector_collection.upsert_many(
        documents, concurrency=2, partial_failures_allowed=True
    )
    assert upsert_result0[0] == _id0
    assert upsert_result0[1] == _id1

    response0 = await async_writable_vector_collection.find_one(filter={"_id": _id0})
    assert response0 is not None
    assert response0["data"]["document"] == documents[0]

    # Distinct variable per lookup (was a copy-pasted `response0` reuse,
    # which obscured which of the two assertions actually failed).
    response1 = await async_writable_vector_collection.find_one(filter={"_id": _id1})
    assert response1 is not None
    assert response1["data"]["document"] == documents[1]

    # Second pass: _id0 already exists, so the upsert must REPLACE it.
    documents2 = [
        {
            "_id": _id0,
            "addresses": {
                "work": {
                    "state": "MN",
                    "floor": 12,
                },
            },
            "hobbies": [
                "ice skating",
                "accounting",
            ],
        }
    ]
    upsert_result2 = await async_writable_vector_collection.upsert_many(documents2)
    assert upsert_result2[0] == _id0

    response2 = await async_writable_vector_collection.find_one(filter={"_id": _id0})
    assert response2 is not None
    assert response2["data"]["document"] == documents2[0]


@pytest.mark.describe("update_one to create a subdocument, not through vector")
async def test_update_one_create_subdocument_novector(
async_disposable_vector_collection: AsyncAstraDBCollection,
Expand Down

0 comments on commit a3b3e56

Please sign in to comment.