diff --git a/django/db/models/lookups.py b/django/db/models/lookups.py index 7610c0dde4..f3bb0fcd33 100644 --- a/django/db/models/lookups.py +++ b/django/db/models/lookups.py @@ -87,6 +87,10 @@ class Transform(RegisterLookupMixin): bilateral_transforms.append((self.__class__, self.init_lookups)) return bilateral_transforms + @cached_property + def contains_aggregate(self): + return self.lhs.contains_aggregate + class Lookup(RegisterLookupMixin): lookup_name = None @@ -189,6 +193,10 @@ class Lookup(RegisterLookupMixin): def as_sql(self, compiler, connection): raise NotImplementedError + @cached_property + def contains_aggregate(self): + return self.lhs.contains_aggregate or getattr(self.rhs, 'contains_aggregate', False) + class BuiltinLookup(Lookup): def process_lhs(self, compiler, connection, lhs=None): diff --git a/django/db/models/query_utils.py b/django/db/models/query_utils.py index 7751f988db..1d6e8f731b 100644 --- a/django/db/models/query_utils.py +++ b/django/db/models/query_utils.py @@ -35,6 +35,8 @@ class QueryWrapper(object): A type that indicates the contents are an SQL fragment and the associate parameters. Can be used to pass opaque data to a where-clause, for example. """ + contains_aggregate = False + def __init__(self, sql, params): self.data = sql, list(params) diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 8f3af37031..210ca742f6 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -46,6 +46,7 @@ class SQLCompiler(object): """ self.setup_query() order_by = self.get_order_by() + self.where, self.having = self.query.where.split_having() extra_select = self.get_extra_select(order_by, self.select) group_by = self.get_group_by(self.select + extra_select, order_by) return extra_select, order_by, group_by @@ -116,12 +117,12 @@ class SQLCompiler(object): if is_ref: continue expressions.extend(expr.get_source_expressions()) - having = self.query.having.get_group_by_cols() - for expr in having: + having_group_by = self.having.get_group_by_cols() if self.having else () + for expr in having_group_by: expressions.append(expr) result = [] seen = set() - expressions = self.collapse_group_by(expressions, having) + expressions = self.collapse_group_by(expressions, having_group_by) for expr in expressions: sql, params = self.compile(expr) @@ -355,11 +356,8 @@ class SQLCompiler(object): If 'with_limits' is False, any limit/offset information is not included in the query. """ - # After executing the query, we must get rid of any joins the query - # setup created. So, take note of alias counts before the query ran. - # However we do not want to get rid of stuff done in pre_sql_setup(), - # as the pre_sql_setup will modify query state in a way that forbids - # another run of it. + if with_limits and self.query.low_mark == self.query.high_mark: + return '', () self.subquery = subquery refcounts_before = self.query.alias_refcount.copy() try: @@ -372,8 +370,8 @@ class SQLCompiler(object): # docstring of get_from_clause() for details. from_, f_params = self.get_from_clause() - where, w_params = self.compile(self.query.where) - having, h_params = self.compile(self.query.having) + where, w_params = self.compile(self.where) if self.where is not None else ("", []) + having, h_params = self.compile(self.having) if self.having is not None else ("", []) params = [] result = ['SELECT'] diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 6b50bb8bf6..719ef0f572 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -25,7 +25,7 @@ from django.db.models.sql.constants import (QUERY_TERMS, ORDER_DIR, SINGLE, from django.db.models.sql.datastructures import ( EmptyResultSet, Empty, MultiJoin, Join, BaseTable) from django.db.models.sql.where import (WhereNode, EverythingNode, - ExtraWhere, AND, OR, EmptyWhere) + ExtraWhere, AND, OR, NothingNode) from django.utils import six from django.utils.deprecation import RemovedInDjango20Warning from django.utils.encoding import force_text @@ -141,7 +141,6 @@ class Query(object): # - True: group by all select fields of the model # See compiler.get_group_by() for details. self.group_by = None - self.having = where() self.order_by = [] self.low_mark, self.high_mark = 0, None # Used for offset/limit self.distinct = False @@ -268,7 +267,6 @@ class Query(object): obj.group_by = True else: obj.group_by = self.group_by[:] - obj.having = self.having.clone() obj.order_by = self.order_by[:] obj.low_mark, obj.high_mark = self.low_mark, self.high_mark obj.distinct = self.distinct @@ -449,7 +447,7 @@ class Query(object): return number def has_filters(self): - return self.where or self.having + return self.where def has_results(self, using): q = self.clone() @@ -770,9 +768,8 @@ class Query(object): else: return col.relabeled_clone(change_map) # 1. Update references in "select" (normal columns plus aliases), - # "group by", "where" and "having". + # "group by" and "where". self.where.relabel_aliases(change_map) - self.having.relabel_aliases(change_map) if isinstance(self.group_by, list): self.group_by = [relabel_column(col) for col in self.group_by] self.select = [col.relabeled_clone(change_map) for col in self.select] @@ -1093,7 +1090,7 @@ class Query(object): """ Builds a WhereNode for a single filter clause, but doesn't add it to this Query. Query.add_q() will then add this filter to the where - or having Node. + Node. The 'branch_negated' tells us if the current branch contains any negations. This will be used to determine if subqueries are needed. @@ -1197,59 +1194,11 @@ class Query(object): def add_filter(self, filter_clause): self.add_q(Q(**{filter_clause[0]: filter_clause[1]})) - def need_having(self, obj): - """ - Returns whether or not all elements of this q_object need to be put - together in the HAVING clause. - """ - if not self._annotations: - return False - if hasattr(obj, 'refs_aggregate'): - return obj.refs_aggregate(self.annotations)[0] - return (refs_aggregate(obj[0].split(LOOKUP_SEP), self.annotations)[0] - or (hasattr(obj[1], 'refs_aggregate') - and obj[1].refs_aggregate(self.annotations)[0])) - - def split_having_parts(self, q_object, negated=False): - """ - Returns a list of q_objects which need to go into the having clause - instead of the where clause. Removes the splitted out nodes from the - given q_object. Note that the q_object is altered, so cloning it is - needed. - """ - having_parts = [] - for c in q_object.children[:]: - # When constructing the having nodes we need to take care to - # preserve the negation status from the upper parts of the tree - if isinstance(c, Node): - # For each negated child, flip the in_negated flag. - in_negated = c.negated ^ negated - if c.connector == OR and self.need_having(c): - # A subtree starting from OR clause must go into having in - # whole if any part of that tree references an aggregate. - q_object.children.remove(c) - having_parts.append(c) - c.negated = in_negated - else: - having_parts.extend( - self.split_having_parts(c, in_negated)[1]) - elif self.need_having(c): - q_object.children.remove(c) - new_q = self.where_class(children=[c], negated=negated) - having_parts.append(new_q) - return q_object, having_parts - def add_q(self, q_object): """ - A preprocessor for the internal _add_q(). Responsible for - splitting the given q_object into where and having parts and - setting up some internal variables. + A preprocessor for the internal _add_q(). Responsible for doing final + join promotion. """ - if not self.need_having(q_object): - where_part, having_parts = q_object, [] - else: - where_part, having_parts = self.split_having_parts( - q_object.clone(), q_object.negated) # For join promotion this case is doing an AND for the added q_object # and existing conditions. So, any existing inner join forces the join # type to remain inner. Existing outer joins can however be demoted. @@ -1258,11 +1207,8 @@ class Query(object): # So, demotion is OK. existing_inner = set( (a for a in self.alias_map if self.alias_map[a].join_type == INNER)) - clause, require_inner = self._add_q(where_part, self.used_aliases) + clause, require_inner = self._add_q(q_object, self.used_aliases) self.where.add(clause, AND) - for hp in having_parts: - clause, _ = self._add_q(hp, self.used_aliases) - self.having.add(clause, AND) self.demote_joins(existing_inner) def _add_q(self, q_object, used_aliases, branch_negated=False, @@ -1557,11 +1503,10 @@ class Query(object): return condition, needed_inner def set_empty(self): - self.where = EmptyWhere() - self.having = EmptyWhere() + self.where.add(NothingNode(), AND) def is_empty(self): - return isinstance(self.where, EmptyWhere) or isinstance(self.having, EmptyWhere) + return any(isinstance(c, NothingNode) for c in self.where.children) def set_limits(self, low=None, high=None): """ diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py index be7438b2a7..727eabc5f5 100644 --- a/django/db/models/sql/subqueries.py +++ b/django/db/models/sql/subqueries.py @@ -54,10 +54,8 @@ class DeleteQuery(Query): self.get_initial_alias() innerq_used_tables = [t for t in innerq.tables if innerq.alias_refcount[t]] - if ((not innerq_used_tables or innerq_used_tables == self.tables) - and not len(innerq.having)): - # There is only the base table in use in the query, and there is - # no aggregate filtering going on. + if not innerq_used_tables or innerq_used_tables == self.tables: + # There is only the base table in use in the query. self.where = innerq.where else: pk = query.model._meta.pk diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index 3b894b8aed..6472c47f4f 100644 --- a/django/db/models/sql/where.py +++ b/django/db/models/sql/where.py @@ -1,9 +1,10 @@ """ Code to manage the creation and SQL rendering of 'where' constraints. """ + from django.db.models.sql.datastructures import EmptyResultSet -from django.utils.functional import cached_property from django.utils import tree +from django.utils.functional import cached_property # Connection types @@ -37,6 +38,37 @@ class WhereNode(tree.Node): """ default = AND + def split_having(self, negated=False): + """ + Returns two possibly None nodes: one for those parts of self that + should be pushed to having and one for those parts of self + that should be in where. + """ + in_negated = negated ^ self.negated + # If the effective connector is OR and this node contains an aggregate, + # then we need to push the whole branch to HAVING clause. + may_need_split = ( + (in_negated and self.connector == AND) or + (not in_negated and self.connector == OR)) + if may_need_split and self.contains_aggregate: + return None, self + where_parts = [] + having_parts = [] + for c in self.children: + if hasattr(c, 'split_having'): + where_part, having_part = c.split_having(in_negated) + if where_part: + where_parts.append(where_part) + if having_part: + having_parts.append(having_part) + elif c.contains_aggregate: + having_parts.append(c) + else: + where_parts.append(c) + having_node = self.__class__(having_parts, self.connector, self.negated) if having_parts else None + where_node = self.__class__(where_parts, self.connector, self.negated) if where_parts else None + return where_node, having_node + def as_sql(self, compiler, connection): """ Returns the SQL version of the where clause and the value to be @@ -145,27 +177,20 @@ class WhereNode(tree.Node): @classmethod def _contains_aggregate(cls, obj): - if not isinstance(obj, tree.Node): - return getattr(obj.lhs, 'contains_aggregate', False) or getattr(obj.rhs, 'contains_aggregate', False) - return any(cls._contains_aggregate(c) for c in obj.children) + if isinstance(obj, tree.Node): + return any(cls._contains_aggregate(c) for c in obj.children) + return obj.contains_aggregate @cached_property def contains_aggregate(self): return self._contains_aggregate(self) -class EmptyWhere(WhereNode): - def add(self, data, connector): - return - - def as_sql(self, compiler=None, connection=None): - raise EmptyResultSet - - class EverythingNode(object): """ A node that matches everything. """ + contains_aggregate = False def as_sql(self, compiler=None, connection=None): return '', [] @@ -175,11 +200,16 @@ class NothingNode(object): """ A node that matches nothing. """ + contains_aggregate = False + def as_sql(self, compiler=None, connection=None): raise EmptyResultSet class ExtraWhere(object): + # The contents are a black box - assume no aggregates are used. + contains_aggregate = False + def __init__(self, sqls, params): self.sqls = sqls self.params = params @@ -190,6 +220,10 @@ class ExtraWhere(object): class SubqueryConstraint(object): + # Even if aggregates would be used in a subquery, the outer query isn't + # interested about those. + contains_aggregate = False + def __init__(self, alias, columns, targets, query_object): self.alias = alias self.columns = columns