diff --git a/docs/source/NEWS.rst b/docs/source/NEWS.rst index a2e4c7f..948d8fc 100644 --- a/docs/source/NEWS.rst +++ b/docs/source/NEWS.rst @@ -22,6 +22,10 @@ API Changes - `find()` method now returns `Cursor()` instance that can be used as async generator to asynchronously iterate over results. It can still be used as Deferred too, so this change is backward-compatible. +- `Cursor()` options can be by chaining its methods, for example: + :: + async for doc in collection.find({"size": "L"}).sort({"price": 1}).limit(10).skip(5): + print(doc) - `find_with_cursor()` is deprecated and will be removed in the next release. diff --git a/tests/basic/test_collection.py b/tests/basic/test_collection.py index 101f863..1d7773d 100644 --- a/tests/basic/test_collection.py +++ b/tests/basic/test_collection.py @@ -38,7 +38,7 @@ def cmp(a, b): return (a > b) - (a < b) -class TestIndexInfo(unittest.TestCase): +class TestCollectionMethods(unittest.TestCase): timeout = 5 @@ -54,7 +54,7 @@ def tearDown(self): yield self.conn.disconnect() @defer.inlineCallbacks - def test_collection(self): + def test_type_checking(self): self.assertRaises(TypeError, Collection, self.db, 5) def make_col(base, name): @@ -76,6 +76,7 @@ def make_col(base, name): self.assertRaises(TypeError, self.db.test.find, projection="test") self.assertRaises(TypeError, self.db.test.find, skip="test") self.assertRaises(TypeError, self.db.test.find, limit="test") + self.assertRaises(TypeError, self.db.test.find, batch_size="test") self.assertRaises(TypeError, self.db.test.find, sort="test") self.assertRaises(TypeError, self.db.test.find, skip="test") self.assertRaises(TypeError, self.db.test.insert_many, [1]) @@ -105,9 +106,32 @@ def make_col(base, name): options = yield self.db.test.options() self.assertTrue(isinstance(options, dict)) + @defer.inlineCallbacks + def test_collection_names(self): + coll_names = [f"coll_{i}" for i in range(10)] + yield defer.gatherResults( + self.db[name].insert_one({"x": 1}) for name in coll_names + ) + + try: + names = yield self.db.collection_names() + self.assertEqual(set(coll_names), set(names)) + names = yield self.db.collection_names(batch_size=10) + self.assertEqual(set(coll_names), set(names)) + finally: + yield defer.gatherResults(self.db[name].drop() for name in coll_names) + + test_collection_names.timeout = 1500 + + @defer.inlineCallbacks + def test_drop_collection(self): + yield self.db.test.insert_one({"x": 1}) + collection_names = yield self.db.collection_names() + self.assertIn("test", collection_names) + yield self.db.drop_collection("test") collection_names = yield self.db.collection_names() - self.assertFalse("test" in collection_names) + self.assertNotIn("test", collection_names) @defer.inlineCallbacks def test_create_index(self): diff --git a/tests/basic/test_filters.py b/tests/basic/test_filters.py index 62b07b6..274b01c 100644 --- a/tests/basic/test_filters.py +++ b/tests/basic/test_filters.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from contextlib import asynccontextmanager, contextmanager from pymongo.errors import OperationFailure from twisted.internet import defer @@ -39,19 +40,43 @@ def tearDown(self): yield self.db.system.profile.drop() yield self.conn.disconnect() - @defer.inlineCallbacks - def test_Hint(self): + @asynccontextmanager + async def _assert_single_command_with_option(self, optionname, optionvalue): + # Checking that `optionname` appears in profiler log with specified value + + await self.db.command("profile", 2) + yield + await self.db.command("profile", 0) + + profile_filter = {"command." + optionname: optionvalue} + cnt = await self.db.system.profile.count(profile_filter) + await self.db.system.profile.drop() + self.assertEqual(cnt, 1) + + async def test_Hint(self): # find() should fail with 'bad hint' if hint specifier works correctly self.assertFailure( self.coll.find({}, sort=qf.hint([("x", 1)])), OperationFailure ) + self.assertFailure(self.coll.find().hint({"x": 1}), OperationFailure) # create index and test it is honoured - yield self.coll.create_index(qf.sort(qf.ASCENDING("x")), name="test_index") - found_1 = yield self.coll.find({}, sort=qf.hint([("x", 1)])) - found_2 = yield self.coll.find({}, sort=qf.hint(qf.ASCENDING("x"))) - found_3 = yield self.coll.find({}, sort=qf.hint("test_index")) - self.assertTrue(found_1 == found_2 == found_3) + await self.coll.create_index(qf.sort(qf.ASCENDING("x")), name="test_index") + forms = [ + [("x", 1)], + {"x": 1}, + qf.ASCENDING("x"), + ] + for form in forms: + async with self._assert_single_command_with_option("hint", {"x": 1}): + await self.coll.find({}, sort=qf.hint(form)) + async with self._assert_single_command_with_option("hint", {"x": 1}): + await self.coll.find().hint(form) + + async with self._assert_single_command_with_option("hint", "test_index"): + await self.coll.find({}, sort=qf.hint("test_index")) + async with self._assert_single_command_with_option("hint", "test_index"): + await self.coll.find().hint("test_index") # find() should fail with 'bad hint' if hint specifier works correctly self.assertFailure( @@ -67,6 +92,10 @@ def test_SortAscendingMultipleFields(self): qf.sort(qf.ASCENDING(["x", "y"])), qf.sort(qf.ASCENDING("x") + qf.ASCENDING("y")), ) + self.assertEqual( + qf.sort(qf.ASCENDING(["x", "y"])), + qf.sort({"x": 1, "y": 1}), + ) def test_SortOneLevelList(self): self.assertEqual(qf.sort([("x", 1)]), qf.sort(("x", 1))) @@ -74,6 +103,7 @@ def test_SortOneLevelList(self): def test_SortInvalidKey(self): self.assertRaises(TypeError, qf.sort, [(1, 2)]) self.assertRaises(TypeError, qf.sort, [("x", 3)]) + self.assertRaises(TypeError, qf.sort, {"x": 3}) def test_SortGeoIndexes(self): self.assertEqual(qf.sort(qf.GEO2D("x")), qf.sort([("x", "2d")])) @@ -83,45 +113,33 @@ def test_SortGeoIndexes(self): def test_TextIndex(self): self.assertEqual(qf.sort(qf.TEXT("title")), qf.sort([("title", "text")])) - def __3_2_or_higher(self): - return self.db.command("buildInfo").addCallback( - lambda info: info["versionArray"] >= [3, 2] - ) - - def __3_6_or_higher(self): - return self.db.command("buildInfo").addCallback( - lambda info: info["versionArray"] >= [3, 6] - ) - - @defer.inlineCallbacks - def __test_simple_filter(self, filter, optionname, optionvalue): - # Checking that `optionname` appears in profiler log with specified value - - yield self.db.command("profile", 2) - yield self.coll.find({}, sort=filter) - yield self.db.command("profile", 0) - - if (yield self.__3_6_or_higher()): - profile_filter = {"command." + optionname: optionvalue} - elif (yield self.__3_2_or_higher()): - # query options format in system.profile have changed in MongoDB 3.2 - profile_filter = {"query." + optionname: optionvalue} - else: - profile_filter = {"query.$" + optionname: optionvalue} - - cnt = yield self.db.system.profile.count(profile_filter) - self.assertEqual(cnt, 1) - - @defer.inlineCallbacks - def test_Comment(self): + async def test_SortProfile(self): + forms = [ + qf.DESCENDING("x"), + {"x": -1}, + [("x", -1)], + ("x", -1), + ] + for form in forms: + async with self._assert_single_command_with_option("sort.x", -1): + await self.coll.find({}, sort=qf.sort(form)) + async with self._assert_single_command_with_option("sort.x", -1): + await self.coll.find().sort(form) + + async def test_Comment(self): comment = "hello world" - yield self.__test_simple_filter(qf.comment(comment), "comment", comment) + async with self._assert_single_command_with_option("comment", comment): + await self.coll.find({}, sort=qf.comment(comment)) + async with self._assert_single_command_with_option("comment", comment): + await self.coll.find().comment(comment) @defer.inlineCallbacks def test_Explain(self): result = yield self.coll.find({}, sort=qf.explain()) self.assertTrue("executionStats" in result[0] or "nscanned" in result[0]) + result = yield self.coll.find().explain() + self.assertTrue("executionStats" in result[0] or "nscanned" in result[0]) @defer.inlineCallbacks def test_FilterMerge(self): @@ -136,12 +154,7 @@ def test_FilterMerge(self): yield self.coll.find({}, sort=qf.sort(qf.ASCENDING("x")) + qf.comment(comment)) yield self.db.command("profile", 0) - if (yield self.__3_6_or_higher()): - profile_filter = {"command.sort.x": 1, "command.comment": comment} - elif (yield self.__3_2_or_higher()): - profile_filter = {"query.sort.x": 1, "query.comment": comment} - else: - profile_filter = {"query.$orderby.x": 1, "query.$comment": comment} + profile_filter = {"command.sort.x": 1, "command.comment": comment} cnt = yield self.db.system.profile.count(profile_filter) self.assertEqual(cnt, 1) diff --git a/tests/basic/test_queries.py b/tests/basic/test_queries.py index 015f196..643d650 100644 --- a/tests/basic/test_queries.py +++ b/tests/basic/test_queries.py @@ -41,7 +41,6 @@ only_for_mongodb_starting_from, ) from tests.utils import SingleCollectionTest -from txmongo.collection import Cursor from txmongo.errors import TimeExceeded from txmongo.protocol import MongoProtocol @@ -56,12 +55,12 @@ def __call__(self, this, *args, **kwargs): return self.original(this, *args, **kwargs) -class TestMongoQueries(SingleCollectionTest): +class TestFind(SingleCollectionTest): timeout = 15 @defer.inlineCallbacks - def test_find_return_type(self): + def test_FindReturnType(self): dfr = self.coll.find() dfr_one = self.coll.find_one() try: @@ -106,7 +105,7 @@ def test_FindWithCursorLimit(self): self.assertEqual(total, 150) @defer.inlineCallbacks - def test_FindWithCursorBatchsize(self): + def test_FindWithCursorBatchSize(self): self.assertRaises(TypeError, self.coll.find_with_cursor, batch_size="string") yield self.coll.insert_many([{"v": i} for i in range(140)]) @@ -119,7 +118,7 @@ def test_FindWithCursorBatchsize(self): self.assertEqual(lengths, [50, 50, 40]) @defer.inlineCallbacks - def test_FindWithCursorBatchsizeLimit(self): + def test_FindWithCursorBatchSizeLimit(self): yield self.coll.insert_many([{"v": i} for i in range(140)]) docs, d = yield self.coll.find_with_cursor(batch_size=50, limit=10) @@ -130,7 +129,7 @@ def test_FindWithCursorBatchsizeLimit(self): self.assertEqual(lengths, [10]) @defer.inlineCallbacks - def test_FindWithCursorZeroBatchsize(self): + def test_FindWithCursorZeroBatchSize(self): yield self.coll.insert_many([{"v": i} for i in range(140)]) docs, d = yield self.coll.find_with_cursor(batch_size=0) @@ -149,13 +148,23 @@ def test_LargeData(self): @defer.inlineCallbacks def test_SpecifiedFields(self): yield self.coll.insert_many([dict((k, v) for k in "abcdefg") for v in range(5)]) - res = yield self.coll.find(projection={"a": 1, "c": 1}) - self.assertTrue(all(x in ["a", "c", "_id"] for x in res[0].keys())) + + res = yield self.coll.find(projection={"a": 1, "c": 1, "_id": 0}) + self.assertTrue(all(set(x.keys()) == {"a", "c"} for x in res)) res = yield self.coll.find(projection=["a", "c"]) - self.assertTrue(all(x in ["a", "c", "_id"] for x in res[0].keys())) + self.assertTrue(all(set(x.keys()) == {"a", "c", "_id"} for x in res)) res = yield self.coll.find(projection=[]) - self.assertTrue(all(x in ["_id"] for x in res[0].keys())) - self.assertRaises(TypeError, self.coll.find, {}, projection=[1]) + self.assertTrue(all(set(x.keys()) == {"_id"} for x in res)) + yield self.assertFailure(self.coll.find({}, projection=[1]), TypeError) + + # Alternative form + res = yield self.coll.find().projection({"a": 1, "c": 1, "_id": 0}) + self.assertTrue(all(set(x.keys()) == {"a", "c"} for x in res)) + res = yield self.coll.find().projection(["a", "c"]) + self.assertTrue(all(set(x.keys()) == {"a", "c", "_id"} for x in res)) + res = yield self.coll.find().projection([]) + self.assertTrue(all(set(x.keys()) == {"_id"} for x in res)) + yield self.assertFailure(self.coll.find().projection([1]), TypeError) def __make_big_object(self): return {"_id": ObjectId(), "x": "a" * 1000} @@ -218,9 +227,9 @@ def test_TimeoutAndDeadline(self): self.assertEqual(len(result), 10) # Timeout cases - dfr = self.coll.find({"$where": "sleep(55); true"}, timeout=0.5) + dfr = self.coll.find({"$where": "sleep(55); true"}).timeout(0.5) yield self.assertFailure(dfr, TimeExceeded) - dfr = self.coll.find({"$where": "sleep(55); true"}, timeout=0.5, batch_size=2) + dfr = self.coll.find({"$where": "sleep(55); true"}).timeout(0.5).batch_size(2) yield self.assertFailure(dfr, TimeExceeded) # Deadline cases @@ -284,7 +293,6 @@ def test_FindOneNone(self): @defer.inlineCallbacks def test_AllowPartialResults(self): - with patch.object( MongoProtocol, "send_msg", side_effect=MongoProtocol.send_msg, autospec=True ) as mock: @@ -295,6 +303,16 @@ def test_AllowPartialResults(self): cmd = bson.decode(msg.body) self.assertEqual(cmd["allowPartialResults"], True) + with patch.object( + MongoProtocol, "send_msg", side_effect=MongoProtocol.send_msg, autospec=True + ) as mock: + yield self.coll.find().limit(1).allow_partial_results() + + mock.assert_called_once() + msg = mock.call_args[0][1] + cmd = bson.decode(msg.body) + self.assertEqual(cmd["allowPartialResults"], True) + async def test_FindIterate(self): await self.coll.insert_many([{"b": i} for i in range(50)]) @@ -344,6 +362,61 @@ def test_IterateNextBatch(self): yield self.__check_no_open_cursors() + @defer.inlineCallbacks + def test_SettingOptionsAfterCommandIsSent(self): + yield self.coll.insert_many([{"c": i} for i in range(50)]) + + cursor = self.coll.find().batch_size(10) + yield cursor.next_batch() + + # all these commands should raise InvalidOperation because query command is already sent + self.assertRaises(InvalidOperation, cursor.projection, {"x": 1}) + self.assertRaises(InvalidOperation, cursor.sort, {"x": 1}) + self.assertRaises(InvalidOperation, cursor.hint, {"x": 1}) + self.assertRaises(InvalidOperation, cursor.comment, "hello") + self.assertRaises(InvalidOperation, cursor.explain) + self.assertRaises(InvalidOperation, cursor.skip, 10) + self.assertRaises(InvalidOperation, cursor.limit, 10) + self.assertRaises(InvalidOperation, cursor.batch_size, 10) + self.assertRaises(InvalidOperation, cursor.allow_partial_results) + self.assertRaises(InvalidOperation, cursor.timeout, 500) + + yield cursor.close() + yield self.__check_no_open_cursors() + + @defer.inlineCallbacks + def test_NextBatchBeforePreviousComplete(self): + """If next_batch() is called before previous one is fired, it will return the same batch""" + yield self.coll.insert_many([{"c": i} for i in range(50)]) + cursor = self.coll.find().batch_size(10) + + batches = yield defer.gatherResults([cursor.next_batch() for _ in range(5)]) + self.assertFalse(cursor.exhausted) + for batch in batches[1:]: + self.assertEqual(batch, batches[0]) + + batch2 = yield cursor.next_batch() + self.assertNotEqual(batch2, batches[0]) + + yield cursor.close() + + @defer.inlineCallbacks + def test_CursorId(self): + yield self.coll.insert_many([{"c": i} for i in range(50)]) + + cursor = self.coll.find().batch_size(45) + yield cursor.next_batch() + try: + self.assertIsInstance(cursor.cursor_id, int) + self.assertNotEqual(cursor.cursor_id, 0) + self.assertIsNotNone(cursor.cursor_id) + finally: + yield cursor.close() + + def test_CursorCollection(self): + cursor = self.coll.find().batch_size(45) + self.assertIs(cursor.collection, self.coll) + class TestLimit(SingleCollectionTest): @@ -354,42 +427,41 @@ def test_LimitBelowBatchThreshold(self): yield self.coll.insert_many([{"v": i} for i in range(50)]) res = yield self.coll.find(limit=20) self.assertEqual(len(res), 20) + res = yield self.coll.find().limit(20) + self.assertEqual(len(res), 20) @defer.inlineCallbacks def test_LimitAboveBatchThreshold(self): yield self.coll.insert_many([{"v": i} for i in range(200)]) res = yield self.coll.find(limit=150) self.assertEqual(len(res), 150) + res = yield self.coll.find().limit(150) + self.assertEqual(len(res), 150) @defer.inlineCallbacks def test_LimitAtBatchThresholdEdge(self): yield self.coll.insert_many([{"v": i} for i in range(200)]) - res = yield self.coll.find(limit=100) - self.assertEqual(len(res), 100) - - yield self.coll.drop() - - yield self.coll.insert_many([{"v": i} for i in range(200)]) - res = yield self.coll.find(limit=101) - self.assertEqual(len(res), 101) - - yield self.coll.drop() - - yield self.coll.insert_many([{"v": i} for i in range(200)]) - res = yield self.coll.find(limit=102) - self.assertEqual(len(res), 102) + for limit in [100, 101, 102]: + res = yield self.coll.find(limit=limit, batch_size=100) + self.assertEqual(len(res), limit) + res = yield self.coll.find().limit(limit).batch_size(100) + self.assertEqual(len(res), limit) @defer.inlineCallbacks def test_LimitAboveMessageSizeThreshold(self): yield self.coll.insert_many([{"v": " " * (2**20)} for _ in range(8)]) res = yield self.coll.find(limit=5) self.assertEqual(len(res), 5) + res = yield self.coll.find().limit(5) + self.assertEqual(len(res), 5) @defer.inlineCallbacks def test_HardLimit(self): yield self.coll.insert_many([{"v": i} for i in range(200)]) res = yield self.coll.find(limit=-150) self.assertEqual(len(res), 150) + res = yield self.coll.find().limit(-150) + self.assertEqual(len(res), 150) class TestSkip(SingleCollectionTest): @@ -399,44 +471,36 @@ class TestSkip(SingleCollectionTest): @defer.inlineCallbacks def test_Skip(self): yield self.coll.insert_many([{"v": i} for i in range(5)]) - res = yield self.coll.find(skip=3) - self.assertEqual(len(res), 2) - yield self.coll.drop() + tests = { + 2: 3, + 3: 2, + 5: 0, + 6: 0, + } - yield self.coll.insert_many([{"v": i} for i in range(5)]) - res = yield self.coll.find(skip=5) - self.assertEqual(len(res), 0) - - yield self.coll.drop() - - yield self.coll.insert_many([{"v": i} for i in range(5)]) - res = yield self.coll.find(skip=6) - self.assertEqual(len(res), 0) + for skip, expected in tests.items(): + res = yield self.coll.find(skip=skip) + self.assertEqual(len(res), expected) + res = yield self.coll.find().skip(skip) + self.assertEqual(len(res), expected) @defer.inlineCallbacks def test_SkipWithLimit(self): yield self.coll.insert_many([{"v": i} for i in range(5)]) - res = yield self.coll.find(skip=3, limit=1) - self.assertEqual(len(res), 1) - - yield self.coll.drop() - - yield self.coll.insert_many([{"v": i} for i in range(5)]) - res = yield self.coll.find(skip=4, limit=2) - self.assertEqual(len(res), 1) - - yield self.coll.drop() - - yield self.coll.insert_many([{"v": i} for i in range(5)]) - res = yield self.coll.find(skip=4, limit=1) - self.assertEqual(len(res), 1) - - yield self.coll.drop() - yield self.coll.insert_many([{"v": i} for i in range(5)]) - res = yield self.coll.find(skip=5, limit=1) - self.assertEqual(len(res), 0) + tests = { + (3, 1): 1, + (4, 2): 1, + (4, 1): 1, + (5, 1): 0, + } + + for (skip, limit), expected in tests.items(): + res = yield self.coll.find(skip=skip, limit=limit) + self.assertEqual(len(res), expected) + res = yield self.coll.find().skip(skip).limit(limit) + self.assertEqual(len(res), expected) class TestCommand(SingleCollectionTest): diff --git a/txmongo/collection.py b/txmongo/collection.py index 896821d..995ac1e 100644 --- a/txmongo/collection.py +++ b/txmongo/collection.py @@ -2,11 +2,13 @@ # Use of this source code is governed by the Apache License that can be # found in the LICENSE file. +from __future__ import annotations + import collections.abc import time import warnings from operator import itemgetter -from typing import Iterable, List, Optional +from typing import Iterable, List, Mapping, Optional, Union from bson import ObjectId from bson.codec_options import CodecOptions @@ -35,11 +37,51 @@ from txmongo import filter as qf from txmongo._bulk import _Bulk, _Run from txmongo._bulk_constants import _INSERT +from txmongo.filter import SortArgument from txmongo.protocol import QUERY_PARTIAL, QUERY_SLAVE_OK, MongoProtocol, Msg from txmongo.pymongo_internals import _check_write_command_response, _merge_command from txmongo.types import Document from txmongo.utils import check_deadline, timeout +_timeout_decorator = timeout + + +def _normalize_fields_projection(fields): + """ + transform a list of fields from ["a", "b"] to {"a":1, "b":1} + """ + if fields is None: + return None + + if isinstance(fields, dict): + return fields + + # Consider fields as iterable + as_dict = {} + for field in fields: + if not isinstance(field, (bytes, str)): + raise TypeError("TxMongo: fields must be a list of key names.") + as_dict[field] = 1 + if not as_dict: + # Empty list should be treated as "_id only" + as_dict = {"_id": 1} + return as_dict + + +def _apply_find_filter(spec, c_filter): + if c_filter: + if "query" not in spec: + spec = {"$query": spec} + + for k, v in c_filter.items(): + if isinstance(v, (list, tuple)): + spec["$" + k] = SON(v) + else: + spec["$" + k] = v + + return spec + + _DEFERRED_METHODS = frozenset( { "addCallback", @@ -95,14 +137,9 @@ def query(): print(doc) """ - cursor_id: Optional[int] = None - """MongoDB cursor id""" - - exhausted: bool = False - """ - Is the cursor exhausted? If not, you can call :meth:`next_batch()` to retrieve the next batch - or :meth:`close()` to close the cursor on the MongoDB's side - """ + _command_sent: bool = False + _cursor_id: Optional[int] = None + _exhausted: bool = False _next_batch_deferreds: Optional[List[Deferred]] = None _current_loading_op: Optional[defer.Deferred] = None @@ -112,27 +149,159 @@ def query(): def __init__( self, - collection: "Collection", - command: Msg, - batch_size: int, - timeout: Optional[float], + collection: Collection, + filter: Optional[dict] = None, + projection: Optional[dict] = None, + skip: int = 0, + limit: int = 0, + modifiers: Optional[Mapping] = None, + batch_size: int = 0, + *, + allow_partial_results: bool = False, + flags: int = 0, + timeout: Optional[float] = None, ): super().__init__() - self.collection = collection - self.command = command - self.batch_size = batch_size - self._timeout = timeout + self._collection = collection + + if filter is None: + filter = {} + if not isinstance(filter, dict): + raise TypeError("TxMongo: filter must be an instance of dict.") + self._filter = filter + + if modifiers: + validate_is_mapping("sort", modifiers) + self._modifiers = modifiers or {} + + self.projection(projection) + self.skip(skip) + self.limit(limit) + self.batch_size(batch_size) + self.allow_partial_results(allow_partial_results) + self.timeout(timeout) + + self._flags = flags + + # When used as deferred, we should treat `timeout` argument as a overall + # timeout for the whole find() operation, including all batches + self._old_style_deadline = (time.time() + timeout) if timeout else None + + @property + def cursor_id(self) -> int: + """MongoDB cursor id""" + return self._cursor_id + + @property + def exhausted(self) -> bool: + """ + Is the cursor exhausted? If not, you can call :meth:`next_batch()` to retrieve the next batch + or :meth:`close()` to close the cursor on the MongoDB's side + """ + return self._exhausted - if timeout: - # When used as deferred, we should treat `timeout` argument as a overall - # timeout for the whole find() operation, including all batches - self._old_style_deadline = time.time() + timeout + @property + def collection(self) -> Collection: + return self._collection + + def _check_command_not_sent(self): + if self._command_sent: + raise InvalidOperation( + "TxMongo: Cannot set cursor options after executing query." + ) + + def projection(self, projection) -> Cursor: + """ + a list of field names that should be returned for each document + in the result set or a dict specifying field names to include or + exclude. If `projection` is a list ``_id`` fields will always be + returned. Use a dict form to exclude fields: ``{"_id": False}``. + """ + if not isinstance(projection, (dict, list)) and projection is not None: + raise TypeError("TxMongo: projection must be an instance of dict or list.") + self._check_command_not_sent() + self._projection = projection + return self + + def sort(self, sort: SortArgument) -> Cursor: + """Specify the order in which to return query results.""" + self._check_command_not_sent() + self._modifiers.update(qf.sort(sort)) + return self + + def hint(self, hint: Union[str, SortArgument]) -> Cursor: + """Adds a `hint`, telling MongoDB the proper index to use for the query.""" + self._check_command_not_sent() + self._modifiers.update(qf.hint(hint)) + return self + + def comment(self, comment: str) -> Cursor: + """Adds a comment to the query.""" + self._check_command_not_sent() + self._modifiers.update(qf.comment(comment)) + return self + + def explain(self) -> Cursor: + """Returns an explain plan for the query.""" + self._check_command_not_sent() + self._modifiers.update(qf.explain()) + return self + + def skip(self, skip: int) -> Cursor: + """ + Set the number of documents to omit from the start of the result set. + """ + if not isinstance(skip, int): + raise TypeError("TxMongo: skip must be an instance of int.") + self._check_command_not_sent() + self._skip = skip + return self + + def limit(self, limit: int) -> Cursor: + """ + Set the maximum number of documents to return. All documents are returned when `limit` is zero. + """ + if not isinstance(limit, int): + raise TypeError("TxMongo: limit must be an instance of int.") + self._check_command_not_sent() + self._limit = limit + return self + + def batch_size(self, batch_size: int) -> Cursor: + """ + Set the number of documents to return in each batch of results. + """ + if not isinstance(batch_size, int): + raise TypeError("TxMongo: batch_size must be an instance of int.") + self._check_command_not_sent() + self._batch_size = batch_size + return self + + def allow_partial_results(self, allow_partial_results: bool = True) -> Cursor: + """ + If True, mongos will return partial results if some shards are down instead of returning an error + """ + self._check_command_not_sent() + self._allow_partial_results = bool(allow_partial_results) + return self + + def timeout(self, timeout: Optional[float]) -> Cursor: + """ + Set the timeout for retrieving batches of results. If Cursor object is used as a Deferred, + this timeout will be used as an overall timeout for the whole results set loading. + """ + if timeout is not None and not isinstance(timeout, (int, float)): + raise TypeError("TxMongo: timeout must be an instance of float or None.") + self._check_command_not_sent() + self._timeout = timeout + self._old_style_deadline = (time.time() + timeout) if timeout else None + return self @inlineCallbacks def _old_style_find(self): result = [] try: - while not self.exhausted: + while not self._exhausted: batch = yield self.next_batch(deadline=self._old_style_deadline) if not batch: continue @@ -153,23 +322,104 @@ def __getattribute__(self, item): return value return super().__getattribute__(item) - def _after_connection(self, proto: MongoProtocol, _deadline: Optional[float]): - return proto.send_msg(self.command, self.collection.codec_options).addCallback( - self._after_reply, _deadline + _MODIFIERS = { + "$query": "filter", + "$orderby": "sort", + "$hint": "hint", + "$comment": "comment", + "$maxScan": "maxScan", + "$maxTimeMS": "maxTimeMS", + "$max": "max", + "$min": "min", + "$returnKey": "returnKey", + "$showRecordId": "showRecordId", + "$showDiskLoc": "showRecordId", # <= MongoDB 3.0 + } + + def _gen_find_command( + self, + db_name: str, + coll_name: str, + filter_with_modifiers, + projection, + skip, + limit, + batch_size, + allow_partial_results, + flags: int, + ) -> Msg: + cmd = {"find": coll_name} + if "$query" in filter_with_modifiers: + cmd.update( + [ + ( + (self._MODIFIERS[key], val) + if key in self._MODIFIERS + else (key, val) + ) + for key, val in filter_with_modifiers.items() + ] + ) + else: + cmd["filter"] = filter_with_modifiers + + if projection: + cmd["projection"] = projection + if skip: + cmd["skip"] = skip + if limit: + cmd["limit"] = abs(limit) + if limit < 0: + cmd["singleBatch"] = True + cmd["batchSize"] = abs(limit) + if batch_size: + cmd["batchSize"] = batch_size + + if flags & QUERY_SLAVE_OK: + cmd["$readPreference"] = {"mode": "secondaryPreferred"} + if allow_partial_results or flags & QUERY_PARTIAL: + cmd["allowPartialResults"] = True + + if "$explain" in filter_with_modifiers: + cmd.pop("$explain") + cmd = {"explain": cmd} + + cmd["$db"] = db_name + return Msg.create(cmd, codec_options=self._collection.codec_options) + + def _build_command(self) -> Msg: + projection = _normalize_fields_projection(self._projection) + filter = _apply_find_filter(self._filter, self._modifiers) + return self._gen_find_command( + self._collection.database.name, + self._collection.name, + filter, + projection, + self._skip, + self._limit, + self._batch_size, + self._allow_partial_results, + self._flags, ) + def _after_connection(self, proto: MongoProtocol, _deadline: Optional[float]): + self._command_sent = True + return proto.send_msg( + self._build_command(), self._collection.codec_options + ).addCallback(self._after_reply, _deadline) + def _get_more(self, proto: MongoProtocol, _deadline: Optional[float]): get_more = { - "getMore": self.cursor_id, - "$db": self.collection.database.name, - "collection": self.collection.name, + "getMore": self._cursor_id, + "$db": self._collection.database.name, + "collection": self._collection.name, } - if self.batch_size: - get_more["batchSize"] = self.batch_size + if self._batch_size: + get_more["batchSize"] = self._batch_size return proto.send_msg( - Msg.create(get_more, codec_options=self.collection.codec_options), - self.collection.codec_options, + Msg.create(get_more, codec_options=self._collection.codec_options), + self._collection.codec_options, ).addCallback(self._after_reply, _deadline) def _after_reply(self, reply: dict, _deadline: Optional[float]): @@ -177,16 +427,16 @@ def _after_reply(self, reply: dict, _deadline: Optional[float]): if "cursor" not in reply: # For example, when we run `explain` command - self.cursor_id = None - self.exhausted = True + self._cursor_id = None + self._exhausted = True return [reply] else: cursor = reply["cursor"] - self.cursor_id = cursor["id"] - self.exhausted = not self.cursor_id + self._cursor_id = cursor["id"] + self._exhausted = not self._cursor_id return cursor["nextBatch" if "nextBatch" in cursor else "firstBatch"] - @timeout + @_timeout_decorator def next_batch(self, _deadline: Optional[float]) -> Deferred[List[dict]]: """next_batch() -> Deferred[list[dict]] @@ -196,7 +446,7 @@ def next_batch(self, _deadline: Optional[float]) -> Deferred[List[dict]]: Check :attr:`exhausted` after calling this method to know if this is a last batch. """ - if self.exhausted: + if self._exhausted: return defer.succeed([]) def on_cancel(d): @@ -221,9 +471,9 @@ def on_result(result): d.callback(result) self._current_loading_op = ( - self.collection.database.connection.getprotocol() + self._collection.database.connection.getprotocol() .addCallback( - (self._after_connection if self.cursor_id is None else self._get_more), + (self._after_connection if self._cursor_id is None else self._get_more), _deadline, ) .addBoth(on_result) @@ -237,10 +487,10 @@ def close(self) -> defer.Deferred: the cursor object as an async generator. But if you use it by calling :meth:`next_batch()`, be sure to close cursor if you stop iterating before the cursor is exhausted. """ - if not self.cursor_id: + if not self._cursor_id: return defer.succeed(None) - return self.collection.database.connection.getprotocol().addCallback( - self.collection._close_cursor_without_response, self.cursor_id + return self._collection.database.connection.getprotocol().addCallback( + self._collection._close_cursor_without_response, self._cursor_id ) async def batches(self): @@ -253,7 +503,7 @@ async def query(): print(doc) """ try: - while not self.exhausted: + while not self._exhausted: batch = await self.next_batch(timeout=self._timeout) if not batch: continue @@ -419,28 +669,6 @@ def with_options(self, **kwargs): codec_options=codec_options, ) - @staticmethod - def _normalize_fields_projection(fields): - """ - transform a list of fields from ["a", "b"] to {"a":1, "b":1} - """ - if fields is None: - return None - - if isinstance(fields, dict): - return fields - - # Consider fields as iterable - as_dict = {} - for field in fields: - if not isinstance(field, (bytes, str)): - raise TypeError("TxMongo: fields must be a list of key names.") - as_dict[field] = 1 - if not as_dict: - # Empty list should be treated as "_id only" - as_dict = {"_id": 1} - return as_dict - @staticmethod def _gen_index_name(keys): return "_".join(["%s_%s" % item for item in keys]) @@ -558,32 +786,19 @@ def query(): if timeout is None and deadline is not None: timeout = deadline - time.time() - return self._create_cursor( + return Cursor( + self, filter, projection, skip, limit, - sort, + modifiers=sort, batch_size=batch_size, allow_partial_results=allow_partial_results, flags=flags, timeout=timeout, ) - @staticmethod - def __apply_find_filter(spec, c_filter): - if c_filter: - if "query" not in spec: - spec = {"$query": spec} - - for k, v in c_filter.items(): - if isinstance(v, (list, tuple)): - spec["$" + k] = SON(v) - else: - spec["$" + k] = v - - return spec - @timeout def find_with_cursor( self, @@ -624,12 +839,13 @@ def query(): DeprecationWarning, ) - cursor = self._create_cursor( + cursor = Cursor( + self, filter, projection, - skip, - limit, - sort, + skip=skip, + limit=limit, + modifiers=sort, batch_size=batch_size, allow_partial_results=allow_partial_results, flags=flags, @@ -645,117 +861,6 @@ def on_batch(batch, this_func): return cursor.next_batch(deadline=_deadline).addCallback(on_batch, on_batch) - def _create_cursor( - self, - filter=None, - projection=None, - skip=0, - limit=0, - sort=None, - batch_size=0, - *, - allow_partial_results: bool = False, - flags=0, - timeout: Optional[float] = None, - ): - if filter is None: - filter = SON() - - if not isinstance(filter, dict): - raise TypeError("TxMongo: filter must be an instance of dict.") - if not isinstance(projection, (dict, list)) and projection is not None: - raise TypeError("TxMongo: projection must be an instance of dict or list.") - if not isinstance(skip, int): - raise TypeError("TxMongo: skip must be an instance of int.") - if not isinstance(limit, int): - raise TypeError("TxMongo: limit must be an instance of int.") - if not isinstance(batch_size, int): - raise TypeError("TxMongo: batch_size must be an instance of int.") - if sort: - validate_is_mapping("sort", sort) - - projection = self._normalize_fields_projection(projection) - - filter = self.__apply_find_filter(filter, sort) - - cmd = self._gen_find_command( - self.database.name, - self.name, - filter, - projection, - skip, - limit, - batch_size, - allow_partial_results, - flags, - ) - return Cursor(self, cmd, batch_size, timeout) - - _MODIFIERS = { - "$query": "filter", - "$orderby": "sort", - "$hint": "hint", - "$comment": "comment", - "$maxScan": "maxScan", - "$maxTimeMS": "maxTimeMS", - "$max": "max", - "$min": "min", - "$returnKey": "returnKey", - "$showRecordId": "showRecordId", - "$showDiskLoc": "showRecordId", # <= MongoDB 3.0 - } - - def _gen_find_command( - self, - db_name: str, - coll_name: str, - filter_with_modifiers, - projection, - skip, - limit, - batch_size, - allow_partial_results, - flags: int, - ) -> Msg: - cmd = {"find": coll_name} - if "$query" in filter_with_modifiers: - cmd.update( - [ - ( - (self._MODIFIERS[key], val) - if key in self._MODIFIERS - else (key, val) - ) - for key, val in filter_with_modifiers.items() - ] - ) - else: - cmd["filter"] = filter_with_modifiers - - if projection: - cmd["projection"] = projection - if skip: - cmd["skip"] = skip - if limit: - cmd["limit"] = abs(limit) - if limit < 0: - cmd["singleBatch"] = True - cmd["batchSize"] = abs(limit) - if batch_size: - cmd["batchSize"] = batch_size - - if flags & QUERY_SLAVE_OK: - cmd["$readPreference"] = {"mode": "secondaryPreferred"} - if allow_partial_results or flags & QUERY_PARTIAL: - cmd["allowPartialResults"] = True - - if "$explain" in filter_with_modifiers: - cmd.pop("$explain") - cmd = {"explain": cmd} - - cmd["$db"] = db_name - return Msg.create(cmd, codec_options=self.codec_options) - def _close_cursor_without_response(self, proto: MongoProtocol, cursor_id: int): proto.send_msg( Msg.create( @@ -796,12 +901,13 @@ def find_one( # Since we specify limit=1, MongoDB will close cursor automatically for us return ( - self._create_cursor( + Cursor( + self, filter, projection, skip, limit=1, - sort=sort, + modifiers=sort, allow_partial_results=allow_partial_results, flags=flags, ) @@ -1304,7 +1410,7 @@ def _find_and_modify( } if projection is not None: - cmd["fields"] = self._normalize_fields_projection(projection) + cmd["fields"] = _normalize_fields_projection(projection) if sort is not None: cmd["sort"] = dict(sort["orderby"]) diff --git a/txmongo/database.py b/txmongo/database.py index c62663a..4652e88 100644 --- a/txmongo/database.py +++ b/txmongo/database.py @@ -2,6 +2,7 @@ # Use of this source code is governed by the Apache License that can be # found in the LICENSE file. from twisted.internet import defer +from twisted.internet.defer import inlineCallbacks from txmongo.collection import Collection from txmongo.protocol import Msg @@ -127,18 +128,39 @@ def drop_collection(self, name_or_collection, _deadline=None): ) @timeout - def collection_names(self, _deadline=None): + @inlineCallbacks + def collection_names(self, *, batch_size=0, _deadline=None): """collection_names()""" - def ok(results): - names = [r["name"] for r in results] - names = [ - n[len(str(self)) + 1 :] for n in names if n.startswith(str(self) + ".") - ] - names = [n for n in names if "$" not in n] - return names - - return self["system.namespaces"].find(deadline=_deadline).addCallback(ok) + cmd = { + "listCollections": 1, + "nameOnly": True, + "authorizedCollections": True, + } + if batch_size: + # "cursor" parameter is undocumented, but working in 4.0-8.0 and + # is useful for testing purposes + cmd["cursor"] = {"batchSize": batch_size} + response = yield self.command( + cmd, + _deadline=_deadline, + ) + names = [] + cursor_id = response["cursor"]["id"] + names.extend(coll["name"] for coll in response["cursor"]["firstBatch"]) + while cursor_id: + response = yield self.command( + { + "getMore": cursor_id, + "$db": self.name, + "collection": "$cmd.listCollections", + "batchSize": batch_size, + }, + _deadline=_deadline, + ) + cursor_id = response["cursor"]["id"] + names.extend(coll["name"] for coll in response["cursor"]["nextBatch"]) + return names def authenticate(self, name, password, mechanism="DEFAULT"): """ diff --git a/txmongo/filter.py b/txmongo/filter.py index e3354d7..24bcbe1 100644 --- a/txmongo/filter.py +++ b/txmongo/filter.py @@ -3,6 +3,7 @@ # found in the LICENSE file. from collections import defaultdict +from typing import List, Literal, Mapping, Tuple, Union """Query filters""" @@ -56,12 +57,21 @@ def TEXT(keys): return _direction(keys, "text") +AllowedDirectionType = Literal[1, -1, "2d", "2dsphere", "geoHaystack", "text"] +SortArgument = Union[ + Mapping[str, AllowedDirectionType], + List[Tuple[str, AllowedDirectionType]], + Tuple[Tuple[str, AllowedDirectionType]], + Tuple[str, AllowedDirectionType], +] + + class _QueryFilter(defaultdict): ALLOWED_DIRECTIONS = {1, -1, "2d", "2dsphere", "geoHaystack", "text"} def __init__(self): - defaultdict.__init__(self, lambda: ()) + super().__init__(lambda: ()) def __add__(self, obj): for k, v in obj.items(): @@ -99,11 +109,11 @@ def __repr__(self): class sort(_QueryFilter): """Sorts the results of a query.""" - def __init__(self, key_list): - _QueryFilter.__init__(self) - try: - assert isinstance(key_list[0], (list, tuple)) - except: + def __init__(self, key_list: SortArgument): + super().__init__() + if isinstance(key_list, Mapping): + key_list = list(key_list.items()) + elif not isinstance(key_list[0], (list, tuple)): key_list = (key_list,) self._index_document("orderby", key_list) @@ -112,7 +122,9 @@ class hint(_QueryFilter): """Adds a `hint`, telling Mongo the proper index to use for the query.""" def __init__(self, index_list_or_name): - _QueryFilter.__init__(self) + super().__init__() + if isinstance(index_list_or_name, Mapping): + index_list_or_name = list(index_list_or_name.items()) if isinstance(index_list_or_name, (list, tuple)): if not isinstance(index_list_or_name[0], (list, tuple)): index_list_or_name = (index_list_or_name,) @@ -125,17 +137,17 @@ class explain(_QueryFilter): """Returns an explain plan for the query.""" def __init__(self): - _QueryFilter.__init__(self) + super().__init__() self["explain"] = True class snapshot(_QueryFilter): def __init__(self): - _QueryFilter.__init__(self) + super().__init__() self["snapshot"] = True class comment(_QueryFilter): def __init__(self, comment): - _QueryFilter.__init__(self) + super().__init__() self["comment"] = comment diff --git a/txmongo/protocol.py b/txmongo/protocol.py index 251be06..7f5208f 100644 --- a/txmongo/protocol.py +++ b/txmongo/protocol.py @@ -13,6 +13,8 @@ decoding as well as Exception types, when applicable. """ +from __future__ import annotations + import base64 import hashlib import hmac @@ -117,7 +119,7 @@ def _payload(self) -> List[bytes]: ... @abstractmethod def decode( cls, request_id: int, response_to: int, opcode: int, message_data: bytes - ) -> "BaseMessage": ... + ) -> BaseMessage: ... QUERY_TAILABLE_CURSOR = 1 << 1 @@ -155,7 +157,7 @@ def _payload(self) -> List[bytes]: @classmethod def decode( cls, request_id: int, response_to: int, opcode: int, message_data: bytes - ) -> "Query": + ) -> Query: (flags,) = struct.unpack(" List[bytes]: @classmethod def decode( cls, request_id: int, response_to: int, opcode: int, message_data: bytes - ) -> "Reply": + ) -> Reply: msg_len = len(message_data) (response_flags, cursor_id, starting_from, n_returned) = struct.unpack( " "Msg": + ) -> Msg: encoded_payload = {} if payload: encoded_payload = { @@ -330,7 +332,7 @@ def _payload(self) -> List[bytes]: @classmethod def decode( cls, request_id: int, response_to: int, opcode: int, message_data: bytes - ) -> "Msg": + ) -> Msg: msg_length = len(message_data) body = None payload: Dict[str, List[bytes]] = {}