[4.2.x] Fixed #34285 -- Fixed index/slice lookups on filtered aggregates with ArrayField.

Thanks Simon Charette for the review.

Backport of ae1fe72e9b from main
This commit is contained in:
Nils VAN ZUIJLEN 2023-02-06 22:46:44 +01:00 committed by Mariusz Felisiak
parent 714d59d57f
commit e8a39da396
2 changed files with 49 additions and 2 deletions

View File

@ -325,7 +325,9 @@ class IndexTransform(Transform):
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
lhs, params = compiler.compile(self.lhs) lhs, params = compiler.compile(self.lhs)
return "%s[%%s]" % lhs, params + [self.index] if not lhs.endswith("]"):
lhs = "(%s)" % lhs
return "%s[%%s]" % lhs, (*params, self.index)
@property @property
def output_field(self): def output_field(self):
@ -349,7 +351,9 @@ class SliceTransform(Transform):
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
lhs, params = compiler.compile(self.lhs) lhs, params = compiler.compile(self.lhs)
return "%s[%%s:%%s]" % lhs, params + [self.start, self.end] if not lhs.endswith("]"):
lhs = "(%s)" % lhs
return "%s[%%s:%%s]" % lhs, (*params, self.start, self.end)
class SliceTransformFactory: class SliceTransformFactory:

View File

@ -365,6 +365,49 @@ class TestGeneralAggregate(PostgreSQLTestCase):
) )
self.assertCountEqual(qs.get(), [1, 2]) self.assertCountEqual(qs.get(), [1, 2])
def test_array_agg_filter_index(self):
aggr1 = AggregateTestModel.objects.create(integer_field=1)
aggr2 = AggregateTestModel.objects.create(integer_field=2)
StatTestModel.objects.bulk_create(
[
StatTestModel(related_field=aggr1, int1=1, int2=0),
StatTestModel(related_field=aggr1, int1=2, int2=1),
StatTestModel(related_field=aggr2, int1=3, int2=0),
StatTestModel(related_field=aggr2, int1=4, int2=1),
]
)
qs = (
AggregateTestModel.objects.filter(pk__in=[aggr1.pk, aggr2.pk])
.annotate(
array=ArrayAgg("stattestmodel__int1", filter=Q(stattestmodel__int2=0))
)
.annotate(array_value=F("array__0"))
.values_list("array_value", flat=True)
)
self.assertCountEqual(qs, [1, 3])
def test_array_agg_filter_slice(self):
aggr1 = AggregateTestModel.objects.create(integer_field=1)
aggr2 = AggregateTestModel.objects.create(integer_field=2)
StatTestModel.objects.bulk_create(
[
StatTestModel(related_field=aggr1, int1=1, int2=0),
StatTestModel(related_field=aggr1, int1=2, int2=1),
StatTestModel(related_field=aggr2, int1=3, int2=0),
StatTestModel(related_field=aggr2, int1=4, int2=1),
StatTestModel(related_field=aggr2, int1=5, int2=0),
]
)
qs = (
AggregateTestModel.objects.filter(pk__in=[aggr1.pk, aggr2.pk])
.annotate(
array=ArrayAgg("stattestmodel__int1", filter=Q(stattestmodel__int2=0))
)
.annotate(array_value=F("array__1_2"))
.values_list("array_value", flat=True)
)
self.assertCountEqual(qs, [[], [5]])
def test_bit_and_general(self): def test_bit_and_general(self):
values = AggregateTestModel.objects.filter(integer_field__in=[0, 1]).aggregate( values = AggregateTestModel.objects.filter(integer_field__in=[0, 1]).aggregate(
bitand=BitAnd("integer_field") bitand=BitAnd("integer_field")