Fixed #23875 -- cleaned up query.get_count()

This commit is contained in:
Anssi Kääriäinen 2014-11-20 12:35:56 +02:00 committed by Tim Graham
parent 87bd13617c
commit c7fd9b242d
3 changed files with 33 additions and 96 deletions

View File

@ -335,12 +335,11 @@ class QuerySet(object):
kwargs[arg.default_alias] = arg kwargs[arg.default_alias] = arg
query = self.query.clone() query = self.query.clone()
force_subq = query.low_mark != 0 or query.high_mark is not None
for (alias, aggregate_expr) in kwargs.items(): for (alias, aggregate_expr) in kwargs.items():
query.add_annotation(aggregate_expr, self.model, alias, is_summary=True) query.add_annotation(aggregate_expr, alias, is_summary=True)
if not query.annotations[alias].contains_aggregate: if not query.annotations[alias].contains_aggregate:
raise TypeError("%s is not an aggregate expression" % alias) raise TypeError("%s is not an aggregate expression" % alias)
return query.get_aggregation(using=self.db, force_subq=force_subq) return query.get_aggregation(self.db, kwargs.keys())
def count(self): def count(self):
""" """
@ -824,7 +823,7 @@ class QuerySet(object):
if alias in names: if alias in names:
raise ValueError("The annotation '%s' conflicts with a field on " raise ValueError("The annotation '%s' conflicts with a field on "
"the model." % alias) "the model." % alias)
obj.query.add_annotation(annotation, self.model, alias, is_summary=False) obj.query.add_annotation(annotation, alias, is_summary=False)
# expressions need to be added to the query before we know if they contain aggregates # expressions need to be added to the query before we know if they contain aggregates
added_aggregates = [] added_aggregates = []
for alias, annotation in obj.query.annotations.items(): for alias, annotation in obj.query.annotations.items():

View File

@ -1097,6 +1097,11 @@ class SQLAggregateCompiler(SQLCompiler):
Creates the SQL for this query. Returns the SQL string and list of Creates the SQL for this query. Returns the SQL string and list of
parameters. parameters.
""" """
# Empty SQL for the inner query is a marker that the inner query
# isn't going to produce any results. This can happen when doing
# LIMIT 0 queries (generated by qs[:0]) for example.
if not self.query.subquery:
raise EmptyResultSet
sql, params = [], [] sql, params = [], []
for annotation in self.query.annotation_select.values(): for annotation in self.query.annotation_select.values():
agg_sql, agg_params = self.compile(annotation) agg_sql, agg_params = self.compile(annotation)

View File

@ -313,32 +313,35 @@ class Query(object):
clone.change_aliases(change_map) clone.change_aliases(change_map)
return clone return clone
def get_aggregation(self, using, force_subq=False): def get_aggregation(self, using, added_aggregate_names):
""" """
Returns the dictionary with the values of the existing aggregations. Returns the dictionary with the values of the existing aggregations.
""" """
if not self.annotation_select: if not self.annotation_select:
return {} return {}
has_limit = self.low_mark != 0 or self.high_mark is not None
# annotations must be forced into subquery has_existing_annotations = any(
has_annotation = any(
annotation for alias, annotation annotation for alias, annotation
in self.annotation_select.items() in self.annotations.items()
if not annotation.contains_aggregate) if alias not in added_aggregate_names
)
# If there is a group by clause, aggregating does not add useful # Decide if we need to use a subquery.
# information but retrieves only the first row. Aggregate #
# over the subquery instead. # Existing annotations would cause incorrect results as get_aggregation()
if self.group_by is not None or force_subq or has_annotation: # must produce just one result and thus must not use GROUP BY. But we
# aren't smart enough to remove the existing annotations from the
# query, so those would force us to use GROUP BY.
#
# If the query has limit or distinct, then those operations must be
# done in a subquery so that we are aggregating on the limit and/or
# distinct results instead of applying the distinct and limit after the
# aggregation.
if (self.group_by or has_limit or has_existing_annotations or self.distinct):
from django.db.models.sql.subqueries import AggregateQuery from django.db.models.sql.subqueries import AggregateQuery
outer_query = AggregateQuery(self.model) outer_query = AggregateQuery(self.model)
inner_query = self.clone() inner_query = self.clone()
if not force_subq: if not has_limit and not self.distinct_fields:
# In forced subq case the ordering and limits will likely
# affect the results.
inner_query.clear_ordering(True) inner_query.clear_ordering(True)
inner_query.clear_limits()
inner_query.select_for_update = False inner_query.select_for_update = False
inner_query.select_related = False inner_query.select_related = False
inner_query.related_select_cols = [] inner_query.related_select_cols = []
@ -398,34 +401,10 @@ class Query(object):
Performs a COUNT() query using the current filter constraints. Performs a COUNT() query using the current filter constraints.
""" """
obj = self.clone() obj = self.clone()
if len(self.select) > 1 or self.annotation_select or (self.distinct and self.distinct_fields): obj.add_annotation(Count('*'), alias='__count', is_summary=True)
# If a select clause exists, then the query has already started to number = obj.get_aggregation(using, ['__count'])['__count']
# specify the columns that are to be returned. if number is None:
# In this case, we need to use a subquery to evaluate the count. number = 0
from django.db.models.sql.subqueries import AggregateQuery
subquery = obj
subquery.clear_ordering(True)
subquery.clear_limits()
obj = AggregateQuery(obj.model)
try:
obj.add_subquery(subquery, using=using)
except EmptyResultSet:
# add_subquery evaluates the query; if it raises EmptyResultSet
# then there can be no results, and therefore the
# count is obviously 0.
return 0
obj.add_count_column()
number = obj.get_aggregation(using=using)[None]
# Apply offset and limit constraints manually, since using LIMIT/OFFSET
# in SQL (in variants that provide them) doesn't change the COUNT
# output.
number = max(0, number - self.low_mark)
if self.high_mark is not None:
number = min(number, self.high_mark - self.low_mark)
return number return number
def has_filters(self): def has_filters(self):
@ -986,9 +965,9 @@ class Query(object):
warnings.warn( warnings.warn(
"add_aggregate() is deprecated. Use add_annotation() instead.", "add_aggregate() is deprecated. Use add_annotation() instead.",
RemovedInDjango20Warning, stacklevel=2) RemovedInDjango20Warning, stacklevel=2)
self.add_annotation(aggregate, model, alias, is_summary) self.add_annotation(aggregate, alias, is_summary)
def add_annotation(self, annotation, model, alias, is_summary): def add_annotation(self, annotation, alias, is_summary):
""" """
Adds a single annotation expression to the Query Adds a single annotation expression to the Query
""" """
@ -1746,52 +1725,6 @@ class Query(object):
for col in annotation.get_group_by_cols(): for col in annotation.get_group_by_cols():
self.group_by.append(col) self.group_by.append(col)
def add_count_column(self):
"""
Converts the query to do count(...) or count(distinct(pk)) in order to
get its size.

Builds a single Count expression, installs it as the query's only
annotation under the key ``None``, calls ``set_annotation_mask(None)``
and sets ``group_by`` to None. When ``self.distinct`` is set, the Count
is created with ``distinct=True`` and ``self.distinct`` is switched off.

Asserts that ``self.select`` holds at most one column.
"""
# Forwarded to resolve_expression() below; only turned on when there are
# no explicit select columns.
summarize = False
if not self.distinct:
if not self.select:
# Nothing selected explicitly: count with '*'.
count = Count('*')
summarize = True
else:
assert len(self.select) == 1, \
"Cannot add count col with multiple cols in 'select': %r" % self.select
col = self.select[0].col
# col is either a tuple/list, whose second item is what gets counted
# (presumably an (alias, column) pair -- confirm), or a value that is
# handed to Count() as is.
if isinstance(col, (tuple, list)):
count = Count(col[1])
else:
count = Count(col)
else:
opts = self.get_meta()
if not self.select:
# join() is called with opts.db_table; only the pk column name (the
# second item of the tuple built here) is handed to Count().
lookup = self.join((None, opts.db_table, None)), opts.pk.column
count = Count(lookup[1], distinct=True)
summarize = True
else:
# Because of SQL portability issues, multi-column, distinct
# counts need a sub-query -- see get_count() for details.
assert len(self.select) == 1, \
"Cannot add count col with multiple cols in 'select'."
col = self.select[0].col
# Same tuple/list handling as in the non-distinct case above, but the
# Count is created with distinct=True.
if isinstance(col, (tuple, list)):
count = Count(col[1], distinct=True)
else:
count = Count(col, distinct=True)
# NOTE(review): indentation is not visible in this listing; presumably
# this reset belongs to the distinct branch -- confirm.
# Distinct handling is done in Count(), so don't do it at this
# level.
self.distinct = False
# Set only aggregate to be the count column.
# Clear out the select cache to reflect the new unmasked annotations.
count = count.resolve_expression(self, summarize=summarize)
# Every previous annotation is dropped; the count is stored under the
# key None.
self._annotations = {None: count}
self.set_annotation_mask(None)
# The count query carries no GROUP BY.
self.group_by = None
def add_select_related(self, fields): def add_select_related(self, fields):
""" """
Sets up the select_related data structure so that we only select Sets up the select_related data structure so that we only select