Skip to content

Commit

Permalink
Merge pull request #17 from laymonage/fix-in
Browse files Browse the repository at this point in the history
Fix __in lookup on key transforms
  • Loading branch information
laymonage authored Sep 6, 2020
2 parents 6e2fdf9 + 55c2c9f commit 67aeb2b
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 0 deletions.
28 changes: 28 additions & 0 deletions src/django_jsonfield_backport/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,33 @@ def as_sqlite(self, compiler, connection):
return super().as_sql(compiler, connection)


class KeyTransformIn(lookups.In):
def process_lhs(self, compiler, connection):
lhs, lhs_params = super().process_lhs(compiler, connection)
if connection.vendor == "mysql" and connection.mysql_is_mariadb:
return "JSON_UNQUOTE(%s)" % lhs, lhs_params
return lhs, lhs_params

def process_rhs(self, compiler, connection):
rhs, rhs_params = super().process_rhs(compiler, connection)
if not connection.features.has_native_json_field:
func = ()
if connection.vendor == "oracle":
func = []
for value in rhs_params:
value = json.loads(value)
function = "JSON_QUERY" if isinstance(value, (list, dict)) else "JSON_VALUE"
func.append("%s('%s', '$.value')" % (function, json.dumps({"value": value})))
func = tuple(func)
rhs_params = ()
elif connection.vendor == "mysql" and connection.mysql_is_mariadb:
func = ("JSON_UNQUOTE(JSON_EXTRACT(%s, '$'))",) * len(rhs_params)
else:
func = ("JSON_EXTRACT(%s, '$')",) * len(rhs_params)
rhs = rhs % func
return rhs, rhs_params


class KeyTransformExact(JSONExact):
def process_lhs(self, compiler, connection):
lhs, lhs_params = super().process_lhs(compiler, connection)
Expand Down Expand Up @@ -601,6 +628,7 @@ def register_lookups():
JSONField.register_lookup(HasAnyKeys)
JSONField.register_lookup(JSONExact)

KeyTransform.register_lookup(KeyTransformIn)
KeyTransform.register_lookup(KeyTransformExact)
KeyTransform.register_lookup(KeyTransformIExact)
KeyTransform.register_lookup(KeyTransformIsNull)
Expand Down
18 changes: 18 additions & 0 deletions tests/test_model_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,24 @@ def test_key_iexact(self):
self.assertIs(NullableJSONModel.objects.filter(value__foo__iexact="BaR").exists(), True)
self.assertIs(NullableJSONModel.objects.filter(value__foo__iexact='"BaR"').exists(), False)

def test_key_in(self):
tests = [
("value__c__in", [14], self.objs[3:5]),
("value__c__in", [14, 15], self.objs[3:5]),
("value__0__in", [1], [self.objs[5]]),
("value__0__in", [1, 3], [self.objs[5]]),
("value__foo__in", ["bar"], [self.objs[7]]),
("value__foo__in", ["bar", "baz"], [self.objs[7]]),
("value__bar__in", [["foo", "bar"]], [self.objs[7]]),
("value__bar__in", [["foo", "bar"], ["a"]], [self.objs[7]]),
("value__bax__in", [{"foo": "bar"}, {"a": "b"}], [self.objs[7]]),
]
for lookup, value, expected in tests:
with self.subTest(lookup=lookup, value=value):
self.assertSequenceEqual(
NullableJSONModel.objects.filter(**{lookup: value}), expected
)

@skipUnlessDBFeature("supports_json_field_contains")
def test_key_contains(self):
self.assertIs(NullableJSONModel.objects.filter(value__foo__contains="ar").exists(), False)
Expand Down

0 comments on commit 67aeb2b

Please sign in to comment.