Fixed #24747 -- Allowed transforms in QuerySet.order_by() and distinct(*fields).

This commit is contained in:
Matthew Wilkes 2017-06-18 16:53:40 +01:00 committed by Tim Graham
parent bf26f66029
commit 2162f0983d
15 changed files with 260 additions and 38 deletions

View File

@ -158,7 +158,7 @@ class BaseDatabaseOperations:
"""
return ''
def distinct_sql(self, fields):
def distinct_sql(self, fields, params):
"""
Return an SQL DISTINCT clause which removes duplicate rows from the
result set. If any fields are given, only check the given fields for
@ -167,7 +167,7 @@ class BaseDatabaseOperations:
if fields:
raise NotSupportedError('DISTINCT ON fields is not supported by this database backend')
else:
return 'DISTINCT'
return ['DISTINCT'], []
def fetch_returned_insert_id(self, cursor):
"""

View File

@ -207,11 +207,12 @@ class DatabaseOperations(BaseDatabaseOperations):
"""
return 63
def distinct_sql(self, fields):
def distinct_sql(self, fields, params):
if fields:
return 'DISTINCT ON (%s)' % ', '.join(fields)
params = [param for param_list in params for param in param_list]
return (['DISTINCT ON (%s)' % ', '.join(fields)], params)
else:
return 'DISTINCT'
return ['DISTINCT'], []
def last_executed_query(self, cursor, sql, params):
# http://initd.org/psycopg/docs/cursor.html#cursor.query

View File

@ -451,7 +451,7 @@ class SQLCompiler:
raise NotSupportedError('{} is not supported on this database backend.'.format(combinator))
result, params = self.get_combinator_sql(combinator, self.query.combinator_all)
else:
distinct_fields = self.get_distinct()
distinct_fields, distinct_params = self.get_distinct()
# This must come after 'select', 'ordering', and 'distinct'
# (see docstring of get_from_clause() for details).
from_, f_params = self.get_from_clause()
@ -461,7 +461,12 @@ class SQLCompiler:
params = []
if self.query.distinct:
result.append(self.connection.ops.distinct_sql(distinct_fields))
distinct_result, distinct_params = self.connection.ops.distinct_sql(
distinct_fields,
distinct_params,
)
result += distinct_result
params += distinct_params
out_cols = []
col_idx = 1
@ -621,21 +626,22 @@ class SQLCompiler:
This method can alter the tables in the query, and thus it must be
called before get_from_clause().
"""
qn = self.quote_name_unless_alias
qn2 = self.connection.ops.quote_name
result = []
params = []
opts = self.query.get_meta()
for name in self.query.distinct_fields:
parts = name.split(LOOKUP_SEP)
_, targets, alias, joins, path, _ = self._setup_joins(parts, opts, None)
_, targets, alias, joins, path, _, transform_function = self._setup_joins(parts, opts, None)
targets, alias, _ = self.query.trim_joins(targets, joins, path)
for target in targets:
if name in self.query.annotation_select:
result.append(name)
else:
result.append("%s.%s" % (qn(alias), qn2(target.column)))
return result
r, p = self.compile(transform_function(target, alias))
result.append(r)
params.append(p)
return result, params
def find_ordering_name(self, name, opts, alias=None, default_order='ASC',
already_seen=None):
@ -647,7 +653,7 @@ class SQLCompiler:
name, order = get_order_dir(name, default_order)
descending = order == 'DESC'
pieces = name.split(LOOKUP_SEP)
field, targets, alias, joins, path, opts = self._setup_joins(pieces, opts, alias)
field, targets, alias, joins, path, opts, transform_function = self._setup_joins(pieces, opts, alias)
# If we get to this point and the field is a relation to another model,
# append the default ordering for that model unless the attribute name
@ -666,7 +672,7 @@ class SQLCompiler:
order, already_seen))
return results
targets, alias, _ = self.query.trim_joins(targets, joins, path)
return [(OrderBy(t.get_col(alias), descending=descending), False) for t in targets]
return [(OrderBy(transform_function(t, alias), descending=descending), False) for t in targets]
def _setup_joins(self, pieces, opts, alias):
"""
@ -677,10 +683,9 @@ class SQLCompiler:
match. Executing SQL where this is not true is an error.
"""
alias = alias or self.query.get_initial_alias()
field, targets, opts, joins, path = self.query.setup_joins(
pieces, opts, alias)
field, targets, opts, joins, path, transform_function = self.query.setup_joins(pieces, opts, alias)
alias = joins[-1]
return field, targets, alias, joins, path, opts
return field, targets, alias, joins, path, opts, transform_function
def get_from_clause(self):
"""
@ -786,7 +791,7 @@ class SQLCompiler:
}
related_klass_infos.append(klass_info)
select_fields = []
_, _, _, joins, _ = self.query.setup_joins(
_, _, _, joins, _, _ = self.query.setup_joins(
[f.name], opts, root_alias)
alias = joins[-1]
columns = self.get_default_columns(start_alias=alias, opts=f.remote_field.model._meta)
@ -843,7 +848,7 @@ class SQLCompiler:
break
if name in self.query._filtered_relations:
fields_found.add(name)
f, _, join_opts, joins, _ = self.query.setup_joins([name], opts, root_alias)
f, _, join_opts, joins, _, _ = self.query.setup_joins([name], opts, root_alias)
model = join_opts.model
alias = joins[-1]
from_parent = issubclass(model, opts.model) and model is not opts.model

View File

@ -6,6 +6,7 @@ themselves do not have to (and could be backed by things other than SQL
databases). The abstraction barrier only works one way: this module has to know
all about the internals of models in order to get the information it needs.
"""
import functools
from collections import Counter, OrderedDict, namedtuple
from collections.abc import Iterator, Mapping
from itertools import chain, count, product
@ -18,6 +19,7 @@ from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connections
from django.db.models.aggregates import Count
from django.db.models.constants import LOOKUP_SEP
from django.db.models.expressions import Col, Ref
from django.db.models.fields import Field
from django.db.models.fields.related_lookups import MultiColSource
from django.db.models.lookups import Lookup
from django.db.models.query_utils import (
@ -56,7 +58,7 @@ def get_children_from_q(q):
JoinInfo = namedtuple(
'JoinInfo',
('final_field', 'targets', 'opts', 'joins', 'path')
('final_field', 'targets', 'opts', 'joins', 'path', 'transform_function')
)
@ -1429,8 +1431,11 @@ class Query:
generate a MultiJoin exception.
Return the final field involved in the joins, the target field (used
for any 'where' constraint), the final 'opts' value, the joins and the
field path travelled to generate the joins.
for any 'where' constraint), the final 'opts' value, the joins, the
field path traveled to generate the joins, and a transform function
that takes a field and alias and is equivalent to `field.get_col(alias)`
in the simple case but wraps field transforms if they were included in
names.
The target field is the field containing the concrete value. Final
field can be something different, for example foreign key pointing to
@ -1439,10 +1444,46 @@ class Query:
key field for example).
"""
joins = [alias]
# First, generate the path for the names
path, final_field, targets, rest = self.names_to_path(
names, opts, allow_many, fail_on_missing=True)
# The transform can't be applied yet, as joins must be trimmed later.
# To avoid making every caller of this method look up transforms
# directly, compute transforms here and and create a partial that
# converts fields to the appropriate wrapped version.
def final_transformer(field, alias):
return field.get_col(alias)
# Try resolving all the names as fields first. If there's an error,
# treat trailing names as lookups until a field can be resolved.
last_field_exception = None
for pivot in range(len(names), 0, -1):
try:
path, final_field, targets, rest = self.names_to_path(
names[:pivot], opts, allow_many, fail_on_missing=True,
)
except FieldError as exc:
if pivot == 1:
# The first item cannot be a lookup, so it's safe
# to raise the field error here.
raise
else:
last_field_exception = exc
else:
# The transforms are the remaining items that couldn't be
# resolved into fields.
transforms = names[pivot:]
break
for name in transforms:
def transform(field, alias, *, name, previous):
try:
wrapped = previous(field, alias)
return self.try_transform(wrapped, name)
except FieldError:
# FieldError is raised if the transform doesn't exist.
if isinstance(final_field, Field) and last_field_exception:
raise last_field_exception
else:
raise
final_transformer = functools.partial(transform, name=name, previous=final_transformer)
# Then, add the path to the query's joins. Note that we can't trim
# joins at this stage - we will need the information about join type
# of the trimmed joins.
@ -1470,7 +1511,7 @@ class Query:
joins.append(alias)
if filtered_relation:
filtered_relation.path = joins[:]
return JoinInfo(final_field, targets, opts, joins, path)
return JoinInfo(final_field, targets, opts, joins, path, final_transformer)
def trim_joins(self, targets, joins, path):
"""
@ -1683,7 +1724,7 @@ class Query:
join_info.path,
)
for target in targets:
cols.append(target.get_col(final_alias))
cols.append(join_info.transform_function(target, final_alias))
if cols:
self.set_select(cols)
except MultiJoin:

View File

@ -138,6 +138,21 @@ SQL::
Note that in case there is no other lookup specified, Django interprets
``change__abs=27`` as ``change__abs__exact=27``.
This also allows the result to be used in ``ORDER BY`` and ``DISTINCT ON``
clauses. For example ``Experiment.objects.order_by('change__abs')`` generates::
SELECT ... ORDER BY ABS("experiments"."change") ASC
And on databases that support distinct on fields (such as PostgreSQL),
``Experiment.objects.distinct('change__abs')`` generates::
SELECT ... DISTINCT ON ABS("experiments"."change")
.. versionchanged:: 2.1
Ordering and distinct support as described in the last two paragraphs was
added.
When looking for which lookups are allowable after the ``Transform`` has been
applied, Django uses the ``output_field`` attribute. We didn't need to specify
this here as it didn't change, but supposing we were applying ``AbsoluteValue``

View File

@ -64,10 +64,14 @@ Some examples
# Aggregates can contain complex computations also
Company.objects.annotate(num_offerings=Count(F('products') + F('services')))
# Expressions can also be used in order_by()
# Expressions can also be used in order_by(), either directly
Company.objects.order_by(Length('name').asc())
Company.objects.order_by(Length('name').desc())
# or using the double underscore lookup syntax.
from django.db.models import CharField
from django.db.models.functions import Length
CharField.register_lookup(Length)
Company.objects.order_by('name__length')
Built-in Expressions
====================

View File

@ -535,6 +535,19 @@ The ``values()`` method also takes optional keyword arguments,
>>> Blog.objects.values(lower_name=Lower('name'))
<QuerySet [{'lower_name': 'beatles blog'}]>
You can use built-in and :doc:`custom lookups </howto/custom-lookups>` in
ordering. For example::
>>> from django.db.models import CharField
>>> from django.db.models.functions import Lower
>>> CharField.register_lookup(Lower, 'lower')
>>> Blog.objects.values('name__lower')
<QuerySet [{'name__lower': 'beatles blog'}]>
.. versionchanged:: 2.1
Support for lookups was added.
An aggregate within a ``values()`` clause is applied before other arguments
within the same ``values()`` clause. If you need to group by another value,
add it to an earlier ``values()`` clause instead. For example::
@ -580,6 +593,25 @@ A few subtleties that are worth mentioning:
* Calling :meth:`only()` and :meth:`defer()` after ``values()`` doesn't make
sense, so doing so will raise a ``NotImplementedError``.
* Combining transforms and aggregates requires the use of two :meth:`annotate`
calls, either explicitly or as keyword arguments to :meth:`values`. As above,
if the transform has been registered on the relevant field type the first
:meth:`annotate` can be omitted, thus the following examples are equivalent::
>>> from django.db.models import CharField, Count
>>> from django.db.models.functions import Lower
>>> CharField.register_lookup(Lower, 'lower')
>>> Blog.objects.values('entry__authors__name__lower').annotate(entries=Count('entry'))
<QuerySet [{'entry__authors__name__lower': 'test author', 'entries': 33}]>
>>> Blog.objects.values(
... entry__authors__name__lower=Lower('entry__authors__name')
... ).annotate(entries=Count('entry'))
<QuerySet [{'entry__authors__name__lower': 'test author', 'entries': 33}]>
>>> Blog.objects.annotate(
... entry__authors__name__lower=Lower('entry__authors__name')
... ).values('entry__authors__name__lower').annotate(entries=Count('entry'))
<QuerySet [{'entry__authors__name__lower': 'test author', 'entries': 33}]>
It is useful when you know you're only going to need values from a small number
of the available fields and you won't need the functionality of a model
instance object. It's more efficient to select only the fields you need to use.

View File

@ -187,6 +187,9 @@ Models
* Query expressions can now be negated using a minus sign.
* :meth:`.QuerySet.order_by` and :meth:`distinct(*fields) <.QuerySet.distinct>`
now support using field transforms.
Requests and Responses
~~~~~~~~~~~~~~~~~~~~~~
@ -242,6 +245,9 @@ Database backend API
* Renamed the ``allow_sliced_subqueries`` database feature flag to
``allow_sliced_subqueries_with_in``.
* ``DatabaseOperations.distinct_sql()`` now requires an additional ``params``
argument and returns a tuple of SQL and parameters instead of a SQL string.
:mod:`django.contrib.gis`
-------------------------

View File

@ -17,7 +17,7 @@ class DatabaseOperationTests(SimpleTestCase):
def test_distinct_on_fields(self):
msg = 'DISTINCT ON fields is not supported by this database backend'
with self.assertRaisesMessage(NotSupportedError, msg):
self.ops.distinct_sql(['a', 'b'])
self.ops.distinct_sql(['a', 'b'], None)
def test_deferrable_sql(self):
self.assertEqual(self.ops.deferrable_sql(), '')

View File

@ -63,6 +63,14 @@ class Mult3BilateralTransform(models.Transform):
return '3 * (%s)' % lhs, lhs_params
class LastDigitTransform(models.Transform):
lookup_name = 'lastdigit'
def as_sql(self, compiler, connection):
lhs, lhs_params = compiler.compile(self.lhs)
return 'SUBSTR(CAST(%s AS CHAR(2)), 2, 1)' % lhs, lhs_params
class UpperBilateralTransform(models.Transform):
bilateral = True
lookup_name = 'upper'
@ -379,6 +387,15 @@ class BilateralTransformTests(TestCase):
self.assertSequenceEqual(baseqs.filter(age__mult3__div3=42), [a1, a2, a3, a4])
self.assertSequenceEqual(baseqs.filter(age__div3__mult3=42), [a3])
def test_transform_order_by(self):
with register_lookup(models.IntegerField, LastDigitTransform):
a1 = Author.objects.create(name='a1', age=11)
a2 = Author.objects.create(name='a2', age=23)
a3 = Author.objects.create(name='a3', age=32)
a4 = Author.objects.create(name='a4', age=40)
qs = Author.objects.order_by('age__lastdigit')
self.assertSequenceEqual(qs, [a4, a1, a3, a2])
def test_bilateral_fexpr(self):
with register_lookup(models.IntegerField, Mult3BilateralTransform):
a1 = Author.objects.create(name='a1', age=1, average_rating=3.2)

View File

@ -1,4 +1,5 @@
from django.db.models import Max
from django.db.models import CharField, Max
from django.db.models.functions import Lower
from django.test import TestCase, skipUnlessDBFeature
from .models import Celebrity, Fan, Staff, StaffTag, Tag
@ -8,19 +9,19 @@ from .models import Celebrity, Fan, Staff, StaffTag, Tag
@skipUnlessDBFeature('supports_nullable_unique_constraints')
class DistinctOnTests(TestCase):
def setUp(self):
t1 = Tag.objects.create(name='t1')
Tag.objects.create(name='t2', parent=t1)
t3 = Tag.objects.create(name='t3', parent=t1)
Tag.objects.create(name='t4', parent=t3)
Tag.objects.create(name='t5', parent=t3)
self.t1 = Tag.objects.create(name='t1')
self.t2 = Tag.objects.create(name='t2', parent=self.t1)
self.t3 = Tag.objects.create(name='t3', parent=self.t1)
self.t4 = Tag.objects.create(name='t4', parent=self.t3)
self.t5 = Tag.objects.create(name='t5', parent=self.t3)
self.p1_o1 = Staff.objects.create(id=1, name="p1", organisation="o1")
self.p2_o1 = Staff.objects.create(id=2, name="p2", organisation="o1")
self.p3_o1 = Staff.objects.create(id=3, name="p3", organisation="o1")
self.p1_o2 = Staff.objects.create(id=4, name="p1", organisation="o2")
self.p1_o1.coworkers.add(self.p2_o1, self.p3_o1)
StaffTag.objects.create(staff=self.p1_o1, tag=t1)
StaffTag.objects.create(staff=self.p1_o1, tag=t1)
StaffTag.objects.create(staff=self.p1_o1, tag=self.t1)
StaffTag.objects.create(staff=self.p1_o1, tag=self.t1)
celeb1 = Celebrity.objects.create(name="c1")
celeb2 = Celebrity.objects.create(name="c2")
@ -95,6 +96,19 @@ class DistinctOnTests(TestCase):
c2 = c1.distinct('pk')
self.assertNotIn('OUTER JOIN', str(c2.query))
def test_transform(self):
new_name = self.t1.name.upper()
self.assertNotEqual(self.t1.name, new_name)
Tag.objects.create(name=new_name)
CharField.register_lookup(Lower)
try:
self.assertCountEqual(
Tag.objects.order_by().distinct('name__lower'),
[self.t1, self.t2, self.t3, self.t4, self.t5],
)
finally:
CharField._unregister_lookup(Lower)
def test_distinct_not_implemented_checks(self):
# distinct + annotate not allowed
msg = 'annotate() + distinct(fields) is not implemented.'

View File

@ -1363,6 +1363,40 @@ class ValueTests(TestCase):
ExpressionList()
class FieldTransformTests(TestCase):
@classmethod
def setUpTestData(cls):
cls.sday = sday = datetime.date(2010, 6, 25)
cls.stime = stime = datetime.datetime(2010, 6, 25, 12, 15, 30, 747000)
cls.ex1 = Experiment.objects.create(
name='Experiment 1',
assigned=sday,
completed=sday + datetime.timedelta(2),
estimated_time=datetime.timedelta(2),
start=stime,
end=stime + datetime.timedelta(2),
)
def test_month_aggregation(self):
self.assertEqual(
Experiment.objects.aggregate(month_count=Count('assigned__month')),
{'month_count': 1}
)
def test_transform_in_values(self):
self.assertQuerysetEqual(
Experiment.objects.values('assigned__month'),
["{'assigned__month': 6}"]
)
def test_multiple_transforms_in_values(self):
self.assertQuerysetEqual(
Experiment.objects.values('end__date__month'),
["{'end__date__month': 6}"]
)
class ReprTests(TestCase):
def test_expressions(self):

View File

@ -309,6 +309,22 @@ class TestQuerying(PostgreSQLTestCase):
self.objs[2:3]
)
def test_order_by_slice(self):
more_objs = (
NullableIntegerArrayModel.objects.create(field=[1, 637]),
NullableIntegerArrayModel.objects.create(field=[2, 1]),
NullableIntegerArrayModel.objects.create(field=[3, -98123]),
NullableIntegerArrayModel.objects.create(field=[4, 2]),
)
self.assertSequenceEqual(
NullableIntegerArrayModel.objects.order_by('field__1'),
[
more_objs[2], more_objs[1], more_objs[3], self.objs[2],
self.objs[3], more_objs[0], self.objs[4], self.objs[1],
self.objs[0],
]
)
@unittest.expectedFailure
def test_slice_nested(self):
instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]])

View File

@ -148,6 +148,18 @@ class TestQuerying(HStoreTestCase):
self.objs[:2]
)
def test_order_by_field(self):
more_objs = (
HStoreModel.objects.create(field={'g': '637'}),
HStoreModel.objects.create(field={'g': '002'}),
HStoreModel.objects.create(field={'g': '042'}),
HStoreModel.objects.create(field={'g': '981'}),
)
self.assertSequenceEqual(
HStoreModel.objects.filter(field__has_key='g').order_by('field__g'),
[more_objs[1], more_objs[2], more_objs[0], more_objs[3]]
)
def test_keys_contains(self):
self.assertSequenceEqual(
HStoreModel.objects.filter(field__keys__contains=['a']),

View File

@ -141,6 +141,31 @@ class TestQuerying(PostgreSQLTestCase):
[self.objs[0]]
)
def test_ordering_by_transform(self):
objs = [
JSONModel.objects.create(field={'ord': 93, 'name': 'bar'}),
JSONModel.objects.create(field={'ord': 22.1, 'name': 'foo'}),
JSONModel.objects.create(field={'ord': -1, 'name': 'baz'}),
JSONModel.objects.create(field={'ord': 21.931902, 'name': 'spam'}),
JSONModel.objects.create(field={'ord': -100291029, 'name': 'eggs'}),
]
query = JSONModel.objects.filter(field__name__isnull=False).order_by('field__ord')
self.assertSequenceEqual(query, [objs[4], objs[2], objs[3], objs[1], objs[0]])
def test_deep_values(self):
query = JSONModel.objects.values_list('field__k__l')
self.assertSequenceEqual(
query,
[
(None,), (None,), (None,), (None,), (None,), (None,),
(None,), (None,), ('m',), (None,), (None,), (None,),
]
)
def test_deep_distinct(self):
query = JSONModel.objects.distinct('field__k__l').values_list('field__k__l')
self.assertSequenceEqual(query, [('m',), (None,)])
def test_isnull_key(self):
# key__isnull works the same as has_key='key'.
self.assertSequenceEqual(