Fixed #24268 -- removed Query.having

Instead of splitting filter clauses to where and having parts before
adding them to query.where or query.having, add all filter clauses to
query.where, and when compiling the query split the where to having and
where parts.
This commit is contained in:
Anssi Kääriäinen 2014-12-23 15:16:56 +02:00 committed by Tim Graham
parent 2be621e44c
commit afe0bb7b13
6 changed files with 75 additions and 90 deletions

View File

@ -87,6 +87,10 @@ class Transform(RegisterLookupMixin):
bilateral_transforms.append((self.__class__, self.init_lookups)) bilateral_transforms.append((self.__class__, self.init_lookups))
return bilateral_transforms return bilateral_transforms
@cached_property
def contains_aggregate(self):
return self.lhs.contains_aggregate
class Lookup(RegisterLookupMixin): class Lookup(RegisterLookupMixin):
lookup_name = None lookup_name = None
@ -189,6 +193,10 @@ class Lookup(RegisterLookupMixin):
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
raise NotImplementedError raise NotImplementedError
@cached_property
def contains_aggregate(self):
return self.lhs.contains_aggregate or getattr(self.rhs, 'contains_aggregate', False)
class BuiltinLookup(Lookup): class BuiltinLookup(Lookup):
def process_lhs(self, compiler, connection, lhs=None): def process_lhs(self, compiler, connection, lhs=None):

View File

@ -35,6 +35,8 @@ class QueryWrapper(object):
A type that indicates the contents are an SQL fragment and the associate A type that indicates the contents are an SQL fragment and the associate
parameters. Can be used to pass opaque data to a where-clause, for example. parameters. Can be used to pass opaque data to a where-clause, for example.
""" """
contains_aggregate = False
def __init__(self, sql, params): def __init__(self, sql, params):
self.data = sql, list(params) self.data = sql, list(params)

View File

@ -46,6 +46,7 @@ class SQLCompiler(object):
""" """
self.setup_query() self.setup_query()
order_by = self.get_order_by() order_by = self.get_order_by()
self.where, self.having = self.query.where.split_having()
extra_select = self.get_extra_select(order_by, self.select) extra_select = self.get_extra_select(order_by, self.select)
group_by = self.get_group_by(self.select + extra_select, order_by) group_by = self.get_group_by(self.select + extra_select, order_by)
return extra_select, order_by, group_by return extra_select, order_by, group_by
@ -116,12 +117,12 @@ class SQLCompiler(object):
if is_ref: if is_ref:
continue continue
expressions.extend(expr.get_source_expressions()) expressions.extend(expr.get_source_expressions())
having = self.query.having.get_group_by_cols() having_group_by = self.having.get_group_by_cols() if self.having else ()
for expr in having: for expr in having_group_by:
expressions.append(expr) expressions.append(expr)
result = [] result = []
seen = set() seen = set()
expressions = self.collapse_group_by(expressions, having) expressions = self.collapse_group_by(expressions, having_group_by)
for expr in expressions: for expr in expressions:
sql, params = self.compile(expr) sql, params = self.compile(expr)
@ -355,11 +356,8 @@ class SQLCompiler(object):
If 'with_limits' is False, any limit/offset information is not included If 'with_limits' is False, any limit/offset information is not included
in the query. in the query.
""" """
# After executing the query, we must get rid of any joins the query if with_limits and self.query.low_mark == self.query.high_mark:
# setup created. So, take note of alias counts before the query ran. return '', ()
# However we do not want to get rid of stuff done in pre_sql_setup(),
# as the pre_sql_setup will modify query state in a way that forbids
# another run of it.
self.subquery = subquery self.subquery = subquery
refcounts_before = self.query.alias_refcount.copy() refcounts_before = self.query.alias_refcount.copy()
try: try:
@ -372,8 +370,8 @@ class SQLCompiler(object):
# docstring of get_from_clause() for details. # docstring of get_from_clause() for details.
from_, f_params = self.get_from_clause() from_, f_params = self.get_from_clause()
where, w_params = self.compile(self.query.where) where, w_params = self.compile(self.where) if self.where is not None else ("", [])
having, h_params = self.compile(self.query.having) having, h_params = self.compile(self.having) if self.having is not None else ("", [])
params = [] params = []
result = ['SELECT'] result = ['SELECT']

View File

@ -25,7 +25,7 @@ from django.db.models.sql.constants import (QUERY_TERMS, ORDER_DIR, SINGLE,
from django.db.models.sql.datastructures import ( from django.db.models.sql.datastructures import (
EmptyResultSet, Empty, MultiJoin, Join, BaseTable) EmptyResultSet, Empty, MultiJoin, Join, BaseTable)
from django.db.models.sql.where import (WhereNode, EverythingNode, from django.db.models.sql.where import (WhereNode, EverythingNode,
ExtraWhere, AND, OR, EmptyWhere) ExtraWhere, AND, OR, NothingNode)
from django.utils import six from django.utils import six
from django.utils.deprecation import RemovedInDjango20Warning from django.utils.deprecation import RemovedInDjango20Warning
from django.utils.encoding import force_text from django.utils.encoding import force_text
@ -141,7 +141,6 @@ class Query(object):
# - True: group by all select fields of the model # - True: group by all select fields of the model
# See compiler.get_group_by() for details. # See compiler.get_group_by() for details.
self.group_by = None self.group_by = None
self.having = where()
self.order_by = [] self.order_by = []
self.low_mark, self.high_mark = 0, None # Used for offset/limit self.low_mark, self.high_mark = 0, None # Used for offset/limit
self.distinct = False self.distinct = False
@ -268,7 +267,6 @@ class Query(object):
obj.group_by = True obj.group_by = True
else: else:
obj.group_by = self.group_by[:] obj.group_by = self.group_by[:]
obj.having = self.having.clone()
obj.order_by = self.order_by[:] obj.order_by = self.order_by[:]
obj.low_mark, obj.high_mark = self.low_mark, self.high_mark obj.low_mark, obj.high_mark = self.low_mark, self.high_mark
obj.distinct = self.distinct obj.distinct = self.distinct
@ -449,7 +447,7 @@ class Query(object):
return number return number
def has_filters(self): def has_filters(self):
return self.where or self.having return self.where
def has_results(self, using): def has_results(self, using):
q = self.clone() q = self.clone()
@ -770,9 +768,8 @@ class Query(object):
else: else:
return col.relabeled_clone(change_map) return col.relabeled_clone(change_map)
# 1. Update references in "select" (normal columns plus aliases), # 1. Update references in "select" (normal columns plus aliases),
# "group by", "where" and "having". # "group by" and "where".
self.where.relabel_aliases(change_map) self.where.relabel_aliases(change_map)
self.having.relabel_aliases(change_map)
if isinstance(self.group_by, list): if isinstance(self.group_by, list):
self.group_by = [relabel_column(col) for col in self.group_by] self.group_by = [relabel_column(col) for col in self.group_by]
self.select = [col.relabeled_clone(change_map) for col in self.select] self.select = [col.relabeled_clone(change_map) for col in self.select]
@ -1093,7 +1090,7 @@ class Query(object):
""" """
Builds a WhereNode for a single filter clause, but doesn't add it Builds a WhereNode for a single filter clause, but doesn't add it
to this Query. Query.add_q() will then add this filter to the where to this Query. Query.add_q() will then add this filter to the where
or having Node. Node.
The 'branch_negated' tells us if the current branch contains any The 'branch_negated' tells us if the current branch contains any
negations. This will be used to determine if subqueries are needed. negations. This will be used to determine if subqueries are needed.
@ -1197,59 +1194,11 @@ class Query(object):
def add_filter(self, filter_clause): def add_filter(self, filter_clause):
self.add_q(Q(**{filter_clause[0]: filter_clause[1]})) self.add_q(Q(**{filter_clause[0]: filter_clause[1]}))
def need_having(self, obj):
"""
Returns whether or not all elements of this q_object need to be put
together in the HAVING clause.
"""
if not self._annotations:
return False
if hasattr(obj, 'refs_aggregate'):
return obj.refs_aggregate(self.annotations)[0]
return (refs_aggregate(obj[0].split(LOOKUP_SEP), self.annotations)[0]
or (hasattr(obj[1], 'refs_aggregate')
and obj[1].refs_aggregate(self.annotations)[0]))
def split_having_parts(self, q_object, negated=False):
"""
Returns a list of q_objects which need to go into the having clause
instead of the where clause. Removes the splitted out nodes from the
given q_object. Note that the q_object is altered, so cloning it is
needed.
"""
having_parts = []
for c in q_object.children[:]:
# When constructing the having nodes we need to take care to
# preserve the negation status from the upper parts of the tree
if isinstance(c, Node):
# For each negated child, flip the in_negated flag.
in_negated = c.negated ^ negated
if c.connector == OR and self.need_having(c):
# A subtree starting from OR clause must go into having in
# whole if any part of that tree references an aggregate.
q_object.children.remove(c)
having_parts.append(c)
c.negated = in_negated
else:
having_parts.extend(
self.split_having_parts(c, in_negated)[1])
elif self.need_having(c):
q_object.children.remove(c)
new_q = self.where_class(children=[c], negated=negated)
having_parts.append(new_q)
return q_object, having_parts
def add_q(self, q_object): def add_q(self, q_object):
""" """
A preprocessor for the internal _add_q(). Responsible for A preprocessor for the internal _add_q(). Responsible for doing final
splitting the given q_object into where and having parts and join promotion.
setting up some internal variables.
""" """
if not self.need_having(q_object):
where_part, having_parts = q_object, []
else:
where_part, having_parts = self.split_having_parts(
q_object.clone(), q_object.negated)
# For join promotion this case is doing an AND for the added q_object # For join promotion this case is doing an AND for the added q_object
# and existing conditions. So, any existing inner join forces the join # and existing conditions. So, any existing inner join forces the join
# type to remain inner. Existing outer joins can however be demoted. # type to remain inner. Existing outer joins can however be demoted.
@ -1258,11 +1207,8 @@ class Query(object):
# So, demotion is OK. # So, demotion is OK.
existing_inner = set( existing_inner = set(
(a for a in self.alias_map if self.alias_map[a].join_type == INNER)) (a for a in self.alias_map if self.alias_map[a].join_type == INNER))
clause, require_inner = self._add_q(where_part, self.used_aliases) clause, require_inner = self._add_q(q_object, self.used_aliases)
self.where.add(clause, AND) self.where.add(clause, AND)
for hp in having_parts:
clause, _ = self._add_q(hp, self.used_aliases)
self.having.add(clause, AND)
self.demote_joins(existing_inner) self.demote_joins(existing_inner)
def _add_q(self, q_object, used_aliases, branch_negated=False, def _add_q(self, q_object, used_aliases, branch_negated=False,
@ -1557,11 +1503,10 @@ class Query(object):
return condition, needed_inner return condition, needed_inner
def set_empty(self): def set_empty(self):
self.where = EmptyWhere() self.where.add(NothingNode(), AND)
self.having = EmptyWhere()
def is_empty(self): def is_empty(self):
return isinstance(self.where, EmptyWhere) or isinstance(self.having, EmptyWhere) return any(isinstance(c, NothingNode) for c in self.where.children)
def set_limits(self, low=None, high=None): def set_limits(self, low=None, high=None):
""" """

View File

@ -54,10 +54,8 @@ class DeleteQuery(Query):
self.get_initial_alias() self.get_initial_alias()
innerq_used_tables = [t for t in innerq.tables innerq_used_tables = [t for t in innerq.tables
if innerq.alias_refcount[t]] if innerq.alias_refcount[t]]
if ((not innerq_used_tables or innerq_used_tables == self.tables) if not innerq_used_tables or innerq_used_tables == self.tables:
and not len(innerq.having)): # There is only the base table in use in the query.
# There is only the base table in use in the query, and there is
# no aggregate filtering going on.
self.where = innerq.where self.where = innerq.where
else: else:
pk = query.model._meta.pk pk = query.model._meta.pk

View File

@ -1,9 +1,10 @@
""" """
Code to manage the creation and SQL rendering of 'where' constraints. Code to manage the creation and SQL rendering of 'where' constraints.
""" """
from django.db.models.sql.datastructures import EmptyResultSet from django.db.models.sql.datastructures import EmptyResultSet
from django.utils.functional import cached_property
from django.utils import tree from django.utils import tree
from django.utils.functional import cached_property
# Connection types # Connection types
@ -37,6 +38,37 @@ class WhereNode(tree.Node):
""" """
default = AND default = AND
def split_having(self, negated=False):
"""
Returns two possibly None nodes: one for those parts of self that
should be pushed to having and one for those parts of self
that should be in where.
"""
in_negated = negated ^ self.negated
# If the effective connector is OR and this node contains an aggregate,
# then we need to push the whole branch to HAVING clause.
may_need_split = (
(in_negated and self.connector == AND) or
(not in_negated and self.connector == OR))
if may_need_split and self.contains_aggregate:
return None, self
where_parts = []
having_parts = []
for c in self.children:
if hasattr(c, 'split_having'):
where_part, having_part = c.split_having(in_negated)
if where_part:
where_parts.append(where_part)
if having_part:
having_parts.append(having_part)
elif c.contains_aggregate:
having_parts.append(c)
else:
where_parts.append(c)
having_node = self.__class__(having_parts, self.connector, self.negated) if having_parts else None
where_node = self.__class__(where_parts, self.connector, self.negated) if where_parts else None
return where_node, having_node
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
""" """
Returns the SQL version of the where clause and the value to be Returns the SQL version of the where clause and the value to be
@ -145,27 +177,20 @@ class WhereNode(tree.Node):
@classmethod @classmethod
def _contains_aggregate(cls, obj): def _contains_aggregate(cls, obj):
if not isinstance(obj, tree.Node): if isinstance(obj, tree.Node):
return getattr(obj.lhs, 'contains_aggregate', False) or getattr(obj.rhs, 'contains_aggregate', False) return any(cls._contains_aggregate(c) for c in obj.children)
return any(cls._contains_aggregate(c) for c in obj.children) return obj.contains_aggregate
@cached_property @cached_property
def contains_aggregate(self): def contains_aggregate(self):
return self._contains_aggregate(self) return self._contains_aggregate(self)
class EmptyWhere(WhereNode):
def add(self, data, connector):
return
def as_sql(self, compiler=None, connection=None):
raise EmptyResultSet
class EverythingNode(object): class EverythingNode(object):
""" """
A node that matches everything. A node that matches everything.
""" """
contains_aggregate = False
def as_sql(self, compiler=None, connection=None): def as_sql(self, compiler=None, connection=None):
return '', [] return '', []
@ -175,11 +200,16 @@ class NothingNode(object):
""" """
A node that matches nothing. A node that matches nothing.
""" """
contains_aggregate = False
def as_sql(self, compiler=None, connection=None): def as_sql(self, compiler=None, connection=None):
raise EmptyResultSet raise EmptyResultSet
class ExtraWhere(object): class ExtraWhere(object):
# The contents are a black box - assume no aggregates are used.
contains_aggregate = False
def __init__(self, sqls, params): def __init__(self, sqls, params):
self.sqls = sqls self.sqls = sqls
self.params = params self.params = params
@ -190,6 +220,10 @@ class ExtraWhere(object):
class SubqueryConstraint(object): class SubqueryConstraint(object):
# Even if aggregates would be used in a subquery, the outer query isn't
# interested about those.
contains_aggregate = False
def __init__(self, alias, columns, targets, query_object): def __init__(self, alias, columns, targets, query_object):
self.alias = alias self.alias = alias
self.columns = columns self.columns = columns