diff --git a/django/db/models/fields/json.py b/django/db/models/fields/json.py index a249f4cdbf8..d0a55ac1e72 100644 --- a/django/db/models/fields/json.py +++ b/django/db/models/fields/json.py @@ -369,27 +369,26 @@ class KeyTransformIsNull(lookups.IsNull): class KeyTransformIn(lookups.In): - def process_rhs(self, compiler, connection): - rhs, rhs_params = super().process_rhs(compiler, connection) - if not connection.features.has_native_json_field: - func = () + def resolve_expression_parameter(self, compiler, connection, sql, param): + sql, params = super().resolve_expression_parameter( + compiler, connection, sql, param, + ) + if ( + not hasattr(param, 'as_sql') and + not connection.features.has_native_json_field + ): if connection.vendor == 'oracle': - func = [] - for value in rhs_params: - value = json.loads(value) - function = 'JSON_QUERY' if isinstance(value, (list, dict)) else 'JSON_VALUE' - func.append("%s('%s', '$.value')" % ( - function, - json.dumps({'value': value}), - )) - func = tuple(func) - rhs_params = () - elif connection.vendor == 'mysql' and connection.mysql_is_mariadb: - func = ("JSON_UNQUOTE(JSON_EXTRACT(%s, '$'))",) * len(rhs_params) + value = json.loads(param) + if isinstance(value, (list, dict)): + sql = "JSON_QUERY(%s, '$.value')" + else: + sql = "JSON_VALUE(%s, '$.value')" + params = (json.dumps({'value': value}),) elif connection.vendor in {'sqlite', 'mysql'}: - func = ("JSON_EXTRACT(%s, '$')",) * len(rhs_params) - rhs = rhs % func - return rhs, rhs_params + sql = "JSON_EXTRACT(%s, '$')" + if connection.vendor == 'mysql' and connection.mysql_is_mariadb: + sql = 'JSON_UNQUOTE(%s)' % sql + return sql, params class KeyTransformExact(JSONExact): diff --git a/django/db/models/lookups.py b/django/db/models/lookups.py index a4fbb046484..576ff400c46 100644 --- a/django/db/models/lookups.py +++ b/django/db/models/lookups.py @@ -241,7 +241,7 @@ class FieldGetDbPrepValueIterableMixin(FieldGetDbPrepValueMixin): if hasattr(param, 'resolve_expression'): param = param.resolve_expression(compiler.query) if hasattr(param, 'as_sql'): - sql, params = param.as_sql(compiler, connection) + sql, params = compiler.compile(param) return sql, params def batch_process_rhs(self, compiler, connection, rhs=None): diff --git a/docs/releases/3.1.3.txt b/docs/releases/3.1.3.txt index 4f547577ce4..82cd5238297 100644 --- a/docs/releases/3.1.3.txt +++ b/docs/releases/3.1.3.txt @@ -26,3 +26,7 @@ Bugfixes :class:`~django.contrib.postgres.aggregates.JSONBAgg`, and :class:`~django.contrib.postgres.aggregates.StringAgg` with ``ordering`` on key transforms for :class:`~django.db.models.JSONField` (:ticket:`32096`). + +* Fixed a regression in Django 3.1 that caused a crash of ``__in`` lookup when + using key transforms for :class:`~django.db.models.JSONField` in the lookup + value (:ticket:`32096`). diff --git a/tests/model_fields/test_jsonfield.py b/tests/model_fields/test_jsonfield.py index 79f2609c35b..a590ac12216 100644 --- a/tests/model_fields/test_jsonfield.py +++ b/tests/model_fields/test_jsonfield.py @@ -700,6 +700,16 @@ class TestQuerying(TestCase): ('value__0__in', [1], [self.objs[5]]), ('value__0__in', [1, 3], [self.objs[5]]), ('value__foo__in', ['bar'], [self.objs[7]]), + ( + 'value__foo__in', + [KeyTransform('foo', KeyTransform('bax', 'value'))], + [self.objs[7]], + ), + ( + 'value__foo__in', + [KeyTransform('foo', KeyTransform('bax', 'value')), 'baz'], + [self.objs[7]], + ), ('value__foo__in', ['bar', 'baz'], [self.objs[7]]), ('value__bar__in', [['foo', 'bar']], [self.objs[7]]), ('value__bar__in', [['foo', 'bar'], ['a']], [self.objs[7]]),