mirror of https://github.com/django/django.git
Fixed #25307 -- Fixed QuerySet.annotate() crash with conditional expressions.
Thanks Travis Newport for the tests and Josh Smeaton for contributing to the patch.
This commit is contained in:
parent
f2d2f17896
commit
1df89a60c5
|
@ -56,6 +56,17 @@ class Lookup(object):
|
||||||
sqls, sqls_params = ['%s'] * len(params), params
|
sqls, sqls_params = ['%s'] * len(params), params
|
||||||
return sqls, sqls_params
|
return sqls, sqls_params
|
||||||
|
|
||||||
|
def get_source_expressions(self):
|
||||||
|
if self.rhs_is_direct_value():
|
||||||
|
return [self.lhs]
|
||||||
|
return [self.lhs, self.rhs]
|
||||||
|
|
||||||
|
def set_source_expressions(self, new_exprs):
|
||||||
|
if len(new_exprs) == 1:
|
||||||
|
self.lhs = new_exprs[0]
|
||||||
|
else:
|
||||||
|
self.lhs, self.rhs = new_exprs
|
||||||
|
|
||||||
def get_prep_lookup(self):
|
def get_prep_lookup(self):
|
||||||
if hasattr(self.rhs, '_prepare'):
|
if hasattr(self.rhs, '_prepare'):
|
||||||
return self.rhs._prepare(self.lhs.output_field)
|
return self.rhs._prepare(self.lhs.output_field)
|
||||||
|
@ -116,6 +127,10 @@ class Lookup(object):
|
||||||
def contains_aggregate(self):
|
def contains_aggregate(self):
|
||||||
return self.lhs.contains_aggregate or getattr(self.rhs, 'contains_aggregate', False)
|
return self.lhs.contains_aggregate or getattr(self.rhs, 'contains_aggregate', False)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_summary(self):
|
||||||
|
return self.lhs.is_summary or getattr(self.rhs, 'is_summary', False)
|
||||||
|
|
||||||
|
|
||||||
class Transform(RegisterLookupMixin, Func):
|
class Transform(RegisterLookupMixin, Func):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -18,6 +18,7 @@ from django.db.models.aggregates import Count
|
||||||
from django.db.models.constants import LOOKUP_SEP
|
from django.db.models.constants import LOOKUP_SEP
|
||||||
from django.db.models.expressions import Col, Ref
|
from django.db.models.expressions import Col, Ref
|
||||||
from django.db.models.fields.related_lookups import MultiColSource
|
from django.db.models.fields.related_lookups import MultiColSource
|
||||||
|
from django.db.models.lookups import Lookup
|
||||||
from django.db.models.query_utils import (
|
from django.db.models.query_utils import (
|
||||||
Q, check_rel_lookup_compatibility, refs_expression,
|
Q, check_rel_lookup_compatibility, refs_expression,
|
||||||
)
|
)
|
||||||
|
@ -366,11 +367,20 @@ class Query(object):
|
||||||
orig_exprs = annotation.get_source_expressions()
|
orig_exprs = annotation.get_source_expressions()
|
||||||
new_exprs = []
|
new_exprs = []
|
||||||
for expr in orig_exprs:
|
for expr in orig_exprs:
|
||||||
|
# FIXME: These conditions are fairly arbitrary. Identify a better
|
||||||
|
# method of having expressions decide which code path they should
|
||||||
|
# take.
|
||||||
if isinstance(expr, Ref):
|
if isinstance(expr, Ref):
|
||||||
# Its already a Ref to subquery (see resolve_ref() for
|
# Its already a Ref to subquery (see resolve_ref() for
|
||||||
# details)
|
# details)
|
||||||
new_exprs.append(expr)
|
new_exprs.append(expr)
|
||||||
elif isinstance(expr, Col):
|
elif isinstance(expr, (WhereNode, Lookup)):
|
||||||
|
# Decompose the subexpressions further. The code here is
|
||||||
|
# copied from the else clause, but this condition must appear
|
||||||
|
# before the contains_aggregate/is_summary condition below.
|
||||||
|
new_expr, col_cnt = self.rewrite_cols(expr, col_cnt)
|
||||||
|
new_exprs.append(new_expr)
|
||||||
|
elif isinstance(expr, Col) or (expr.contains_aggregate and not expr.is_summary):
|
||||||
# Reference to column. Make sure the referenced column
|
# Reference to column. Make sure the referenced column
|
||||||
# is selected.
|
# is selected.
|
||||||
col_cnt += 1
|
col_cnt += 1
|
||||||
|
|
|
@ -118,6 +118,13 @@ class WhereNode(tree.Node):
|
||||||
cols.extend(child.get_group_by_cols())
|
cols.extend(child.get_group_by_cols())
|
||||||
return cols
|
return cols
|
||||||
|
|
||||||
|
def get_source_expressions(self):
|
||||||
|
return self.children[:]
|
||||||
|
|
||||||
|
def set_source_expressions(self, children):
|
||||||
|
assert len(children) == len(self.children)
|
||||||
|
self.children = children
|
||||||
|
|
||||||
def relabel_aliases(self, change_map):
|
def relabel_aliases(self, change_map):
|
||||||
"""
|
"""
|
||||||
Relabels the alias values of any children. 'change_map' is a dictionary
|
Relabels the alias values of any children. 'change_map' is a dictionary
|
||||||
|
@ -160,6 +167,10 @@ class WhereNode(tree.Node):
|
||||||
def contains_aggregate(self):
|
def contains_aggregate(self):
|
||||||
return self._contains_aggregate(self)
|
return self._contains_aggregate(self)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_summary(self):
|
||||||
|
return any(child.is_summary for child in self.children)
|
||||||
|
|
||||||
|
|
||||||
class NothingNode(object):
|
class NothingNode(object):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -9,7 +9,8 @@ from django.contrib.contenttypes.models import ContentType
|
||||||
from django.core.exceptions import FieldError
|
from django.core.exceptions import FieldError
|
||||||
from django.db import connection
|
from django.db import connection
|
||||||
from django.db.models import (
|
from django.db.models import (
|
||||||
Avg, Count, F, Max, Q, StdDev, Sum, Value, Variance,
|
Avg, Case, Count, DecimalField, F, IntegerField, Max, Q, StdDev, Sum,
|
||||||
|
Value, Variance, When,
|
||||||
)
|
)
|
||||||
from django.test import TestCase, skipUnlessAnyDBFeature, skipUnlessDBFeature
|
from django.test import TestCase, skipUnlessAnyDBFeature, skipUnlessDBFeature
|
||||||
from django.test.utils import Approximate
|
from django.test.utils import Approximate
|
||||||
|
@ -371,6 +372,51 @@ class AggregationTests(TestCase):
|
||||||
{'c__max': 3}
|
{'c__max': 3}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_conditional_aggreate(self):
|
||||||
|
# Conditional aggregation of a grouped queryset.
|
||||||
|
self.assertEqual(
|
||||||
|
Book.objects.annotate(c=Count('authors')).values('pk').aggregate(test=Sum(
|
||||||
|
Case(When(c__gt=1, then=1), output_field=IntegerField())
|
||||||
|
))['test'],
|
||||||
|
3
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_sliced_conditional_aggregate(self):
|
||||||
|
self.assertEqual(
|
||||||
|
Author.objects.all()[:5].aggregate(test=Sum(Case(
|
||||||
|
When(age__lte=35, then=1), output_field=IntegerField()
|
||||||
|
)))['test'],
|
||||||
|
3
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_annotated_conditional_aggregate(self):
|
||||||
|
annotated_qs = Book.objects.annotate(discount_price=F('price') * 0.75)
|
||||||
|
self.assertAlmostEqual(
|
||||||
|
annotated_qs.aggregate(test=Avg(Case(
|
||||||
|
When(pages__lt=400, then='discount_price'),
|
||||||
|
output_field=DecimalField()
|
||||||
|
)))['test'],
|
||||||
|
22.27, places=2
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_distinct_conditional_aggregate(self):
|
||||||
|
self.assertEqual(
|
||||||
|
Book.objects.distinct().aggregate(test=Avg(Case(
|
||||||
|
When(price=Decimal('29.69'), then='pages'),
|
||||||
|
output_field=IntegerField()
|
||||||
|
)))['test'],
|
||||||
|
325
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_conditional_aggregate_on_complex_condition(self):
|
||||||
|
self.assertEqual(
|
||||||
|
Book.objects.distinct().aggregate(test=Avg(Case(
|
||||||
|
When(Q(price__gte=Decimal('29')) & Q(price__lt=Decimal('30')), then='pages'),
|
||||||
|
output_field=IntegerField()
|
||||||
|
)))['test'],
|
||||||
|
325
|
||||||
|
)
|
||||||
|
|
||||||
def test_decimal_aggregate_annotation_filter(self):
|
def test_decimal_aggregate_annotation_filter(self):
|
||||||
"""
|
"""
|
||||||
Filtering on an aggregate annotation with Decimal values should work.
|
Filtering on an aggregate annotation with Decimal values should work.
|
||||||
|
|
Loading…
Reference in New Issue