Fixed #10847 -- Modified handling of extra() to use a masking strategy, rather than last-minute trimming. Thanks to Tai Lee for the report, and Alex Gaynor for his work on the patch.

This enables querysets with an extra clause to be used in an __in filter; as a side effect, it also means that as_sql() now returns the correct result for any query with an extra clause.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@10648 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Russell Keith-Magee 2009-04-30 15:40:09 +00:00
parent 17958fa7a9
commit 5e2d38465a
4 changed files with 84 additions and 32 deletions

View File

@ -715,9 +715,6 @@ class ValuesQuerySet(QuerySet):
def iterator(self): def iterator(self):
# Purge any extra columns that haven't been explicitly asked for # Purge any extra columns that haven't been explicitly asked for
if self.extra_names is not None:
self.query.trim_extra_select(self.extra_names)
extra_names = self.query.extra_select.keys() extra_names = self.query.extra_select.keys()
field_names = self.field_names field_names = self.field_names
aggregate_names = self.query.aggregate_select.keys() aggregate_names = self.query.aggregate_select.keys()
@ -741,13 +738,18 @@ class ValuesQuerySet(QuerySet):
if self._fields: if self._fields:
self.extra_names = [] self.extra_names = []
self.aggregate_names = [] self.aggregate_names = []
if not self.query.extra_select and not self.query.aggregate_select: if not self.query.extra and not self.query.aggregates:
# Short cut - if there are no extra or aggregates, then
# the values() clause must be just field names.
self.field_names = list(self._fields) self.field_names = list(self._fields)
else: else:
self.query.default_cols = False self.query.default_cols = False
self.field_names = [] self.field_names = []
for f in self._fields: for f in self._fields:
if self.query.extra_select.has_key(f): # we inspect the full extra_select list since we might
# be adding back an extra select item that we hadn't
# had selected previously.
if self.query.extra.has_key(f):
self.extra_names.append(f) self.extra_names.append(f)
elif self.query.aggregate_select.has_key(f): elif self.query.aggregate_select.has_key(f):
self.aggregate_names.append(f) self.aggregate_names.append(f)
@ -760,6 +762,8 @@ class ValuesQuerySet(QuerySet):
self.aggregate_names = None self.aggregate_names = None
self.query.select = [] self.query.select = []
if self.extra_names is not None:
self.query.set_extra_mask(self.extra_names)
self.query.add_fields(self.field_names, False) self.query.add_fields(self.field_names, False)
if self.aggregate_names is not None: if self.aggregate_names is not None:
self.query.set_aggregate_mask(self.aggregate_names) self.query.set_aggregate_mask(self.aggregate_names)
@ -816,9 +820,6 @@ class ValuesQuerySet(QuerySet):
class ValuesListQuerySet(ValuesQuerySet): class ValuesListQuerySet(ValuesQuerySet):
def iterator(self): def iterator(self):
if self.extra_names is not None:
self.query.trim_extra_select(self.extra_names)
if self.flat and len(self._fields) == 1: if self.flat and len(self._fields) == 1:
for row in self.query.results_iter(): for row in self.query.results_iter():
yield row[0] yield row[0]

View File

@ -88,7 +88,10 @@ class BaseQuery(object):
# These are for extensions. The contents are more or less appended # These are for extensions. The contents are more or less appended
# verbatim to the appropriate clause. # verbatim to the appropriate clause.
self.extra_select = SortedDict() # Maps col_alias -> (col_sql, params). self.extra = SortedDict() # Maps col_alias -> (col_sql, params).
self.extra_select_mask = None
self._extra_select_cache = None
self.extra_tables = () self.extra_tables = ()
self.extra_where = () self.extra_where = ()
self.extra_params = () self.extra_params = ()
@ -214,13 +217,21 @@ class BaseQuery(object):
if self.aggregate_select_mask is None: if self.aggregate_select_mask is None:
obj.aggregate_select_mask = None obj.aggregate_select_mask = None
else: else:
obj.aggregate_select_mask = self.aggregate_select_mask[:] obj.aggregate_select_mask = self.aggregate_select_mask.copy()
if self._aggregate_select_cache is None: if self._aggregate_select_cache is None:
obj._aggregate_select_cache = None obj._aggregate_select_cache = None
else: else:
obj._aggregate_select_cache = self._aggregate_select_cache.copy() obj._aggregate_select_cache = self._aggregate_select_cache.copy()
obj.max_depth = self.max_depth obj.max_depth = self.max_depth
obj.extra_select = self.extra_select.copy() obj.extra = self.extra.copy()
if self.extra_select_mask is None:
obj.extra_select_mask = None
else:
obj.extra_select_mask = self.extra_select_mask.copy()
if self._extra_select_cache is None:
obj._extra_select_cache = None
else:
obj._extra_select_cache = self._extra_select_cache.copy()
obj.extra_tables = self.extra_tables obj.extra_tables = self.extra_tables
obj.extra_where = self.extra_where obj.extra_where = self.extra_where
obj.extra_params = self.extra_params obj.extra_params = self.extra_params
@ -325,7 +336,7 @@ class BaseQuery(object):
query = self query = self
self.select = [] self.select = []
self.default_cols = False self.default_cols = False
self.extra_select = {} self.extra = {}
self.remove_inherited_models() self.remove_inherited_models()
query.clear_ordering(True) query.clear_ordering(True)
@ -540,13 +551,20 @@ class BaseQuery(object):
# It would be nice to be able to handle this, but the queries don't # It would be nice to be able to handle this, but the queries don't
# really make sense (or return consistent value sets). Not worth # really make sense (or return consistent value sets). Not worth
# the extra complexity when you can write a real query instead. # the extra complexity when you can write a real query instead.
if self.extra_select and rhs.extra_select: if self.extra and rhs.extra:
raise ValueError("When merging querysets using 'or', you " raise ValueError("When merging querysets using 'or', you "
"cannot have extra(select=...) on both sides.") "cannot have extra(select=...) on both sides.")
if self.extra_where and rhs.extra_where: if self.extra_where and rhs.extra_where:
raise ValueError("When merging querysets using 'or', you " raise ValueError("When merging querysets using 'or', you "
"cannot have extra(where=...) on both sides.") "cannot have extra(where=...) on both sides.")
self.extra_select.update(rhs.extra_select) self.extra.update(rhs.extra)
extra_select_mask = set()
if self.extra_select_mask is not None:
extra_select_mask.update(self.extra_select_mask)
if rhs.extra_select_mask is not None:
extra_select_mask.update(rhs.extra_select_mask)
if extra_select_mask:
self.set_extra_mask(extra_select_mask)
self.extra_tables += rhs.extra_tables self.extra_tables += rhs.extra_tables
self.extra_where += rhs.extra_where self.extra_where += rhs.extra_where
self.extra_params += rhs.extra_params self.extra_params += rhs.extra_params
@ -2011,7 +2029,7 @@ class BaseQuery(object):
except MultiJoin: except MultiJoin:
raise FieldError("Invalid field name: '%s'" % name) raise FieldError("Invalid field name: '%s'" % name)
except FieldError: except FieldError:
names = opts.get_all_field_names() + self.extra_select.keys() + self.aggregate_select.keys() names = opts.get_all_field_names() + self.extra.keys() + self.aggregate_select.keys()
names.sort() names.sort()
raise FieldError("Cannot resolve keyword %r into field. " raise FieldError("Cannot resolve keyword %r into field. "
"Choices are: %s" % (name, ", ".join(names))) "Choices are: %s" % (name, ", ".join(names)))
@ -2139,7 +2157,7 @@ class BaseQuery(object):
pos = entry.find("%s", pos + 2) pos = entry.find("%s", pos + 2)
select_pairs[name] = (entry, entry_params) select_pairs[name] = (entry, entry_params)
# This is order preserving, since self.extra_select is a SortedDict. # This is order preserving, since self.extra_select is a SortedDict.
self.extra_select.update(select_pairs) self.extra.update(select_pairs)
if where: if where:
self.extra_where += tuple(where) self.extra_where += tuple(where)
if params: if params:
@ -2213,22 +2231,26 @@ class BaseQuery(object):
""" """
target[model] = set([f.name for f in fields]) target[model] = set([f.name for f in fields])
def trim_extra_select(self, names):
"""
Removes any aliases in the extra_select dictionary that aren't in
'names'.
This is needed if we are selecting certain values that don't incldue
all of the extra_select names.
"""
for key in set(self.extra_select).difference(set(names)):
del self.extra_select[key]
def set_aggregate_mask(self, names): def set_aggregate_mask(self, names):
"Set the mask of aggregates that will actually be returned by the SELECT" "Set the mask of aggregates that will actually be returned by the SELECT"
self.aggregate_select_mask = names if names is None:
self.aggregate_select_mask = None
else:
self.aggregate_select_mask = set(names)
self._aggregate_select_cache = None self._aggregate_select_cache = None
def set_extra_mask(self, names):
"""
Set the mask of extra select items that will be returned by SELECT,
we don't actually remove them from the Query since they might be used
later
"""
if names is None:
self.extra_select_mask = None
else:
self.extra_select_mask = set(names)
self._extra_select_cache = None
def _aggregate_select(self): def _aggregate_select(self):
"""The SortedDict of aggregate columns that are not masked, and should """The SortedDict of aggregate columns that are not masked, and should
be used in the SELECT clause. be used in the SELECT clause.
@ -2247,6 +2269,19 @@ class BaseQuery(object):
return self.aggregates return self.aggregates
aggregate_select = property(_aggregate_select) aggregate_select = property(_aggregate_select)
def _extra_select(self):
if self._extra_select_cache is not None:
return self._extra_select_cache
elif self.extra_select_mask is not None:
self._extra_select_cache = SortedDict([
(k,v) for k,v in self.extra.items()
if k in self.extra_select_mask
])
return self._extra_select_cache
else:
return self.extra
extra_select = property(_extra_select)
def set_start(self, start): def set_start(self, start):
""" """
Sets the table from which to start joining. The start position is Sets the table from which to start joining. The start position is

View File

@ -178,7 +178,7 @@ class UpdateQuery(Query):
# from other tables. # from other tables.
query = self.clone(klass=Query) query = self.clone(klass=Query)
query.bump_prefix() query.bump_prefix()
query.extra_select = {} query.extra = {}
query.select = [] query.select = []
query.add_fields([query.model._meta.pk.name]) query.add_fields([query.model._meta.pk.name])
must_pre_select = count > 1 and not self.connection.features.update_can_self_select must_pre_select = count > 1 and not self.connection.features.update_can_self_select
@ -409,7 +409,7 @@ class DateQuery(Query):
self.select = [select] self.select = [select]
self.select_fields = [None] self.select_fields = [None]
self.select_related = False # See #7097. self.select_related = False # See #7097.
self.extra_select = {} self.extra = {}
self.distinct = True self.distinct = True
self.order_by = order == 'ASC' and [1] or [-1] self.order_by = order == 'ASC' and [1] or [-1]

View File

@ -35,6 +35,9 @@ class TestObject(models.Model):
second = models.CharField(max_length=20) second = models.CharField(max_length=20)
third = models.CharField(max_length=20) third = models.CharField(max_length=20)
def __unicode__(self):
return u'TestObject: %s,%s,%s' % (self.first,self.second,self.third)
__test__ = {"API_TESTS": """ __test__ = {"API_TESTS": """
# Regression tests for #7314 and #7372 # Regression tests for #7314 and #7372
@ -189,6 +192,19 @@ True
>>> TestObject.objects.extra(select=SortedDict((('foo','first'),('bar','second'),('whiz','third')))).values_list('whiz', 'first', 'bar', 'id') >>> TestObject.objects.extra(select=SortedDict((('foo','first'),('bar','second'),('whiz','third')))).values_list('whiz', 'first', 'bar', 'id')
[(u'third', u'first', u'second', 1)] [(u'third', u'first', u'second', 1)]
# Regression for #10847: the list of extra columns can always be accurately evaluated.
# Using an inner query ensures that as_sql() is producing correct output
# without requiring full evaluation and execution of the inner query.
>>> TestObject.objects.extra(select={'extra': 1}).values('pk')
[{'pk': 1}]
>>> TestObject.objects.filter(pk__in=TestObject.objects.extra(select={'extra': 1}).values('pk'))
[<TestObject: TestObject: first,second,third>]
>>> TestObject.objects.values('pk').extra(select={'extra': 1})
[{'pk': 1}]
>>> TestObject.objects.filter(pk__in=TestObject.objects.values('pk').extra(select={'extra': 1}))
[<TestObject: TestObject: first,second,third>]
"""} """}