diff --git a/django/contrib/postgres/fields/array.py b/django/contrib/postgres/fields/array.py index 5f30ed1ab16..9c1bb96b612 100644 --- a/django/contrib/postgres/fields/array.py +++ b/django/contrib/postgres/fields/array.py @@ -4,7 +4,7 @@ from django.contrib.postgres import lookups from django.contrib.postgres.forms import SimpleArrayField from django.contrib.postgres.validators import ArrayMaxLengthValidator from django.core import checks, exceptions -from django.db.models import Field, IntegerField, Transform +from django.db.models import Field, Func, IntegerField, Transform, Value from django.db.models.fields.mixins import CheckFieldDefaultMixin from django.db.models.lookups import Exact, In from django.utils.translation import gettext_lazy as _ @@ -198,7 +198,22 @@ class ArrayField(CheckFieldDefaultMixin, Field): }) -class ArrayCastRHSMixin: +class ArrayRHSMixin: + def __init__(self, lhs, rhs): + if isinstance(rhs, (tuple, list)): + expressions = [] + for value in rhs: + if not hasattr(value, 'resolve_expression'): + field = lhs.output_field + value = Value(field.base_field.get_prep_value(value)) + expressions.append(value) + rhs = Func( + *expressions, + function='ARRAY', + template='%(function)s[%(expressions)s]', + ) + super().__init__(lhs, rhs) + def process_rhs(self, compiler, connection): rhs, rhs_params = super().process_rhs(compiler, connection) cast_type = self.lhs.output_field.cast_db_type(connection) @@ -206,22 +221,22 @@ class ArrayCastRHSMixin: @ArrayField.register_lookup -class ArrayContains(ArrayCastRHSMixin, lookups.DataContains): +class ArrayContains(ArrayRHSMixin, lookups.DataContains): pass @ArrayField.register_lookup -class ArrayContainedBy(ArrayCastRHSMixin, lookups.ContainedBy): +class ArrayContainedBy(ArrayRHSMixin, lookups.ContainedBy): pass @ArrayField.register_lookup -class ArrayExact(ArrayCastRHSMixin, Exact): +class ArrayExact(ArrayRHSMixin, Exact): pass @ArrayField.register_lookup -class ArrayOverlap(ArrayCastRHSMixin, lookups.Overlap): +class ArrayOverlap(ArrayRHSMixin, lookups.Overlap): pass diff --git a/docs/releases/3.2.txt b/docs/releases/3.2.txt index 83b717893ef..a96a777e2d7 100644 --- a/docs/releases/3.2.txt +++ b/docs/releases/3.2.txt @@ -143,6 +143,9 @@ Minor features allow creating and dropping collations on PostgreSQL. See :ref:`manage-postgresql-collations` for more details. +* Lookups for :class:`~django.contrib.postgres.fields.ArrayField` now allow + (non-nested) arrays containing expressions as right-hand sides. + :mod:`django.contrib.redirects` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/tests/postgres_tests/test_array.py b/tests/postgres_tests/test_array.py index 4d6f302b91e..27ccf18581c 100644 --- a/tests/postgres_tests/test_array.py +++ b/tests/postgres_tests/test_array.py @@ -9,8 +9,8 @@ from django.core import checks, exceptions, serializers, validators from django.core.exceptions import FieldError from django.core.management import call_command from django.db import IntegrityError, connection, models -from django.db.models.expressions import Exists, OuterRef, RawSQL -from django.db.models.functions import Cast +from django.db.models.expressions import Exists, OuterRef, RawSQL, Value +from django.db.models.functions import Cast, Upper from django.test import TransactionTestCase, modify_settings, override_settings from django.test.utils import isolate_apps from django.utils import timezone @@ -226,6 +226,12 @@ class TestQuerying(PostgreSQLTestCase): self.objs[:1] ) + def test_exact_with_expression(self): + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__exact=[Value(1)]), + self.objs[:1], + ) + def test_exact_charfield(self): instance = CharArrayModel.objects.create(field=['text']) self.assertSequenceEqual( @@ -296,15 +302,10 @@ class TestQuerying(PostgreSQLTestCase): self.objs[:2] ) - @unittest.expectedFailure def test_contained_by_including_F_object(self): - # This test asserts that Array objects passed to filters can be - # constructed to contain F objects. This currently doesn't work as the - # psycopg2 mogrify method that generates the ARRAY() syntax is - # expecting literals, not column references (#27095). self.assertSequenceEqual( NullableIntegerArrayModel.objects.filter(field__contained_by=[models.F('id'), 2]), - self.objs[:2] + self.objs[:3], ) def test_contains(self): @@ -326,6 +327,14 @@ class TestQuerying(PostgreSQLTestCase): self.objs[1:3], ) + def test_contains_including_expression(self): + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter( + field__contains=[2, Value(6) / Value(2)], + ), + self.objs[2:3], + ) + def test_icontains(self): # Using the __icontains lookup with ArrayField is inefficient. instance = CharArrayModel.objects.create(field=['FoO']) @@ -353,6 +362,18 @@ class TestQuerying(PostgreSQLTestCase): [] ) + def test_overlap_charfield_including_expression(self): + obj_1 = CharArrayModel.objects.create(field=['TEXT', 'lower text']) + obj_2 = CharArrayModel.objects.create(field=['lower text', 'TEXT']) + CharArrayModel.objects.create(field=['lower text', 'text']) + self.assertSequenceEqual( + CharArrayModel.objects.filter(field__overlap=[ + Upper(Value('text')), + 'other', + ]), + [obj_1, obj_2], + ) + def test_lookups_autofield_array(self): qs = NullableIntegerArrayModel.objects.filter( field__0__isnull=False,