From 041237f7059e1186ed23b7b966753b810a46cad3 Mon Sep 17 00:00:00 2001
From: sage <laymonage@gmail.com>
Date: Sun, 6 Sep 2020 17:22:13 +0700
Subject: [PATCH] Fix __in lookup on key transforms

---
 src/django_jsonfield_backport/models.py | 22 ++++++++++++++++++++++
 tests/test_model_field.py               | 18 ++++++++++++++++++
 2 files changed, 40 insertions(+)

diff --git a/src/django_jsonfield_backport/models.py b/src/django_jsonfield_backport/models.py
index 357c5f6..4d112b9 100644
--- a/src/django_jsonfield_backport/models.py
+++ b/src/django_jsonfield_backport/models.py
@@ -491,6 +491,27 @@ def as_sqlite(self, compiler, connection):
         return super().as_sql(compiler, connection)
 
 
+class KeyTransformIn(lookups.In):
+    def process_rhs(self, compiler, connection):
+        rhs, rhs_params = super().process_rhs(compiler, connection)
+        if not connection.features.has_native_json_field:
+            func = ()
+            if connection.vendor == "oracle":
+                func = []
+                for value in rhs_params:
+                    value = json.loads(value)
+                    function = "JSON_QUERY" if isinstance(value, (list, dict)) else "JSON_VALUE"
+                    func.append("%s('%s', '$.value')" % (function, json.dumps({"value": value})))
+                func = tuple(func)
+                rhs_params = ()
+            elif connection.vendor == "mysql" and connection.mysql_is_mariadb:
+                func = ("JSON_UNQUOTE(JSON_EXTRACT(%s, '$'))",) * len(rhs_params)
+            else:
+                func = ("JSON_EXTRACT(%s, '$')",) * len(rhs_params)
+            rhs = rhs % func
+        return rhs, rhs_params
+
+
 class KeyTransformExact(JSONExact):
     def process_lhs(self, compiler, connection):
         lhs, lhs_params = super().process_lhs(compiler, connection)
@@ -601,6 +622,7 @@ def register_lookups():
     JSONField.register_lookup(HasAnyKeys)
     JSONField.register_lookup(JSONExact)
 
+    KeyTransform.register_lookup(KeyTransformIn)
     KeyTransform.register_lookup(KeyTransformExact)
     KeyTransform.register_lookup(KeyTransformIExact)
     KeyTransform.register_lookup(KeyTransformIsNull)
diff --git a/tests/test_model_field.py b/tests/test_model_field.py
index f4f5cf3..a17c357 100644
--- a/tests/test_model_field.py
+++ b/tests/test_model_field.py
@@ -626,6 +626,24 @@ def test_key_iexact(self):
         self.assertIs(NullableJSONModel.objects.filter(value__foo__iexact="BaR").exists(), True)
         self.assertIs(NullableJSONModel.objects.filter(value__foo__iexact='"BaR"').exists(), False)
 
+    def test_key_in(self):
+        tests = [
+            ("value__c__in", [14], self.objs[3:5]),
+            ("value__c__in", [14, 15], self.objs[3:5]),
+            ("value__0__in", [1], [self.objs[5]]),
+            ("value__0__in", [1, 3], [self.objs[5]]),
+            ("value__foo__in", ["bar"], [self.objs[7]]),
+            ("value__foo__in", ["bar", "baz"], [self.objs[7]]),
+            ("value__bar__in", [["foo", "bar"]], [self.objs[7]]),
+            ("value__bar__in", [["foo", "bar"], ["a"]], [self.objs[7]]),
+            ("value__bax__in", [{"foo": "bar"}, {"a": "b"}], [self.objs[7]]),
+        ]
+        for lookup, value, expected in tests:
+            with self.subTest(lookup=lookup, value=value):
+                self.assertSequenceEqual(
+                    NullableJSONModel.objects.filter(**{lookup: value}), expected
+                )
+
     @skipUnlessDBFeature("supports_json_field_contains")
     def test_key_contains(self):
         self.assertIs(NullableJSONModel.objects.filter(value__foo__contains="ar").exists(), False)