Fixed #12855 -- QuerySets with `extra` where parameters now combine correctly. Thanks, Alex Gaynor.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@12502 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Justin Bronn 2010-02-23 04:39:39 +00:00
parent c1d795df45
commit 6ed7bd5609
4 changed files with 17 additions and 22 deletions

View File

@ -84,12 +84,6 @@ class SQLCompiler(object):
if where: if where:
result.append('WHERE %s' % where) result.append('WHERE %s' % where)
params.extend(w_params) params.extend(w_params)
if self.query.extra_where:
if not where:
result.append('WHERE')
else:
result.append('AND')
result.append(' AND '.join(self.query.extra_where))
grouping, gb_params = self.get_grouping() grouping, gb_params = self.get_grouping()
if grouping: if grouping:
@ -124,7 +118,6 @@ class SQLCompiler(object):
result.append('LIMIT %d' % val) result.append('LIMIT %d' % val)
result.append('OFFSET %d' % self.query.low_mark) result.append('OFFSET %d' % self.query.low_mark)
params.extend(self.query.extra_params)
return ' '.join(result), tuple(params) return ' '.join(result), tuple(params)
def as_nested_sql(self): def as_nested_sql(self):

View File

@ -19,7 +19,8 @@ from django.db.models.sql import aggregates as base_aggregates_module
from django.db.models.sql.constants import * from django.db.models.sql.constants import *
from django.db.models.sql.datastructures import EmptyResultSet, Empty, MultiJoin from django.db.models.sql.datastructures import EmptyResultSet, Empty, MultiJoin
from django.db.models.sql.expressions import SQLEvaluator from django.db.models.sql.expressions import SQLEvaluator
from django.db.models.sql.where import WhereNode, Constraint, EverythingNode, AND, OR from django.db.models.sql.where import (WhereNode, Constraint, EverythingNode,
ExtraWhere, AND, OR)
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
__all__ = ['Query', 'RawQuery'] __all__ = ['Query', 'RawQuery']
@ -128,8 +129,6 @@ class Query(object):
self._extra_select_cache = None self._extra_select_cache = None
self.extra_tables = () self.extra_tables = ()
self.extra_where = ()
self.extra_params = ()
self.extra_order_by = () self.extra_order_by = ()
# A tuple that is a set of model field names and either True, if these # A tuple that is a set of model field names and either True, if these
@ -256,8 +255,6 @@ class Query(object):
else: else:
obj._extra_select_cache = self._extra_select_cache.copy() obj._extra_select_cache = self._extra_select_cache.copy()
obj.extra_tables = self.extra_tables obj.extra_tables = self.extra_tables
obj.extra_where = self.extra_where
obj.extra_params = self.extra_params
obj.extra_order_by = self.extra_order_by obj.extra_order_by = self.extra_order_by
obj.deferred_loading = deepcopy(self.deferred_loading) obj.deferred_loading = deepcopy(self.deferred_loading)
if self.filter_is_sticky and self.used_aliases: if self.filter_is_sticky and self.used_aliases:
@ -466,9 +463,6 @@ class Query(object):
if self.extra and rhs.extra: if self.extra and rhs.extra:
raise ValueError("When merging querysets using 'or', you " raise ValueError("When merging querysets using 'or', you "
"cannot have extra(select=...) on both sides.") "cannot have extra(select=...) on both sides.")
if self.extra_where and rhs.extra_where:
raise ValueError("When merging querysets using 'or', you "
"cannot have extra(where=...) on both sides.")
self.extra.update(rhs.extra) self.extra.update(rhs.extra)
extra_select_mask = set() extra_select_mask = set()
if self.extra_select_mask is not None: if self.extra_select_mask is not None:
@ -478,8 +472,6 @@ class Query(object):
if extra_select_mask: if extra_select_mask:
self.set_extra_mask(extra_select_mask) self.set_extra_mask(extra_select_mask)
self.extra_tables += rhs.extra_tables self.extra_tables += rhs.extra_tables
self.extra_where += rhs.extra_where
self.extra_params += rhs.extra_params
# Ordering uses the 'rhs' ordering, unless it has none, in which case # Ordering uses the 'rhs' ordering, unless it has none, in which case
# the current ordering is used. # the current ordering is used.
@ -1611,10 +1603,8 @@ class Query(object):
select_pairs[name] = (entry, entry_params) select_pairs[name] = (entry, entry_params)
# This is order preserving, since self.extra_select is a SortedDict. # This is order preserving, since self.extra_select is a SortedDict.
self.extra.update(select_pairs) self.extra.update(select_pairs)
if where: if where or params:
self.extra_where += tuple(where) self.where.add(ExtraWhere(where, params), AND)
if params:
self.extra_params += tuple(params)
if tables: if tables:
self.extra_tables += tuple(tables) self.extra_tables += tuple(tables)
if order_by: if order_by:

View File

@ -220,7 +220,7 @@ class WhereNode(tree.Node):
child.relabel_aliases(change_map) child.relabel_aliases(change_map)
elif isinstance(child, tree.Node): elif isinstance(child, tree.Node):
self.relabel_aliases(change_map, child) self.relabel_aliases(change_map, child)
else: elif isinstance(child, (list, tuple)):
if isinstance(child[0], (list, tuple)): if isinstance(child[0], (list, tuple)):
elt = list(child[0]) elt = list(child[0])
if elt[0] in change_map: if elt[0] in change_map:
@ -254,6 +254,14 @@ class NothingNode(object):
def relabel_aliases(self, change_map, node=None): def relabel_aliases(self, change_map, node=None):
return return
class ExtraWhere(object):
def __init__(self, sqls, params):
self.sqls = sqls
self.params = params
def as_sql(self, qn=None, connection=None):
return " AND ".join(self.sqls), tuple(self.params or ())
class Constraint(object): class Constraint(object):
""" """
An object that can be passed to WhereNode.add() and knows how to An object that can be passed to WhereNode.add() and knows how to

View File

@ -210,4 +210,8 @@ True
>>> TestObject.objects.filter(pk__in=TestObject.objects.values('pk').extra(select={'extra': 1})) >>> TestObject.objects.filter(pk__in=TestObject.objects.values('pk').extra(select={'extra': 1}))
[<TestObject: TestObject: first,second,third>] [<TestObject: TestObject: first,second,third>]
>>> pk = TestObject.objects.get().pk
>>> TestObject.objects.filter(pk=pk) | TestObject.objects.extra(where=["id > %s"], params=[pk])
[<TestObject: TestObject: first,second,third>]
"""} """}