diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 8c5ef3355e..61c1e51557 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -16,7 +16,7 @@ from django.db import connection from django.db.models import signals from django.db.models.fields import FieldDoesNotExist from django.db.models.query_utils import select_related_descend -from django.db.models.sql.where import WhereNode, EverythingNode, AND, OR +from django.db.models.sql.where import WhereNode, Constraint, EverythingNode, AND, OR from django.db.models.sql.datastructures import Count from django.core.exceptions import FieldError from datastructures import EmptyResultSet, Empty, MultiJoin @@ -66,7 +66,7 @@ class BaseQuery(object): self.where = where() self.where_class = where self.group_by = [] - self.having = [] + self.having = where() self.order_by = [] self.low_mark, self.high_mark = 0, None # Used for offset/limit self.distinct = False @@ -172,7 +172,7 @@ class BaseQuery(object): obj.where = deepcopy(self.where) obj.where_class = self.where_class obj.group_by = self.group_by[:] - obj.having = self.having[:] + obj.having = deepcopy(self.having) obj.order_by = self.order_by[:] obj.low_mark, obj.high_mark = self.low_mark, self.high_mark obj.distinct = self.distinct @@ -261,7 +261,9 @@ class BaseQuery(object): # get_from_clause() for details. from_, f_params = self.get_from_clause() - where, w_params = self.where.as_sql(qn=self.quote_name_unless_alias) + qn = self.quote_name_unless_alias + where, w_params = self.where.as_sql(qn=qn) + having, h_params = self.having.as_sql(qn=qn) params = [] for val in self.extra_select.itervalues(): params.extend(val[1]) @@ -291,9 +293,8 @@ class BaseQuery(object): if not ordering: ordering = self.connection.ops.force_no_ordering() - if self.having: - having, h_params = self.get_having() - result.append('HAVING %s' % ', '.join(having)) + if having: + result.append('HAVING %s' % having) params.extend(h_params) if ordering: @@ -577,24 +578,6 @@ class BaseQuery(object): result.append(str(col)) return result - def get_having(self): - """ - Returns a tuple representing the SQL elements in the "having" clause. - By default, the elements of self.having have their as_sql() method - called or are returned unchanged (if they don't have an as_sql() - method). - """ - result = [] - params = [] - for elt in self.having: - if hasattr(elt, 'as_sql'): - sql, params = elt.as_sql() - result.append(sql) - params.extend(params) - else: - result.append(elt) - return result, params - def get_ordering(self): """ Returns list representing the SQL elements in the "order by" clause. @@ -1197,7 +1180,8 @@ class BaseQuery(object): self.promote_alias_chain(join_it, join_promote) self.promote_alias_chain(table_it, table_promote) - self.where.add((alias, col, field, lookup_type, value), connector) + self.where.add((Constraint(alias, col, field), lookup_type, value), + connector) if negate: self.promote_alias_chain(join_list) @@ -1207,7 +1191,7 @@ class BaseQuery(object): if self.alias_map[alias][JOIN_TYPE] == self.LOUTER: j_col = self.alias_map[alias][RHS_JOIN_COL] entry = self.where_class() - entry.add((alias, j_col, None, 'isnull', True), AND) + entry.add((Constraint(alias, j_col, None), 'isnull', True), AND) entry.negate() self.where.add(entry, AND) break @@ -1216,7 +1200,7 @@ class BaseQuery(object): # exclude the "foo__in=[]" case from this handling, because # it's short-circuited in the Where class. entry = self.where_class() - entry.add((alias, col, None, 'isnull', True), AND) + entry.add((Constraint(alias, col, None), 'isnull', True), AND) entry.negate() self.where.add(entry, AND) diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py index 39ef439dc9..524b0894c5 100644 --- a/django/db/models/sql/subqueries.py +++ b/django/db/models/sql/subqueries.py @@ -6,7 +6,7 @@ from django.core.exceptions import FieldError from django.db.models.sql.constants import * from django.db.models.sql.datastructures import Date from django.db.models.sql.query import Query -from django.db.models.sql.where import AND +from django.db.models.sql.where import AND, Constraint __all__ = ['DeleteQuery', 'UpdateQuery', 'InsertQuery', 'DateQuery', 'CountQuery'] @@ -48,8 +48,9 @@ class DeleteQuery(Query): if not isinstance(related.field, generic.GenericRelation): for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): where = self.where_class() - where.add((None, related.field.m2m_reverse_name(), - related.field, 'in', + where.add((Constraint(None, + related.field.m2m_reverse_name(), related.field), + 'in', pk_list[offset : offset+GET_ITERATOR_CHUNK_SIZE]), AND) self.do_query(related.field.m2m_db_table(), where) @@ -59,11 +60,11 @@ class DeleteQuery(Query): if isinstance(f, generic.GenericRelation): from django.contrib.contenttypes.models import ContentType field = f.rel.to._meta.get_field(f.content_type_field_name) - w1.add((None, field.column, field, 'exact', + w1.add((Constraint(None, field.column, field), 'exact', ContentType.objects.get_for_model(cls).id), AND) for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): where = self.where_class() - where.add((None, f.m2m_column_name(), f, 'in', + where.add((Constraint(None, f.m2m_column_name(), f), 'in', pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]), AND) if w1: @@ -81,7 +82,7 @@ class DeleteQuery(Query): for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): where = self.where_class() field = self.model._meta.pk - where.add((None, field.column, field, 'in', + where.add((Constraint(None, field.column, field), 'in', pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]), AND) self.do_query(self.model._meta.db_table, where) @@ -212,7 +213,7 @@ class UpdateQuery(Query): for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): self.where = self.where_class() f = self.model._meta.pk - self.where.add((None, f.column, f, 'in', + self.where.add((Constraint(None, f.column, f), 'in', pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]), AND) self.values = [(related_field.column, None, '%s')] diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index 662d99a4a2..a9fca7df11 100644 --- a/django/db/models/sql/where.py +++ b/django/db/models/sql/where.py @@ -13,6 +13,13 @@ from datastructures import EmptyResultSet, FullResultSet AND = 'AND' OR = 'OR' +class EmptyShortCircuit(Exception): + """ + Internal exception used to indicate that a "matches nothing" node should be + added to the where-clause. + """ + pass + class WhereNode(tree.Node): """ Used to represent the SQL where-clause. @@ -35,36 +42,35 @@ class WhereNode(tree.Node): storing any reference to field objects). Otherwise, the 'data' is stored unchanged and can be anything with an 'as_sql()' method. """ - # Because of circular imports, we need to import this here. - from django.db.models.base import ObjectDoesNotExist - if not isinstance(data, (list, tuple)): super(WhereNode, self).add(data, connector) return - alias, col, field, lookup_type, value = data - try: - if field: - params = field.get_db_prep_lookup(lookup_type, value) - db_type = field.db_type() - else: - # This is possible when we add a comparison to NULL sometimes - # (we don't really need to waste time looking up the associated - # field object). - params = Field().get_db_prep_lookup(lookup_type, value) - db_type = None - except ObjectDoesNotExist: - # This can happen when trying to insert a reference to a null pk. - # We break out of the normal path and indicate there's nothing to - # match. - super(WhereNode, self).add(NothingNode(), connector) - return + obj, lookup_type, value = data + if hasattr(obj, "process"): + try: + obj, params = obj.process(lookup_type, value) + except EmptyShortCircuit: + # There are situations where we want to short-circuit any + # comparisons and make sure that nothing is returned. One + # example is when checking for a NULL pk value, or the + # equivalent. + super(WhereNode, self).add(NothingNode(), connector) + return + else: + params = Field().get_db_prep_lookup(lookup_type, value) + + # The "annotation" parameter is used to pass auxilliary information + # about the value(s) to the query construction. Specifically, datetime + # and empty values need special handling. Other types could be used + # here in the future (using Python types is suggested for consistency). if isinstance(value, datetime.datetime): annotation = datetime.datetime else: annotation = bool(value) - super(WhereNode, self).add((alias, col, db_type, lookup_type, - annotation, params), connector) + + super(WhereNode, self).add((obj, lookup_type, annotation, params), + connector) def as_sql(self, qn=None): """ @@ -130,12 +136,13 @@ class WhereNode(tree.Node): Returns the string for the SQL fragment and the parameters to use for it. """ - table_alias, name, db_type, lookup_type, value_annot, params = child - if table_alias: - lhs = '%s.%s' % (qn(table_alias), qn(name)) + lvalue, lookup_type, value_annot, params = child + if isinstance(lvalue, tuple): + # A direct database column lookup. + field_sql = self.sql_for_columns(lvalue, qn) else: - lhs = qn(name) - field_sql = connection.ops.field_cast_sql(db_type) % lhs + # A smart object with an as_sql() method. + field_sql = lvalue.as_sql(quote_func=qn) if value_annot is datetime.datetime: cast_sql = connection.ops.datetime_cast_sql() @@ -175,6 +182,19 @@ class WhereNode(tree.Node): raise TypeError('Invalid lookup_type: %r' % lookup_type) + def sql_for_columns(self, data, qn): + """ + Returns the SQL fragment used for the left-hand side of a column + constraint (for example, the "T1.foo" portion in the clause + "WHERE ... T1.foo = 6"). + """ + table_alias, name, db_type = data + if table_alias: + lhs = '%s.%s' % (qn(table_alias), qn(name)) + else: + lhs = qn(name) + return connection.ops.field_cast_sql(db_type) % lhs + def relabel_aliases(self, change_map, node=None): """ Relabels the alias values of any children. 'change_map' is a dictionary @@ -188,8 +208,10 @@ class WhereNode(tree.Node): elif isinstance(child, tree.Node): self.relabel_aliases(change_map, child) else: - if child[0] in change_map: - node.children[pos] = (change_map[child[0]],) + child[1:] + elt = list(child[0]) + if elt[0] in change_map: + elt[0] = change_map[elt[0]] + node.children[pos] = (tuple(elt),) + child[1:] class EverythingNode(object): """ @@ -211,3 +233,33 @@ class NothingNode(object): def relabel_aliases(self, change_map, node=None): return +class Constraint(object): + """ + An object that can be passed to WhereNode.add() and knows how to + pre-process itself prior to including in the WhereNode. + """ + def __init__(self, alias, col, field): + self.alias, self.col, self.field = alias, col, field + + def process(self, lookup_type, value): + """ + Returns a tuple of data suitable for inclusion in a WhereNode + instance. + """ + # Because of circular imports, we need to import this here. + from django.db.models.base import ObjectDoesNotExist + try: + if self.field: + params = self.field.get_db_prep_lookup(lookup_type, value) + db_type = self.field.db_type() + else: + # This branch is used at times when we add a comparison to NULL + # (we don't really want to waste time looking up the associated + # field object at the calling location). + params = Field().get_db_prep_lookup(lookup_type, value) + db_type = None + except ObjectDoesNotExist: + raise EmptyShortCircuit + + return (self.alias, self.col, db_type), params + diff --git a/tests/regressiontests/queries/models.py b/tests/regressiontests/queries/models.py index d8737df197..94f2045fe7 100644 --- a/tests/regressiontests/queries/models.py +++ b/tests/regressiontests/queries/models.py @@ -973,19 +973,6 @@ relations. >>> len([x[2] for x in q.alias_map.values() if x[2] == q.LOUTER and q.alias_refcount[x[1]]]) 1 -A check to ensure we don't break the internal query construction of GROUP BY -and HAVING. These aren't supported in the public API, but the Query class knows -about them and shouldn't do bad things. ->>> qs = Tag.objects.values_list('parent_id', flat=True).order_by() ->>> qs.query.group_by = ['parent_id'] ->>> qs.query.having = ['count(parent_id) > 1'] ->>> expected = [t3.parent_id, t4.parent_id] ->>> expected.sort() ->>> result = list(qs) ->>> result.sort() ->>> expected == result -True - Make sure bump_prefix() (an internal Query method) doesn't (re-)break. It's sufficient that this query runs without error. >>> qs = Tag.objects.values_list('id', flat=True).order_by('id')