Fixed #24171 -- Fixed failure with complex aggregate query and expressions

The query used a construct of qs.annotate().values().aggregate() where
the first annotate used an F-object reference and the values() and
aggregate() calls referenced that F-object.

Also made sure the inner query's select clause is as simple as possible,
and made sure .values().distinct().aggreate() works correctly.
This commit is contained in:
Anssi Kääriäinen 2015-03-04 14:56:20 +02:00 committed by Tim Graham
parent 63f2dd4ad7
commit fb146193c4
4 changed files with 48 additions and 9 deletions

View File

@ -351,7 +351,7 @@ class Query(object):
# is selected. # is selected.
col_cnt += 1 col_cnt += 1
col_alias = '__col%d' % col_cnt col_alias = '__col%d' % col_cnt
self.annotation_select[col_alias] = expr self.annotations[col_alias] = expr
self.append_annotation_mask([col_alias]) self.append_annotation_mask([col_alias])
new_exprs.append(Ref(col_alias, expr)) new_exprs.append(Ref(col_alias, expr))
else: else:
@ -390,10 +390,22 @@ class Query(object):
from django.db.models.sql.subqueries import AggregateQuery from django.db.models.sql.subqueries import AggregateQuery
outer_query = AggregateQuery(self.model) outer_query = AggregateQuery(self.model)
inner_query = self.clone() inner_query = self.clone()
if not has_limit and not self.distinct_fields:
inner_query.clear_ordering(True)
inner_query.select_for_update = False inner_query.select_for_update = False
inner_query.select_related = False inner_query.select_related = False
if not has_limit and not self.distinct_fields:
# Queries with distinct_fields need ordering and when a limit
# is applied we must take the slice from the ordered query.
# Otherwise no need for ordering.
inner_query.clear_ordering(True)
if not inner_query.distinct:
# If the inner query uses default select and it has some
# aggregate annotations, then we must make sure the inner
# query is grouped by the main model's primary key. However,
# clearing the select clause can alter results if distinct is
# used.
if inner_query.default_cols and has_existing_annotations:
inner_query.group_by = [self.model._meta.pk.get_col(inner_query.get_initial_alias())]
inner_query.default_cols = False
relabels = {t: 'subquery' for t in inner_query.tables} relabels = {t: 'subquery' for t in inner_query.tables}
relabels[None] = 'subquery' relabels[None] = 'subquery'
@ -404,7 +416,14 @@ class Query(object):
if expression.is_summary: if expression.is_summary:
expression, col_cnt = inner_query.rewrite_cols(expression, col_cnt) expression, col_cnt = inner_query.rewrite_cols(expression, col_cnt)
outer_query.annotations[alias] = expression.relabeled_clone(relabels) outer_query.annotations[alias] = expression.relabeled_clone(relabels)
del inner_query.annotation_select[alias] del inner_query.annotations[alias]
# Make sure the annotation_select wont use cached results.
inner_query.set_annotation_mask(inner_query.annotation_select_mask)
if inner_query.select == [] and not inner_query.default_cols and not inner_query.annotation_select_mask:
# In case of Model.objects[0:3].count(), there would be no
# field selected in the inner query, yet we must use a subquery.
# So, make sure at least one field is selected.
inner_query.select = [self.model._meta.pk.get_col(inner_query.get_initial_alias())]
try: try:
outer_query.add_subquery(inner_query, using) outer_query.add_subquery(inner_query, using)
except EmptyResultSet: except EmptyResultSet:

View File

@ -7,7 +7,9 @@ from operator import attrgetter
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db.models import F, Q, Avg, Count, Max, StdDev, Sum, Variance from django.db.models import (
F, Q, Avg, Count, Max, StdDev, Sum, Value, Variance,
)
from django.test import TestCase, skipUnlessDBFeature from django.test import TestCase, skipUnlessDBFeature
from django.test.utils import Approximate from django.test.utils import Approximate
from django.utils import six from django.utils import six
@ -1232,6 +1234,14 @@ class AggregationTests(TestCase):
) )
self.assertEqual(qs['publisher_awards'], 30) self.assertEqual(qs['publisher_awards'], 30)
def test_annotate_distinct_aggregate(self):
# There are three books with rating of 4.0 and two of the books have
# the same price. Hence, the distinct removes one rating of 4.0
# from the results.
vals1 = Book.objects.values('rating', 'price').distinct().aggregate(result=Sum('rating'))
vals2 = Book.objects.aggregate(result=Sum('rating') - Value(4.0))
self.assertEqual(vals1, vals2)
class JoinPromotionTests(TestCase): class JoinPromotionTests(TestCase):
def test_ticket_21150(self): def test_ticket_21150(self):

View File

@ -12,6 +12,7 @@ from django.utils.encoding import python_2_unicode_compatible
class Employee(models.Model): class Employee(models.Model):
firstname = models.CharField(max_length=50) firstname = models.CharField(max_length=50)
lastname = models.CharField(max_length=50) lastname = models.CharField(max_length=50)
salary = models.IntegerField(blank=True, null=True)
def __str__(self): def __str__(self):
return '%s %s' % (self.firstname, self.lastname) return '%s %s' % (self.firstname, self.lastname)

View File

@ -5,7 +5,7 @@ import uuid
from copy import deepcopy from copy import deepcopy
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db import DatabaseError, connection, transaction from django.db import DatabaseError, connection, models, transaction
from django.db.models import TimeField, UUIDField from django.db.models import TimeField, UUIDField
from django.db.models.aggregates import ( from django.db.models.aggregates import (
Avg, Count, Max, Min, StdDev, Sum, Variance, Avg, Count, Max, Min, StdDev, Sum, Variance,
@ -30,15 +30,15 @@ class BasicExpressionsTests(TestCase):
def setUpTestData(cls): def setUpTestData(cls):
Company.objects.create( Company.objects.create(
name="Example Inc.", num_employees=2300, num_chairs=5, name="Example Inc.", num_employees=2300, num_chairs=5,
ceo=Employee.objects.create(firstname="Joe", lastname="Smith") ceo=Employee.objects.create(firstname="Joe", lastname="Smith", salary=10)
) )
Company.objects.create( Company.objects.create(
name="Foobar Ltd.", num_employees=3, num_chairs=4, name="Foobar Ltd.", num_employees=3, num_chairs=4,
ceo=Employee.objects.create(firstname="Frank", lastname="Meyer") ceo=Employee.objects.create(firstname="Frank", lastname="Meyer", salary=20)
) )
Company.objects.create( Company.objects.create(
name="Test GmbH", num_employees=32, num_chairs=1, name="Test GmbH", num_employees=32, num_chairs=1,
ceo=Employee.objects.create(firstname="Max", lastname="Mustermann") ceo=Employee.objects.create(firstname="Max", lastname="Mustermann", salary=30)
) )
def setUp(self): def setUp(self):
@ -48,6 +48,15 @@ class BasicExpressionsTests(TestCase):
"name", "num_employees", "num_chairs" "name", "num_employees", "num_chairs"
) )
def test_annotate_values_aggregate(self):
companies = Company.objects.annotate(
salaries=F('ceo__salary'),
).values('num_employees', 'salaries').aggregate(
result=Sum(F('salaries') + F('num_employees'),
output_field=models.IntegerField()),
)
self.assertEqual(companies['result'], 2395)
def test_filter_inter_attribute(self): def test_filter_inter_attribute(self):
# We can filter on attribute relationships on same model obj, e.g. # We can filter on attribute relationships on same model obj, e.g.
# find companies where the number of employees is greater # find companies where the number of employees is greater