diff --git a/django/contrib/postgres/fields/array.py b/django/contrib/postgres/fields/array.py index 5d49d49e22..87760613fb 100644 --- a/django/contrib/postgres/fields/array.py +++ b/django/contrib/postgres/fields/array.py @@ -5,6 +5,7 @@ from django.contrib.postgres.forms import SimpleArrayField from django.contrib.postgres.validators import ArrayMaxLengthValidator from django.core import checks, exceptions from django.db.models import Field, IntegerField, Transform +from django.db.models.lookups import Exact from django.utils import six from django.utils.translation import string_concat, ugettext_lazy as _ @@ -162,7 +163,7 @@ class ArrayField(Field): class ArrayContains(lookups.DataContains): def as_sql(self, qn, connection): sql, params = super(ArrayContains, self).as_sql(qn, connection) - sql += '::%s' % self.lhs.output_field.db_type(connection) + sql = '%s::%s' % (sql, self.lhs.output_field.db_type(connection)) return sql, params @@ -170,7 +171,15 @@ class ArrayContains(lookups.DataContains): class ArrayContainedBy(lookups.ContainedBy): def as_sql(self, qn, connection): sql, params = super(ArrayContainedBy, self).as_sql(qn, connection) - sql += '::%s' % self.lhs.output_field.db_type(connection) + sql = '%s::%s' % (sql, self.lhs.output_field.db_type(connection)) + return sql, params + + +@ArrayField.register_lookup +class ArrayExact(Exact): + def as_sql(self, qn, connection): + sql, params = super(ArrayExact, self).as_sql(qn, connection) + sql = '%s::%s' % (sql, self.lhs.output_field.db_type(connection)) return sql, params @@ -178,7 +187,7 @@ class ArrayContainedBy(lookups.ContainedBy): class ArrayOverlap(lookups.Overlap): def as_sql(self, qn, connection): sql, params = super(ArrayOverlap, self).as_sql(qn, connection) - sql += '::%s' % self.lhs.output_field.db_type(connection) + sql = '%s::%s' % (sql, self.lhs.output_field.db_type(connection)) return sql, params diff --git a/docs/releases/1.8.7.txt b/docs/releases/1.8.7.txt index f73d53e799..acc82f1581 100644 --- a/docs/releases/1.8.7.txt +++ b/docs/releases/1.8.7.txt @@ -34,3 +34,5 @@ Bugfixes * Fixed serialization of :class:`~django.contrib.postgres.fields.DateRangeField` and :class:`~django.contrib.postgres.fields.DateTimeRangeField` (:ticket:`24937`). + +* Fixed the exact lookup of ``ArrayField`` (:ticket:`25666`). diff --git a/tests/postgres_tests/test_array.py b/tests/postgres_tests/test_array.py index d85c0f58ec..fb7eb16da2 100644 --- a/tests/postgres_tests/test_array.py +++ b/tests/postgres_tests/test_array.py @@ -110,6 +110,20 @@ class TestQuerying(TestCase): self.objs[:1] ) + def test_exact_charfield(self): + instance = CharArrayModel.objects.create(field=['text']) + self.assertSequenceEqual( + CharArrayModel.objects.filter(field=['text']), + [instance] + ) + + def test_exact_nested(self): + instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]]) + self.assertSequenceEqual( + NestedIntegerArrayModel.objects.filter(field=[[1, 2], [3, 4]]), + [instance] + ) + def test_isnull(self): self.assertSequenceEqual( NullableIntegerArrayModel.objects.filter(field__isnull=True), @@ -232,6 +246,73 @@ class TestQuerying(TestCase): ) +class TestDateTimeExactQuerying(TestCase): + + def setUp(self): + now = timezone.now() + self.datetimes = [now] + self.dates = [now.date()] + self.times = [now.time()] + self.objs = [ + DateTimeArrayModel.objects.create( + datetimes=self.datetimes, + dates=self.dates, + times=self.times, + ) + ] + + def test_exact_datetimes(self): + self.assertSequenceEqual( + DateTimeArrayModel.objects.filter(datetimes=self.datetimes), + self.objs + ) + + def test_exact_dates(self): + self.assertSequenceEqual( + DateTimeArrayModel.objects.filter(dates=self.dates), + self.objs + ) + + def test_exact_times(self): + self.assertSequenceEqual( + DateTimeArrayModel.objects.filter(times=self.times), + self.objs + ) + + +class TestOtherTypesExactQuerying(TestCase): + + def setUp(self): + self.ips = ['192.168.0.1', '::1'] + self.uuids = [uuid.uuid4()] + self.decimals = [decimal.Decimal(1.25), 1.75] + self.objs = [ + OtherTypesArrayModel.objects.create( + ips=self.ips, + uuids=self.uuids, + decimals=self.decimals, + ) + ] + + def test_exact_ip_addresses(self): + self.assertSequenceEqual( + OtherTypesArrayModel.objects.filter(ips=self.ips), + self.objs + ) + + def test_exact_uuids(self): + self.assertSequenceEqual( + OtherTypesArrayModel.objects.filter(uuids=self.uuids), + self.objs + ) + + def test_exact_decimals(self): + self.assertSequenceEqual( + OtherTypesArrayModel.objects.filter(decimals=self.decimals), + self.objs + ) + + class TestChecks(TestCase): def test_field_checks(self):