Fixed #10847 -- Modified handling of extra() to use a masking strategy, rather than last-minute trimming. Thanks to Tai Lee for the report, and Alex Gaynor for his work on the patch.

This enables querysets with an extra clause to be used in an __in filter; as a side effect, it also means that as_sql() now returns the correct result for any query with an extra clause.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@10648 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Russell Keith-Magee 2009-04-30 15:40:09 +00:00
parent 17958fa7a9
commit 5e2d38465a
4 changed files with 84 additions and 32 deletions

View File

@ -715,9 +715,6 @@ class ValuesQuerySet(QuerySet):
def iterator(self):
# Purge any extra columns that haven't been explicitly asked for
if self.extra_names is not None:
self.query.trim_extra_select(self.extra_names)
extra_names = self.query.extra_select.keys()
field_names = self.field_names
aggregate_names = self.query.aggregate_select.keys()
@ -741,13 +738,18 @@ class ValuesQuerySet(QuerySet):
if self._fields:
self.extra_names = []
self.aggregate_names = []
if not self.query.extra_select and not self.query.aggregate_select:
if not self.query.extra and not self.query.aggregates:
# Short cut - if there are no extra or aggregates, then
# the values() clause must be just field names.
self.field_names = list(self._fields)
else:
self.query.default_cols = False
self.field_names = []
for f in self._fields:
if self.query.extra_select.has_key(f):
# we inspect the full extra_select list since we might
# be adding back an extra select item that we hadn't
# had selected previously.
if self.query.extra.has_key(f):
self.extra_names.append(f)
elif self.query.aggregate_select.has_key(f):
self.aggregate_names.append(f)
@ -760,6 +762,8 @@ class ValuesQuerySet(QuerySet):
self.aggregate_names = None
self.query.select = []
if self.extra_names is not None:
self.query.set_extra_mask(self.extra_names)
self.query.add_fields(self.field_names, False)
if self.aggregate_names is not None:
self.query.set_aggregate_mask(self.aggregate_names)
@ -816,9 +820,6 @@ class ValuesQuerySet(QuerySet):
class ValuesListQuerySet(ValuesQuerySet):
def iterator(self):
if self.extra_names is not None:
self.query.trim_extra_select(self.extra_names)
if self.flat and len(self._fields) == 1:
for row in self.query.results_iter():
yield row[0]

View File

@ -88,7 +88,10 @@ class BaseQuery(object):
# These are for extensions. The contents are more or less appended
# verbatim to the appropriate clause.
self.extra_select = SortedDict() # Maps col_alias -> (col_sql, params).
self.extra = SortedDict() # Maps col_alias -> (col_sql, params).
self.extra_select_mask = None
self._extra_select_cache = None
self.extra_tables = ()
self.extra_where = ()
self.extra_params = ()
@ -214,13 +217,21 @@ class BaseQuery(object):
if self.aggregate_select_mask is None:
obj.aggregate_select_mask = None
else:
obj.aggregate_select_mask = self.aggregate_select_mask[:]
obj.aggregate_select_mask = self.aggregate_select_mask.copy()
if self._aggregate_select_cache is None:
obj._aggregate_select_cache = None
else:
obj._aggregate_select_cache = self._aggregate_select_cache.copy()
obj.max_depth = self.max_depth
obj.extra_select = self.extra_select.copy()
obj.extra = self.extra.copy()
if self.extra_select_mask is None:
obj.extra_select_mask = None
else:
obj.extra_select_mask = self.extra_select_mask.copy()
if self._extra_select_cache is None:
obj._extra_select_cache = None
else:
obj._extra_select_cache = self._extra_select_cache.copy()
obj.extra_tables = self.extra_tables
obj.extra_where = self.extra_where
obj.extra_params = self.extra_params
@ -325,7 +336,7 @@ class BaseQuery(object):
query = self
self.select = []
self.default_cols = False
self.extra_select = {}
self.extra = {}
self.remove_inherited_models()
query.clear_ordering(True)
@ -540,13 +551,20 @@ class BaseQuery(object):
# It would be nice to be able to handle this, but the queries don't
# really make sense (or return consistent value sets). Not worth
# the extra complexity when you can write a real query instead.
if self.extra_select and rhs.extra_select:
if self.extra and rhs.extra:
raise ValueError("When merging querysets using 'or', you "
"cannot have extra(select=...) on both sides.")
if self.extra_where and rhs.extra_where:
raise ValueError("When merging querysets using 'or', you "
"cannot have extra(where=...) on both sides.")
self.extra_select.update(rhs.extra_select)
self.extra.update(rhs.extra)
extra_select_mask = set()
if self.extra_select_mask is not None:
extra_select_mask.update(self.extra_select_mask)
if rhs.extra_select_mask is not None:
extra_select_mask.update(rhs.extra_select_mask)
if extra_select_mask:
self.set_extra_mask(extra_select_mask)
self.extra_tables += rhs.extra_tables
self.extra_where += rhs.extra_where
self.extra_params += rhs.extra_params
@ -2011,7 +2029,7 @@ class BaseQuery(object):
except MultiJoin:
raise FieldError("Invalid field name: '%s'" % name)
except FieldError:
names = opts.get_all_field_names() + self.extra_select.keys() + self.aggregate_select.keys()
names = opts.get_all_field_names() + self.extra.keys() + self.aggregate_select.keys()
names.sort()
raise FieldError("Cannot resolve keyword %r into field. "
"Choices are: %s" % (name, ", ".join(names)))
@ -2139,7 +2157,7 @@ class BaseQuery(object):
pos = entry.find("%s", pos + 2)
select_pairs[name] = (entry, entry_params)
# This is order preserving, since self.extra_select is a SortedDict.
self.extra_select.update(select_pairs)
self.extra.update(select_pairs)
if where:
self.extra_where += tuple(where)
if params:
@ -2213,22 +2231,26 @@ class BaseQuery(object):
"""
target[model] = set([f.name for f in fields])
def trim_extra_select(self, names):
"""
Removes any aliases in the extra_select dictionary that aren't in
'names'.
This is needed if we are selecting certain values that don't incldue
all of the extra_select names.
"""
for key in set(self.extra_select).difference(set(names)):
del self.extra_select[key]
def set_aggregate_mask(self, names):
"Set the mask of aggregates that will actually be returned by the SELECT"
self.aggregate_select_mask = names
if names is None:
self.aggregate_select_mask = None
else:
self.aggregate_select_mask = set(names)
self._aggregate_select_cache = None
def set_extra_mask(self, names):
"""
Set the mask of extra select items that will be returned by SELECT,
we don't actually remove them from the Query since they might be used
later
"""
if names is None:
self.extra_select_mask = None
else:
self.extra_select_mask = set(names)
self._extra_select_cache = None
def _aggregate_select(self):
"""The SortedDict of aggregate columns that are not masked, and should
be used in the SELECT clause.
@ -2247,6 +2269,19 @@ class BaseQuery(object):
return self.aggregates
aggregate_select = property(_aggregate_select)
def _extra_select(self):
if self._extra_select_cache is not None:
return self._extra_select_cache
elif self.extra_select_mask is not None:
self._extra_select_cache = SortedDict([
(k,v) for k,v in self.extra.items()
if k in self.extra_select_mask
])
return self._extra_select_cache
else:
return self.extra
extra_select = property(_extra_select)
def set_start(self, start):
"""
Sets the table from which to start joining. The start position is

View File

@ -178,7 +178,7 @@ class UpdateQuery(Query):
# from other tables.
query = self.clone(klass=Query)
query.bump_prefix()
query.extra_select = {}
query.extra = {}
query.select = []
query.add_fields([query.model._meta.pk.name])
must_pre_select = count > 1 and not self.connection.features.update_can_self_select
@ -409,7 +409,7 @@ class DateQuery(Query):
self.select = [select]
self.select_fields = [None]
self.select_related = False # See #7097.
self.extra_select = {}
self.extra = {}
self.distinct = True
self.order_by = order == 'ASC' and [1] or [-1]

View File

@ -35,6 +35,9 @@ class TestObject(models.Model):
second = models.CharField(max_length=20)
third = models.CharField(max_length=20)
def __unicode__(self):
return u'TestObject: %s,%s,%s' % (self.first,self.second,self.third)
__test__ = {"API_TESTS": """
# Regression tests for #7314 and #7372
@ -189,6 +192,19 @@ True
>>> TestObject.objects.extra(select=SortedDict((('foo','first'),('bar','second'),('whiz','third')))).values_list('whiz', 'first', 'bar', 'id')
[(u'third', u'first', u'second', 1)]
# Regression for #10847: the list of extra columns can always be accurately evaluated.
# Using an inner query ensures that as_sql() is producing correct output
# without requiring full evaluation and execution of the inner query.
>>> TestObject.objects.extra(select={'extra': 1}).values('pk')
[{'pk': 1}]
>>> TestObject.objects.filter(pk__in=TestObject.objects.extra(select={'extra': 1}).values('pk'))
[<TestObject: TestObject: first,second,third>]
>>> TestObject.objects.values('pk').extra(select={'extra': 1})
[{'pk': 1}]
>>> TestObject.objects.filter(pk__in=TestObject.objects.values('pk').extra(select={'extra': 1}))
[<TestObject: TestObject: first,second,third>]
"""}