Refs #27095 -- Allowed (non-nested) arrays containing expressions for ArrayField lookups.

This commit is contained in:
Hannes Ljungberg 2020-11-02 20:38:34 +01:00 committed by Mariusz Felisiak
parent 755b327552
commit 33403bf80f
3 changed files with 53 additions and 14 deletions

View File

@ -4,7 +4,7 @@ from django.contrib.postgres import lookups
from django.contrib.postgres.forms import SimpleArrayField from django.contrib.postgres.forms import SimpleArrayField
from django.contrib.postgres.validators import ArrayMaxLengthValidator from django.contrib.postgres.validators import ArrayMaxLengthValidator
from django.core import checks, exceptions from django.core import checks, exceptions
from django.db.models import Field, IntegerField, Transform from django.db.models import Field, Func, IntegerField, Transform, Value
from django.db.models.fields.mixins import CheckFieldDefaultMixin from django.db.models.fields.mixins import CheckFieldDefaultMixin
from django.db.models.lookups import Exact, In from django.db.models.lookups import Exact, In
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
@ -198,7 +198,22 @@ class ArrayField(CheckFieldDefaultMixin, Field):
}) })
class ArrayCastRHSMixin: class ArrayRHSMixin:
def __init__(self, lhs, rhs):
if isinstance(rhs, (tuple, list)):
expressions = []
for value in rhs:
if not hasattr(value, 'resolve_expression'):
field = lhs.output_field
value = Value(field.base_field.get_prep_value(value))
expressions.append(value)
rhs = Func(
*expressions,
function='ARRAY',
template='%(function)s[%(expressions)s]',
)
super().__init__(lhs, rhs)
def process_rhs(self, compiler, connection): def process_rhs(self, compiler, connection):
rhs, rhs_params = super().process_rhs(compiler, connection) rhs, rhs_params = super().process_rhs(compiler, connection)
cast_type = self.lhs.output_field.cast_db_type(connection) cast_type = self.lhs.output_field.cast_db_type(connection)
@ -206,22 +221,22 @@ class ArrayCastRHSMixin:
@ArrayField.register_lookup @ArrayField.register_lookup
class ArrayContains(ArrayCastRHSMixin, lookups.DataContains): class ArrayContains(ArrayRHSMixin, lookups.DataContains):
pass pass
@ArrayField.register_lookup @ArrayField.register_lookup
class ArrayContainedBy(ArrayCastRHSMixin, lookups.ContainedBy): class ArrayContainedBy(ArrayRHSMixin, lookups.ContainedBy):
pass pass
@ArrayField.register_lookup @ArrayField.register_lookup
class ArrayExact(ArrayCastRHSMixin, Exact): class ArrayExact(ArrayRHSMixin, Exact):
pass pass
@ArrayField.register_lookup @ArrayField.register_lookup
class ArrayOverlap(ArrayCastRHSMixin, lookups.Overlap): class ArrayOverlap(ArrayRHSMixin, lookups.Overlap):
pass pass

View File

@ -143,6 +143,9 @@ Minor features
allow creating and dropping collations on PostgreSQL. See allow creating and dropping collations on PostgreSQL. See
:ref:`manage-postgresql-collations` for more details. :ref:`manage-postgresql-collations` for more details.
* Lookups for :class:`~django.contrib.postgres.fields.ArrayField` now allow
(non-nested) arrays containing expressions as right-hand sides.
:mod:`django.contrib.redirects` :mod:`django.contrib.redirects`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -9,8 +9,8 @@ from django.core import checks, exceptions, serializers, validators
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.core.management import call_command from django.core.management import call_command
from django.db import IntegrityError, connection, models from django.db import IntegrityError, connection, models
from django.db.models.expressions import Exists, OuterRef, RawSQL from django.db.models.expressions import Exists, OuterRef, RawSQL, Value
from django.db.models.functions import Cast from django.db.models.functions import Cast, Upper
from django.test import TransactionTestCase, modify_settings, override_settings from django.test import TransactionTestCase, modify_settings, override_settings
from django.test.utils import isolate_apps from django.test.utils import isolate_apps
from django.utils import timezone from django.utils import timezone
@ -226,6 +226,12 @@ class TestQuerying(PostgreSQLTestCase):
self.objs[:1] self.objs[:1]
) )
def test_exact_with_expression(self):
self.assertSequenceEqual(
NullableIntegerArrayModel.objects.filter(field__exact=[Value(1)]),
self.objs[:1],
)
def test_exact_charfield(self): def test_exact_charfield(self):
instance = CharArrayModel.objects.create(field=['text']) instance = CharArrayModel.objects.create(field=['text'])
self.assertSequenceEqual( self.assertSequenceEqual(
@ -296,15 +302,10 @@ class TestQuerying(PostgreSQLTestCase):
self.objs[:2] self.objs[:2]
) )
@unittest.expectedFailure
def test_contained_by_including_F_object(self): def test_contained_by_including_F_object(self):
# This test asserts that Array objects passed to filters can be
# constructed to contain F objects. This currently doesn't work as the
# psycopg2 mogrify method that generates the ARRAY() syntax is
# expecting literals, not column references (#27095).
self.assertSequenceEqual( self.assertSequenceEqual(
NullableIntegerArrayModel.objects.filter(field__contained_by=[models.F('id'), 2]), NullableIntegerArrayModel.objects.filter(field__contained_by=[models.F('id'), 2]),
self.objs[:2] self.objs[:3],
) )
def test_contains(self): def test_contains(self):
@ -326,6 +327,14 @@ class TestQuerying(PostgreSQLTestCase):
self.objs[1:3], self.objs[1:3],
) )
def test_contains_including_expression(self):
self.assertSequenceEqual(
NullableIntegerArrayModel.objects.filter(
field__contains=[2, Value(6) / Value(2)],
),
self.objs[2:3],
)
def test_icontains(self): def test_icontains(self):
# Using the __icontains lookup with ArrayField is inefficient. # Using the __icontains lookup with ArrayField is inefficient.
instance = CharArrayModel.objects.create(field=['FoO']) instance = CharArrayModel.objects.create(field=['FoO'])
@ -353,6 +362,18 @@ class TestQuerying(PostgreSQLTestCase):
[] []
) )
def test_overlap_charfield_including_expression(self):
obj_1 = CharArrayModel.objects.create(field=['TEXT', 'lower text'])
obj_2 = CharArrayModel.objects.create(field=['lower text', 'TEXT'])
CharArrayModel.objects.create(field=['lower text', 'text'])
self.assertSequenceEqual(
CharArrayModel.objects.filter(field__overlap=[
Upper(Value('text')),
'other',
]),
[obj_1, obj_2],
)
def test_lookups_autofield_array(self): def test_lookups_autofield_array(self):
qs = NullableIntegerArrayModel.objects.filter( qs = NullableIntegerArrayModel.objects.filter(
field__0__isnull=False, field__0__isnull=False,