From 6a28f581c0fec79f387fa10dddb9bdc8594a03d6 Mon Sep 17 00:00:00 2001 From: Justin Bronn Date: Tue, 23 Feb 2010 05:59:04 +0000 Subject: [PATCH] Fixed #12855 -- QuerySets? with extra where parameters now combine correctly. Thanks, Alex Gaynor. Backport of r12502 from trunk. git-svn-id: http://code.djangoproject.com/svn/django/branches/releases/1.1.X@12507 bcc190cf-cafb-0310-a4f2-bffc1f526a37 --- django/db/models/sql/query.py | 25 +++---------------- django/db/models/sql/where.py | 10 +++++++- tests/regressiontests/extra_regress/models.py | 4 +++ 3 files changed, 17 insertions(+), 22 deletions(-) diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 033a267a8b..a94dca770f 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -18,7 +18,8 @@ from django.db.models.fields import FieldDoesNotExist from django.db.models.query_utils import select_related_descend from django.db.models.sql import aggregates as base_aggregates_module from django.db.models.sql.expressions import SQLEvaluator -from django.db.models.sql.where import WhereNode, Constraint, EverythingNode, AND, OR +from django.db.models.sql.where import (WhereNode, Constraint, EverythingNode, + ExtraWhere, AND, OR) from django.core.exceptions import FieldError from datastructures import EmptyResultSet, Empty, MultiJoin from constants import * @@ -92,8 +93,6 @@ class BaseQuery(object): self._extra_select_cache = None self.extra_tables = () - self.extra_where = () - self.extra_params = () self.extra_order_by = () # A tuple that is a set of model field names and either True, if these @@ -232,8 +231,6 @@ class BaseQuery(object): else: obj._extra_select_cache = self._extra_select_cache.copy() obj.extra_tables = self.extra_tables - obj.extra_where = self.extra_where - obj.extra_params = self.extra_params obj.extra_order_by = self.extra_order_by obj.deferred_loading = deepcopy(self.deferred_loading) if self.filter_is_sticky and self.used_aliases: @@ -418,12 +415,6 @@ class BaseQuery(object): if where: result.append('WHERE %s' % where) params.extend(w_params) - if self.extra_where: - if not where: - result.append('WHERE') - else: - result.append('AND') - result.append(' AND '.join(self.extra_where)) grouping, gb_params = self.get_grouping() if grouping: @@ -458,7 +449,6 @@ class BaseQuery(object): result.append('LIMIT %d' % val) result.append('OFFSET %d' % self.low_mark) - params.extend(self.extra_params) return ' '.join(result), tuple(params) def as_nested_sql(self): @@ -553,9 +543,6 @@ class BaseQuery(object): if self.extra and rhs.extra: raise ValueError("When merging querysets using 'or', you " "cannot have extra(select=...) on both sides.") - if self.extra_where and rhs.extra_where: - raise ValueError("When merging querysets using 'or', you " - "cannot have extra(where=...) on both sides.") self.extra.update(rhs.extra) extra_select_mask = set() if self.extra_select_mask is not None: @@ -565,8 +552,6 @@ class BaseQuery(object): if extra_select_mask: self.set_extra_mask(extra_select_mask) self.extra_tables += rhs.extra_tables - self.extra_where += rhs.extra_where - self.extra_params += rhs.extra_params # Ordering uses the 'rhs' ordering, unless it has none, in which case # the current ordering is used. @@ -2181,10 +2166,8 @@ class BaseQuery(object): select_pairs[name] = (entry, entry_params) # This is order preserving, since self.extra_select is a SortedDict. self.extra.update(select_pairs) - if where: - self.extra_where += tuple(where) - if params: - self.extra_params += tuple(params) + if where or params: + self.where.add(ExtraWhere(where, params), AND) if tables: self.extra_tables += tuple(tables) if order_by: diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index ec0545ca5b..9ae5e74c42 100644 --- a/django/db/models/sql/where.py +++ b/django/db/models/sql/where.py @@ -216,7 +216,7 @@ class WhereNode(tree.Node): child.relabel_aliases(change_map) elif isinstance(child, tree.Node): self.relabel_aliases(change_map, child) - else: + elif isinstance(child, (list, tuple)): if isinstance(child[0], (list, tuple)): elt = list(child[0]) if elt[0] in change_map: @@ -249,6 +249,14 @@ class NothingNode(object): def relabel_aliases(self, change_map, node=None): return +class ExtraWhere(object): + def __init__(self, sqls, params): + self.sqls = sqls + self.params = params + + def as_sql(self, qn=None): + return " AND ".join(self.sqls), tuple(self.params or ()) + class Constraint(object): """ An object that can be passed to WhereNode.add() and knows how to diff --git a/tests/regressiontests/extra_regress/models.py b/tests/regressiontests/extra_regress/models.py index 76eb549f81..65c53e65b3 100644 --- a/tests/regressiontests/extra_regress/models.py +++ b/tests/regressiontests/extra_regress/models.py @@ -208,4 +208,8 @@ True >>> TestObject.objects.filter(pk__in=TestObject.objects.values('pk').extra(select={'extra': 1})) [] +>>> pk = TestObject.objects.get().pk +>>> TestObject.objects.filter(pk=pk) | TestObject.objects.extra(where=["id > %s"], params=[pk]) +[] + """}