From bd337184f1b4901abdfd07b0219d6b9b45d0de2f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Anssi=20K=C3=A4=C3=A4ri=C3=A4inen?= Date: Thu, 20 Nov 2014 14:30:25 +0200 Subject: [PATCH] Fixed #23877 -- aggregation's subquery missed target col Aggregation over subquery produced syntactically incorrect queries in some cases as Django didn't ensure that source expressions of the aggregation were present in the subquery. --- django/db/models/sql/compiler.py | 12 ++++++-- django/db/models/sql/query.py | 49 +++++++++++++++++++++++++++--- tests/aggregation_regress/tests.py | 27 ++++++++++++++++ 3 files changed, 81 insertions(+), 7 deletions(-) diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index dbbc9aecc8b..bfce9063e5e 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -580,10 +580,10 @@ class SQLCompiler(object): if isinstance(col, (list, tuple)): sql = '%s.%s' % (qn(col[0]), qn(col[1])) elif hasattr(col, 'as_sql'): - self.compile(col) + sql, col_params = self.compile(col) else: sql = '(%s)' % str(col) - if sql not in seen: + if sql not in seen or col_params: result.append(sql) params.extend(col_params) seen.add(sql) @@ -604,6 +604,14 @@ class SQLCompiler(object): sql = '(%s)' % str(extra_select) result.append(sql) params.extend(extra_params) + # Finally, add needed group by cols from annotations + for annotation in self.query.annotation_select.values(): + cols = annotation.get_group_by_cols() + for col in cols: + sql = '%s.%s' % (qn(col[0]), qn(col[1])) + if sql not in seen: + result.append(sql) + seen.add(sql) return result, params diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 81b38703391..9f5ca0dc506 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -313,6 +313,41 @@ class Query(object): clone.change_aliases(change_map) return clone + def rewrite_cols(self, annotation, col_cnt): + # We must make sure the inner query has the referred columns in it. + # If we are aggregating over an annotation, then Django uses Ref() + # instances to note this. However, if we are annotating over a column + # of a related model, then it might be that column isn't part of the + # SELECT clause of the inner query, and we must manually make sure + # the column is selected. An example case is: + # .aggregate(Sum('author__awards')) + # Resolving this expression results in a join to author, but there + # is no guarantee the awards column of author is in the select clause + # of the query. Thus we must manually add the column to the inner + # query. + orig_exprs = annotation.get_source_expressions() + new_exprs = [] + for expr in orig_exprs: + if isinstance(expr, Ref): + # Its already a Ref to subquery (see resolve_ref() for + # details) + new_exprs.append(expr) + elif isinstance(expr, Col): + # Reference to column. Make sure the referenced column + # is selected. + col_cnt += 1 + col_alias = '__col%d' % col_cnt + self.annotation_select[col_alias] = expr + self.append_annotation_mask([col_alias]) + new_exprs.append(Ref(col_alias, expr)) + else: + # Some other expression not referencing database values + # directly. Its subexpression might contain Cols. + new_expr, col_cnt = self.rewrite_cols(expr, col_cnt) + new_exprs.append(new_expr) + annotation.set_source_expressions(new_exprs) + return annotation, col_cnt + def get_aggregation(self, using, added_aggregate_names): """ Returns the dictionary with the values of the existing aggregations. @@ -350,11 +385,11 @@ class Query(object): relabels[None] = 'subquery' # Remove any aggregates marked for reduction from the subquery # and move them to the outer AggregateQuery. - for alias, annotation in inner_query.annotation_select.items(): - if annotation.is_summary: - # The annotation is already referring the subquery alias, so we - # just need to move the annotation to the outer query. - outer_query.annotations[alias] = annotation.relabeled_clone(relabels) + col_cnt = 0 + for alias, expression in inner_query.annotation_select.items(): + if expression.is_summary: + expression, col_cnt = inner_query.rewrite_cols(expression, col_cnt) + outer_query.annotations[alias] = expression.relabeled_clone(relabels) del inner_query.annotation_select[alias] try: outer_query.add_subquery(inner_query, using) @@ -1495,6 +1530,10 @@ class Query(object): raise FieldError("Joined field references are not permitted in this query") if name in self.annotations: if summarize: + # Summarize currently means we are doing an aggregate() query + # which is executed as a wrapped subquery if any of the + # aggregate() elements reference an existing annotation. In + # that case we need to return a Ref to the subquery's annotation. return Ref(name, self.annotation_select[name]) else: return self.annotation_select[name] diff --git a/tests/aggregation_regress/tests.py b/tests/aggregation_regress/tests.py index b9daa63f1be..de18372f9cd 100644 --- a/tests/aggregation_regress/tests.py +++ b/tests/aggregation_regress/tests.py @@ -1168,3 +1168,30 @@ class JoinPromotionTests(TestCase): def test_non_nullable_fk_not_promoted(self): qs = Book.objects.annotate(Count('contact__name')) self.assertIn(' INNER JOIN ', str(qs.query)) + + +class AggregationOnRelationTest(TestCase): + def setUp(self): + self.a = Author.objects.create(name='Anssi', age=33) + self.p = Publisher.objects.create(name='Manning', num_awards=3) + Book.objects.create(isbn='asdf', name='Foo', pages=10, rating=0.1, price="0.0", + contact=self.a, publisher=self.p, pubdate=datetime.date.today()) + + def test_annotate_on_relation(self): + qs = Book.objects.annotate(avg_price=Avg('price'), publisher_name=F('publisher__name')) + self.assertEqual(qs[0].avg_price, 0.0) + self.assertEqual(qs[0].publisher_name, "Manning") + + def test_aggregate_on_relation(self): + # A query with an existing annotation aggregation on a relation should + # succeed. + qs = Book.objects.annotate(avg_price=Avg('price')).aggregate( + publisher_awards=Sum('publisher__num_awards') + ) + self.assertEqual(qs['publisher_awards'], 3) + Book.objects.create(isbn='asdf', name='Foo', pages=10, rating=0.1, price="0.0", + contact=self.a, publisher=self.p, pubdate=datetime.date.today()) + qs = Book.objects.annotate(avg_price=Avg('price')).aggregate( + publisher_awards=Sum('publisher__num_awards') + ) + self.assertEqual(qs['publisher_awards'], 6)