diff --git a/astrapy/idiomatic/collection.py b/astrapy/idiomatic/collection.py index 8ec33209..1ff99fd3 100644 --- a/astrapy/idiomatic/collection.py +++ b/astrapy/idiomatic/collection.py @@ -14,8 +14,10 @@ from __future__ import annotations +import asyncio import json -from typing import Any, Dict, Iterable, List, Optional, Union +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Dict, Iterable, List, Optional, Union, TYPE_CHECKING from astrapy.db import AstraDBCollection, AsyncAstraDBCollection from astrapy.idiomatic.types import ( @@ -30,11 +32,17 @@ InsertManyResult, InsertOneResult, UpdateResult, + BulkWriteResult, ) from astrapy.idiomatic.cursors import AsyncCursor, Cursor +if TYPE_CHECKING: + from astrapy.idiomatic.operations import AsyncBaseOperation, BaseOperation + + INSERT_MANY_CONCURRENCY = 20 +BULK_WRITE_CONCURRENCY = 10 def _prepare_update_info(status: Dict[str, Any]) -> Dict[str, Any]: @@ -143,6 +151,7 @@ def insert_one( if io_response["status"]["insertedIds"]: inserted_id = io_response["status"]["insertedIds"][0] return InsertOneResult( + raw_result=io_response, inserted_id=inserted_id, ) else: @@ -195,7 +204,11 @@ def insert_many( if isinstance(response, dict) for ins_id in (response.get("status") or {}).get("insertedIds", []) ] - return InsertManyResult(inserted_ids=inserted_ids) + return InsertManyResult( + # if we are here, cim_responses are all dicts (no exceptions) + raw_result=cim_responses, # type: ignore[arg-type] + inserted_ids=inserted_ids, + ) def find( self, @@ -473,7 +486,7 @@ def delete_many( raw_result=dm_responses, ) else: - # expected a non-negative integer (None : + # per API specs, deleted_count has to be a non-negative integer. return DeleteResult( deleted_count=deleted_count, raw_result=dm_responses, @@ -484,6 +497,37 @@ def delete_many( f"(gotten '${json.dumps(dm_responses)}')" ) + def bulk_write( + self, + requests: Iterable[BaseOperation], + *, + ordered: bool = True, + ) -> BulkWriteResult: + # lazy importing here against circular-import error + from astrapy.idiomatic.operations import reduce_bulk_write_results + + if ordered: + bulk_write_results = [ + operation.execute(self, operation_i) + for operation_i, operation in enumerate(requests) + ] + return reduce_bulk_write_results(bulk_write_results) + else: + with ThreadPoolExecutor(max_workers=BULK_WRITE_CONCURRENCY) as executor: + bulk_write_futures = [ + executor.submit( + operation.execute, + self, + operation_i, + ) + for operation_i, operation in enumerate(requests) + ] + bulk_write_results = [ + bulk_write_future.result() + for bulk_write_future in bulk_write_futures + ] + return reduce_bulk_write_results(bulk_write_results) + class AsyncCollection: def __init__( @@ -579,6 +623,7 @@ async def insert_one( if io_response["status"]["insertedIds"]: inserted_id = io_response["status"]["insertedIds"][0] return InsertOneResult( + raw_result=io_response, inserted_id=inserted_id, ) else: @@ -631,7 +676,11 @@ async def insert_many( if isinstance(response, dict) for ins_id in (response.get("status") or {}).get("insertedIds", []) ] - return InsertManyResult(inserted_ids=inserted_ids) + return InsertManyResult( + # if we are here, cim_responses are all dicts (no exceptions) + raw_result=cim_responses, # type: ignore[arg-type] + inserted_ids=inserted_ids, + ) def find( self, @@ -916,7 +965,7 @@ async def delete_many( raw_result=dm_responses, ) else: - # expected a non-negative integer (None : + # per API specs, deleted_count has to be a non-negative integer. return DeleteResult( deleted_count=deleted_count, raw_result=dm_responses, @@ -926,3 +975,40 @@ async def delete_many( "Could not complete a chunked_delete_many operation. " f"(gotten '${json.dumps(dm_responses)}')" ) + + async def bulk_write( + self, + requests: Iterable[AsyncBaseOperation], + *, + ordered: bool = True, + ) -> BulkWriteResult: + # lazy importing here against circular-import error + from astrapy.idiomatic.operations import reduce_bulk_write_results + + if ordered: + bulk_write_results = [ + await operation.execute(self, operation_i) + for operation_i, operation in enumerate(requests) + ] + return reduce_bulk_write_results(bulk_write_results) + else: + sem = asyncio.Semaphore(BULK_WRITE_CONCURRENCY) + + async def concurrent_execute_operation( + operation: AsyncBaseOperation, + collection: AsyncCollection, + index_in_bulk_write: int, + ) -> BulkWriteResult: + async with sem: + return await operation.execute( + collection=collection, index_in_bulk_write=index_in_bulk_write + ) + + tasks = [ + asyncio.create_task( + concurrent_execute_operation(operation, self, operation_i) + ) + for operation_i, operation in enumerate(requests) + ] + bulk_write_results = await asyncio.gather(*tasks) + return reduce_bulk_write_results(bulk_write_results) diff --git a/astrapy/idiomatic/operations.py b/astrapy/idiomatic/operations.py new file mode 100644 index 00000000..1a03a214 --- /dev/null +++ b/astrapy/idiomatic/operations.py @@ -0,0 +1,542 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from functools import reduce +from typing import ( + Any, + Dict, + Iterable, + List, +) + +from astrapy.idiomatic.types import DocumentType +from astrapy.idiomatic.results import BulkWriteResult +from astrapy.idiomatic.collection import AsyncCollection, Collection + + +def reduce_bulk_write_results(results: List[BulkWriteResult]) -> BulkWriteResult: + zero = BulkWriteResult( + bulk_api_results={}, + deleted_count=0, + inserted_count=0, + matched_count=0, + modified_count=0, + upserted_count=0, + upserted_ids={}, + ) + + def _sum_results(r1: BulkWriteResult, r2: BulkWriteResult) -> BulkWriteResult: + bulk_api_results = {**r1.bulk_api_results, **r2.bulk_api_results} + if r1.deleted_count is None or r2.deleted_count is None: + deleted_count = None + else: + deleted_count = r1.deleted_count + r2.deleted_count + inserted_count = r1.inserted_count + r2.inserted_count + matched_count = r1.matched_count + r2.matched_count + modified_count = r1.modified_count + r2.modified_count + upserted_count = r1.upserted_count + r2.upserted_count + upserted_ids = {**r1.upserted_ids, **r2.upserted_ids} + return BulkWriteResult( + bulk_api_results=bulk_api_results, + deleted_count=deleted_count, + inserted_count=inserted_count, + matched_count=matched_count, + modified_count=modified_count, + upserted_count=upserted_count, + upserted_ids=upserted_ids, + ) + + return reduce(_sum_results, results, zero) + + +class BaseOperation(ABC): + @abstractmethod + def execute( + self, collection: Collection, index_in_bulk_write: int + ) -> BulkWriteResult: ... + + +@dataclass +class InsertOne(BaseOperation): + document: DocumentType + + def __init__( + self, + document: DocumentType, + ) -> None: + self.document = document + + def execute( + self, collection: Collection, index_in_bulk_write: int + ) -> BulkWriteResult: + op_result = collection.insert_one(document=self.document) + return BulkWriteResult( + bulk_api_results={index_in_bulk_write: op_result.raw_result}, + deleted_count=0, + inserted_count=1, + matched_count=0, + modified_count=0, + upserted_count=0, + upserted_ids={}, + ) + + +@dataclass +class InsertMany(BaseOperation): + documents: Iterable[DocumentType] + ordered: bool + + def __init__( + self, + documents: Iterable[DocumentType], + ordered: bool = True, + ) -> None: + self.documents = documents + self.ordered = ordered + + def execute( + self, collection: Collection, index_in_bulk_write: int + ) -> BulkWriteResult: + op_result = collection.insert_many( + documents=self.documents, + ordered=self.ordered, + ) + return BulkWriteResult( + bulk_api_results={index_in_bulk_write: op_result.raw_result}, + deleted_count=0, + inserted_count=len(op_result.inserted_ids), + matched_count=0, + modified_count=0, + upserted_count=0, + upserted_ids={}, + ) + + +@dataclass +class UpdateOne(BaseOperation): + filter: Dict[str, Any] + update: Dict[str, Any] + upsert: bool + + def __init__( + self, + filter: Dict[str, Any], + update: Dict[str, Any], + *, + upsert: bool = False, + ) -> None: + self.filter = filter + self.update = update + self.upsert = upsert + + def execute( + self, collection: Collection, index_in_bulk_write: int + ) -> BulkWriteResult: + op_result = collection.update_one( + filter=self.filter, + update=self.update, + upsert=self.upsert, + ) + inserted_count = 1 if "upserted" in op_result.update_info else 0 + matched_count = (op_result.update_info.get("n") or 0) - inserted_count + if "upserted" in op_result.update_info: + upserted_ids = {index_in_bulk_write: op_result.update_info["upserted"]} + else: + upserted_ids = {} + return BulkWriteResult( + bulk_api_results={index_in_bulk_write: op_result.raw_result}, + deleted_count=0, + inserted_count=inserted_count, + matched_count=matched_count, + modified_count=op_result.update_info.get("nModified") or 0, + upserted_count=1 if "upserted" in op_result.update_info else 0, + upserted_ids=upserted_ids, + ) + + +@dataclass +class UpdateMany(BaseOperation): + filter: Dict[str, Any] + update: Dict[str, Any] + upsert: bool + + def __init__( + self, + filter: Dict[str, Any], + update: Dict[str, Any], + *, + upsert: bool = False, + ) -> None: + self.filter = filter + self.update = update + self.upsert = upsert + + def execute( + self, collection: Collection, index_in_bulk_write: int + ) -> BulkWriteResult: + op_result = collection.update_many( + filter=self.filter, + update=self.update, + upsert=self.upsert, + ) + inserted_count = 1 if "upserted" in op_result.update_info else 0 + matched_count = (op_result.update_info.get("n") or 0) - inserted_count + if "upserted" in op_result.update_info: + upserted_ids = {index_in_bulk_write: op_result.update_info["upserted"]} + else: + upserted_ids = {} + return BulkWriteResult( + bulk_api_results={index_in_bulk_write: op_result.raw_result}, + deleted_count=0, + inserted_count=inserted_count, + matched_count=matched_count, + modified_count=op_result.update_info.get("nModified") or 0, + upserted_count=1 if "upserted" in op_result.update_info else 0, + upserted_ids=upserted_ids, + ) + + +@dataclass +class ReplaceOne(BaseOperation): + filter: Dict[str, Any] + replacement: DocumentType + upsert: bool + + def __init__( + self, + filter: Dict[str, Any], + replacement: DocumentType, + *, + upsert: bool = False, + ) -> None: + self.filter = filter + self.replacement = replacement + self.upsert = upsert + + def execute( + self, collection: Collection, index_in_bulk_write: int + ) -> BulkWriteResult: + op_result = collection.replace_one( + filter=self.filter, + replacement=self.replacement, + upsert=self.upsert, + ) + inserted_count = 1 if "upserted" in op_result.update_info else 0 + matched_count = (op_result.update_info.get("n") or 0) - inserted_count + if "upserted" in op_result.update_info: + upserted_ids = {index_in_bulk_write: op_result.update_info["upserted"]} + else: + upserted_ids = {} + return BulkWriteResult( + bulk_api_results={index_in_bulk_write: op_result.raw_result}, + deleted_count=0, + inserted_count=inserted_count, + matched_count=matched_count, + modified_count=op_result.update_info.get("nModified") or 0, + upserted_count=1 if "upserted" in op_result.update_info else 0, + upserted_ids=upserted_ids, + ) + + +@dataclass +class DeleteOne(BaseOperation): + filter: Dict[str, Any] + + def __init__( + self, + filter: Dict[str, Any], + ) -> None: + self.filter = filter + + def execute( + self, collection: Collection, index_in_bulk_write: int + ) -> BulkWriteResult: + op_result = collection.delete_one(filter=self.filter) + return BulkWriteResult( + bulk_api_results={index_in_bulk_write: op_result.raw_result}, + deleted_count=op_result.deleted_count, + inserted_count=0, + matched_count=0, + modified_count=0, + upserted_count=0, + upserted_ids={}, + ) + + +@dataclass +class DeleteMany(BaseOperation): + filter: Dict[str, Any] + + def __init__( + self, + filter: Dict[str, Any], + ) -> None: + self.filter = filter + + def execute( + self, collection: Collection, index_in_bulk_write: int + ) -> BulkWriteResult: + op_result = collection.delete_many(filter=self.filter) + return BulkWriteResult( + bulk_api_results={index_in_bulk_write: op_result.raw_result}, + deleted_count=op_result.deleted_count, + inserted_count=0, + matched_count=0, + modified_count=0, + upserted_count=0, + upserted_ids={}, + ) + + +class AsyncBaseOperation(ABC): + @abstractmethod + async def execute( + self, collection: AsyncCollection, index_in_bulk_write: int + ) -> BulkWriteResult: ... + + +@dataclass +class AsyncInsertOne(AsyncBaseOperation): + document: DocumentType + + def __init__( + self, + document: DocumentType, + ) -> None: + self.document = document + + async def execute( + self, collection: AsyncCollection, index_in_bulk_write: int + ) -> BulkWriteResult: + op_result = await collection.insert_one(document=self.document) + return BulkWriteResult( + bulk_api_results={index_in_bulk_write: op_result.raw_result}, + deleted_count=0, + inserted_count=1, + matched_count=0, + modified_count=0, + upserted_count=0, + upserted_ids={}, + ) + + +@dataclass +class AsyncInsertMany(AsyncBaseOperation): + documents: Iterable[DocumentType] + ordered: bool + + def __init__( + self, + documents: Iterable[DocumentType], + ordered: bool = True, + ) -> None: + self.documents = documents + self.ordered = ordered + + async def execute( + self, collection: AsyncCollection, index_in_bulk_write: int + ) -> BulkWriteResult: + op_result = await collection.insert_many( + documents=self.documents, + ordered=self.ordered, + ) + return BulkWriteResult( + bulk_api_results={index_in_bulk_write: op_result.raw_result}, + deleted_count=0, + inserted_count=len(op_result.inserted_ids), + matched_count=0, + modified_count=0, + upserted_count=0, + upserted_ids={}, + ) + + +@dataclass +class AsyncUpdateOne(AsyncBaseOperation): + filter: Dict[str, Any] + update: Dict[str, Any] + upsert: bool + + def __init__( + self, + filter: Dict[str, Any], + update: Dict[str, Any], + *, + upsert: bool = False, + ) -> None: + self.filter = filter + self.update = update + self.upsert = upsert + + async def execute( + self, collection: AsyncCollection, index_in_bulk_write: int + ) -> BulkWriteResult: + op_result = await collection.update_one( + filter=self.filter, + update=self.update, + upsert=self.upsert, + ) + inserted_count = 1 if "upserted" in op_result.update_info else 0 + matched_count = (op_result.update_info.get("n") or 0) - inserted_count + if "upserted" in op_result.update_info: + upserted_ids = {index_in_bulk_write: op_result.update_info["upserted"]} + else: + upserted_ids = {} + return BulkWriteResult( + bulk_api_results={index_in_bulk_write: op_result.raw_result}, + deleted_count=0, + inserted_count=inserted_count, + matched_count=matched_count, + modified_count=op_result.update_info.get("nModified") or 0, + upserted_count=1 if "upserted" in op_result.update_info else 0, + upserted_ids=upserted_ids, + ) + + +@dataclass +class AsyncUpdateMany(AsyncBaseOperation): + filter: Dict[str, Any] + update: Dict[str, Any] + upsert: bool + + def __init__( + self, + filter: Dict[str, Any], + update: Dict[str, Any], + *, + upsert: bool = False, + ) -> None: + self.filter = filter + self.update = update + self.upsert = upsert + + async def execute( + self, collection: AsyncCollection, index_in_bulk_write: int + ) -> BulkWriteResult: + op_result = await collection.update_many( + filter=self.filter, + update=self.update, + upsert=self.upsert, + ) + inserted_count = 1 if "upserted" in op_result.update_info else 0 + matched_count = (op_result.update_info.get("n") or 0) - inserted_count + if "upserted" in op_result.update_info: + upserted_ids = {index_in_bulk_write: op_result.update_info["upserted"]} + else: + upserted_ids = {} + return BulkWriteResult( + bulk_api_results={index_in_bulk_write: op_result.raw_result}, + deleted_count=0, + inserted_count=inserted_count, + matched_count=matched_count, + modified_count=op_result.update_info.get("nModified") or 0, + upserted_count=1 if "upserted" in op_result.update_info else 0, + upserted_ids=upserted_ids, + ) + + +@dataclass +class AsyncReplaceOne(AsyncBaseOperation): + filter: Dict[str, Any] + replacement: DocumentType + upsert: bool + + def __init__( + self, + filter: Dict[str, Any], + replacement: DocumentType, + *, + upsert: bool = False, + ) -> None: + self.filter = filter + self.replacement = replacement + self.upsert = upsert + + async def execute( + self, collection: AsyncCollection, index_in_bulk_write: int + ) -> BulkWriteResult: + op_result = await collection.replace_one( + filter=self.filter, + replacement=self.replacement, + upsert=self.upsert, + ) + inserted_count = 1 if "upserted" in op_result.update_info else 0 + matched_count = (op_result.update_info.get("n") or 0) - inserted_count + if "upserted" in op_result.update_info: + upserted_ids = {index_in_bulk_write: op_result.update_info["upserted"]} + else: + upserted_ids = {} + return BulkWriteResult( + bulk_api_results={index_in_bulk_write: op_result.raw_result}, + deleted_count=0, + inserted_count=inserted_count, + matched_count=matched_count, + modified_count=op_result.update_info.get("nModified") or 0, + upserted_count=1 if "upserted" in op_result.update_info else 0, + upserted_ids=upserted_ids, + ) + + +@dataclass +class AsyncDeleteOne(AsyncBaseOperation): + filter: Dict[str, Any] + + def __init__( + self, + filter: Dict[str, Any], + ) -> None: + self.filter = filter + + async def execute( + self, collection: AsyncCollection, index_in_bulk_write: int + ) -> BulkWriteResult: + op_result = await collection.delete_one(filter=self.filter) + return BulkWriteResult( + bulk_api_results={index_in_bulk_write: op_result.raw_result}, + deleted_count=op_result.deleted_count, + inserted_count=0, + matched_count=0, + modified_count=0, + upserted_count=0, + upserted_ids={}, + ) + + +@dataclass +class AsyncDeleteMany(AsyncBaseOperation): + filter: Dict[str, Any] + + def __init__( + self, + filter: Dict[str, Any], + ) -> None: + self.filter = filter + + async def execute( + self, collection: AsyncCollection, index_in_bulk_write: int + ) -> BulkWriteResult: + op_result = await collection.delete_many(filter=self.filter) + return BulkWriteResult( + bulk_api_results={index_in_bulk_write: op_result.raw_result}, + deleted_count=op_result.deleted_count, + inserted_count=0, + matched_count=0, + modified_count=0, + upserted_count=0, + upserted_ids={}, + ) diff --git a/astrapy/idiomatic/results.py b/astrapy/idiomatic/results.py index 223f0ab4..13030d09 100644 --- a/astrapy/idiomatic/results.py +++ b/astrapy/idiomatic/results.py @@ -27,12 +27,14 @@ class DeleteResult: @dataclass class InsertOneResult: + raw_result: Dict[str, Any] inserted_id: Any acknowledged: bool = True @dataclass class InsertManyResult: + raw_result: List[Dict[str, Any]] inserted_ids: List[Any] acknowledged: bool = True @@ -42,3 +44,15 @@ class UpdateResult: raw_result: Dict[str, Any] update_info: Dict[str, Any] acknowledged: bool = True + + +@dataclass +class BulkWriteResult: + bulk_api_results: Dict[int, Union[Dict[str, Any], List[Dict[str, Any]]]] + deleted_count: Optional[int] + inserted_count: int + matched_count: int + modified_count: int + upserted_count: int + upserted_ids: Dict[int, Any] + acknowledged: bool = True diff --git a/tests/idiomatic/integration/test_dml_async.py b/tests/idiomatic/integration/test_dml_async.py index 72556896..c76a6c8e 100644 --- a/tests/idiomatic/integration/test_dml_async.py +++ b/tests/idiomatic/integration/test_dml_async.py @@ -22,6 +22,15 @@ from astrapy.idiomatic.types import DocumentType from astrapy.idiomatic.cursors import AsyncCursor from astrapy.idiomatic.types import ReturnDocument +from astrapy.idiomatic.operations import ( + AsyncInsertOne, + AsyncInsertMany, + AsyncUpdateOne, + AsyncUpdateMany, + AsyncReplaceOne, + AsyncDeleteOne, + AsyncDeleteMany, +) class TestDMLAsync: @@ -954,3 +963,73 @@ async def test_collection_find_one_and_update_async( assert resp_pr2 is not None assert set(resp_pr2.keys()) == {"f"} await acol.delete_many({}) + + @pytest.mark.describe("test of ordered bulk_write, async") + async def test_collection_ordered_bulk_write_async( + self, + async_empty_collection: AsyncCollection, + ) -> None: + acol = async_empty_collection + + bw_ops = [ + AsyncInsertOne({"seq": 0}), + AsyncInsertMany([{"seq": 1}, {"seq": 2}, {"seq": 3}]), + AsyncUpdateOne({"seq": 0}, {"$set": {"edited": 1}}), + AsyncUpdateMany({"seq": {"$gt": 0}}, {"$set": {"positive": True}}), + AsyncReplaceOne({"edited": 1}, {"seq": 0, "edited": 2}), + AsyncDeleteOne({"seq": 1}), + AsyncDeleteMany({"seq": {"$gt": 1}}), + AsyncReplaceOne( + {"no": "matches"}, {"_id": "seq4", "from_upsert": True}, upsert=True + ), + ] + + bw_result = await acol.bulk_write(bw_ops) + + assert bw_result.deleted_count == 3 + assert bw_result.inserted_count == 5 + assert bw_result.matched_count == 5 + assert bw_result.modified_count == 5 + assert bw_result.upserted_count == 1 + assert set(bw_result.upserted_ids.keys()) == {7} + + found_docs = sorted( + [doc async for doc in acol.find({})], + key=lambda doc: doc.get("seq", 10), + ) + assert len(found_docs) == 2 + assert found_docs[0]["seq"] == 0 + assert found_docs[0]["edited"] == 2 + assert "_id" in found_docs[0] + assert len(found_docs[0]) == 3 + assert found_docs[1] == {"_id": "seq4", "from_upsert": True} + + @pytest.mark.describe("test of unordered bulk_write, async") + async def test_collection_unordered_bulk_write_async( + self, + async_empty_collection: AsyncCollection, + ) -> None: + acol = async_empty_collection + + bw_u_ops = [ + AsyncInsertOne({"a": 1}), + AsyncUpdateOne({"b": 1}, {"$set": {"newfield": True}}, upsert=True), + AsyncDeleteMany({"x": 100}), + ] + + bw_u_result = await acol.bulk_write(bw_u_ops, ordered=False) + + assert bw_u_result.deleted_count == 0 + assert bw_u_result.inserted_count == 2 + assert bw_u_result.matched_count == 0 + assert bw_u_result.modified_count == 0 + assert bw_u_result.upserted_count == 1 + assert set(bw_u_result.upserted_ids.keys()) == {1} + + found_docs = [doc async for doc in acol.find({})] + no_id_found_docs = [ + {k: v for k, v in doc.items() if k != "_id"} for doc in found_docs + ] + assert len(no_id_found_docs) == 2 + assert {"a": 1} in no_id_found_docs + assert {"b": 1, "newfield": True} in no_id_found_docs diff --git a/tests/idiomatic/integration/test_dml_sync.py b/tests/idiomatic/integration/test_dml_sync.py index fb7359e9..13f8469a 100644 --- a/tests/idiomatic/integration/test_dml_sync.py +++ b/tests/idiomatic/integration/test_dml_sync.py @@ -18,6 +18,15 @@ from astrapy.results import DeleteResult, InsertOneResult from astrapy.api import APIRequestError from astrapy.idiomatic.types import ReturnDocument +from astrapy.idiomatic.operations import ( + InsertOne, + InsertMany, + UpdateOne, + UpdateMany, + ReplaceOne, + DeleteOne, + DeleteMany, +) class TestDMLSync: @@ -928,3 +937,73 @@ def test_collection_find_one_and_update_sync( assert resp_pr2 is not None assert set(resp_pr2.keys()) == {"f"} col.delete_many({}) + + @pytest.mark.describe("test of ordered bulk_write, sync") + def test_collection_ordered_bulk_write_sync( + self, + sync_empty_collection: Collection, + ) -> None: + col = sync_empty_collection + + bw_ops = [ + InsertOne({"seq": 0}), + InsertMany([{"seq": 1}, {"seq": 2}, {"seq": 3}]), + UpdateOne({"seq": 0}, {"$set": {"edited": 1}}), + UpdateMany({"seq": {"$gt": 0}}, {"$set": {"positive": True}}), + ReplaceOne({"edited": 1}, {"seq": 0, "edited": 2}), + DeleteOne({"seq": 1}), + DeleteMany({"seq": {"$gt": 1}}), + ReplaceOne( + {"no": "matches"}, {"_id": "seq4", "from_upsert": True}, upsert=True + ), + ] + + bw_result = col.bulk_write(bw_ops) + + assert bw_result.deleted_count == 3 + assert bw_result.inserted_count == 5 + assert bw_result.matched_count == 5 + assert bw_result.modified_count == 5 + assert bw_result.upserted_count == 1 + assert set(bw_result.upserted_ids.keys()) == {7} + + found_docs = sorted( + col.find({}), + key=lambda doc: doc.get("seq", 10), + ) + assert len(found_docs) == 2 + assert found_docs[0]["seq"] == 0 + assert found_docs[0]["edited"] == 2 + assert "_id" in found_docs[0] + assert len(found_docs[0]) == 3 + assert found_docs[1] == {"_id": "seq4", "from_upsert": True} + + @pytest.mark.describe("test of unordered bulk_write, sync") + def test_collection_unordered_bulk_write_sync( + self, + sync_empty_collection: Collection, + ) -> None: + col = sync_empty_collection + + bw_u_ops = [ + InsertOne({"a": 1}), + UpdateOne({"b": 1}, {"$set": {"newfield": True}}, upsert=True), + DeleteMany({"x": 100}), + ] + + bw_u_result = col.bulk_write(bw_u_ops, ordered=False) + + assert bw_u_result.deleted_count == 0 + assert bw_u_result.inserted_count == 2 + assert bw_u_result.matched_count == 0 + assert bw_u_result.modified_count == 0 + assert bw_u_result.upserted_count == 1 + assert set(bw_u_result.upserted_ids.keys()) == {1} + + found_docs = list(col.find({})) + no_id_found_docs = [ + {k: v for k, v in doc.items() if k != "_id"} for doc in found_docs + ] + assert len(no_id_found_docs) == 2 + assert {"a": 1} in no_id_found_docs + assert {"b": 1, "newfield": True} in no_id_found_docs diff --git a/tests/idiomatic/unit/test_bulk_write_results.py b/tests/idiomatic/unit/test_bulk_write_results.py new file mode 100644 index 00000000..d1fdad44 --- /dev/null +++ b/tests/idiomatic/unit/test_bulk_write_results.py @@ -0,0 +1,93 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from astrapy.idiomatic.results import BulkWriteResult +from astrapy.idiomatic.operations import reduce_bulk_write_results + + +class TestBulkWriteResults: + @pytest.mark.describe("test of reduction of bulk write results") + def test_reduce_bulk_write_results(self) -> None: + bwr1 = BulkWriteResult( + bulk_api_results={1: {"seq1": 1}}, + deleted_count=100, + inserted_count=200, + matched_count=300, + modified_count=400, + upserted_count=500, + upserted_ids={1: {"useq1": 1}}, + ) + bwr2 = BulkWriteResult( + bulk_api_results={}, + deleted_count=10, + inserted_count=20, + matched_count=30, + modified_count=40, + upserted_count=50, + upserted_ids={2: {"useq2": 2}}, + ) + bwr3 = BulkWriteResult( + bulk_api_results={3: {"seq3": 3}}, + deleted_count=1, + inserted_count=2, + matched_count=3, + modified_count=4, + upserted_count=5, + upserted_ids={}, + ) + + reduced_a = reduce_bulk_write_results([bwr1, bwr2, bwr3]) + expected_a = BulkWriteResult( + bulk_api_results={1: {"seq1": 1}, 3: {"seq3": 3}}, + deleted_count=111, + inserted_count=222, + matched_count=333, + modified_count=444, + upserted_count=555, + upserted_ids={1: {"useq1": 1}, 2: {"useq2": 2}}, + ) + assert reduced_a == expected_a + + bwr_n = BulkWriteResult( + bulk_api_results={}, + deleted_count=None, + inserted_count=0, + matched_count=0, + modified_count=0, + upserted_count=0, + upserted_ids={}, + ) + bwr_1 = BulkWriteResult( + bulk_api_results={}, + deleted_count=1, + inserted_count=1, + matched_count=1, + modified_count=1, + upserted_count=1, + upserted_ids={}, + ) + + reduced_n = reduce_bulk_write_results([bwr_1, bwr_n, bwr_1]) + expected_n = BulkWriteResult( + bulk_api_results={}, + deleted_count=None, + inserted_count=2, + matched_count=2, + modified_count=2, + upserted_count=2, + upserted_ids={}, + ) + assert reduced_n == expected_n