diff --git a/django/contrib/gis/db/models/sql/where.py b/django/contrib/gis/db/models/sql/where.py index 6ef34db0a3..c29533bf84 100644 --- a/django/contrib/gis/db/models/sql/where.py +++ b/django/contrib/gis/db/models/sql/where.py @@ -32,13 +32,14 @@ class GeoWhereNode(WhereNode): Used to represent the SQL where-clause for spatial databases -- these are tied to the GeoQuery class that created it. """ - def add(self, data, connector): + + def _prepare_data(self, data): if isinstance(data, (list, tuple)): obj, lookup_type, value = data if ( isinstance(obj, Constraint) and isinstance(obj.field, GeometryField) ): data = (GeoConstraint(obj), lookup_type, value) - super(GeoWhereNode, self).add(data, connector) + return super(GeoWhereNode, self)._prepare_data(data) def make_atom(self, child, qn, connection): lvalue, lookup_type, value_annot, params_or_value = child diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py index a2349cf5c6..b89db1c563 100644 --- a/django/db/models/aggregates.py +++ b/django/db/models/aggregates.py @@ -1,6 +1,19 @@ """ Classes to represent the definitions of aggregate functions. """ +from django.db.models.constants import LOOKUP_SEP + +def refs_aggregate(lookup_parts, aggregates): + """ + A little helper method to check if the lookup_parts contains references + to the given aggregates set. Because the LOOKUP_SEP is contained in the + default annotation names we must check each prefix of the lookup_parts + for match. + """ + for i in range(len(lookup_parts) + 1): + if LOOKUP_SEP.join(lookup_parts[0:i]) in aggregates: + return True + return False class Aggregate(object): """ diff --git a/django/db/models/constants.py b/django/db/models/constants.py index 629497eb3d..a7e6c252d9 100644 --- a/django/db/models/constants.py +++ b/django/db/models/constants.py @@ -4,4 +4,3 @@ Constants used across the ORM in general. # Separator used to split filter strings apart. 
LOOKUP_SEP = '__' - diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index 3566d777c6..6e0f3c434e 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -1,4 +1,7 @@ import datetime + +from django.db.models.aggregates import refs_aggregate +from django.db.models.constants import LOOKUP_SEP from django.utils import tree class ExpressionNode(tree.Node): @@ -37,6 +40,18 @@ class ExpressionNode(tree.Node): obj.add(other, connector) return obj + def contains_aggregate(self, existing_aggregates): + if self.children: + return any(child.contains_aggregate(existing_aggregates) + for child in self.children + if hasattr(child, 'contains_aggregate')) + else: + return refs_aggregate(self.name.split(LOOKUP_SEP), + existing_aggregates) + + def prepare_database_save(self, unused): + return self + ################### # VISITOR METHODS # ################### @@ -113,9 +128,6 @@ class ExpressionNode(tree.Node): "Use .bitand() and .bitor() for bitwise logical operations." ) - def prepare_database_save(self, unused): - return self - class F(ExpressionNode): """ An expression representing the value of the given field. diff --git a/django/db/models/query_utils.py b/django/db/models/query_utils.py index c82cc45617..a33c44833c 100644 --- a/django/db/models/query_utils.py +++ b/django/db/models/query_utils.py @@ -47,6 +47,7 @@ class Q(tree.Node): if not isinstance(other, Q): raise TypeError(other) obj = type(self)() + obj.connector = conn obj.add(self, conn) obj.add(other, conn) return obj @@ -63,6 +64,16 @@ class Q(tree.Node): obj.negate() return obj + def clone(self): + clone = self.__class__._new_instance( + children=[], connector=self.connector, negated=self.negated) + for child in self.children: + if hasattr(child, 'clone'): + clone.children.append(child.clone()) + else: + clone.children.append(child) + return clone + class DeferredAttribute(object): """ A wrapper for a deferred-loading field. 
When the value is read from this diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 7ddee9785c..4711ea6e19 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -87,6 +87,7 @@ class SQLCompiler(object): where, w_params = self.query.where.as_sql(qn=qn, connection=self.connection) having, h_params = self.query.having.as_sql(qn=qn, connection=self.connection) + having_group_by = self.query.having.get_cols() params = [] for val in six.itervalues(self.query.extra_select): params.extend(val[1]) @@ -107,7 +108,7 @@ class SQLCompiler(object): result.append('WHERE %s' % where) params.extend(w_params) - grouping, gb_params = self.get_grouping(ordering_group_by) + grouping, gb_params = self.get_grouping(having_group_by, ordering_group_by) if grouping: if distinct_fields: raise NotImplementedError( @@ -534,7 +535,7 @@ class SQLCompiler(object): first = False return result, from_params - def get_grouping(self, ordering_group_by): + def get_grouping(self, having_group_by, ordering_group_by): """ Returns a tuple representing the SQL elements in the "group by" clause. 
""" @@ -551,7 +552,7 @@ class SQLCompiler(object): ] select_cols = [] seen = set() - cols = self.query.group_by + select_cols + cols = self.query.group_by + having_group_by + select_cols for col in cols: col_params = () if isinstance(col, (list, tuple)): diff --git a/django/db/models/sql/expressions.py b/django/db/models/sql/expressions.py index 55ae655cb0..389099161a 100644 --- a/django/db/models/sql/expressions.py +++ b/django/db/models/sql/expressions.py @@ -7,16 +7,14 @@ class SQLEvaluator(object): def __init__(self, expression, query, allow_joins=True, reuse=None): self.expression = expression self.opts = query.get_meta() - self.cols = [] - - self.contains_aggregate = False self.reuse = reuse + self.cols = [] self.expression.prepare(self, query, allow_joins) def relabeled_clone(self, change_map): clone = copy.copy(self) clone.cols = [] - for node, col in self.cols[:]: + for node, col in self.cols: if hasattr(col, 'relabeled_clone'): clone.cols.append((node, col.relabeled_clone(change_map))) else: @@ -24,6 +22,15 @@ class SQLEvaluator(object): (change_map.get(col[0], col[0]), col[1]))) return clone + def get_cols(self): + cols = [] + for node, col in self.cols: + if hasattr(node, 'get_cols'): + cols.extend(node.get_cols()) + elif isinstance(col, tuple): + cols.append(col) + return cols + def prepare(self): return self @@ -44,9 +51,7 @@ class SQLEvaluator(object): raise FieldError("Joined field references are not permitted in this query") field_list = node.name.split(LOOKUP_SEP) - if (len(field_list) == 1 and - node.name in query.aggregate_select.keys()): - self.contains_aggregate = True + if node.name in query.aggregates: self.cols.append((node, query.aggregate_select[node.name])) else: try: diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index fa583f6120..2fd67ff0ca 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -15,6 +15,7 @@ from django.utils.tree import Node from django.utils import six from 
django.db import connections, DEFAULT_DB_ALIAS from django.db.models.constants import LOOKUP_SEP +from django.db.models.aggregates import refs_aggregate from django.db.models.expressions import ExpressionNode from django.db.models.fields import FieldDoesNotExist from django.db.models.loading import get_model @@ -1004,19 +1005,6 @@ class Query(object): self.unref_alias(alias) self.included_inherited_models = {} - def need_force_having(self, q_object): - """ - Returns whether or not all elements of this q_object need to be put - together in the HAVING clause. - """ - for child in q_object.children: - if isinstance(child, Node): - if self.need_force_having(child): - return True - else: - if child[0].split(LOOKUP_SEP)[0] in self.aggregates: - return True - return False def add_aggregate(self, aggregate, model, alias, is_summary): """ @@ -1065,24 +1053,32 @@ class Query(object): # Add the aggregate to the query aggregate.add_to_query(self, alias, col=col, source=source, is_summary=is_summary) - def add_filter(self, filter_expr, connector=AND, negate=False, - can_reuse=None, force_having=False): + def build_filter(self, filter_expr, branch_negated=False, current_negated=False, + can_reuse=None): """ - Add a single filter to the query. The 'filter_expr' is a pair: - (filter_string, value). E.g. ('name__contains', 'fred') + Builds a WhereNode for a single filter clause, but doesn't add it + to this Query. Query.add_q() will then add this filter to the where + or having Node. - If 'negate' is True, this is an exclude() filter. It's important to - note that this method does not negate anything in the where-clause - object when inserting the filter constraints. This is because negated - filters often require multiple calls to add_filter() and the negation - should only happen once. So the caller is responsible for this (the - caller will normally be add_q(), so that as an example). + The 'branch_negated' tells us if the current branch contains any + negations. 
This will be used to determine if subqueries are needed. - If 'can_reuse' is a set, we are processing a component of a - multi-component filter (e.g. filter(Q1, Q2)). In this case, 'can_reuse' - will be a set of table aliases that can be reused in this filter, even - if we would otherwise force the creation of new aliases for a join - (needed for nested Q-filters). The set is updated by this method. + The 'current_negated' is used to determine if the current filter is + negated or not and this will be used to determine if IS NULL filtering + is needed. + + The difference between current_negated and branch_negated is that + branch_negated is set on first negation, but current_negated is + flipped for each negation. + + Note that add_filter will not do any negating itself, that is done + higher up in the code by add_q(). + + The 'can_reuse' is a set of reusable joins for multijoins. + + The method will create a filter clause that can be added to the current + query. However, if the filter isn't added to the query then the caller + is responsible for unreffing the joins used. """ arg, value = filter_expr parts = arg.split(LOOKUP_SEP) @@ -1091,10 +1087,10 @@ class Query(object): # Work out the lookup type and remove it from the end of 'parts', # if necessary. - lookup_type = 'exact' # Default lookup type + lookup_type = 'exact' # Default lookup type num_parts = len(parts) if (len(parts) > 1 and parts[-1] in self.query_terms - and arg not in self.aggregates): + and arg not in self.aggregates): # Traverse the lookup query to distinguish related fields from # lookup types. lookup_model = self.model @@ -1115,10 +1111,7 @@ class Query(object): lookup_type = parts.pop() break - # By default, this is a WHERE clause. If an aggregate is referenced - # in the value, the filter will be promoted to a HAVING - having_clause = False - + clause = self.where_class() # Interpret '__exact=None' as the sql 'is NULL'; otherwise, reject all # uses of None as a query value. 
if value is None: @@ -1131,20 +1124,15 @@ class Query(object): elif isinstance(value, ExpressionNode): # If value is a query expression, evaluate it value = SQLEvaluator(value, self, reuse=can_reuse) - having_clause = value.contains_aggregate for alias, aggregate in self.aggregates.items(): if alias in (parts[0], LOOKUP_SEP.join(parts)): - entry = self.where_class() - entry.add((aggregate, lookup_type, value), AND) - if negate: - entry.negate() - self.having.add(entry, connector) - return + clause.add((aggregate, lookup_type, value), AND) + return clause opts = self.get_meta() alias = self.get_initial_alias() - allow_many = not negate + allow_many = not branch_negated try: field, target, opts, join_list, path = self.setup_joins( @@ -1153,11 +1141,10 @@ class Query(object): if can_reuse is not None: can_reuse.update(join_list) except MultiJoin as e: - self.split_exclude(filter_expr, LOOKUP_SEP.join(parts[:e.level]), - can_reuse, e.names_with_path) - return + return self.split_exclude(filter_expr, LOOKUP_SEP.join(parts[:e.level]), + can_reuse, e.names_with_path) - if (lookup_type == 'isnull' and value is True and not negate and + if (lookup_type == 'isnull' and value is True and not current_negated and len(join_list) > 1): # If the comparison is against NULL, we may need to use some left # outer joins when creating the join chain. This is only done when @@ -1169,17 +1156,9 @@ class Query(object): # promotion must happen before join trimming to have the join type # information available when reusing joins. 
target, alias, join_list = self.trim_joins(target, join_list, path) - - if having_clause or force_having: - if (alias, target.column) not in self.group_by: - self.group_by.append((alias, target.column)) - self.having.add((Constraint(alias, target.column, field), lookup_type, value), - connector) - else: - self.where.add((Constraint(alias, target.column, field), lookup_type, value), - connector) - - if negate: + clause.add((Constraint(alias, target.column, field), lookup_type, value), + AND) + if current_negated and (lookup_type != 'isnull' or value is False): self.promote_joins(join_list) if (lookup_type != 'isnull' and ( self.is_nullable(target) or self.alias_map[join_list[-1]].join_type == self.LOUTER)): @@ -1192,64 +1171,112 @@ class Query(object): # (col IS NULL OR col != someval) # <=> # NOT (col IS NOT NULL AND col = someval). - self.where.add((Constraint(alias, target.column, None), 'isnull', False), AND) + clause.add((Constraint(alias, target.column, None), 'isnull', False), AND) + return clause - def add_q(self, q_object, used_aliases=None, force_having=False): + def add_filter(self, filter_clause): + self.where.add(self.build_filter(filter_clause), 'AND') + + def need_having(self, obj): + """ + Returns whether or not all elements of this q_object need to be put + together in the HAVING clause. + """ + if not isinstance(obj, Node): + return (refs_aggregate(obj[0].split(LOOKUP_SEP), self.aggregates) + or (hasattr(obj[1], 'contains_aggregate') + and obj[1].contains_aggregate(self.aggregates))) + return any(self.need_having(c) for c in obj.children) + + def split_having_parts(self, q_object, negated=False): + """ + Returns a list of q_objects which need to go into the having clause + instead of the where clause. Removes the splitted out nodes from the + given q_object. Note that the q_object is altered, so cloning it is + needed. 
+ """ + having_parts = [] + for c in q_object.children[:]: + # When constucting the having nodes we need to take care to + # preserve the negation status from the upper parts of the tree + if isinstance(c, Node): + # For each negated child, flip the in_negated flag. + in_negated = c.negated ^ negated + if c.connector == OR and self.need_having(c): + # A subtree starting from OR clause must go into having in + # whole if any part of that tree references an aggregate. + q_object.children.remove(c) + having_parts.append(c) + c.negated = in_negated + else: + having_parts.extend( + self.split_having_parts(c, in_negated)[1]) + elif self.need_having(c): + q_object.children.remove(c) + new_q = self.where_class(children=[c], negated=negated) + having_parts.append(new_q) + return q_object, having_parts + + def add_q(self, q_object): + """ + A preprocessor for the internal _add_q(). Responsible for + splitting the given q_object into where and having parts and + setting up some internal variables. + """ + if not self.need_having(q_object): + where_part, having_parts = q_object, [] + else: + where_part, having_parts = self.split_having_parts( + q_object.clone(), q_object.negated) + used_aliases = self.used_aliases + clause = self._add_q(where_part, used_aliases) + self.where.add(clause, AND) + for hp in having_parts: + clause = self._add_q(hp, used_aliases) + self.having.add(clause, AND) + if self.filter_is_sticky: + self.used_aliases = used_aliases + + def _add_q(self, q_object, used_aliases, branch_negated=False, + current_negated=False): """ Adds a Q-object to the current filter. Can also be used to add anything that has an 'add_to_query()' method. """ - if used_aliases is None: - used_aliases = self.used_aliases - if hasattr(q_object, 'add_to_query'): - # Complex custom objects are responsible for adding themselves. 
- q_object.add_to_query(self, used_aliases) - else: - if self.where and q_object.connector != AND and len(q_object) > 1: - self.where.start_subtree(AND) - subtree = True + connector = q_object.connector + current_negated = current_negated ^ q_object.negated + branch_negated = branch_negated or q_object.negated + # Note that if the connector happens to match what we have already in + # the tree, the add will be a no-op. + target_clause = self.where_class(connector=connector, + negated=q_object.negated) + + if connector == OR: + alias_usage_counts = dict() + aliases_before = set(self.tables) + for child in q_object.children: + if connector == OR: + refcounts_before = self.alias_refcount.copy() + if isinstance(child, Node): + child_clause = self._add_q( + child, used_aliases, branch_negated, + current_negated) else: - subtree = False - connector = q_object.connector + child_clause = self.build_filter( + child, can_reuse=used_aliases, branch_negated=branch_negated, + current_negated=current_negated) + target_clause.add(child_clause, connector) if connector == OR: - alias_usage_counts = dict() - aliases_before = set(self.tables) - if q_object.connector == OR and not force_having: - force_having = self.need_force_having(q_object) - for child in q_object.children: - if force_having: - self.having.start_subtree(connector) - else: - self.where.start_subtree(connector) - if connector == OR: - refcounts_before = self.alias_refcount.copy() - if isinstance(child, Node): - self.add_q(child, used_aliases, force_having=force_having) - else: - self.add_filter(child, connector, q_object.negated, - can_reuse=used_aliases, force_having=force_having) - if connector == OR: - used = alias_diff(refcounts_before, self.alias_refcount) - for alias in used: - alias_usage_counts[alias] = alias_usage_counts.get(alias, 0) + 1 - if force_having: - self.having.end_subtree() - else: - self.where.end_subtree() + used = alias_diff(refcounts_before, self.alias_refcount) + for alias in used: + 
alias_usage_counts[alias] = alias_usage_counts.get(alias, 0) + 1 + if connector == OR: + self.promote_disjunction(aliases_before, alias_usage_counts, + len(q_object.children)) + return target_clause - if connector == OR: - self.promote_disjunction(aliases_before, alias_usage_counts, - len(q_object.children)) - if q_object.negated: - self.where.negate() - if subtree: - self.where.end_subtree() - if self.filter_is_sticky: - self.used_aliases = used_aliases - - def names_to_path(self, names, opts, allow_many=False, - allow_explicit_fk=True): + def names_to_path(self, names, opts, allow_many, allow_explicit_fk): """ Walks the names path and turns them PathInfo tuples. Note that a single name in 'names' can generate multiple PathInfos (m2m for @@ -1413,7 +1440,7 @@ class Query(object): """ # Generate the inner query. query = Query(self.model) - query.add_filter(filter_expr) + query.where.add(query.build_filter(filter_expr), AND) query.bump_prefix() query.clear_ordering(True) # Try to have as simple as possible subquery -> trim leading joins from @@ -1443,8 +1470,9 @@ class Query(object): path[paths_in_prefix - len(path)].from_field.name) break trimmed_prefix = LOOKUP_SEP.join(trimmed_prefix) - self.add_filter(('%s__in' % trimmed_prefix, query), negate=True, - can_reuse=can_reuse) + return self.build_filter( + ('%s__in' % trimmed_prefix, query), + current_negated=True, branch_negated=True, can_reuse=can_reuse) def set_empty(self): self.where = EmptyWhere() diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py index 6aac5c898c..78727e394a 100644 --- a/django/db/models/sql/subqueries.py +++ b/django/db/models/sql/subqueries.py @@ -45,7 +45,7 @@ class DeleteQuery(Query): for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): where = self.where_class() where.add((Constraint(None, field.column, field), 'in', - pk_list[offset:offset + GET_ITERATOR_CHUNK_SIZE]), AND) + pk_list[offset:offset + GET_ITERATOR_CHUNK_SIZE]), AND) 
self.do_query(self.model._meta.db_table, where, using=using) def delete_qs(self, query, using): @@ -117,8 +117,8 @@ class UpdateQuery(Query): for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): self.where = self.where_class() self.where.add((Constraint(None, pk_field.column, pk_field), 'in', - pk_list[offset:offset + GET_ITERATOR_CHUNK_SIZE]), - AND) + pk_list[offset:offset + GET_ITERATOR_CHUNK_SIZE]), + AND) self.get_compiler(using).execute_sql(None) def add_update_values(self, values): diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index 152a396785..ced5325754 100644 --- a/django/db/models/sql/where.py +++ b/django/db/models/sql/where.py @@ -46,18 +46,17 @@ class WhereNode(tree.Node): """ default = AND - def add(self, data, connector): + def _prepare_data(self, data): """ - Add a node to the where-tree. If the data is a list or tuple, it is - expected to be of the form (obj, lookup_type, value), where obj is - a Constraint object, and is then slightly munged before being stored - (to avoid storing any reference to field objects). Otherwise, the 'data' - is stored unchanged and can be any class with an 'as_sql()' method. + Prepare data for addition to the tree. If the data is a list or tuple, + it is expected to be of the form (obj, lookup_type, value), where obj + is a Constraint object, and is then slightly munged before being + stored (to avoid storing any reference to field objects). Otherwise, + the 'data' is stored unchanged and can be any class with an 'as_sql()' + method. 
""" if not isinstance(data, (list, tuple)): - super(WhereNode, self).add(data, connector) - return - + return data obj, lookup_type, value = data if isinstance(value, collections.Iterator): # Consume any generators immediately, so that we can determine @@ -78,9 +77,7 @@ class WhereNode(tree.Node): if hasattr(obj, "prepare"): value = obj.prepare(lookup_type, value) - - super(WhereNode, self).add( - (obj, lookup_type, value_annotation, value), connector) + return (obj, lookup_type, value_annotation, value) def as_sql(self, qn, connection): """ @@ -154,6 +151,18 @@ class WhereNode(tree.Node): sql_string = '(%s)' % sql_string return sql_string, result_params + def get_cols(self): + cols = [] + for child in self.children: + if hasattr(child, 'get_cols'): + cols.extend(child.get_cols()) + else: + if isinstance(child[0], Constraint): + cols.append((child[0].alias, child[0].col)) + if hasattr(child[3], 'get_cols'): + cols.extend(child[3].get_cols()) + return cols + def make_atom(self, child, qn, connection): """ Turn a tuple (Constraint(table_alias, column_name, db_type), @@ -284,7 +293,6 @@ class WhereNode(tree.Node): with empty subtree_parents). Childs must be either (Contraint, lookup, value) tuples, or objects supporting .clone(). """ - assert not self.subtree_parents clone = self.__class__._new_instance( children=[], connector=self.connector, negated=self.negated) for child in self.children: diff --git a/django/utils/tree.py b/django/utils/tree.py index ce490224e0..4152b4600b 100644 --- a/django/utils/tree.py +++ b/django/utils/tree.py @@ -19,14 +19,9 @@ class Node(object): """ Constructs a new Node. If no connector is given, the default will be used. - - Warning: You probably don't want to pass in the 'negated' parameter. It - is NOT the same as constructing a node and calling negate() on the - result. 
""" self.children = children and children[:] or [] self.connector = connector or self.default - self.subtree_parents = [] self.negated = negated # We need this because of django.db.models.query_utils.Q. Q. __init__() is @@ -59,7 +54,6 @@ class Node(object): obj = Node(connector=self.connector, negated=self.negated) obj.__class__ = self.__class__ obj.children = copy.deepcopy(self.children, memodict) - obj.subtree_parents = copy.deepcopy(self.subtree_parents, memodict) return obj def __len__(self): @@ -83,74 +77,60 @@ class Node(object): """ return other in self.children - def add(self, node, conn_type): + def _prepare_data(self, data): """ - Adds a new node to the tree. If the conn_type is the same as the root's - current connector type, the node is added to the first level. - Otherwise, the whole tree is pushed down one level and a new root - connector is created, connecting the existing tree and the new node. + A subclass hook for doing subclass specific transformations of the + given data on combine() or add(). """ - if node in self.children and conn_type == self.connector: - return - if len(self.children) < 2: - self.connector = conn_type + return data + + def add(self, data, conn_type, squash=True): + """ + Combines this tree and the data represented by data using the + connector conn_type. The combine is done by squashing the node other + away if possible. + + This tree (self) will never be pushed to a child node of the + combined tree, nor will the connector or negated properties change. + + The function returns a node which can be used in place of data + regardless if the node other got squashed or not. + + If `squash` is False the data is prepared and added as a child to + this tree without further logic. 
+ """ + if data in self.children: + return data + data = self._prepare_data(data) + if not squash: + self.children.append(data) + return data if self.connector == conn_type: - if isinstance(node, Node) and (node.connector == conn_type or - len(node) == 1): - self.children.extend(node.children) + # We can reuse self.children to append or squash the node other. + if (isinstance(data, Node) and not data.negated + and (data.connector == conn_type or len(data) == 1)): + # We can squash the other node's children directly into this + # node. We are just doing (AB)(CD) == (ABCD) here, with the + # addition that if the length of the other node is 1 the + # connector doesn't matter. However, for the len(self) == 1 + # case we don't want to do the squashing, as it would alter + # self.connector. + self.children.extend(data.children) + return self else: - self.children.append(node) + # We could use perhaps additional logic here to see if some + # children could be used for pushdown here. + self.children.append(data) + return data else: obj = self._new_instance(self.children, self.connector, - self.negated) + self.negated) self.connector = conn_type - self.children = [obj, node] + self.children = [obj, data] + return data def negate(self): """ - Negate the sense of the root connector. This reorganises the children - so that the current node has a single child: a negated node containing - all the previous children. This slightly odd construction makes adding - new children behave more intuitively. - - Interpreting the meaning of this negate is up to client code. This - method is useful for implementing "not" arrangements. + Negate the sense of the root connector. """ - self.children = [self._new_instance(self.children, self.connector, - not self.negated)] - self.connector = self.default - - def start_subtree(self, conn_type): - """ - Sets up internal state so that new nodes are added to a subtree of the - current node. 
The conn_type specifies how the sub-tree is joined to the - existing children. - """ - if len(self.children) == 1: - self.connector = conn_type - elif self.connector != conn_type: - self.children = [self._new_instance(self.children, self.connector, - self.negated)] - self.connector = conn_type - self.negated = False - - self.subtree_parents.append(self.__class__(self.children, - self.connector, self.negated)) - self.connector = self.default - self.negated = False - self.children = [] - - def end_subtree(self): - """ - Closes off the most recently unmatched start_subtree() call. - - This puts the current state into a node of the parent tree and returns - the current instances state to be the parent. - """ - obj = self.subtree_parents.pop() - node = self.__class__(self.children, self.connector) - self.connector = obj.connector - self.negated = obj.negated - self.children = obj.children - self.children.append(node) - + self.negated = not self.negated diff --git a/tests/aggregation_regress/tests.py b/tests/aggregation_regress/tests.py index bb1eb59b07..20677270f5 100644 --- a/tests/aggregation_regress/tests.py +++ b/tests/aggregation_regress/tests.py @@ -10,6 +10,7 @@ from django.contrib.contenttypes.models import ContentType from django.db.models import Count, Max, Avg, Sum, StdDev, Variance, F, Q from django.test import TestCase, Approximate, skipUnlessDBFeature from django.utils import six +from django.utils.unittest import expectedFailure from .models import (Author, Book, Publisher, Clues, Entries, HardbackBook, ItemTag, WithManualPK) @@ -472,7 +473,7 @@ class AggregationTests(TestCase): # Regression for #15709 - Ensure each group_by field only exists once # per query qs = Book.objects.values('publisher').annotate(max_pages=Max('pages')).order_by() - grouping, gb_params = qs.query.get_compiler(qs.db).get_grouping([]) + grouping, gb_params = qs.query.get_compiler(qs.db).get_grouping([], []) self.assertEqual(len(grouping), 1) def test_duplicate_alias(self): @@ 
-847,14 +848,14 @@ class AggregationTests(TestCase): # The name of the explicitly provided annotation name in this case # poses no problem - qs = Author.objects.annotate(book_cnt=Count('book')).filter(book_cnt=2) + qs = Author.objects.annotate(book_cnt=Count('book')).filter(book_cnt=2).order_by('name') self.assertQuerysetEqual( qs, ['Peter Norvig'], lambda b: b.name ) # Neither in this case - qs = Author.objects.annotate(book_count=Count('book')).filter(book_count=2) + qs = Author.objects.annotate(book_count=Count('book')).filter(book_count=2).order_by('name') self.assertQuerysetEqual( qs, ['Peter Norvig'], @@ -862,7 +863,7 @@ class AggregationTests(TestCase): ) # This case used to fail because the ORM couldn't resolve the # automatically generated annotation name `book__count` - qs = Author.objects.annotate(Count('book')).filter(book__count=2) + qs = Author.objects.annotate(Count('book')).filter(book__count=2).order_by('name') self.assertQuerysetEqual( qs, ['Peter Norvig'], @@ -1020,3 +1021,83 @@ class AggregationTests(TestCase): ('The Definitive Guide to Django: Web Development Done Right', 0) ] ) + + def test_negated_aggregation(self): + expected_results = Author.objects.exclude( + pk__in=Author.objects.annotate(book_cnt=Count('book')).filter(book_cnt=2) + ).order_by('name') + expected_results = [a.name for a in expected_results] + qs = Author.objects.annotate(book_cnt=Count('book')).exclude( + Q(book_cnt=2), Q(book_cnt=2)).order_by('name') + self.assertQuerysetEqual( + qs, + expected_results, + lambda b: b.name + ) + expected_results = Author.objects.exclude( + pk__in=Author.objects.annotate(book_cnt=Count('book')).filter(book_cnt=2) + ).order_by('name') + expected_results = [a.name for a in expected_results] + qs = Author.objects.annotate(book_cnt=Count('book')).exclude(Q(book_cnt=2)|Q(book_cnt=2)).order_by('name') + self.assertQuerysetEqual( + qs, + expected_results, + lambda b: b.name + ) + + def test_name_filters(self): + qs = 
Author.objects.annotate(Count('book')).filter( + Q(book__count__exact=2)|Q(name='Adrian Holovaty') + ).order_by('name') + self.assertQuerysetEqual( + qs, + ['Adrian Holovaty', 'Peter Norvig'], + lambda b: b.name + ) + + def test_name_expressions(self): + # Test that aggregates are spotted correctly from F objects. + # Note that Adrian's age is 34 in the fixtures, and he has one book + # so both conditions match one author. + qs = Author.objects.annotate(Count('book')).filter( + Q(name='Peter Norvig')|Q(age=F('book__count') + 33) + ).order_by('name') + self.assertQuerysetEqual( + qs, + ['Adrian Holovaty', 'Peter Norvig'], + lambda b: b.name + ) + + def test_ticket_11293(self): + q1 = Q(price__gt=50) + q2 = Q(authors__count__gt=1) + query = Book.objects.annotate(Count('authors')).filter( + q1 | q2).order_by('pk') + self.assertQuerysetEqual( + query, [1, 4, 5, 6], + lambda b: b.pk) + + def test_ticket_11293_q_immutable(self): + """ + Check that splitting a q object into parts for where/having doesn't alter + the original q-object. + """ + q1 = Q(isbn='') + q2 = Q(authors__count__gt=1) + query = Book.objects.annotate(Count('authors')) + query.filter(q1 | q2) + self.assertEqual(len(q2.children), 1) + + def test_fobj_group_by(self): + """ + Check that an F() object referring to related column works correctly + in group by. 
+ """ + qs = Book.objects.annotate( + acount=Count('authors') + ).filter( + acount=F('publisher__num_awards') + ) + self.assertQuerysetEqual( + qs, ['Sams Teach Yourself Django in 24 Hours'], + lambda b: b.name) diff --git a/tests/queries/models.py b/tests/queries/models.py index 6132544c2f..71346d8be9 100644 --- a/tests/queries/models.py +++ b/tests/queries/models.py @@ -475,3 +475,25 @@ class MyObject(models.Model): parent = models.ForeignKey('self', null=True, blank=True, related_name='children') data = models.CharField(max_length=100) created_at = models.DateTimeField(auto_now_add=True) + +# Models for #17600 regressions +@python_2_unicode_compatible +class Order(models.Model): + id = models.IntegerField(primary_key=True) + + class Meta: + ordering = ('pk', ) + + def __str__(self): + return '%s' % self.pk + +@python_2_unicode_compatible +class OrderItem(models.Model): + order = models.ForeignKey(Order, related_name='items') + status = models.IntegerField() + + class Meta: + ordering = ('pk', ) + + def __str__(self): + return '%s' % self.pk diff --git a/tests/queries/tests.py b/tests/queries/tests.py index 976a0ab05e..31d4b4e1a4 100644 --- a/tests/queries/tests.py +++ b/tests/queries/tests.py @@ -23,9 +23,9 @@ from .models import (Annotation, Article, Author, Celebrity, Child, Cover, Ranking, Related, Report, ReservedName, Tag, TvChef, Valid, X, Food, Eaten, Node, ObjectA, ObjectB, ObjectC, CategoryItem, SimpleCategory, SpecialCategory, OneToOneCategory, NullableName, ProxyCategory, - SingleObject, RelatedObject, ModelA, ModelD, Responsibility, Job, - JobResponsibilities, BaseA, Identifier, Program, Channel, Page, Paragraph, - Chapter, Book, MyObject) + SingleObject, RelatedObject, ModelA, ModelB, ModelC, ModelD, Responsibility, + Job, JobResponsibilities, BaseA, Identifier, Program, Channel, Page, + Paragraph, Chapter, Book, MyObject, Order, OrderItem) class BaseQuerysetTest(TestCase): @@ -834,7 +834,6 @@ class Queries1Tests(BaseQuerysetTest): 
Note.objects.filter(Q(extrainfo__author=self.a1)|Q(extrainfo=xx)), ['', ''] ) - xx.delete() q = Note.objects.filter(Q(extrainfo__author=self.a1)|Q(extrainfo=xx)).query self.assertEqual( len([x[2] for x in q.alias_map.values() if x[2] == q.LOUTER and q.alias_refcount[x[1]]]), @@ -880,7 +879,6 @@ class Queries1Tests(BaseQuerysetTest): Item.objects.filter(Q(tags__name='t4')), [repr(i) for i in Item.objects.filter(~Q(~Q(tags__name='t4')))]) - @unittest.expectedFailure def test_exclude_in(self): self.assertQuerysetEqual( Item.objects.exclude(Q(tags__name__in=['t4', 't3'])), @@ -2291,6 +2289,103 @@ class ExcludeTest(TestCase): Responsibility.objects.exclude(jobs__name='Manager'), ['']) +class ExcludeTest17600(TestCase): + """ + Some regressiontests for ticket #17600. Some of these likely duplicate + other existing tests. + """ + + def setUp(self): + # Create a few Orders. + self.o1 = Order.objects.create(pk=1) + self.o2 = Order.objects.create(pk=2) + self.o3 = Order.objects.create(pk=3) + + # Create some OrderItems for the first order with homogeneous + # status_id values + self.oi1 = OrderItem.objects.create(order=self.o1, status=1) + self.oi2 = OrderItem.objects.create(order=self.o1, status=1) + self.oi3 = OrderItem.objects.create(order=self.o1, status=1) + + # Create some OrderItems for the second order with heterogeneous + # status_id values + self.oi4 = OrderItem.objects.create(order=self.o2, status=1) + self.oi5 = OrderItem.objects.create(order=self.o2, status=2) + self.oi6 = OrderItem.objects.create(order=self.o2, status=3) + + # Create some OrderItems for the second order with heterogeneous + # status_id values + self.oi7 = OrderItem.objects.create(order=self.o3, status=2) + self.oi8 = OrderItem.objects.create(order=self.o3, status=3) + self.oi9 = OrderItem.objects.create(order=self.o3, status=4) + + def test_exclude_plain(self): + """ + This should exclude Orders which have some items with status 1 + + """ + self.assertQuerysetEqual( + 
Order.objects.exclude(items__status=1), + ['']) + + def test_exclude_plain_distinct(self): + """ + This should exclude Orders which have some items with status 1 + + """ + self.assertQuerysetEqual( + Order.objects.exclude(items__status=1).distinct(), + ['']) + + def test_exclude_with_q_object_distinct(self): + """ + This should exclude Orders which have some items with status 1 + + """ + self.assertQuerysetEqual( + Order.objects.exclude(Q(items__status=1)).distinct(), + ['']) + + def test_exclude_with_q_object_no_distinct(self): + """ + This should exclude Orders which have some items with status 1 + + """ + self.assertQuerysetEqual( + Order.objects.exclude(Q(items__status=1)), + ['']) + + def test_exclude_with_q_is_equal_to_plain_exclude(self): + """ + Using exclude(condition) and exclude(Q(condition)) should + yield the same QuerySet + + """ + self.assertEqual( + list(Order.objects.exclude(items__status=1).distinct()), + list(Order.objects.exclude(Q(items__status=1)).distinct())) + + def test_exclude_with_q_is_equal_to_plain_exclude_variation(self): + """ + Using exclude(condition) and exclude(Q(condition)) should + yield the same QuerySet + + """ + self.assertEqual( + list(Order.objects.exclude(items__status=1)), + list(Order.objects.exclude(Q(items__status=1)).distinct())) + + @unittest.expectedFailure + def test_only_orders_with_all_items_having_status_1(self): + """ + This should only return orders having ALL items set to status 1, or + those items not having any orders at all. The correct way to write + this query in SQL seems to be using two nested subqueries. 
+ """ + self.assertQuerysetEqual( + Order.objects.exclude(~Q(items__status=1)).distinct(), + ['']) + class NullInExcludeTest(TestCase): def setUp(self): NullableName.objects.create(name='i1') @@ -2326,6 +2421,14 @@ class NullInExcludeTest(TestCase): NullableName.objects.exclude(name__in=[None]), ['i1'], attrgetter('name')) + def test_double_exclude(self): + self.assertEqual( + list(NullableName.objects.filter(~~Q(name='i1'))), + list(NullableName.objects.filter(Q(name='i1')))) + self.assertNotIn( + 'IS NOT NULL', + str(NullableName.objects.filter(~~Q(name='i1')).query)) + class EmptyStringsAsNullTest(TestCase): """ Test that filtering on non-null character fields works as expected. @@ -2433,8 +2536,12 @@ class WhereNodeTest(TestCase): class NullJoinPromotionOrTest(TestCase): def setUp(self): - d = ModelD.objects.create(name='foo') - ModelA.objects.create(name='bar', d=d) + self.d1 = ModelD.objects.create(name='foo') + d2 = ModelD.objects.create(name='bar') + self.a1 = ModelA.objects.create(name='a1', d=self.d1) + c = ModelC.objects.create(name='c') + b = ModelB.objects.create(name='b', c=c) + self.a2 = ModelA.objects.create(name='a2', b=b, d=d2) def test_ticket_17886(self): # The first Q-object is generating the match, the rest of the filters @@ -2448,12 +2555,38 @@ class NullJoinPromotionOrTest(TestCase): Q(b__c__name='foo') ) qset = ModelA.objects.filter(q_obj) - self.assertEqual(len(qset), 1) + self.assertEqual(list(qset), [self.a1]) # We generate one INNER JOIN to D. The join is direct and not nullable # so we can use INNER JOIN for it. However, we can NOT use INNER JOIN # for the b->c join, as a->b is nullable. 
self.assertEqual(str(qset.query).count('INNER JOIN'), 1) + def test_isnull_filter_promotion(self): + qs = ModelA.objects.filter(Q(b__name__isnull=True)) + self.assertEqual(str(qs.query).count('LEFT OUTER'), 1) + self.assertEqual(list(qs), [self.a1]) + + qs = ModelA.objects.filter(~Q(b__name__isnull=True)) + self.assertEqual(str(qs.query).count('INNER JOIN'), 1) + self.assertEqual(list(qs), [self.a2]) + + qs = ModelA.objects.filter(~~Q(b__name__isnull=True)) + self.assertEqual(str(qs.query).count('LEFT OUTER'), 1) + self.assertEqual(list(qs), [self.a1]) + + qs = ModelA.objects.filter(Q(b__name__isnull=False)) + self.assertEqual(str(qs.query).count('INNER JOIN'), 1) + self.assertEqual(list(qs), [self.a2]) + + qs = ModelA.objects.filter(~Q(b__name__isnull=False)) + self.assertEqual(str(qs.query).count('LEFT OUTER'), 1) + self.assertEqual(list(qs), [self.a1]) + + qs = ModelA.objects.filter(~~Q(b__name__isnull=False)) + self.assertEqual(str(qs.query).count('INNER JOIN'), 1) + self.assertEqual(list(qs), [self.a2]) + + class ReverseJoinTrimmingTest(TestCase): def test_reverse_trimming(self): # Check that we don't accidentally trim reverse joins - we can't know