mirror of https://github.com/django/django.git
[4.2.x] Fixed #34285 -- Fixed index/slice lookups on filtered aggregates with ArrayField.
Thanks Simon Charette for the review.
Backport of ae1fe72e9b
from main
This commit is contained in:
parent
714d59d57f
commit
e8a39da396
|
@ -325,7 +325,9 @@ class IndexTransform(Transform):
|
||||||
|
|
||||||
def as_sql(self, compiler, connection):
|
def as_sql(self, compiler, connection):
|
||||||
lhs, params = compiler.compile(self.lhs)
|
lhs, params = compiler.compile(self.lhs)
|
||||||
return "%s[%%s]" % lhs, params + [self.index]
|
if not lhs.endswith("]"):
|
||||||
|
lhs = "(%s)" % lhs
|
||||||
|
return "%s[%%s]" % lhs, (*params, self.index)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def output_field(self):
|
def output_field(self):
|
||||||
|
@ -349,7 +351,9 @@ class SliceTransform(Transform):
|
||||||
|
|
||||||
def as_sql(self, compiler, connection):
|
def as_sql(self, compiler, connection):
|
||||||
lhs, params = compiler.compile(self.lhs)
|
lhs, params = compiler.compile(self.lhs)
|
||||||
return "%s[%%s:%%s]" % lhs, params + [self.start, self.end]
|
if not lhs.endswith("]"):
|
||||||
|
lhs = "(%s)" % lhs
|
||||||
|
return "%s[%%s:%%s]" % lhs, (*params, self.start, self.end)
|
||||||
|
|
||||||
|
|
||||||
class SliceTransformFactory:
|
class SliceTransformFactory:
|
||||||
|
|
|
@ -365,6 +365,49 @@ class TestGeneralAggregate(PostgreSQLTestCase):
|
||||||
)
|
)
|
||||||
self.assertCountEqual(qs.get(), [1, 2])
|
self.assertCountEqual(qs.get(), [1, 2])
|
||||||
|
|
||||||
|
def test_array_agg_filter_index(self):
|
||||||
|
aggr1 = AggregateTestModel.objects.create(integer_field=1)
|
||||||
|
aggr2 = AggregateTestModel.objects.create(integer_field=2)
|
||||||
|
StatTestModel.objects.bulk_create(
|
||||||
|
[
|
||||||
|
StatTestModel(related_field=aggr1, int1=1, int2=0),
|
||||||
|
StatTestModel(related_field=aggr1, int1=2, int2=1),
|
||||||
|
StatTestModel(related_field=aggr2, int1=3, int2=0),
|
||||||
|
StatTestModel(related_field=aggr2, int1=4, int2=1),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
qs = (
|
||||||
|
AggregateTestModel.objects.filter(pk__in=[aggr1.pk, aggr2.pk])
|
||||||
|
.annotate(
|
||||||
|
array=ArrayAgg("stattestmodel__int1", filter=Q(stattestmodel__int2=0))
|
||||||
|
)
|
||||||
|
.annotate(array_value=F("array__0"))
|
||||||
|
.values_list("array_value", flat=True)
|
||||||
|
)
|
||||||
|
self.assertCountEqual(qs, [1, 3])
|
||||||
|
|
||||||
|
def test_array_agg_filter_slice(self):
|
||||||
|
aggr1 = AggregateTestModel.objects.create(integer_field=1)
|
||||||
|
aggr2 = AggregateTestModel.objects.create(integer_field=2)
|
||||||
|
StatTestModel.objects.bulk_create(
|
||||||
|
[
|
||||||
|
StatTestModel(related_field=aggr1, int1=1, int2=0),
|
||||||
|
StatTestModel(related_field=aggr1, int1=2, int2=1),
|
||||||
|
StatTestModel(related_field=aggr2, int1=3, int2=0),
|
||||||
|
StatTestModel(related_field=aggr2, int1=4, int2=1),
|
||||||
|
StatTestModel(related_field=aggr2, int1=5, int2=0),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
qs = (
|
||||||
|
AggregateTestModel.objects.filter(pk__in=[aggr1.pk, aggr2.pk])
|
||||||
|
.annotate(
|
||||||
|
array=ArrayAgg("stattestmodel__int1", filter=Q(stattestmodel__int2=0))
|
||||||
|
)
|
||||||
|
.annotate(array_value=F("array__1_2"))
|
||||||
|
.values_list("array_value", flat=True)
|
||||||
|
)
|
||||||
|
self.assertCountEqual(qs, [[], [5]])
|
||||||
|
|
||||||
def test_bit_and_general(self):
|
def test_bit_and_general(self):
|
||||||
values = AggregateTestModel.objects.filter(integer_field__in=[0, 1]).aggregate(
|
values = AggregateTestModel.objects.filter(integer_field__in=[0, 1]).aggregate(
|
||||||
bitand=BitAnd("integer_field")
|
bitand=BitAnd("integer_field")
|
||||||
|
|
Loading…
Reference in New Issue