Fixed #10182 -- Corrected realiasing and the process of evaluating values() for queries with aggregate clauses. This means that aggregate queries can now be used as subqueries (such as in an __in clause). Thanks to omat for the report.

This involves a slight change to the interaction of annotate() and values() clauses that specify a list of columns. See the docs for details.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@9888 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Russell Keith-Magee 2009-02-23 14:47:59 +00:00
parent 4bd24474c0
commit 542709d0d1
7 changed files with 102 additions and 32 deletions

View File

@ -46,7 +46,7 @@ class Aggregate(object):
# Validate that the backend has a fully supported, correct # Validate that the backend has a fully supported, correct
# implementation of this aggregate # implementation of this aggregate
query.connection.ops.check_aggregate_support(aggregate) query.connection.ops.check_aggregate_support(aggregate)
query.aggregate_select[alias] = aggregate query.aggregates[alias] = aggregate
class Avg(Aggregate): class Avg(Aggregate):
name = 'Avg' name = 'Avg'

View File

@ -596,7 +596,7 @@ class QuerySet(object):
obj = self._clone() obj = self._clone()
obj._setup_aggregate_query() obj._setup_aggregate_query(kwargs.keys())
# Add the aggregates to the query # Add the aggregates to the query
for (alias, aggregate_expr) in kwargs.items(): for (alias, aggregate_expr) in kwargs.items():
@ -693,7 +693,7 @@ class QuerySet(object):
""" """
pass pass
def _setup_aggregate_query(self): def _setup_aggregate_query(self, aggregates):
""" """
Prepare the query for computing a result that contains aggregate annotations. Prepare the query for computing a result that contains aggregate annotations.
""" """
@ -773,6 +773,8 @@ class ValuesQuerySet(QuerySet):
self.query.select = [] self.query.select = []
self.query.add_fields(self.field_names, False) self.query.add_fields(self.field_names, False)
if self.aggregate_names is not None:
self.query.set_aggregate_mask(self.aggregate_names)
def _clone(self, klass=None, setup=False, **kwargs): def _clone(self, klass=None, setup=False, **kwargs):
""" """
@ -798,13 +800,17 @@ class ValuesQuerySet(QuerySet):
raise TypeError("Merging '%s' classes must involve the same values in each case." raise TypeError("Merging '%s' classes must involve the same values in each case."
% self.__class__.__name__) % self.__class__.__name__)
def _setup_aggregate_query(self): def _setup_aggregate_query(self, aggregates):
""" """
Prepare the query for computing a result that contains aggregate annotations. Prepare the query for computing a result that contains aggregate annotations.
""" """
self.query.set_group_by() self.query.set_group_by()
super(ValuesQuerySet, self)._setup_aggregate_query() if self.aggregate_names is not None:
self.aggregate_names.extend(aggregates)
self.query.set_aggregate_mask(self.aggregate_names)
super(ValuesQuerySet, self)._setup_aggregate_query(aggregates)
def as_sql(self): def as_sql(self):
""" """
@ -824,6 +830,7 @@ class ValuesListQuerySet(ValuesQuerySet):
def iterator(self): def iterator(self):
if self.extra_names is not None: if self.extra_names is not None:
self.query.trim_extra_select(self.extra_names) self.query.trim_extra_select(self.extra_names)
if self.flat and len(self._fields) == 1: if self.flat and len(self._fields) == 1:
for row in self.query.results_iter(): for row in self.query.results_iter():
yield row[0] yield row[0]
@ -837,6 +844,7 @@ class ValuesListQuerySet(ValuesQuerySet):
extra_names = self.query.extra_select.keys() extra_names = self.query.extra_select.keys()
field_names = self.field_names field_names = self.field_names
aggregate_names = self.query.aggregate_select.keys() aggregate_names = self.query.aggregate_select.keys()
names = extra_names + field_names + aggregate_names names = extra_names + field_names + aggregate_names
# If a field list has been specified, use it. Otherwise, use the # If a field list has been specified, use it. Otherwise, use the

View File

@ -77,7 +77,9 @@ class BaseQuery(object):
self.related_select_cols = [] self.related_select_cols = []
# SQL aggregate-related attributes # SQL aggregate-related attributes
self.aggregate_select = SortedDict() # Maps alias -> SQL aggregate function self.aggregates = SortedDict() # Maps alias -> SQL aggregate function
self.aggregate_select_mask = None
self._aggregate_select_cache = None
# Arbitrary maximum limit for select_related. Prevents infinite # Arbitrary maximum limit for select_related. Prevents infinite
# recursion. Can be changed by the depth parameter to select_related(). # recursion. Can be changed by the depth parameter to select_related().
@ -187,7 +189,15 @@ class BaseQuery(object):
obj.distinct = self.distinct obj.distinct = self.distinct
obj.select_related = self.select_related obj.select_related = self.select_related
obj.related_select_cols = [] obj.related_select_cols = []
obj.aggregate_select = self.aggregate_select.copy() obj.aggregates = self.aggregates.copy()
if self.aggregate_select_mask is None:
obj.aggregate_select_mask = None
else:
obj.aggregate_select_mask = self.aggregate_select_mask[:]
if self._aggregate_select_cache is None:
obj._aggregate_select_cache = None
else:
obj._aggregate_select_cache = self._aggregate_select_cache.copy()
obj.max_depth = self.max_depth obj.max_depth = self.max_depth
obj.extra_select = self.extra_select.copy() obj.extra_select = self.extra_select.copy()
obj.extra_tables = self.extra_tables obj.extra_tables = self.extra_tables
@ -940,12 +950,15 @@ class BaseQuery(object):
""" """
assert set(change_map.keys()).intersection(set(change_map.values())) == set() assert set(change_map.keys()).intersection(set(change_map.values())) == set()
# 1. Update references in "select" and "where". # 1. Update references in "select" (normal columns plus aliases),
# "group by", "where" and "having".
self.where.relabel_aliases(change_map) self.where.relabel_aliases(change_map)
for pos, col in enumerate(self.select): self.having.relabel_aliases(change_map)
for columns in (self.select, self.aggregates.values(), self.group_by or []):
for pos, col in enumerate(columns):
if isinstance(col, (list, tuple)): if isinstance(col, (list, tuple)):
old_alias = col[0] old_alias = col[0]
self.select[pos] = (change_map.get(old_alias, old_alias), col[1]) columns[pos] = (change_map.get(old_alias, old_alias), col[1])
else: else:
col.relabel_aliases(change_map) col.relabel_aliases(change_map)
@ -1205,11 +1218,11 @@ class BaseQuery(object):
opts = model._meta opts = model._meta
field_list = aggregate.lookup.split(LOOKUP_SEP) field_list = aggregate.lookup.split(LOOKUP_SEP)
if (len(field_list) == 1 and if (len(field_list) == 1 and
aggregate.lookup in self.aggregate_select.keys()): aggregate.lookup in self.aggregates.keys()):
# Aggregate is over an annotation # Aggregate is over an annotation
field_name = field_list[0] field_name = field_list[0]
col = field_name col = field_name
source = self.aggregate_select[field_name] source = self.aggregates[field_name]
elif (len(field_list) > 1 or elif (len(field_list) > 1 or
field_list[0] not in [i.name for i in opts.fields]): field_list[0] not in [i.name for i in opts.fields]):
field, source, opts, join_list, last, _ = self.setup_joins( field, source, opts, join_list, last, _ = self.setup_joins(
@ -1299,7 +1312,7 @@ class BaseQuery(object):
value = SQLEvaluator(value, self) value = SQLEvaluator(value, self)
having_clause = value.contains_aggregate having_clause = value.contains_aggregate
for alias, aggregate in self.aggregate_select.items(): for alias, aggregate in self.aggregates.items():
if alias == parts[0]: if alias == parts[0]:
entry = self.where_class() entry = self.where_class()
entry.add((aggregate, lookup_type, value), AND) entry.add((aggregate, lookup_type, value), AND)
@ -1824,8 +1837,8 @@ class BaseQuery(object):
self.group_by = [] self.group_by = []
if self.connection.features.allows_group_by_pk: if self.connection.features.allows_group_by_pk:
if len(self.select) == len(self.model._meta.fields): if len(self.select) == len(self.model._meta.fields):
self.group_by.append('.'.join([self.model._meta.db_table, self.group_by.append((self.model._meta.db_table,
self.model._meta.pk.column])) self.model._meta.pk.column))
return return
for sel in self.select: for sel in self.select:
@ -1858,7 +1871,11 @@ class BaseQuery(object):
# Distinct handling is done in Count(), so don't do it at this # Distinct handling is done in Count(), so don't do it at this
# level. # level.
self.distinct = False self.distinct = False
self.aggregate_select = {None: count}
# Set only aggregate to be the count column.
# Clear out the select cache to reflect the new unmasked aggregates.
self.aggregates = {None: count}
self.set_aggregate_mask(None)
def add_select_related(self, fields): def add_select_related(self, fields):
""" """
@ -1920,6 +1937,29 @@ class BaseQuery(object):
for key in set(self.extra_select).difference(set(names)): for key in set(self.extra_select).difference(set(names)):
del self.extra_select[key] del self.extra_select[key]
def set_aggregate_mask(self, names):
"Set the mask of aggregates that will actually be returned by the SELECT"
self.aggregate_select_mask = names
self._aggregate_select_cache = None
def _aggregate_select(self):
"""The SortedDict of aggregate columns that are not masked, and should
be used in the SELECT clause.
This result is cached for optimization purposes.
"""
if self._aggregate_select_cache is not None:
return self._aggregate_select_cache
elif self.aggregate_select_mask is not None:
self._aggregate_select_cache = SortedDict([
(k,v) for k,v in self.aggregates.items()
if k in self.aggregate_select_mask
])
return self._aggregate_select_cache
else:
return self.aggregates
aggregate_select = property(_aggregate_select)
def set_start(self, start): def set_start(self, start):
""" """
Sets the table from which to start joining. The start position is Sets the table from which to start joining. The start position is

View File

@ -213,10 +213,14 @@ class WhereNode(tree.Node):
elif isinstance(child, tree.Node): elif isinstance(child, tree.Node):
self.relabel_aliases(change_map, child) self.relabel_aliases(change_map, child)
else: else:
if isinstance(child[0], (list, tuple)):
elt = list(child[0]) elt = list(child[0])
if elt[0] in change_map: if elt[0] in change_map:
elt[0] = change_map[elt[0]] elt[0] = change_map[elt[0]]
node.children[pos] = (tuple(elt),) + child[1:] node.children[pos] = (tuple(elt),) + child[1:]
else:
child[0].relabel_aliases(change_map)
# Check if the query value also requires relabelling # Check if the query value also requires relabelling
if hasattr(child[3], 'relabel_aliases'): if hasattr(child[3], 'relabel_aliases'):
child[3].relabel_aliases(change_map) child[3].relabel_aliases(change_map)

View File

@ -284,9 +284,6 @@ two authors with the same name, their results will be merged into a single
result in the output of the query; the average will be computed as the result in the output of the query; the average will be computed as the
average over the books written by both authors. average over the books written by both authors.
The annotation name will be added to the fields returned
as part of the ``ValuesQuerySet``.
Order of ``annotate()`` and ``values()`` clauses Order of ``annotate()`` and ``values()`` clauses
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@ -303,12 +300,21 @@ output.
For example, if we reverse the order of the ``values()`` and ``annotate()`` For example, if we reverse the order of the ``values()`` and ``annotate()``
clause from our previous example:: clause from our previous example::
>>> Author.objects.annotate(average_rating=Avg('book__rating')).values('name') >>> Author.objects.annotate(average_rating=Avg('book__rating')).values('name', 'average_rating')
This will now yield one unique result for each author; however, only This will now yield one unique result for each author; however, only
the author's name and the ``average_rating`` annotation will be returned the author's name and the ``average_rating`` annotation will be returned
in the output data. in the output data.
You should also note that ``average_rating`` has been explicitly included
in the list of values to be returned. This is required because of the
ordering of the ``values()`` and ``annotate()`` clause.
If the ``values()`` clause precedes the ``annotate()`` clause, any annotations
will be automatically added to the result set. However, if the ``values()``
clause is applied after the ``annotate()`` clause, you need to explicitly
include the aggregate column.
Aggregating annotations Aggregating annotations
----------------------- -----------------------

View File

@ -207,10 +207,9 @@ u'The Definitive Guide to Django: Web Development Done Right'
>>> Book.objects.filter(pk=1).annotate(mean_age=Avg('authors__age')).values('pk', 'isbn', 'mean_age') >>> Book.objects.filter(pk=1).annotate(mean_age=Avg('authors__age')).values('pk', 'isbn', 'mean_age')
[{'pk': 1, 'isbn': u'159059725', 'mean_age': 34.5}] [{'pk': 1, 'isbn': u'159059725', 'mean_age': 34.5}]
# Calling it with paramters reduces the output but does not remove the # Calling values() with parameters reduces the output
# annotation.
>>> Book.objects.filter(pk=1).annotate(mean_age=Avg('authors__age')).values('name') >>> Book.objects.filter(pk=1).annotate(mean_age=Avg('authors__age')).values('name')
[{'name': u'The Definitive Guide to Django: Web Development Done Right', 'mean_age': 34.5}] [{'name': u'The Definitive Guide to Django: Web Development Done Right'}]
# An empty values() call before annotating has the same effect as an # An empty values() call before annotating has the same effect as an
# empty values() call after annotating # empty values() call after annotating

View File

@ -95,10 +95,18 @@ __test__ = {'API_TESTS': """
>>> sorted(Book.objects.all().values().annotate(mean_auth_age=Avg('authors__age')).extra(select={'manufacture_cost' : 'price * .5'}).get(pk=2).items()) >>> sorted(Book.objects.all().values().annotate(mean_auth_age=Avg('authors__age')).extra(select={'manufacture_cost' : 'price * .5'}).get(pk=2).items())
[('contact_id', 3), ('id', 2), ('isbn', u'067232959'), ('manufacture_cost', ...11.545...), ('mean_auth_age', 45.0), ('name', u'Sams Teach Yourself Django in 24 Hours'), ('pages', 528), ('price', Decimal("23.09")), ('pubdate', datetime.date(2008, 3, 3)), ('publisher_id', 2), ('rating', 3.0)] [('contact_id', 3), ('id', 2), ('isbn', u'067232959'), ('manufacture_cost', ...11.545...), ('mean_auth_age', 45.0), ('name', u'Sams Teach Yourself Django in 24 Hours'), ('pages', 528), ('price', Decimal("23.09")), ('pubdate', datetime.date(2008, 3, 3)), ('publisher_id', 2), ('rating', 3.0)]
# A values query that selects specific columns reduces the output # If the annotation precedes the values clause, it won't be included
# unless it is explicitly named
>>> sorted(Book.objects.all().annotate(mean_auth_age=Avg('authors__age')).extra(select={'price_per_page' : 'price / pages'}).values('name').get(pk=1).items()) >>> sorted(Book.objects.all().annotate(mean_auth_age=Avg('authors__age')).extra(select={'price_per_page' : 'price / pages'}).values('name').get(pk=1).items())
[('name', u'The Definitive Guide to Django: Web Development Done Right')]
>>> sorted(Book.objects.all().annotate(mean_auth_age=Avg('authors__age')).extra(select={'price_per_page' : 'price / pages'}).values('name','mean_auth_age').get(pk=1).items())
[('mean_auth_age', 34.5), ('name', u'The Definitive Guide to Django: Web Development Done Right')] [('mean_auth_age', 34.5), ('name', u'The Definitive Guide to Django: Web Development Done Right')]
# If an annotation isn't included in the values, it can still be used in a filter
>>> Book.objects.annotate(n_authors=Count('authors')).values('name').filter(n_authors__gt=2)
[{'name': u'Python Web Development with Django'}]
# The annotations are added to values output if values() precedes annotate() # The annotations are added to values output if values() precedes annotate()
>>> sorted(Book.objects.all().values('name').annotate(mean_auth_age=Avg('authors__age')).extra(select={'price_per_page' : 'price / pages'}).get(pk=1).items()) >>> sorted(Book.objects.all().values('name').annotate(mean_auth_age=Avg('authors__age')).extra(select={'price_per_page' : 'price / pages'}).get(pk=1).items())
[('mean_auth_age', 34.5), ('name', u'The Definitive Guide to Django: Web Development Done Right')] [('mean_auth_age', 34.5), ('name', u'The Definitive Guide to Django: Web Development Done Right')]
@ -207,6 +215,11 @@ FieldError: Cannot resolve keyword 'foo' into field. Choices are: authors, conta
>>> Book.objects.extra(select={'pub':'publisher_id','foo':'pages'}).values('pub').annotate(Count('id')).order_by('pub') >>> Book.objects.extra(select={'pub':'publisher_id','foo':'pages'}).values('pub').annotate(Count('id')).order_by('pub')
[{'pub': 1, 'id__count': 2}, {'pub': 2, 'id__count': 1}, {'pub': 3, 'id__count': 2}, {'pub': 4, 'id__count': 1}] [{'pub': 1, 'id__count': 2}, {'pub': 2, 'id__count': 1}, {'pub': 3, 'id__count': 2}, {'pub': 4, 'id__count': 1}]
# Regression for #10182 - Queries with aggregate calls are correctly realiased when used in a subquery
>>> ids = Book.objects.filter(pages__gt=100).annotate(n_authors=Count('authors')).filter(n_authors__gt=2).order_by('n_authors')
>>> Book.objects.filter(id__in=ids)
[<Book: Python Web Development with Django>]
# Regression for #10199 - Aggregate calls clone the original query so the original query can still be used # Regression for #10199 - Aggregate calls clone the original query so the original query can still be used
>>> books = Book.objects.all() >>> books = Book.objects.all()
>>> _ = books.aggregate(Avg('authors__age')) >>> _ = books.aggregate(Avg('authors__age'))