diff --git a/django/contrib/postgres/fields/array.py b/django/contrib/postgres/fields/array.py index 73944056d5..1180795e8f 100644 --- a/django/contrib/postgres/fields/array.py +++ b/django/contrib/postgres/fields/array.py @@ -82,6 +82,10 @@ class ArrayField(CheckFieldDefaultMixin, Field): size = self.size or '' return '%s[%s]' % (self.base_field.db_type(connection), size) + def cast_db_type(self, connection): + size = self.size or '' + return '%s[%s]' % (self.base_field.cast_db_type(connection), size) + def get_placeholder(self, value, compiler, connection): return '%s::{}'.format(self.db_type(connection)) @@ -193,7 +197,7 @@ class ArrayField(CheckFieldDefaultMixin, Field): class ArrayCastRHSMixin: def process_rhs(self, compiler, connection): rhs, rhs_params = super().process_rhs(compiler, connection) - cast_type = self.lhs.output_field.db_type(connection) + cast_type = self.lhs.output_field.cast_db_type(connection) return '%s::%s' % (rhs, cast_type), rhs_params diff --git a/tests/postgres_tests/test_array.py b/tests/postgres_tests/test_array.py index fc4c07ea86..42885e61f6 100644 --- a/tests/postgres_tests/test_array.py +++ b/tests/postgres_tests/test_array.py @@ -25,6 +25,7 @@ from .models import ( ) try: + from django.contrib.postgres.aggregates import ArrayAgg from django.contrib.postgres.fields import ArrayField from django.contrib.postgres.fields.array import IndexTransform, SliceTransform from django.contrib.postgres.forms import ( @@ -280,6 +281,27 @@ class TestQuerying(PostgreSQLTestCase): [] ) + def test_lookups_autofield_array(self): + qs = NullableIntegerArrayModel.objects.filter( + field__0__isnull=False, + ).values('field__0').annotate( + arrayagg=ArrayAgg('id'), + ).order_by('field__0') + tests = ( + ('contained_by', [self.objs[1].pk, self.objs[2].pk, 0], [2]), + ('contains', [self.objs[2].pk], [2]), + ('exact', [self.objs[3].pk], [20]), + ('overlap', [self.objs[1].pk, self.objs[3].pk], [2, 20]), + ) + for lookup, value, expected in tests: + with self.subTest(lookup=lookup): + self.assertSequenceEqual( + qs.filter( + **{'arrayagg__' + lookup: value}, + ).values_list('field__0', flat=True), + expected, + ) + def test_index(self): self.assertSequenceEqual( NullableIntegerArrayModel.objects.filter(field__0=2),