diff --git a/docs/changelog.rst b/docs/changelog.rst index 1ac334e5a..d1c962216 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -20,6 +20,7 @@ Development it is not recommended to do so (DecimalField uses float/str to store the value, Decimal128Field uses Decimal128). - BREAKING CHANGE: When using ListField(EnumField) or DictField(EnumField), the values weren't always cast into the Enum (#2531) - BREAKING CHANGE (bugfix) Querying ObjectIdField or ComplexDateTimeField with None will no longer raise a ValidationError (#2681) +- Allow updating a field that has an operator name e.g. "type" with .update(set__type="foo"). It was raising an error previously. #2595 Changes in 0.25.0 ================= diff --git a/mongoengine/queryset/transform.py b/mongoengine/queryset/transform.py index 749cf1f61..a95a84681 100644 --- a/mongoengine/queryset/transform.py +++ b/mongoengine/queryset/transform.py @@ -264,12 +264,24 @@ def update(_doc_cls=None, **update): op = operator_map.get(op, op) match = None - if len(parts) > 1 and parts[-1] in COMPARISON_OPERATORS: - match = parts.pop() - # Allow to escape operator-like field name by __ - if len(parts) > 1 and parts[-1] == "": - parts.pop() + if len(parts) == 1: + # typical update like set__field + # but also allows to update a field named like a comparison operator + # like set__type = "something" (without clashing with the 'type' operator) + pass + elif len(parts) > 1: + # can be either an embedded field like set__foo__bar + # or a comparison operator as in pull__foo__in + if parts[-1] in COMPARISON_OPERATORS: + match = parts.pop() # e.g. pop 'in' from pull__foo__in + + # Allow to escape operator-like field name by __ + # e.g. in the case of an embedded foo.type field + # Doc.objects().update(set__foo__type="bar") + # see https://github.com/MongoEngine/mongoengine/pull/1351 + if parts[-1] == "": + match = parts.pop() # e.g. pop last '__' from set__foo__type__ if _doc_cls: # Switch field names to proper names [set in Field(name='foo')] diff --git a/tests/queryset/test_transform.py b/tests/queryset/test_transform.py index 9a7d6365e..5627597f8 100644 --- a/tests/queryset/test_transform.py +++ b/tests/queryset/test_transform.py @@ -275,7 +275,7 @@ class Doc(Document): assert Doc.objects(df__type=2).count() == 1 # str assert Doc.objects(df__type=16).count() == 1 # int - def test_last_field_name_like_operator(self): + def test_embedded_field_name_like_operator(self): class EmbeddedItem(EmbeddedDocument): type = StringField() name = StringField() @@ -295,6 +295,30 @@ class Doc(Document): assert 1 == Doc.objects(item__type__="sword").count() assert 0 == Doc.objects(item__type__="axe").count() + def test_regular_field_named_like_operator(self): + class SimpleDoc(Document): + size = StringField() + type = StringField() + + SimpleDoc.drop_collection() + SimpleDoc(type="ok", size="ok").save() + + qry = transform.query(SimpleDoc, type="testtype") + assert qry == {"type": "testtype"} + + assert SimpleDoc.objects(type="ok").count() == 1 + assert SimpleDoc.objects(size="ok").count() == 1 + + update = transform.update(SimpleDoc, set__type="testtype") + assert update == {"$set": {"type": "testtype"}} + + SimpleDoc.objects.update(set__type="testtype") + SimpleDoc.objects.update(set__size="testsize") + + s = SimpleDoc.objects.first() + assert s.type == "testtype" + assert s.size == "testsize" + def test_understandable_error_raised(self): class Event(Document): title = StringField()