diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py index 45348a3f754..ce8829c593c 100644 --- a/django/db/models/aggregates.py +++ b/django/db/models/aggregates.py @@ -46,7 +46,7 @@ class Aggregate(object): # Validate that the backend has a fully supported, correct # implementation of this aggregate query.connection.ops.check_aggregate_support(aggregate) - query.aggregate_select[alias] = aggregate + query.aggregates[alias] = aggregate class Avg(Aggregate): name = 'Avg' diff --git a/django/db/models/query.py b/django/db/models/query.py index 5b1f4e66fd4..f9b5577891d 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -596,7 +596,7 @@ class QuerySet(object): obj = self._clone() - obj._setup_aggregate_query() + obj._setup_aggregate_query(kwargs.keys()) # Add the aggregates to the query for (alias, aggregate_expr) in kwargs.items(): @@ -693,7 +693,7 @@ class QuerySet(object): """ pass - def _setup_aggregate_query(self): + def _setup_aggregate_query(self, aggregates): """ Prepare the query for computing a result that contains aggregate annotations. """ @@ -773,6 +773,8 @@ class ValuesQuerySet(QuerySet): self.query.select = [] self.query.add_fields(self.field_names, False) + if self.aggregate_names is not None: + self.query.set_aggregate_mask(self.aggregate_names) def _clone(self, klass=None, setup=False, **kwargs): """ @@ -798,13 +800,17 @@ class ValuesQuerySet(QuerySet): raise TypeError("Merging '%s' classes must involve the same values in each case." % self.__class__.__name__) - def _setup_aggregate_query(self): + def _setup_aggregate_query(self, aggregates): """ Prepare the query for computing a result that contains aggregate annotations. """ self.query.set_group_by() - super(ValuesQuerySet, self)._setup_aggregate_query() + if self.aggregate_names is not None: + self.aggregate_names.extend(aggregates) + self.query.set_aggregate_mask(self.aggregate_names) + + super(ValuesQuerySet, self)._setup_aggregate_query(aggregates) def as_sql(self): """ @@ -824,6 +830,7 @@ class ValuesListQuerySet(ValuesQuerySet): def iterator(self): if self.extra_names is not None: self.query.trim_extra_select(self.extra_names) + if self.flat and len(self._fields) == 1: for row in self.query.results_iter(): yield row[0] @@ -837,6 +844,7 @@ class ValuesListQuerySet(ValuesQuerySet): extra_names = self.query.extra_select.keys() field_names = self.field_names aggregate_names = self.query.aggregate_select.keys() + names = extra_names + field_names + aggregate_names # If a field list has been specified, use it. Otherwise, use the diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 629afa29e78..fbc5467b3cc 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -77,7 +77,9 @@ class BaseQuery(object): self.related_select_cols = [] # SQL aggregate-related attributes - self.aggregate_select = SortedDict() # Maps alias -> SQL aggregate function + self.aggregates = SortedDict() # Maps alias -> SQL aggregate function + self.aggregate_select_mask = None + self._aggregate_select_cache = None # Arbitrary maximum limit for select_related. Prevents infinite # recursion. Can be changed by the depth parameter to select_related(). @@ -187,7 +189,15 @@ class BaseQuery(object): obj.distinct = self.distinct obj.select_related = self.select_related obj.related_select_cols = [] - obj.aggregate_select = self.aggregate_select.copy() + obj.aggregates = self.aggregates.copy() + if self.aggregate_select_mask is None: + obj.aggregate_select_mask = None + else: + obj.aggregate_select_mask = self.aggregate_select_mask[:] + if self._aggregate_select_cache is None: + obj._aggregate_select_cache = None + else: + obj._aggregate_select_cache = self._aggregate_select_cache.copy() obj.max_depth = self.max_depth obj.extra_select = self.extra_select.copy() obj.extra_tables = self.extra_tables @@ -940,14 +950,17 @@ class BaseQuery(object): """ assert set(change_map.keys()).intersection(set(change_map.values())) == set() - # 1. Update references in "select" and "where". + # 1. Update references in "select" (normal columns plus aliases), + # "group by", "where" and "having". self.where.relabel_aliases(change_map) - for pos, col in enumerate(self.select): - if isinstance(col, (list, tuple)): - old_alias = col[0] - self.select[pos] = (change_map.get(old_alias, old_alias), col[1]) - else: - col.relabel_aliases(change_map) + self.having.relabel_aliases(change_map) + for columns in (self.select, self.aggregates.values(), self.group_by or []): + for pos, col in enumerate(columns): + if isinstance(col, (list, tuple)): + old_alias = col[0] + columns[pos] = (change_map.get(old_alias, old_alias), col[1]) + else: + col.relabel_aliases(change_map) # 2. Rename the alias in the internal table/alias datastructures. for old_alias, new_alias in change_map.iteritems(): @@ -1205,11 +1218,11 @@ class BaseQuery(object): opts = model._meta field_list = aggregate.lookup.split(LOOKUP_SEP) if (len(field_list) == 1 and - aggregate.lookup in self.aggregate_select.keys()): + aggregate.lookup in self.aggregates.keys()): # Aggregate is over an annotation field_name = field_list[0] col = field_name - source = self.aggregate_select[field_name] + source = self.aggregates[field_name] elif (len(field_list) > 1 or field_list[0] not in [i.name for i in opts.fields]): field, source, opts, join_list, last, _ = self.setup_joins( @@ -1299,7 +1312,7 @@ class BaseQuery(object): value = SQLEvaluator(value, self) having_clause = value.contains_aggregate - for alias, aggregate in self.aggregate_select.items(): + for alias, aggregate in self.aggregates.items(): if alias == parts[0]: entry = self.where_class() entry.add((aggregate, lookup_type, value), AND) @@ -1824,8 +1837,8 @@ class BaseQuery(object): self.group_by = [] if self.connection.features.allows_group_by_pk: if len(self.select) == len(self.model._meta.fields): - self.group_by.append('.'.join([self.model._meta.db_table, - self.model._meta.pk.column])) + self.group_by.append((self.model._meta.db_table, + self.model._meta.pk.column)) return for sel in self.select: @@ -1858,7 +1871,11 @@ class BaseQuery(object): # Distinct handling is done in Count(), so don't do it at this # level. self.distinct = False - self.aggregate_select = {None: count} + + # Set only aggregate to be the count column. + # Clear out the select cache to reflect the new unmasked aggregates. + self.aggregates = {None: count} + self.set_aggregate_mask(None) def add_select_related(self, fields): """ @@ -1920,6 +1937,29 @@ class BaseQuery(object): for key in set(self.extra_select).difference(set(names)): del self.extra_select[key] + def set_aggregate_mask(self, names): + "Set the mask of aggregates that will actually be returned by the SELECT" + self.aggregate_select_mask = names + self._aggregate_select_cache = None + + def _aggregate_select(self): + """The SortedDict of aggregate columns that are not masked, and should + be used in the SELECT clause. + + This result is cached for optimization purposes. + """ + if self._aggregate_select_cache is not None: + return self._aggregate_select_cache + elif self.aggregate_select_mask is not None: + self._aggregate_select_cache = SortedDict([ + (k,v) for k,v in self.aggregates.items() + if k in self.aggregate_select_mask + ]) + return self._aggregate_select_cache + else: + return self.aggregates + aggregate_select = property(_aggregate_select) + def set_start(self, start): """ Sets the table from which to start joining. The start position is diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index 1d4df127fec..43ac42489aa 100644 --- a/django/db/models/sql/where.py +++ b/django/db/models/sql/where.py @@ -213,10 +213,14 @@ class WhereNode(tree.Node): elif isinstance(child, tree.Node): self.relabel_aliases(change_map, child) else: - elt = list(child[0]) - if elt[0] in change_map: - elt[0] = change_map[elt[0]] - node.children[pos] = (tuple(elt),) + child[1:] + if isinstance(child[0], (list, tuple)): + elt = list(child[0]) + if elt[0] in change_map: + elt[0] = change_map[elt[0]] + node.children[pos] = (tuple(elt),) + child[1:] + else: + child[0].relabel_aliases(change_map) + # Check if the query value also requires relabelling if hasattr(child[3], 'relabel_aliases'): child[3].relabel_aliases(change_map) diff --git a/docs/topics/db/aggregation.txt b/docs/topics/db/aggregation.txt index 51942d9a1c5..a861959e667 100644 --- a/docs/topics/db/aggregation.txt +++ b/docs/topics/db/aggregation.txt @@ -284,9 +284,6 @@ two authors with the same name, their results will be merged into a single result in the output of the query; the average will be computed as the average over the books written by both authors. -The annotation name will be added to the fields returned -as part of the ``ValuesQuerySet``. - Order of ``annotate()`` and ``values()`` clauses ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -303,12 +300,21 @@ output. For example, if we reverse the order of the ``values()`` and ``annotate()`` clause from our previous example:: - >>> Author.objects.annotate(average_rating=Avg('book__rating')).values('name') + >>> Author.objects.annotate(average_rating=Avg('book__rating')).values('name', 'average_rating') This will now yield one unique result for each author; however, only the author's name and the ``average_rating`` annotation will be returned in the output data. +You should also note that ``average_rating`` has been explicitly included +in the list of values to be returned. This is required because of the +ordering of the ``values()`` and ``annotate()`` clause. + +If the ``values()`` clause precedes the ``annotate()`` clause, any annotations +will be automatically added to the result set. However, if the ``values()`` +clause is applied after the ``annotate()`` clause, you need to explicitly +include the aggregate column. + Aggregating annotations ----------------------- diff --git a/tests/modeltests/aggregation/models.py b/tests/modeltests/aggregation/models.py index 8537d0bd66a..04b05f90f90 100644 --- a/tests/modeltests/aggregation/models.py +++ b/tests/modeltests/aggregation/models.py @@ -207,10 +207,9 @@ u'The Definitive Guide to Django: Web Development Done Right' >>> Book.objects.filter(pk=1).annotate(mean_age=Avg('authors__age')).values('pk', 'isbn', 'mean_age') [{'pk': 1, 'isbn': u'159059725', 'mean_age': 34.5}] -# Calling it with paramters reduces the output but does not remove the -# annotation. +# Calling values() with parameters reduces the output >>> Book.objects.filter(pk=1).annotate(mean_age=Avg('authors__age')).values('name') -[{'name': u'The Definitive Guide to Django: Web Development Done Right', 'mean_age': 34.5}] +[{'name': u'The Definitive Guide to Django: Web Development Done Right'}] # An empty values() call before annotating has the same effect as an # empty values() call after annotating diff --git a/tests/regressiontests/aggregation_regress/models.py b/tests/regressiontests/aggregation_regress/models.py index de913a0a9fb..fae0c2673ef 100644 --- a/tests/regressiontests/aggregation_regress/models.py +++ b/tests/regressiontests/aggregation_regress/models.py @@ -95,10 +95,18 @@ __test__ = {'API_TESTS': """ >>> sorted(Book.objects.all().values().annotate(mean_auth_age=Avg('authors__age')).extra(select={'manufacture_cost' : 'price * .5'}).get(pk=2).items()) [('contact_id', 3), ('id', 2), ('isbn', u'067232959'), ('manufacture_cost', ...11.545...), ('mean_auth_age', 45.0), ('name', u'Sams Teach Yourself Django in 24 Hours'), ('pages', 528), ('price', Decimal("23.09")), ('pubdate', datetime.date(2008, 3, 3)), ('publisher_id', 2), ('rating', 3.0)] -# A values query that selects specific columns reduces the output +# If the annotation precedes the values clause, it won't be included +# unless it is explicitly named >>> sorted(Book.objects.all().annotate(mean_auth_age=Avg('authors__age')).extra(select={'price_per_page' : 'price / pages'}).values('name').get(pk=1).items()) +[('name', u'The Definitive Guide to Django: Web Development Done Right')] + +>>> sorted(Book.objects.all().annotate(mean_auth_age=Avg('authors__age')).extra(select={'price_per_page' : 'price / pages'}).values('name','mean_auth_age').get(pk=1).items()) [('mean_auth_age', 34.5), ('name', u'The Definitive Guide to Django: Web Development Done Right')] +# If an annotation isn't included in the values, it can still be used in a filter +>>> Book.objects.annotate(n_authors=Count('authors')).values('name').filter(n_authors__gt=2) +[{'name': u'Python Web Development with Django'}] + # The annotations are added to values output if values() precedes annotate() >>> sorted(Book.objects.all().values('name').annotate(mean_auth_age=Avg('authors__age')).extra(select={'price_per_page' : 'price / pages'}).get(pk=1).items()) [('mean_auth_age', 34.5), ('name', u'The Definitive Guide to Django: Web Development Done Right')] @@ -207,6 +215,11 @@ FieldError: Cannot resolve keyword 'foo' into field. Choices are: authors, conta >>> Book.objects.extra(select={'pub':'publisher_id','foo':'pages'}).values('pub').annotate(Count('id')).order_by('pub') [{'pub': 1, 'id__count': 2}, {'pub': 2, 'id__count': 1}, {'pub': 3, 'id__count': 2}, {'pub': 4, 'id__count': 1}] +# Regression for #10182 - Queries with aggregate calls are correctly realiased when used in a subquery +>>> ids = Book.objects.filter(pages__gt=100).annotate(n_authors=Count('authors')).filter(n_authors__gt=2).order_by('n_authors') +>>> Book.objects.filter(id__in=ids) +[] + # Regression for #10199 - Aggregate calls clone the original query so the original query can still be used >>> books = Book.objects.all() >>> _ = books.aggregate(Avg('authors__age'))