mirror of https://github.com/django/django.git
Fixed #10847 -- Modified handling of extra() to use a masking strategy, rather than last-minute trimming. Thanks to Tai Lee for the report, and Alex Gaynor for his work on the patch.
This enables querysets with an extra clause to be used in an __in filter; as a side effect, it also means that as_sql() now returns the correct result for any query with an extra clause. git-svn-id: http://code.djangoproject.com/svn/django/trunk@10648 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
parent
17958fa7a9
commit
5e2d38465a
|
@ -715,9 +715,6 @@ class ValuesQuerySet(QuerySet):
|
||||||
|
|
||||||
def iterator(self):
|
def iterator(self):
|
||||||
# Purge any extra columns that haven't been explicitly asked for
|
# Purge any extra columns that haven't been explicitly asked for
|
||||||
if self.extra_names is not None:
|
|
||||||
self.query.trim_extra_select(self.extra_names)
|
|
||||||
|
|
||||||
extra_names = self.query.extra_select.keys()
|
extra_names = self.query.extra_select.keys()
|
||||||
field_names = self.field_names
|
field_names = self.field_names
|
||||||
aggregate_names = self.query.aggregate_select.keys()
|
aggregate_names = self.query.aggregate_select.keys()
|
||||||
|
@ -741,13 +738,18 @@ class ValuesQuerySet(QuerySet):
|
||||||
if self._fields:
|
if self._fields:
|
||||||
self.extra_names = []
|
self.extra_names = []
|
||||||
self.aggregate_names = []
|
self.aggregate_names = []
|
||||||
if not self.query.extra_select and not self.query.aggregate_select:
|
if not self.query.extra and not self.query.aggregates:
|
||||||
|
# Short cut - if there are no extra or aggregates, then
|
||||||
|
# the values() clause must be just field names.
|
||||||
self.field_names = list(self._fields)
|
self.field_names = list(self._fields)
|
||||||
else:
|
else:
|
||||||
self.query.default_cols = False
|
self.query.default_cols = False
|
||||||
self.field_names = []
|
self.field_names = []
|
||||||
for f in self._fields:
|
for f in self._fields:
|
||||||
if self.query.extra_select.has_key(f):
|
# we inspect the full extra_select list since we might
|
||||||
|
# be adding back an extra select item that we hadn't
|
||||||
|
# had selected previously.
|
||||||
|
if self.query.extra.has_key(f):
|
||||||
self.extra_names.append(f)
|
self.extra_names.append(f)
|
||||||
elif self.query.aggregate_select.has_key(f):
|
elif self.query.aggregate_select.has_key(f):
|
||||||
self.aggregate_names.append(f)
|
self.aggregate_names.append(f)
|
||||||
|
@ -760,6 +762,8 @@ class ValuesQuerySet(QuerySet):
|
||||||
self.aggregate_names = None
|
self.aggregate_names = None
|
||||||
|
|
||||||
self.query.select = []
|
self.query.select = []
|
||||||
|
if self.extra_names is not None:
|
||||||
|
self.query.set_extra_mask(self.extra_names)
|
||||||
self.query.add_fields(self.field_names, False)
|
self.query.add_fields(self.field_names, False)
|
||||||
if self.aggregate_names is not None:
|
if self.aggregate_names is not None:
|
||||||
self.query.set_aggregate_mask(self.aggregate_names)
|
self.query.set_aggregate_mask(self.aggregate_names)
|
||||||
|
@ -816,9 +820,6 @@ class ValuesQuerySet(QuerySet):
|
||||||
|
|
||||||
class ValuesListQuerySet(ValuesQuerySet):
|
class ValuesListQuerySet(ValuesQuerySet):
|
||||||
def iterator(self):
|
def iterator(self):
|
||||||
if self.extra_names is not None:
|
|
||||||
self.query.trim_extra_select(self.extra_names)
|
|
||||||
|
|
||||||
if self.flat and len(self._fields) == 1:
|
if self.flat and len(self._fields) == 1:
|
||||||
for row in self.query.results_iter():
|
for row in self.query.results_iter():
|
||||||
yield row[0]
|
yield row[0]
|
||||||
|
|
|
@ -88,7 +88,10 @@ class BaseQuery(object):
|
||||||
|
|
||||||
# These are for extensions. The contents are more or less appended
|
# These are for extensions. The contents are more or less appended
|
||||||
# verbatim to the appropriate clause.
|
# verbatim to the appropriate clause.
|
||||||
self.extra_select = SortedDict() # Maps col_alias -> (col_sql, params).
|
self.extra = SortedDict() # Maps col_alias -> (col_sql, params).
|
||||||
|
self.extra_select_mask = None
|
||||||
|
self._extra_select_cache = None
|
||||||
|
|
||||||
self.extra_tables = ()
|
self.extra_tables = ()
|
||||||
self.extra_where = ()
|
self.extra_where = ()
|
||||||
self.extra_params = ()
|
self.extra_params = ()
|
||||||
|
@ -214,13 +217,21 @@ class BaseQuery(object):
|
||||||
if self.aggregate_select_mask is None:
|
if self.aggregate_select_mask is None:
|
||||||
obj.aggregate_select_mask = None
|
obj.aggregate_select_mask = None
|
||||||
else:
|
else:
|
||||||
obj.aggregate_select_mask = self.aggregate_select_mask[:]
|
obj.aggregate_select_mask = self.aggregate_select_mask.copy()
|
||||||
if self._aggregate_select_cache is None:
|
if self._aggregate_select_cache is None:
|
||||||
obj._aggregate_select_cache = None
|
obj._aggregate_select_cache = None
|
||||||
else:
|
else:
|
||||||
obj._aggregate_select_cache = self._aggregate_select_cache.copy()
|
obj._aggregate_select_cache = self._aggregate_select_cache.copy()
|
||||||
obj.max_depth = self.max_depth
|
obj.max_depth = self.max_depth
|
||||||
obj.extra_select = self.extra_select.copy()
|
obj.extra = self.extra.copy()
|
||||||
|
if self.extra_select_mask is None:
|
||||||
|
obj.extra_select_mask = None
|
||||||
|
else:
|
||||||
|
obj.extra_select_mask = self.extra_select_mask.copy()
|
||||||
|
if self._extra_select_cache is None:
|
||||||
|
obj._extra_select_cache = None
|
||||||
|
else:
|
||||||
|
obj._extra_select_cache = self._extra_select_cache.copy()
|
||||||
obj.extra_tables = self.extra_tables
|
obj.extra_tables = self.extra_tables
|
||||||
obj.extra_where = self.extra_where
|
obj.extra_where = self.extra_where
|
||||||
obj.extra_params = self.extra_params
|
obj.extra_params = self.extra_params
|
||||||
|
@ -325,7 +336,7 @@ class BaseQuery(object):
|
||||||
query = self
|
query = self
|
||||||
self.select = []
|
self.select = []
|
||||||
self.default_cols = False
|
self.default_cols = False
|
||||||
self.extra_select = {}
|
self.extra = {}
|
||||||
self.remove_inherited_models()
|
self.remove_inherited_models()
|
||||||
|
|
||||||
query.clear_ordering(True)
|
query.clear_ordering(True)
|
||||||
|
@ -540,13 +551,20 @@ class BaseQuery(object):
|
||||||
# It would be nice to be able to handle this, but the queries don't
|
# It would be nice to be able to handle this, but the queries don't
|
||||||
# really make sense (or return consistent value sets). Not worth
|
# really make sense (or return consistent value sets). Not worth
|
||||||
# the extra complexity when you can write a real query instead.
|
# the extra complexity when you can write a real query instead.
|
||||||
if self.extra_select and rhs.extra_select:
|
if self.extra and rhs.extra:
|
||||||
raise ValueError("When merging querysets using 'or', you "
|
raise ValueError("When merging querysets using 'or', you "
|
||||||
"cannot have extra(select=...) on both sides.")
|
"cannot have extra(select=...) on both sides.")
|
||||||
if self.extra_where and rhs.extra_where:
|
if self.extra_where and rhs.extra_where:
|
||||||
raise ValueError("When merging querysets using 'or', you "
|
raise ValueError("When merging querysets using 'or', you "
|
||||||
"cannot have extra(where=...) on both sides.")
|
"cannot have extra(where=...) on both sides.")
|
||||||
self.extra_select.update(rhs.extra_select)
|
self.extra.update(rhs.extra)
|
||||||
|
extra_select_mask = set()
|
||||||
|
if self.extra_select_mask is not None:
|
||||||
|
extra_select_mask.update(self.extra_select_mask)
|
||||||
|
if rhs.extra_select_mask is not None:
|
||||||
|
extra_select_mask.update(rhs.extra_select_mask)
|
||||||
|
if extra_select_mask:
|
||||||
|
self.set_extra_mask(extra_select_mask)
|
||||||
self.extra_tables += rhs.extra_tables
|
self.extra_tables += rhs.extra_tables
|
||||||
self.extra_where += rhs.extra_where
|
self.extra_where += rhs.extra_where
|
||||||
self.extra_params += rhs.extra_params
|
self.extra_params += rhs.extra_params
|
||||||
|
@ -2011,7 +2029,7 @@ class BaseQuery(object):
|
||||||
except MultiJoin:
|
except MultiJoin:
|
||||||
raise FieldError("Invalid field name: '%s'" % name)
|
raise FieldError("Invalid field name: '%s'" % name)
|
||||||
except FieldError:
|
except FieldError:
|
||||||
names = opts.get_all_field_names() + self.extra_select.keys() + self.aggregate_select.keys()
|
names = opts.get_all_field_names() + self.extra.keys() + self.aggregate_select.keys()
|
||||||
names.sort()
|
names.sort()
|
||||||
raise FieldError("Cannot resolve keyword %r into field. "
|
raise FieldError("Cannot resolve keyword %r into field. "
|
||||||
"Choices are: %s" % (name, ", ".join(names)))
|
"Choices are: %s" % (name, ", ".join(names)))
|
||||||
|
@ -2139,7 +2157,7 @@ class BaseQuery(object):
|
||||||
pos = entry.find("%s", pos + 2)
|
pos = entry.find("%s", pos + 2)
|
||||||
select_pairs[name] = (entry, entry_params)
|
select_pairs[name] = (entry, entry_params)
|
||||||
# This is order preserving, since self.extra_select is a SortedDict.
|
# This is order preserving, since self.extra_select is a SortedDict.
|
||||||
self.extra_select.update(select_pairs)
|
self.extra.update(select_pairs)
|
||||||
if where:
|
if where:
|
||||||
self.extra_where += tuple(where)
|
self.extra_where += tuple(where)
|
||||||
if params:
|
if params:
|
||||||
|
@ -2213,22 +2231,26 @@ class BaseQuery(object):
|
||||||
"""
|
"""
|
||||||
target[model] = set([f.name for f in fields])
|
target[model] = set([f.name for f in fields])
|
||||||
|
|
||||||
def trim_extra_select(self, names):
|
|
||||||
"""
|
|
||||||
Removes any aliases in the extra_select dictionary that aren't in
|
|
||||||
'names'.
|
|
||||||
|
|
||||||
This is needed if we are selecting certain values that don't incldue
|
|
||||||
all of the extra_select names.
|
|
||||||
"""
|
|
||||||
for key in set(self.extra_select).difference(set(names)):
|
|
||||||
del self.extra_select[key]
|
|
||||||
|
|
||||||
def set_aggregate_mask(self, names):
|
def set_aggregate_mask(self, names):
|
||||||
"Set the mask of aggregates that will actually be returned by the SELECT"
|
"Set the mask of aggregates that will actually be returned by the SELECT"
|
||||||
self.aggregate_select_mask = names
|
if names is None:
|
||||||
|
self.aggregate_select_mask = None
|
||||||
|
else:
|
||||||
|
self.aggregate_select_mask = set(names)
|
||||||
self._aggregate_select_cache = None
|
self._aggregate_select_cache = None
|
||||||
|
|
||||||
|
def set_extra_mask(self, names):
|
||||||
|
"""
|
||||||
|
Set the mask of extra select items that will be returned by SELECT,
|
||||||
|
we don't actually remove them from the Query since they might be used
|
||||||
|
later
|
||||||
|
"""
|
||||||
|
if names is None:
|
||||||
|
self.extra_select_mask = None
|
||||||
|
else:
|
||||||
|
self.extra_select_mask = set(names)
|
||||||
|
self._extra_select_cache = None
|
||||||
|
|
||||||
def _aggregate_select(self):
|
def _aggregate_select(self):
|
||||||
"""The SortedDict of aggregate columns that are not masked, and should
|
"""The SortedDict of aggregate columns that are not masked, and should
|
||||||
be used in the SELECT clause.
|
be used in the SELECT clause.
|
||||||
|
@ -2247,6 +2269,19 @@ class BaseQuery(object):
|
||||||
return self.aggregates
|
return self.aggregates
|
||||||
aggregate_select = property(_aggregate_select)
|
aggregate_select = property(_aggregate_select)
|
||||||
|
|
||||||
|
def _extra_select(self):
|
||||||
|
if self._extra_select_cache is not None:
|
||||||
|
return self._extra_select_cache
|
||||||
|
elif self.extra_select_mask is not None:
|
||||||
|
self._extra_select_cache = SortedDict([
|
||||||
|
(k,v) for k,v in self.extra.items()
|
||||||
|
if k in self.extra_select_mask
|
||||||
|
])
|
||||||
|
return self._extra_select_cache
|
||||||
|
else:
|
||||||
|
return self.extra
|
||||||
|
extra_select = property(_extra_select)
|
||||||
|
|
||||||
def set_start(self, start):
|
def set_start(self, start):
|
||||||
"""
|
"""
|
||||||
Sets the table from which to start joining. The start position is
|
Sets the table from which to start joining. The start position is
|
||||||
|
|
|
@ -178,7 +178,7 @@ class UpdateQuery(Query):
|
||||||
# from other tables.
|
# from other tables.
|
||||||
query = self.clone(klass=Query)
|
query = self.clone(klass=Query)
|
||||||
query.bump_prefix()
|
query.bump_prefix()
|
||||||
query.extra_select = {}
|
query.extra = {}
|
||||||
query.select = []
|
query.select = []
|
||||||
query.add_fields([query.model._meta.pk.name])
|
query.add_fields([query.model._meta.pk.name])
|
||||||
must_pre_select = count > 1 and not self.connection.features.update_can_self_select
|
must_pre_select = count > 1 and not self.connection.features.update_can_self_select
|
||||||
|
@ -409,7 +409,7 @@ class DateQuery(Query):
|
||||||
self.select = [select]
|
self.select = [select]
|
||||||
self.select_fields = [None]
|
self.select_fields = [None]
|
||||||
self.select_related = False # See #7097.
|
self.select_related = False # See #7097.
|
||||||
self.extra_select = {}
|
self.extra = {}
|
||||||
self.distinct = True
|
self.distinct = True
|
||||||
self.order_by = order == 'ASC' and [1] or [-1]
|
self.order_by = order == 'ASC' and [1] or [-1]
|
||||||
|
|
||||||
|
|
|
@ -35,6 +35,9 @@ class TestObject(models.Model):
|
||||||
second = models.CharField(max_length=20)
|
second = models.CharField(max_length=20)
|
||||||
third = models.CharField(max_length=20)
|
third = models.CharField(max_length=20)
|
||||||
|
|
||||||
|
def __unicode__(self):
|
||||||
|
return u'TestObject: %s,%s,%s' % (self.first,self.second,self.third)
|
||||||
|
|
||||||
__test__ = {"API_TESTS": """
|
__test__ = {"API_TESTS": """
|
||||||
# Regression tests for #7314 and #7372
|
# Regression tests for #7314 and #7372
|
||||||
|
|
||||||
|
@ -189,6 +192,19 @@ True
|
||||||
>>> TestObject.objects.extra(select=SortedDict((('foo','first'),('bar','second'),('whiz','third')))).values_list('whiz', 'first', 'bar', 'id')
|
>>> TestObject.objects.extra(select=SortedDict((('foo','first'),('bar','second'),('whiz','third')))).values_list('whiz', 'first', 'bar', 'id')
|
||||||
[(u'third', u'first', u'second', 1)]
|
[(u'third', u'first', u'second', 1)]
|
||||||
|
|
||||||
|
# Regression for #10847: the list of extra columns can always be accurately evaluated.
|
||||||
|
# Using an inner query ensures that as_sql() is producing correct output
|
||||||
|
# without requiring full evaluation and execution of the inner query.
|
||||||
|
>>> TestObject.objects.extra(select={'extra': 1}).values('pk')
|
||||||
|
[{'pk': 1}]
|
||||||
|
|
||||||
|
>>> TestObject.objects.filter(pk__in=TestObject.objects.extra(select={'extra': 1}).values('pk'))
|
||||||
|
[<TestObject: TestObject: first,second,third>]
|
||||||
|
|
||||||
|
>>> TestObject.objects.values('pk').extra(select={'extra': 1})
|
||||||
|
[{'pk': 1}]
|
||||||
|
|
||||||
|
>>> TestObject.objects.filter(pk__in=TestObject.objects.values('pk').extra(select={'extra': 1}))
|
||||||
|
[<TestObject: TestObject: first,second,third>]
|
||||||
|
|
||||||
"""}
|
"""}
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue