Skip to content

Commit

Permalink
Allow .update(set__type='foo') to work without raising an error in ca…
Browse files Browse the repository at this point in the history
…se 'type' is one of the field. Avoids that it clash with the type operator
  • Loading branch information
bagerard committed Jan 9, 2023
1 parent d9dd375 commit 41689b8
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 6 deletions.
1 change: 1 addition & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ Development
it is not recommended to do so (DecimalField uses float/str to store the value, Decimal128Field uses Decimal128).
- BREAKING CHANGE: When using ListField(EnumField) or DictField(EnumField), the values weren't always cast into the Enum (#2531)
- BREAKING CHANGE (bugfix) Querying ObjectIdField or ComplexDateTimeField with None will no longer raise a ValidationError (#2681)
- Allow updating a field that has an operator name e.g. "type" with .update(set__type="foo"). It was raising an error previously. #2595

Changes in 0.25.0
=================
Expand Down
22 changes: 17 additions & 5 deletions mongoengine/queryset/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,12 +264,24 @@ def update(_doc_cls=None, **update):
op = operator_map.get(op, op)

match = None
if len(parts) > 1 and parts[-1] in COMPARISON_OPERATORS:
match = parts.pop()

# Allow to escape operator-like field name by __
if len(parts) > 1 and parts[-1] == "":
parts.pop()
if len(parts) == 1:
# typical update like set__field
# but also allows to update a field named like a comparison operator
# like set__type = "something" (without clashing with the 'type' operator)
pass
elif len(parts) > 1:
# can be either an embedded field like set__foo__bar
# or a comparison operator as in pull__foo__in
if parts[-1] in COMPARISON_OPERATORS:
match = parts.pop() # e.g. pop 'in' from pull__foo__in

# Allow to escape operator-like field name by __
# e.g. in the case of an embedded foo.type field
# Doc.objects().update(set__foo__type="bar")
# see https://github.com/MongoEngine/mongoengine/pull/1351
if parts[-1] == "":
match = parts.pop() # e.g. pop last '__' from set__foo__type__

if _doc_cls:
# Switch field names to proper names [set in Field(name='foo')]
Expand Down
26 changes: 25 additions & 1 deletion tests/queryset/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ class Doc(Document):
assert Doc.objects(df__type=2).count() == 1 # str
assert Doc.objects(df__type=16).count() == 1 # int

def test_last_field_name_like_operator(self):
def test_embedded_field_name_like_operator(self):
class EmbeddedItem(EmbeddedDocument):
type = StringField()
name = StringField()
Expand All @@ -295,6 +295,30 @@ class Doc(Document):
assert 1 == Doc.objects(item__type__="sword").count()
assert 0 == Doc.objects(item__type__="axe").count()

def test_regular_field_named_like_operator(self):
class SimpleDoc(Document):
size = StringField()
type = StringField()

SimpleDoc.drop_collection()
SimpleDoc(type="ok", size="ok").save()

qry = transform.query(SimpleDoc, type="testtype")
assert qry == {"type": "testtype"}

assert SimpleDoc.objects(type="ok").count() == 1
assert SimpleDoc.objects(size="ok").count() == 1

update = transform.update(SimpleDoc, set__type="testtype")
assert update == {"$set": {"type": "testtype"}}

SimpleDoc.objects.update(set__type="testtype")
SimpleDoc.objects.update(set__size="testsize")

s = SimpleDoc.objects.first()
assert s.type == "testtype"
assert s.size == "testsize"

def test_understandable_error_raised(self):
class Event(Document):
title = StringField()
Expand Down

0 comments on commit 41689b8

Please sign in to comment.