From 6ed7bd5609d298858c02d5d3b6357b434df06d9f Mon Sep 17 00:00:00 2001 From: Justin Bronn Date: Tue, 23 Feb 2010 04:39:39 +0000 Subject: [PATCH] Fixed #12855 -- QuerySets with `extra` where parameters now combine correctly. Thanks, Alex Gaynor. git-svn-id: http://code.djangoproject.com/svn/django/trunk@12502 bcc190cf-cafb-0310-a4f2-bffc1f526a37 --- django/db/models/sql/compiler.py | 7 ------- django/db/models/sql/query.py | 18 ++++-------------- django/db/models/sql/where.py | 10 +++++++++- tests/regressiontests/extra_regress/models.py | 4 ++++ 4 files changed, 17 insertions(+), 22 deletions(-) diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 1625a0e6c9..f364b1de82 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -84,12 +84,6 @@ class SQLCompiler(object): if where: result.append('WHERE %s' % where) params.extend(w_params) - if self.query.extra_where: - if not where: - result.append('WHERE') - else: - result.append('AND') - result.append(' AND '.join(self.query.extra_where)) grouping, gb_params = self.get_grouping() if grouping: @@ -124,7 +118,6 @@ class SQLCompiler(object): result.append('LIMIT %d' % val) result.append('OFFSET %d' % self.query.low_mark) - params.extend(self.query.extra_params) return ' '.join(result), tuple(params) def as_nested_sql(self): diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index dde1494662..2d3f6109a0 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -19,7 +19,8 @@ from django.db.models.sql import aggregates as base_aggregates_module from django.db.models.sql.constants import * from django.db.models.sql.datastructures import EmptyResultSet, Empty, MultiJoin from django.db.models.sql.expressions import SQLEvaluator -from django.db.models.sql.where import WhereNode, Constraint, EverythingNode, AND, OR +from django.db.models.sql.where import (WhereNode, Constraint, EverythingNode, + ExtraWhere, AND, OR) from django.core.exceptions import FieldError __all__ = ['Query', 'RawQuery'] @@ -128,8 +129,6 @@ class Query(object): self._extra_select_cache = None self.extra_tables = () - self.extra_where = () - self.extra_params = () self.extra_order_by = () # A tuple that is a set of model field names and either True, if these @@ -256,8 +255,6 @@ class Query(object): else: obj._extra_select_cache = self._extra_select_cache.copy() obj.extra_tables = self.extra_tables - obj.extra_where = self.extra_where - obj.extra_params = self.extra_params obj.extra_order_by = self.extra_order_by obj.deferred_loading = deepcopy(self.deferred_loading) if self.filter_is_sticky and self.used_aliases: @@ -466,9 +463,6 @@ class Query(object): if self.extra and rhs.extra: raise ValueError("When merging querysets using 'or', you " "cannot have extra(select=...) on both sides.") - if self.extra_where and rhs.extra_where: - raise ValueError("When merging querysets using 'or', you " - "cannot have extra(where=...) on both sides.") self.extra.update(rhs.extra) extra_select_mask = set() if self.extra_select_mask is not None: @@ -478,8 +472,6 @@ class Query(object): if extra_select_mask: self.set_extra_mask(extra_select_mask) self.extra_tables += rhs.extra_tables - self.extra_where += rhs.extra_where - self.extra_params += rhs.extra_params # Ordering uses the 'rhs' ordering, unless it has none, in which case # the current ordering is used. @@ -1611,10 +1603,8 @@ class Query(object): select_pairs[name] = (entry, entry_params) # This is order preserving, since self.extra_select is a SortedDict. self.extra.update(select_pairs) - if where: - self.extra_where += tuple(where) - if params: - self.extra_params += tuple(params) + if where or params: + self.where.add(ExtraWhere(where, params), AND) if tables: self.extra_tables += tuple(tables) if order_by: diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index 4aa2351f17..cf147c6ad9 100644 --- a/django/db/models/sql/where.py +++ b/django/db/models/sql/where.py @@ -220,7 +220,7 @@ class WhereNode(tree.Node): child.relabel_aliases(change_map) elif isinstance(child, tree.Node): self.relabel_aliases(change_map, child) - else: + elif isinstance(child, (list, tuple)): if isinstance(child[0], (list, tuple)): elt = list(child[0]) if elt[0] in change_map: @@ -254,6 +254,14 @@ class NothingNode(object): def relabel_aliases(self, change_map, node=None): return +class ExtraWhere(object): + def __init__(self, sqls, params): + self.sqls = sqls + self.params = params + + def as_sql(self, qn=None, connection=None): + return " AND ".join(self.sqls), tuple(self.params or ()) + class Constraint(object): """ An object that can be passed to WhereNode.add() and knows how to diff --git a/tests/regressiontests/extra_regress/models.py b/tests/regressiontests/extra_regress/models.py index d4d7cb86e5..1e94de0566 100644 --- a/tests/regressiontests/extra_regress/models.py +++ b/tests/regressiontests/extra_regress/models.py @@ -210,4 +210,8 @@ True >>> TestObject.objects.filter(pk__in=TestObject.objects.values('pk').extra(select={'extra': 1})) [] +>>> pk = TestObject.objects.get().pk +>>> TestObject.objects.filter(pk=pk) | TestObject.objects.extra(where=["id > %s"], params=[pk]) +[] + """}