mirror of https://github.com/django/django.git
Fixed #25666 -- Fixed the exact lookup of ArrayField.
This commit is contained in:
parent
b8f78823ee
commit
263b3d2ba1
|
@ -5,6 +5,7 @@ from django.contrib.postgres.forms import SimpleArrayField
|
|||
from django.contrib.postgres.validators import ArrayMaxLengthValidator
|
||||
from django.core import checks, exceptions
|
||||
from django.db.models import Field, IntegerField, Transform
|
||||
from django.db.models.lookups import Exact
|
||||
from django.utils import six
|
||||
from django.utils.translation import string_concat, ugettext_lazy as _
|
||||
|
||||
|
@ -166,7 +167,7 @@ class ArrayField(Field):
|
|||
class ArrayContains(lookups.DataContains):
|
||||
def as_sql(self, qn, connection):
|
||||
sql, params = super(ArrayContains, self).as_sql(qn, connection)
|
||||
sql += '::%s' % self.lhs.output_field.db_type(connection)
|
||||
sql = '%s::%s' % (sql, self.lhs.output_field.db_type(connection))
|
||||
return sql, params
|
||||
|
||||
|
||||
|
@ -174,7 +175,15 @@ class ArrayContains(lookups.DataContains):
|
|||
class ArrayContainedBy(lookups.ContainedBy):
|
||||
def as_sql(self, qn, connection):
|
||||
sql, params = super(ArrayContainedBy, self).as_sql(qn, connection)
|
||||
sql += '::%s' % self.lhs.output_field.db_type(connection)
|
||||
sql = '%s::%s' % (sql, self.lhs.output_field.db_type(connection))
|
||||
return sql, params
|
||||
|
||||
|
||||
@ArrayField.register_lookup
|
||||
class ArrayExact(Exact):
|
||||
def as_sql(self, qn, connection):
|
||||
sql, params = super(ArrayExact, self).as_sql(qn, connection)
|
||||
sql = '%s::%s' % (sql, self.lhs.output_field.db_type(connection))
|
||||
return sql, params
|
||||
|
||||
|
||||
|
@ -182,7 +191,7 @@ class ArrayContainedBy(lookups.ContainedBy):
|
|||
class ArrayOverlap(lookups.Overlap):
|
||||
def as_sql(self, qn, connection):
|
||||
sql, params = super(ArrayOverlap, self).as_sql(qn, connection)
|
||||
sql += '::%s' % self.lhs.output_field.db_type(connection)
|
||||
sql = '%s::%s' % (sql, self.lhs.output_field.db_type(connection))
|
||||
return sql, params
|
||||
|
||||
|
||||
|
|
|
@ -34,3 +34,5 @@ Bugfixes
|
|||
* Fixed serialization of
|
||||
:class:`~django.contrib.postgres.fields.DateRangeField` and
|
||||
:class:`~django.contrib.postgres.fields.DateTimeRangeField` (:ticket:`24937`).
|
||||
|
||||
* Fixed the exact lookup of ``ArrayField`` (:ticket:`25666`).
|
||||
|
|
|
@ -122,6 +122,20 @@ class TestQuerying(PostgreSQLTestCase):
|
|||
self.objs[:1]
|
||||
)
|
||||
|
||||
def test_exact_charfield(self):
|
||||
instance = CharArrayModel.objects.create(field=['text'])
|
||||
self.assertSequenceEqual(
|
||||
CharArrayModel.objects.filter(field=['text']),
|
||||
[instance]
|
||||
)
|
||||
|
||||
def test_exact_nested(self):
|
||||
instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]])
|
||||
self.assertSequenceEqual(
|
||||
NestedIntegerArrayModel.objects.filter(field=[[1, 2], [3, 4]]),
|
||||
[instance]
|
||||
)
|
||||
|
||||
def test_isnull(self):
|
||||
self.assertSequenceEqual(
|
||||
NullableIntegerArrayModel.objects.filter(field__isnull=True),
|
||||
|
@ -244,6 +258,73 @@ class TestQuerying(PostgreSQLTestCase):
|
|||
)
|
||||
|
||||
|
||||
class TestDateTimeExactQuerying(PostgreSQLTestCase):
|
||||
|
||||
def setUp(self):
|
||||
now = timezone.now()
|
||||
self.datetimes = [now]
|
||||
self.dates = [now.date()]
|
||||
self.times = [now.time()]
|
||||
self.objs = [
|
||||
DateTimeArrayModel.objects.create(
|
||||
datetimes=self.datetimes,
|
||||
dates=self.dates,
|
||||
times=self.times,
|
||||
)
|
||||
]
|
||||
|
||||
def test_exact_datetimes(self):
|
||||
self.assertSequenceEqual(
|
||||
DateTimeArrayModel.objects.filter(datetimes=self.datetimes),
|
||||
self.objs
|
||||
)
|
||||
|
||||
def test_exact_dates(self):
|
||||
self.assertSequenceEqual(
|
||||
DateTimeArrayModel.objects.filter(dates=self.dates),
|
||||
self.objs
|
||||
)
|
||||
|
||||
def test_exact_times(self):
|
||||
self.assertSequenceEqual(
|
||||
DateTimeArrayModel.objects.filter(times=self.times),
|
||||
self.objs
|
||||
)
|
||||
|
||||
|
||||
class TestOtherTypesExactQuerying(PostgreSQLTestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.ips = ['192.168.0.1', '::1']
|
||||
self.uuids = [uuid.uuid4()]
|
||||
self.decimals = [decimal.Decimal(1.25), 1.75]
|
||||
self.objs = [
|
||||
OtherTypesArrayModel.objects.create(
|
||||
ips=self.ips,
|
||||
uuids=self.uuids,
|
||||
decimals=self.decimals,
|
||||
)
|
||||
]
|
||||
|
||||
def test_exact_ip_addresses(self):
|
||||
self.assertSequenceEqual(
|
||||
OtherTypesArrayModel.objects.filter(ips=self.ips),
|
||||
self.objs
|
||||
)
|
||||
|
||||
def test_exact_uuids(self):
|
||||
self.assertSequenceEqual(
|
||||
OtherTypesArrayModel.objects.filter(uuids=self.uuids),
|
||||
self.objs
|
||||
)
|
||||
|
||||
def test_exact_decimals(self):
|
||||
self.assertSequenceEqual(
|
||||
OtherTypesArrayModel.objects.filter(decimals=self.decimals),
|
||||
self.objs
|
||||
)
|
||||
|
||||
|
||||
class TestChecks(PostgreSQLTestCase):
|
||||
|
||||
def test_field_checks(self):
|
||||
|
|
Loading…
Reference in New Issue