diff --git a/django/db/models/fields/json.py b/django/db/models/fields/json.py index fafc1beee8..00df1ae206 100644 --- a/django/db/models/fields/json.py +++ b/django/db/models/fields/json.py @@ -378,6 +378,30 @@ class KeyTransformIsNull(lookups.IsNull): return super().as_sql(compiler, connection) +class KeyTransformIn(lookups.In): + def process_rhs(self, compiler, connection): + rhs, rhs_params = super().process_rhs(compiler, connection) + if not connection.features.has_native_json_field: + func = () + if connection.vendor == 'oracle': + func = [] + for value in rhs_params: + value = json.loads(value) + function = 'JSON_QUERY' if isinstance(value, (list, dict)) else 'JSON_VALUE' + func.append("%s('%s', '$.value')" % ( + function, + json.dumps({'value': value}), + )) + func = tuple(func) + rhs_params = () + elif connection.vendor == 'mysql' and connection.mysql_is_mariadb: + func = ("JSON_UNQUOTE(JSON_EXTRACT(%s, '$'))",) * len(rhs_params) + elif connection.vendor in {'sqlite', 'mysql'}: + func = ("JSON_EXTRACT(%s, '$')",) * len(rhs_params) + rhs = rhs % func + return rhs, rhs_params + + class KeyTransformExact(JSONExact): def process_lhs(self, compiler, connection): lhs, lhs_params = super().process_lhs(compiler, connection) @@ -479,6 +503,7 @@ class KeyTransformGte(KeyTransformNumericLookupMixin, lookups.GreaterThanOrEqual pass +KeyTransform.register_lookup(KeyTransformIn) KeyTransform.register_lookup(KeyTransformExact) KeyTransform.register_lookup(KeyTransformIExact) KeyTransform.register_lookup(KeyTransformIsNull) diff --git a/docs/releases/3.1.1.txt b/docs/releases/3.1.1.txt index 6359f181fa..83ff2b07d2 100644 --- a/docs/releases/3.1.1.txt +++ b/docs/releases/3.1.1.txt @@ -39,3 +39,7 @@ Bugfixes * Enforced thread sensitivity of the :class:`MiddlewareMixin.process_request() ` and ``process_response()`` hooks when in an async context (:ticket:`31905`). + +* Fixed ``__in`` lookup on key transforms for + :class:`~django.db.models.JSONField` with MariaDB, MySQL, Oracle, and SQLite + (:ticket:`31936`). diff --git a/tests/model_fields/test_jsonfield.py b/tests/model_fields/test_jsonfield.py index 26643adcdd..debf02517a 100644 --- a/tests/model_fields/test_jsonfield.py +++ b/tests/model_fields/test_jsonfield.py @@ -612,6 +612,25 @@ class TestQuerying(TestCase): self.assertIs(NullableJSONModel.objects.filter(value__foo__iexact='BaR').exists(), True) self.assertIs(NullableJSONModel.objects.filter(value__foo__iexact='"BaR"').exists(), False) + def test_key_in(self): + tests = [ + ('value__c__in', [14], self.objs[3:5]), + ('value__c__in', [14, 15], self.objs[3:5]), + ('value__0__in', [1], [self.objs[5]]), + ('value__0__in', [1, 3], [self.objs[5]]), + ('value__foo__in', ['bar'], [self.objs[7]]), + ('value__foo__in', ['bar', 'baz'], [self.objs[7]]), + ('value__bar__in', [['foo', 'bar']], [self.objs[7]]), + ('value__bar__in', [['foo', 'bar'], ['a']], [self.objs[7]]), + ('value__bax__in', [{'foo': 'bar'}, {'a': 'b'}], [self.objs[7]]), + ] + for lookup, value, expected in tests: + with self.subTest(lookup=lookup, value=value): + self.assertSequenceEqual( + NullableJSONModel.objects.filter(**{lookup: value}), + expected, + ) + @skipUnlessDBFeature('supports_json_field_contains') def test_key_contains(self): self.assertIs(NullableJSONModel.objects.filter(value__foo__contains='ar').exists(), False)