Fixed #14930 -- values_list() failure on qs ordered by extra column

Thanks lsaffre for the report and simon29, vicould, and Florian Hahn
for the patch.

Some changes done by committer.
This commit is contained in:
Florian Hahn 2013-02-06 00:11:28 +01:00 committed by Anssi Kääriäinen
parent 9da9b3eb04
commit 2f35c6f10f
3 changed files with 83 additions and 21 deletions

View File

@ -22,6 +22,12 @@ class SQLCompiler(object):
self.connection = connection self.connection = connection
self.using = using self.using = using
self.quote_cache = {} self.quote_cache = {}
# When ordering a queryset with distinct on a column not part of the
# select set, the ordering column needs to be added to the select
# clause. This information is needed both in SQL construction and
# masking away the ordering selects from the returned row.
self.ordering_aliases = []
self.ordering_params = []
def pre_sql_setup(self): def pre_sql_setup(self):
""" """
@ -74,7 +80,7 @@ class SQLCompiler(object):
# another run of it. # another run of it.
self.refcounts_before = self.query.alias_refcount.copy() self.refcounts_before = self.query.alias_refcount.copy()
out_cols, s_params = self.get_columns(with_col_aliases) out_cols, s_params = self.get_columns(with_col_aliases)
ordering, ordering_group_by = self.get_ordering() ordering, o_params, ordering_group_by = self.get_ordering()
distinct_fields = self.get_distinct() distinct_fields = self.get_distinct()
@ -95,9 +101,10 @@ class SQLCompiler(object):
if self.query.distinct: if self.query.distinct:
result.append(self.connection.ops.distinct_sql(distinct_fields)) result.append(self.connection.ops.distinct_sql(distinct_fields))
params.extend(o_params)
result.append(', '.join(out_cols + self.query.ordering_aliases)) result.append(', '.join(out_cols + self.ordering_aliases))
params.extend(s_params) params.extend(s_params)
params.extend(self.ordering_params)
result.append('FROM') result.append('FROM')
result.extend(from_) result.extend(from_)
@ -319,7 +326,6 @@ class SQLCompiler(object):
result.append("%s.%s" % (qn(alias), qn2(col))) result.append("%s.%s" % (qn(alias), qn2(col)))
return result return result
def get_ordering(self): def get_ordering(self):
""" """
Returns a tuple containing a list representing the SQL elements in the Returns a tuple containing a list representing the SQL elements in the
@ -357,7 +363,9 @@ class SQLCompiler(object):
# the table/column pairs we use and discard any after the first use. # the table/column pairs we use and discard any after the first use.
processed_pairs = set() processed_pairs = set()
for field in ordering: params = []
ordering_params = []
for pos, field in enumerate(ordering):
if field == '?': if field == '?':
result.append(self.connection.ops.random_function_sql()) result.append(self.connection.ops.random_function_sql())
continue continue
@ -384,7 +392,7 @@ class SQLCompiler(object):
if not distinct or elt in select_aliases: if not distinct or elt in select_aliases:
result.append('%s %s' % (elt, order)) result.append('%s %s' % (elt, order))
group_by.append((elt, [])) group_by.append((elt, []))
elif get_order_dir(field)[0] not in self.query.extra_select: elif get_order_dir(field)[0] not in self.query.extra:
# 'col' is of the form 'field' or 'field1__field2' or # 'col' is of the form 'field' or 'field1__field2' or
# '-field1__field2__field', etc. # '-field1__field2__field', etc.
for table, cols, order in self.find_ordering_name(field, for table, cols, order in self.find_ordering_name(field,
@ -399,12 +407,19 @@ class SQLCompiler(object):
group_by.append((elt, [])) group_by.append((elt, []))
else: else:
elt = qn2(col) elt = qn2(col)
if col not in self.query.extra_select:
sql = "(%s) AS %s" % (self.query.extra[col][0], elt)
ordering_aliases.append(sql)
ordering_params.extend(self.query.extra[col][1])
else:
if distinct and col not in select_aliases: if distinct and col not in select_aliases:
ordering_aliases.append(elt) ordering_aliases.append(elt)
ordering_params.extend(params)
result.append('%s %s' % (elt, order)) result.append('%s %s' % (elt, order))
group_by.append(self.query.extra_select[col]) group_by.append(self.query.extra[col])
self.query.ordering_aliases = ordering_aliases self.ordering_aliases = ordering_aliases
return result, group_by self.ordering_params = ordering_params
return result, params, group_by
def find_ordering_name(self, name, opts, alias=None, default_order='ASC', def find_ordering_name(self, name, opts, alias=None, default_order='ASC',
already_seen=None): already_seen=None):
@ -764,13 +779,13 @@ class SQLCompiler(object):
if not result_type: if not result_type:
return cursor return cursor
if result_type == SINGLE: if result_type == SINGLE:
if self.query.ordering_aliases: if self.ordering_aliases:
return cursor.fetchone()[:-len(self.query.ordering_aliases)] return cursor.fetchone()[:-len(self.ordering_aliases)]
return cursor.fetchone() return cursor.fetchone()
# The MULTI case. # The MULTI case.
if self.query.ordering_aliases: if self.ordering_aliases:
result = order_modified_iter(cursor, len(self.query.ordering_aliases), result = order_modified_iter(cursor, len(self.ordering_aliases),
self.connection.features.empty_fetchmany_value) self.connection.features.empty_fetchmany_value)
else: else:
result = iter((lambda: cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)), result = iter((lambda: cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)),

View File

@ -115,7 +115,6 @@ class Query(object):
self.default_cols = True self.default_cols = True
self.default_ordering = True self.default_ordering = True
self.standard_ordering = True self.standard_ordering = True
self.ordering_aliases = []
self.used_aliases = set() self.used_aliases = set()
self.filter_is_sticky = False self.filter_is_sticky = False
self.included_inherited_models = {} self.included_inherited_models = {}
@ -227,7 +226,6 @@ class Query(object):
obj.default_ordering = self.default_ordering obj.default_ordering = self.default_ordering
obj.standard_ordering = self.standard_ordering obj.standard_ordering = self.standard_ordering
obj.included_inherited_models = self.included_inherited_models.copy() obj.included_inherited_models = self.included_inherited_models.copy()
obj.ordering_aliases = []
obj.select = self.select[:] obj.select = self.select[:]
obj.related_select_cols = [] obj.related_select_cols = []
obj.tables = self.tables[:] obj.tables = self.tables[:]

View File

@ -1976,13 +1976,62 @@ class EmptyQuerySetTests(TestCase):
class ValuesQuerysetTests(BaseQuerysetTest): class ValuesQuerysetTests(BaseQuerysetTest):
def test_flat_values_lits(self): def setUp(self):
Number.objects.create(num=72) Number.objects.create(num=72)
self.identity = lambda x: x
def test_flat_values_list(self):
qs = Number.objects.values_list("num") qs = Number.objects.values_list("num")
qs = qs.values_list("num", flat=True) qs = qs.values_list("num", flat=True)
self.assertValueQuerysetEqual( self.assertValueQuerysetEqual(qs, [72])
qs, [72]
) def test_extra_values(self):
# testing for ticket 14930 issues
qs = Number.objects.extra(select=SortedDict([('value_plus_x', 'num+%s'),
('value_minus_x', 'num-%s')]),
select_params=(1, 2))
qs = qs.order_by('value_minus_x')
qs = qs.values('num')
self.assertQuerysetEqual(qs, [{'num': 72}], self.identity)
def test_extra_values_order_twice(self):
# testing for ticket 14930 issues
qs = Number.objects.extra(select={'value_plus_one': 'num+1', 'value_minus_one': 'num-1'})
qs = qs.order_by('value_minus_one').order_by('value_plus_one')
qs = qs.values('num')
self.assertQuerysetEqual(qs, [{'num': 72}], self.identity)
def test_extra_values_order_multiple(self):
# Postgres doesn't allow constants in order by, so check for that.
qs = Number.objects.extra(select={
'value_plus_one': 'num+1',
'value_minus_one': 'num-1',
'constant_value': '1'
})
qs = qs.order_by('value_plus_one', 'value_minus_one', 'constant_value')
qs = qs.values('num')
self.assertQuerysetEqual(qs, [{'num': 72}], self.identity)
def test_extra_values_order_in_extra(self):
# testing for ticket 14930 issues
qs = Number.objects.extra(
select={'value_plus_one': 'num+1', 'value_minus_one': 'num-1'},
order_by=['value_minus_one'])
qs = qs.values('num')
def test_extra_values_list(self):
# testing for ticket 14930 issues
qs = Number.objects.extra(select={'value_plus_one': 'num+1'})
qs = qs.order_by('value_plus_one')
qs = qs.values_list('num')
self.assertQuerysetEqual(qs, [(72,)], self.identity)
def test_flat_extra_values_list(self):
# testing for ticket 14930 issues
qs = Number.objects.extra(select={'value_plus_one': 'num+1'})
qs = qs.order_by('value_plus_one')
qs = qs.values_list('num', flat=True)
self.assertQuerysetEqual(qs, [72], self.identity)
class WeirdQuerysetSlicingTests(BaseQuerysetTest): class WeirdQuerysetSlicingTests(BaseQuerysetTest):