diff --git a/src/django_jsonfield_backport/models.py b/src/django_jsonfield_backport/models.py index 357c5f6..49b73ee 100644 --- a/src/django_jsonfield_backport/models.py +++ b/src/django_jsonfield_backport/models.py @@ -491,6 +491,33 @@ def as_sqlite(self, compiler, connection): return super().as_sql(compiler, connection) +class KeyTransformIn(lookups.In): + def process_lhs(self, compiler, connection): + lhs, lhs_params = super().process_lhs(compiler, connection) + if connection.vendor == "mysql" and connection.mysql_is_mariadb: + return "JSON_UNQUOTE(%s)" % lhs, lhs_params + return lhs, lhs_params + + def process_rhs(self, compiler, connection): + rhs, rhs_params = super().process_rhs(compiler, connection) + if not connection.features.has_native_json_field: + func = () + if connection.vendor == "oracle": + func = [] + for value in rhs_params: + value = json.loads(value) + function = "JSON_QUERY" if isinstance(value, (list, dict)) else "JSON_VALUE" + func.append("%s('%s', '$.value')" % (function, json.dumps({"value": value}))) + func = tuple(func) + rhs_params = () + elif connection.vendor == "mysql" and connection.mysql_is_mariadb: + func = ("JSON_UNQUOTE(JSON_EXTRACT(%s, '$'))",) * len(rhs_params) + else: + func = ("JSON_EXTRACT(%s, '$')",) * len(rhs_params) + rhs = rhs % func + return rhs, rhs_params + + class KeyTransformExact(JSONExact): def process_lhs(self, compiler, connection): lhs, lhs_params = super().process_lhs(compiler, connection) @@ -601,6 +628,7 @@ def register_lookups(): JSONField.register_lookup(HasAnyKeys) JSONField.register_lookup(JSONExact) + KeyTransform.register_lookup(KeyTransformIn) KeyTransform.register_lookup(KeyTransformExact) KeyTransform.register_lookup(KeyTransformIExact) KeyTransform.register_lookup(KeyTransformIsNull) diff --git a/tests/test_model_field.py b/tests/test_model_field.py index f4f5cf3..a17c357 100644 --- a/tests/test_model_field.py +++ b/tests/test_model_field.py @@ -626,6 +626,24 @@ def test_key_iexact(self): self.assertIs(NullableJSONModel.objects.filter(value__foo__iexact="BaR").exists(), True) self.assertIs(NullableJSONModel.objects.filter(value__foo__iexact='"BaR"').exists(), False) + def test_key_in(self): + tests = [ + ("value__c__in", [14], self.objs[3:5]), + ("value__c__in", [14, 15], self.objs[3:5]), + ("value__0__in", [1], [self.objs[5]]), + ("value__0__in", [1, 3], [self.objs[5]]), + ("value__foo__in", ["bar"], [self.objs[7]]), + ("value__foo__in", ["bar", "baz"], [self.objs[7]]), + ("value__bar__in", [["foo", "bar"]], [self.objs[7]]), + ("value__bar__in", [["foo", "bar"], ["a"]], [self.objs[7]]), + ("value__bax__in", [{"foo": "bar"}, {"a": "b"}], [self.objs[7]]), + ] + for lookup, value, expected in tests: + with self.subTest(lookup=lookup, value=value): + self.assertSequenceEqual( + NullableJSONModel.objects.filter(**{lookup: value}), expected + ) + @skipUnlessDBFeature("supports_json_field_contains") def test_key_contains(self): self.assertIs(NullableJSONModel.objects.filter(value__foo__contains="ar").exists(), False)