mirror of https://github.com/django/django.git
Refs #27095 -- Allowed (non-nested) arrays containing expressions for ArrayField lookups.
This commit is contained in:
parent
755b327552
commit
33403bf80f
|
@ -4,7 +4,7 @@ from django.contrib.postgres import lookups
|
||||||
from django.contrib.postgres.forms import SimpleArrayField
|
from django.contrib.postgres.forms import SimpleArrayField
|
||||||
from django.contrib.postgres.validators import ArrayMaxLengthValidator
|
from django.contrib.postgres.validators import ArrayMaxLengthValidator
|
||||||
from django.core import checks, exceptions
|
from django.core import checks, exceptions
|
||||||
from django.db.models import Field, IntegerField, Transform
|
from django.db.models import Field, Func, IntegerField, Transform, Value
|
||||||
from django.db.models.fields.mixins import CheckFieldDefaultMixin
|
from django.db.models.fields.mixins import CheckFieldDefaultMixin
|
||||||
from django.db.models.lookups import Exact, In
|
from django.db.models.lookups import Exact, In
|
||||||
from django.utils.translation import gettext_lazy as _
|
from django.utils.translation import gettext_lazy as _
|
||||||
|
@ -198,7 +198,22 @@ class ArrayField(CheckFieldDefaultMixin, Field):
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
class ArrayCastRHSMixin:
|
class ArrayRHSMixin:
|
||||||
|
def __init__(self, lhs, rhs):
|
||||||
|
if isinstance(rhs, (tuple, list)):
|
||||||
|
expressions = []
|
||||||
|
for value in rhs:
|
||||||
|
if not hasattr(value, 'resolve_expression'):
|
||||||
|
field = lhs.output_field
|
||||||
|
value = Value(field.base_field.get_prep_value(value))
|
||||||
|
expressions.append(value)
|
||||||
|
rhs = Func(
|
||||||
|
*expressions,
|
||||||
|
function='ARRAY',
|
||||||
|
template='%(function)s[%(expressions)s]',
|
||||||
|
)
|
||||||
|
super().__init__(lhs, rhs)
|
||||||
|
|
||||||
def process_rhs(self, compiler, connection):
|
def process_rhs(self, compiler, connection):
|
||||||
rhs, rhs_params = super().process_rhs(compiler, connection)
|
rhs, rhs_params = super().process_rhs(compiler, connection)
|
||||||
cast_type = self.lhs.output_field.cast_db_type(connection)
|
cast_type = self.lhs.output_field.cast_db_type(connection)
|
||||||
|
@ -206,22 +221,22 @@ class ArrayCastRHSMixin:
|
||||||
|
|
||||||
|
|
||||||
@ArrayField.register_lookup
|
@ArrayField.register_lookup
|
||||||
class ArrayContains(ArrayCastRHSMixin, lookups.DataContains):
|
class ArrayContains(ArrayRHSMixin, lookups.DataContains):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@ArrayField.register_lookup
|
@ArrayField.register_lookup
|
||||||
class ArrayContainedBy(ArrayCastRHSMixin, lookups.ContainedBy):
|
class ArrayContainedBy(ArrayRHSMixin, lookups.ContainedBy):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@ArrayField.register_lookup
|
@ArrayField.register_lookup
|
||||||
class ArrayExact(ArrayCastRHSMixin, Exact):
|
class ArrayExact(ArrayRHSMixin, Exact):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@ArrayField.register_lookup
|
@ArrayField.register_lookup
|
||||||
class ArrayOverlap(ArrayCastRHSMixin, lookups.Overlap):
|
class ArrayOverlap(ArrayRHSMixin, lookups.Overlap):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -143,6 +143,9 @@ Minor features
|
||||||
allow creating and dropping collations on PostgreSQL. See
|
allow creating and dropping collations on PostgreSQL. See
|
||||||
:ref:`manage-postgresql-collations` for more details.
|
:ref:`manage-postgresql-collations` for more details.
|
||||||
|
|
||||||
|
* Lookups for :class:`~django.contrib.postgres.fields.ArrayField` now allow
|
||||||
|
(non-nested) arrays containing expressions as right-hand sides.
|
||||||
|
|
||||||
:mod:`django.contrib.redirects`
|
:mod:`django.contrib.redirects`
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
|
|
@ -9,8 +9,8 @@ from django.core import checks, exceptions, serializers, validators
|
||||||
from django.core.exceptions import FieldError
|
from django.core.exceptions import FieldError
|
||||||
from django.core.management import call_command
|
from django.core.management import call_command
|
||||||
from django.db import IntegrityError, connection, models
|
from django.db import IntegrityError, connection, models
|
||||||
from django.db.models.expressions import Exists, OuterRef, RawSQL
|
from django.db.models.expressions import Exists, OuterRef, RawSQL, Value
|
||||||
from django.db.models.functions import Cast
|
from django.db.models.functions import Cast, Upper
|
||||||
from django.test import TransactionTestCase, modify_settings, override_settings
|
from django.test import TransactionTestCase, modify_settings, override_settings
|
||||||
from django.test.utils import isolate_apps
|
from django.test.utils import isolate_apps
|
||||||
from django.utils import timezone
|
from django.utils import timezone
|
||||||
|
@ -226,6 +226,12 @@ class TestQuerying(PostgreSQLTestCase):
|
||||||
self.objs[:1]
|
self.objs[:1]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_exact_with_expression(self):
|
||||||
|
self.assertSequenceEqual(
|
||||||
|
NullableIntegerArrayModel.objects.filter(field__exact=[Value(1)]),
|
||||||
|
self.objs[:1],
|
||||||
|
)
|
||||||
|
|
||||||
def test_exact_charfield(self):
|
def test_exact_charfield(self):
|
||||||
instance = CharArrayModel.objects.create(field=['text'])
|
instance = CharArrayModel.objects.create(field=['text'])
|
||||||
self.assertSequenceEqual(
|
self.assertSequenceEqual(
|
||||||
|
@ -296,15 +302,10 @@ class TestQuerying(PostgreSQLTestCase):
|
||||||
self.objs[:2]
|
self.objs[:2]
|
||||||
)
|
)
|
||||||
|
|
||||||
@unittest.expectedFailure
|
|
||||||
def test_contained_by_including_F_object(self):
|
def test_contained_by_including_F_object(self):
|
||||||
# This test asserts that Array objects passed to filters can be
|
|
||||||
# constructed to contain F objects. This currently doesn't work as the
|
|
||||||
# psycopg2 mogrify method that generates the ARRAY() syntax is
|
|
||||||
# expecting literals, not column references (#27095).
|
|
||||||
self.assertSequenceEqual(
|
self.assertSequenceEqual(
|
||||||
NullableIntegerArrayModel.objects.filter(field__contained_by=[models.F('id'), 2]),
|
NullableIntegerArrayModel.objects.filter(field__contained_by=[models.F('id'), 2]),
|
||||||
self.objs[:2]
|
self.objs[:3],
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_contains(self):
|
def test_contains(self):
|
||||||
|
@ -326,6 +327,14 @@ class TestQuerying(PostgreSQLTestCase):
|
||||||
self.objs[1:3],
|
self.objs[1:3],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_contains_including_expression(self):
|
||||||
|
self.assertSequenceEqual(
|
||||||
|
NullableIntegerArrayModel.objects.filter(
|
||||||
|
field__contains=[2, Value(6) / Value(2)],
|
||||||
|
),
|
||||||
|
self.objs[2:3],
|
||||||
|
)
|
||||||
|
|
||||||
def test_icontains(self):
|
def test_icontains(self):
|
||||||
# Using the __icontains lookup with ArrayField is inefficient.
|
# Using the __icontains lookup with ArrayField is inefficient.
|
||||||
instance = CharArrayModel.objects.create(field=['FoO'])
|
instance = CharArrayModel.objects.create(field=['FoO'])
|
||||||
|
@ -353,6 +362,18 @@ class TestQuerying(PostgreSQLTestCase):
|
||||||
[]
|
[]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_overlap_charfield_including_expression(self):
|
||||||
|
obj_1 = CharArrayModel.objects.create(field=['TEXT', 'lower text'])
|
||||||
|
obj_2 = CharArrayModel.objects.create(field=['lower text', 'TEXT'])
|
||||||
|
CharArrayModel.objects.create(field=['lower text', 'text'])
|
||||||
|
self.assertSequenceEqual(
|
||||||
|
CharArrayModel.objects.filter(field__overlap=[
|
||||||
|
Upper(Value('text')),
|
||||||
|
'other',
|
||||||
|
]),
|
||||||
|
[obj_1, obj_2],
|
||||||
|
)
|
||||||
|
|
||||||
def test_lookups_autofield_array(self):
|
def test_lookups_autofield_array(self):
|
||||||
qs = NullableIntegerArrayModel.objects.filter(
|
qs = NullableIntegerArrayModel.objects.filter(
|
||||||
field__0__isnull=False,
|
field__0__isnull=False,
|
||||||
|
|
Loading…
Reference in New Issue