diff --git a/django/db/models/lookups.py b/django/db/models/lookups.py index 1dddcc8c8e9..8f7db0723ca 100644 --- a/django/db/models/lookups.py +++ b/django/db/models/lookups.py @@ -56,6 +56,17 @@ class Lookup(object): sqls, sqls_params = ['%s'] * len(params), params return sqls, sqls_params + def get_source_expressions(self): + if self.rhs_is_direct_value(): + return [self.lhs] + return [self.lhs, self.rhs] + + def set_source_expressions(self, new_exprs): + if len(new_exprs) == 1: + self.lhs = new_exprs[0] + else: + self.lhs, self.rhs = new_exprs + def get_prep_lookup(self): if hasattr(self.rhs, '_prepare'): return self.rhs._prepare(self.lhs.output_field) @@ -116,6 +127,10 @@ class Lookup(object): def contains_aggregate(self): return self.lhs.contains_aggregate or getattr(self.rhs, 'contains_aggregate', False) + @property + def is_summary(self): + return self.lhs.is_summary or getattr(self.rhs, 'is_summary', False) + class Transform(RegisterLookupMixin, Func): """ diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 16ed92a4d45..5f657863737 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -18,6 +18,7 @@ from django.db.models.aggregates import Count from django.db.models.constants import LOOKUP_SEP from django.db.models.expressions import Col, Ref from django.db.models.fields.related_lookups import MultiColSource +from django.db.models.lookups import Lookup from django.db.models.query_utils import ( Q, check_rel_lookup_compatibility, refs_expression, ) @@ -366,11 +367,20 @@ class Query(object): orig_exprs = annotation.get_source_expressions() new_exprs = [] for expr in orig_exprs: + # FIXME: These conditions are fairly arbitrary. Identify a better + # method of having expressions decide which code path they should + # take. if isinstance(expr, Ref): # Its already a Ref to subquery (see resolve_ref() for # details) new_exprs.append(expr) - elif isinstance(expr, Col): + elif isinstance(expr, (WhereNode, Lookup)): + # Decompose the subexpressions further. The code here is + # copied from the else clause, but this condition must appear + # before the contains_aggregate/is_summary condition below. + new_expr, col_cnt = self.rewrite_cols(expr, col_cnt) + new_exprs.append(new_expr) + elif isinstance(expr, Col) or (expr.contains_aggregate and not expr.is_summary): # Reference to column. Make sure the referenced column # is selected. col_cnt += 1 diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index df44803b8ac..bb4e9252b44 100644 --- a/django/db/models/sql/where.py +++ b/django/db/models/sql/where.py @@ -118,6 +118,13 @@ class WhereNode(tree.Node): cols.extend(child.get_group_by_cols()) return cols + def get_source_expressions(self): + return self.children[:] + + def set_source_expressions(self, children): + assert len(children) == len(self.children) + self.children = children + def relabel_aliases(self, change_map): """ Relabels the alias values of any children. 'change_map' is a dictionary @@ -160,6 +167,10 @@ class WhereNode(tree.Node): def contains_aggregate(self): return self._contains_aggregate(self) + @property + def is_summary(self): + return any(child.is_summary for child in self.children) + class NothingNode(object): """ diff --git a/tests/aggregation_regress/tests.py b/tests/aggregation_regress/tests.py index 0229020234a..2958736fe9a 100644 --- a/tests/aggregation_regress/tests.py +++ b/tests/aggregation_regress/tests.py @@ -9,7 +9,8 @@ from django.contrib.contenttypes.models import ContentType from django.core.exceptions import FieldError from django.db import connection from django.db.models import ( - Avg, Count, F, Max, Q, StdDev, Sum, Value, Variance, + Avg, Case, Count, DecimalField, F, IntegerField, Max, Q, StdDev, Sum, + Value, Variance, When, ) from django.test import TestCase, skipUnlessAnyDBFeature, skipUnlessDBFeature from django.test.utils import Approximate @@ -371,6 +372,51 @@ class AggregationTests(TestCase): {'c__max': 3} ) + def test_conditional_aggreate(self): + # Conditional aggregation of a grouped queryset. + self.assertEqual( + Book.objects.annotate(c=Count('authors')).values('pk').aggregate(test=Sum( + Case(When(c__gt=1, then=1), output_field=IntegerField()) + ))['test'], + 3 + ) + + def test_sliced_conditional_aggregate(self): + self.assertEqual( + Author.objects.all()[:5].aggregate(test=Sum(Case( + When(age__lte=35, then=1), output_field=IntegerField() + )))['test'], + 3 + ) + + def test_annotated_conditional_aggregate(self): + annotated_qs = Book.objects.annotate(discount_price=F('price') * 0.75) + self.assertAlmostEqual( + annotated_qs.aggregate(test=Avg(Case( + When(pages__lt=400, then='discount_price'), + output_field=DecimalField() + )))['test'], + 22.27, places=2 + ) + + def test_distinct_conditional_aggregate(self): + self.assertEqual( + Book.objects.distinct().aggregate(test=Avg(Case( + When(price=Decimal('29.69'), then='pages'), + output_field=IntegerField() + )))['test'], + 325 + ) + + def test_conditional_aggregate_on_complex_condition(self): + self.assertEqual( + Book.objects.distinct().aggregate(test=Avg(Case( + When(Q(price__gte=Decimal('29')) & Q(price__lt=Decimal('30')), then='pages'), + output_field=IntegerField() + )))['test'], + 325 + ) + def test_decimal_aggregate_annotation_filter(self): """ Filtering on an aggregate annotation with Decimal values should work.