Fixed #30188 -- Fixed a crash when aggregating over a subquery annotation.

This commit is contained in:
Simon Charette 2019-03-08 11:26:53 -05:00 committed by Tim Graham
parent f19a4945e1
commit bdc07f176e
2 changed files with 15 additions and 2 deletions

View File

@ -22,7 +22,7 @@ from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connections
from django.db.models.aggregates import Count from django.db.models.aggregates import Count
from django.db.models.constants import LOOKUP_SEP from django.db.models.constants import LOOKUP_SEP
from django.db.models.expressions import ( from django.db.models.expressions import (
BaseExpression, Col, F, OuterRef, Ref, SimpleCol, BaseExpression, Col, F, OuterRef, Ref, SimpleCol, Subquery,
) )
from django.db.models.fields import Field from django.db.models.fields import Field
from django.db.models.fields.related_lookups import MultiColSource from django.db.models.fields.related_lookups import MultiColSource
@ -382,7 +382,7 @@ class Query(BaseExpression):
# before the contains_aggregate/is_summary condition below. # before the contains_aggregate/is_summary condition below.
new_expr, col_cnt = self.rewrite_cols(expr, col_cnt) new_expr, col_cnt = self.rewrite_cols(expr, col_cnt)
new_exprs.append(new_expr) new_exprs.append(new_expr)
elif isinstance(expr, Col) or (expr.contains_aggregate and not expr.is_summary): elif isinstance(expr, (Col, Subquery)) or (expr.contains_aggregate and not expr.is_summary):
# Reference to column. Make sure the referenced column # Reference to column. Make sure the referenced column
# is selected. # is selected.
col_cnt += 1 col_cnt += 1

View File

@ -551,6 +551,19 @@ class BasicExpressionsTests(TestCase):
) )
self.assertEqual(qs.get().float, 1.2) self.assertEqual(qs.get().float, 1.2)
@skipUnlessDBFeature('supports_subqueries_in_group_by')
def test_aggregate_subquery_annotation(self):
aggregate = Company.objects.annotate(
ceo_salary=Subquery(
Employee.objects.filter(
id=OuterRef('ceo_id'),
).values('salary')
),
).aggregate(
ceo_salary_gt_20=Count('pk', filter=Q(ceo_salary__gt=20)),
)
self.assertEqual(aggregate, {'ceo_salary_gt_20': 1})
def test_explicit_output_field(self): def test_explicit_output_field(self):
class FuncA(Func): class FuncA(Func):
output_field = models.CharField() output_field = models.CharField()