diff --git a/django/contrib/postgres/fields/array.py b/django/contrib/postgres/fields/array.py index 27c369d61f..4da2a9658b 100644 --- a/django/contrib/postgres/fields/array.py +++ b/django/contrib/postgres/fields/array.py @@ -5,6 +5,7 @@ from django.contrib.postgres.forms import SimpleArrayField from django.contrib.postgres.validators import ArrayMaxLengthValidator from django.core import checks, exceptions from django.db.models import Field, IntegerField, Transform +from django.db.models.lookups import Exact from django.utils import six from django.utils.translation import string_concat, ugettext_lazy as _ @@ -166,7 +167,7 @@ class ArrayField(Field): class ArrayContains(lookups.DataContains): def as_sql(self, qn, connection): sql, params = super(ArrayContains, self).as_sql(qn, connection) - sql += '::%s' % self.lhs.output_field.db_type(connection) + sql = '%s::%s' % (sql, self.lhs.output_field.db_type(connection)) return sql, params @@ -174,7 +175,15 @@ class ArrayContains(lookups.DataContains): class ArrayContainedBy(lookups.ContainedBy): def as_sql(self, qn, connection): sql, params = super(ArrayContainedBy, self).as_sql(qn, connection) - sql += '::%s' % self.lhs.output_field.db_type(connection) + sql = '%s::%s' % (sql, self.lhs.output_field.db_type(connection)) + return sql, params + + +@ArrayField.register_lookup +class ArrayExact(Exact): + def as_sql(self, qn, connection): + sql, params = super(ArrayExact, self).as_sql(qn, connection) + sql = '%s::%s' % (sql, self.lhs.output_field.db_type(connection)) return sql, params @@ -182,7 +191,7 @@ class ArrayContainedBy(lookups.ContainedBy): class ArrayOverlap(lookups.Overlap): def as_sql(self, qn, connection): sql, params = super(ArrayOverlap, self).as_sql(qn, connection) - sql += '::%s' % self.lhs.output_field.db_type(connection) + sql = '%s::%s' % (sql, self.lhs.output_field.db_type(connection)) return sql, params diff --git a/docs/releases/1.8.7.txt b/docs/releases/1.8.7.txt index f73d53e799..acc82f1581 100644 --- a/docs/releases/1.8.7.txt +++ b/docs/releases/1.8.7.txt @@ -34,3 +34,5 @@ Bugfixes * Fixed serialization of :class:`~django.contrib.postgres.fields.DateRangeField` and :class:`~django.contrib.postgres.fields.DateTimeRangeField` (:ticket:`24937`). + +* Fixed the exact lookup of ``ArrayField`` (:ticket:`25666`). diff --git a/tests/postgres_tests/test_array.py b/tests/postgres_tests/test_array.py index 9f982178c7..b5a594592d 100644 --- a/tests/postgres_tests/test_array.py +++ b/tests/postgres_tests/test_array.py @@ -122,6 +122,20 @@ class TestQuerying(PostgreSQLTestCase): self.objs[:1] ) + def test_exact_charfield(self): + instance = CharArrayModel.objects.create(field=['text']) + self.assertSequenceEqual( + CharArrayModel.objects.filter(field=['text']), + [instance] + ) + + def test_exact_nested(self): + instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]]) + self.assertSequenceEqual( + NestedIntegerArrayModel.objects.filter(field=[[1, 2], [3, 4]]), + [instance] + ) + def test_isnull(self): self.assertSequenceEqual( NullableIntegerArrayModel.objects.filter(field__isnull=True), @@ -244,6 +258,73 @@ class TestQuerying(PostgreSQLTestCase): ) +class TestDateTimeExactQuerying(PostgreSQLTestCase): + + def setUp(self): + now = timezone.now() + self.datetimes = [now] + self.dates = [now.date()] + self.times = [now.time()] + self.objs = [ + DateTimeArrayModel.objects.create( + datetimes=self.datetimes, + dates=self.dates, + times=self.times, + ) + ] + + def test_exact_datetimes(self): + self.assertSequenceEqual( + DateTimeArrayModel.objects.filter(datetimes=self.datetimes), + self.objs + ) + + def test_exact_dates(self): + self.assertSequenceEqual( + DateTimeArrayModel.objects.filter(dates=self.dates), + self.objs + ) + + def test_exact_times(self): + self.assertSequenceEqual( + DateTimeArrayModel.objects.filter(times=self.times), + self.objs + ) + + +class TestOtherTypesExactQuerying(PostgreSQLTestCase): + + def setUp(self): + self.ips = ['192.168.0.1', '::1'] + self.uuids = [uuid.uuid4()] + self.decimals = [decimal.Decimal(1.25), 1.75] + self.objs = [ + OtherTypesArrayModel.objects.create( + ips=self.ips, + uuids=self.uuids, + decimals=self.decimals, + ) + ] + + def test_exact_ip_addresses(self): + self.assertSequenceEqual( + OtherTypesArrayModel.objects.filter(ips=self.ips), + self.objs + ) + + def test_exact_uuids(self): + self.assertSequenceEqual( + OtherTypesArrayModel.objects.filter(uuids=self.uuids), + self.objs + ) + + def test_exact_decimals(self): + self.assertSequenceEqual( + OtherTypesArrayModel.objects.filter(decimals=self.decimals), + self.objs + ) + + class TestChecks(PostgreSQLTestCase): def test_field_checks(self):