Skip to content

Commit

Permalink
Add async chunked_insert_many and upsert_many
Browse files Browse the repository at this point in the history
  • Loading branch information
cbornet committed Jan 10, 2024
1 parent 4e63996 commit d6902d6
Show file tree
Hide file tree
Showing 2 changed files with 294 additions and 1 deletion.
73 changes: 73 additions & 0 deletions astrapy/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -1525,6 +1525,52 @@ async def insert_many(

return response

async def chunked_insert_many(
self,
documents: List[API_DOC],
options: Optional[Dict[str, Any]] = None,
partial_failures_allowed: bool = False,
chunk_size: int = MAX_INSERT_NUM_DOCUMENTS,
concurrency: int = 1,
) -> List[Union[API_RESPONSE, BaseException]]:
"""
Insert multiple documents into the collection, handling chunking and
optionally with concurrent insertions.
Args:
documents (list): A list of documents to insert.
options (dict, optional): Additional options for the insert operation.
partial_failures_allowed (bool, optional): Whether to allow partial
failures in the chunk. Should be used combined with
options={"ordered": False} in most cases.
chunk_size (int, optional): Override the default insertion chunk size.
concurrency (int, optional): The number of concurrent chunk insertions.
Default is no concurrency.
Returns:
list: The responses from the database after the chunked insert operation.
This is a list of individual responses from the API: the caller
will need to inspect them all, e.g. to collate the inserted IDs.
"""
sem = asyncio.Semaphore(concurrency)

async def concurrent_insert_many(
docs: List[API_DOC], index: int
) -> API_RESPONSE:
async with sem:
logger.debug(f"Processing chunk #{index + 1} of size {len(docs)}")
return await self.insert_many(
documents=docs,
options=options,
partial_failures_allowed=partial_failures_allowed,
)

tasks = [
asyncio.create_task(
concurrent_insert_many(documents[i : i + chunk_size], i)
)
for i in range(0, len(documents), chunk_size)
]
return await asyncio.gather(*tasks, return_exceptions=partial_failures_allowed)

async def update_one(
self, filter: Dict[str, Any], update: Dict[str, Any]
) -> API_RESPONSE:
Expand Down Expand Up @@ -1651,6 +1697,33 @@ async def upsert(self, document: API_DOC) -> str:

return upserted_id

async def upsert_many(
self,
documents: list[API_DOC],
concurrency: int = 1,
partial_failures_allowed: bool = False,
) -> List[Union[str, BaseException]]:
"""
Emulate an upsert operation for multiple documents in the collection.
This method attempts to insert the documents.
If a document with the same _id exists, it updates the existing document.
Args:
documents (List[dict]): The documents to insert or update.
concurrency (int, optional): The number of concurrent upserts.
partial_failures_allowed (bool, optional): Whether to allow partial
failures in the batch.
Returns:
List[str]: A list of "_id"s of the inserted or updated documents.
"""
sem = asyncio.Semaphore(concurrency)

async def concurrent_upsert(doc: API_DOC) -> str:
async with sem:
return await self.upsert(document=doc)

tasks = [asyncio.create_task(concurrent_upsert(doc)) for doc in documents]
return await asyncio.gather(*tasks, return_exceptions=partial_failures_allowed)


class AstraDB:
# Initialize the shared httpx client as a class attribute
Expand Down
222 changes: 221 additions & 1 deletion tests/astrapy/test_async_db_dml.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ async def test_insert_many(
{
"_id": _id2,
"name": "Ciccio",
"description": "The thid in this list",
"description": "The third in this list",
"$vector": [0.4, 0.3],
},
]
Expand All @@ -319,6 +319,165 @@ async def test_insert_many(
assert isinstance(list(inserted_ids - {_id0, _id2})[0], str)


@pytest.mark.describe("chunked_insert_many")
async def test_chunked_insert_many(
async_writable_vector_collection: AsyncAstraDBCollection,
) -> None:
_ids0 = [str(uuid.uuid4()) for _ in range(20)]
documents0: List[API_DOC] = [
{
"_id": _id,
"specs": {
"gen": "0",
"doc_idx": doc_idx,
},
"$vector": [1, doc_idx],
}
for doc_idx, _id in enumerate(_ids0)
]

responses0 = await async_writable_vector_collection.chunked_insert_many(
documents0, chunk_size=3
)
assert responses0 is not None
inserted_ids0 = [
ins_id
for response in responses0
for ins_id in response["status"]["insertedIds"]
]
assert inserted_ids0 == _ids0

response0a = await async_writable_vector_collection.find_one(
filter={"_id": _ids0[0]}
)
assert response0a is not None
assert response0a["data"]["document"] == documents0[0]

# partial overlap of IDs for failure modes
_ids1 = [
_id0 if idx % 3 == 0 else str(uuid.uuid4()) for idx, _id0 in enumerate(_ids0)
]
documents1: List[API_DOC] = [
{
"_id": _id,
"specs": {
"gen": "1",
"doc_idx": doc_idx,
},
"$vector": [1, doc_idx],
}
for doc_idx, _id in enumerate(_ids1)
]

with pytest.raises(ValueError):
_ = await async_writable_vector_collection.chunked_insert_many(
documents1, chunk_size=3
)

responses1_ok = await async_writable_vector_collection.chunked_insert_many(
documents1,
chunk_size=3,
options={"ordered": False},
partial_failures_allowed=True,
)
inserted_ids1 = [
ins_id
for response in responses1_ok
if "status" in response and "insertedIds" in response["status"]
for ins_id in response["status"]["insertedIds"]
]
# insertions that succeeded are those with a new ID
assert set(inserted_ids1) == set(_ids1) - set(_ids0)
# we can check that the failures are as many as the preexisting docs
errors1 = [
err
for response in responses1_ok
if "errors" in response
for err in response["errors"]
]
assert len(set(_ids0) & set(_ids1)) == len(errors1)


@pytest.mark.describe("chunked_insert_many concurrently")
async def test_concurrent_chunked_insert_many(
async_writable_vector_collection: AsyncAstraDBCollection,
) -> None:
_ids0 = [str(uuid.uuid4()) for _ in range(20)]
documents0: List[API_DOC] = [
{
"_id": _id,
"specs": {
"gen": "0",
"doc_idx": doc_idx,
},
"$vector": [2, doc_idx],
}
for doc_idx, _id in enumerate(_ids0)
]

responses0 = await async_writable_vector_collection.chunked_insert_many(
documents0, chunk_size=3, concurrency=4
)
assert responses0 is not None
inserted_ids0 = [
ins_id
for response in responses0
for ins_id in response["status"]["insertedIds"]
]
assert inserted_ids0 == _ids0

response0a = await async_writable_vector_collection.find_one(
filter={"_id": _ids0[0]}
)
assert response0a is not None
assert response0a["data"]["document"] == documents0[0]

# partial overlap of IDs for failure modes
_ids1 = [
_id0 if idx % 3 == 0 else str(uuid.uuid4()) for idx, _id0 in enumerate(_ids0)
]
documents1: List[API_DOC] = [
{
"_id": _id,
"specs": {
"gen": "1",
"doc_idx": doc_idx,
},
"$vector": [1, doc_idx],
}
for doc_idx, _id in enumerate(_ids1)
]

with pytest.raises(ValueError):
_ = await async_writable_vector_collection.chunked_insert_many(
documents1, chunk_size=3, concurrency=4
)

responses1_ok = await async_writable_vector_collection.chunked_insert_many(
documents1,
chunk_size=3,
options={"ordered": False},
partial_failures_allowed=True,
concurrency=4,
)
inserted_ids1 = [
ins_id
for response in responses1_ok
if "status" in response and "insertedIds" in response["status"]
for ins_id in response["status"]["insertedIds"]
]
# insertions that succeeded are those with a new ID
assert set(inserted_ids1) == set(_ids1) - set(_ids0)
# we can check that the failures are as many as the preexisting docs
errors1 = [
err
for response in responses1_ok
if "errors" in response
for err in response["errors"]
]
assert len(set(_ids0) & set(_ids1)) == len(errors1)


@pytest.mark.describe("insert_many with 'ordered' set to False")
async def test_insert_many_ordered_false(
async_writable_vector_collection: AsyncAstraDBCollection,
Expand Down Expand Up @@ -376,6 +535,67 @@ async def test_insert_many_ordered_false(
assert check_response["data"]["document"]["_id"] == _id1


@pytest.mark.describe("upsert_many")
async def test_upsert_many(
async_writable_vector_collection: AsyncAstraDBCollection,
) -> None:
_ids0 = [str(uuid.uuid4()) for _ in range(12)]
documents0 = [
{
"_id": _id,
"specs": {
"gen": "0",
"doc_i": doc_i,
},
}
for doc_i, _id in enumerate(_ids0)
]

upsert_result0 = await async_writable_vector_collection.upsert_many(documents0)
assert upsert_result0 == [doc["_id"] for doc in documents0]

response0a = await async_writable_vector_collection.find_one(
filter={"_id": _ids0[0]}
)
assert response0a is not None
assert response0a["data"]["document"] == documents0[0]

response0b = await async_writable_vector_collection.find_one(
filter={"_id": _ids0[-1]}
)
assert response0b is not None
assert response0b["data"]["document"] == documents0[-1]

_ids1 = _ids0[::2] + [str(uuid.uuid4()) for _ in range(3)]
documents1 = [
{
"_id": _id,
"specs": {
"gen": "1",
"doc_i": doc_i,
},
}
for doc_i, _id in enumerate(_ids1)
]
upsert_result1 = await async_writable_vector_collection.upsert_many(
documents1,
concurrency=5,
)
assert upsert_result1 == [doc["_id"] for doc in documents1]

response1a = await async_writable_vector_collection.find_one(
filter={"_id": _ids1[0]}
)
assert response1a is not None
assert response1a["data"]["document"] == documents1[0]

response1b = await async_writable_vector_collection.find_one(
filter={"_id": _ids1[-1]}
)
assert response1b is not None
assert response1b["data"]["document"] == documents1[-1]


@pytest.mark.describe("upsert")
async def test_upsert_document(
async_writable_vector_collection: AsyncAstraDBCollection,
Expand Down

0 comments on commit d6902d6

Please sign in to comment.