diff --git a/django/contrib/postgres/aggregates/general.py b/django/contrib/postgres/aggregates/general.py index 471eda2970..806ecd1b78 100644 --- a/django/contrib/postgres/aggregates/general.py +++ b/django/contrib/postgres/aggregates/general.py @@ -1,4 +1,4 @@ -from django.contrib.postgres.fields import JSONField +from django.contrib.postgres.fields import ArrayField, JSONField from django.db.models.aggregates import Aggregate __all__ = [ @@ -10,6 +10,10 @@ class ArrayAgg(Aggregate): function = 'ARRAY_AGG' template = '%(function)s(%(distinct)s%(expressions)s)' + @property + def output_field(self): + return ArrayField(self.source_expressions[0].output_field) + def __init__(self, expression, distinct=False, **extra): super().__init__(expression, distinct='DISTINCT ' if distinct else '', **extra) diff --git a/tests/postgres_tests/test_aggregates.py b/tests/postgres_tests/test_aggregates.py index 056d08441b..d4a01ff027 100644 --- a/tests/postgres_tests/test_aggregates.py +++ b/tests/postgres_tests/test_aggregates.py @@ -46,6 +46,18 @@ class TestGeneralAggregate(PostgreSQLTestCase): values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg('boolean_field')) self.assertEqual(values, {'arrayagg': []}) + def test_array_agg_lookups(self): + aggr1 = AggregateTestModel.objects.create() + aggr2 = AggregateTestModel.objects.create() + StatTestModel.objects.create(related_field=aggr1, int1=1, int2=0) + StatTestModel.objects.create(related_field=aggr1, int1=2, int2=0) + StatTestModel.objects.create(related_field=aggr2, int1=3, int2=0) + StatTestModel.objects.create(related_field=aggr2, int1=4, int2=0) + qs = StatTestModel.objects.values('related_field').annotate( + array=ArrayAgg('int1') + ).filter(array__overlap=[2]).values_list('array', flat=True) + self.assertCountEqual(qs.get(), [1, 2]) + def test_bit_and_general(self): values = AggregateTestModel.objects.filter( integer_field__in=[0, 1]).aggregate(bitand=BitAnd('integer_field'))