diff --git a/django/db/backends/base/operations.py b/django/db/backends/base/operations.py index f6d8925278f..465ac70b7b6 100644 --- a/django/db/backends/base/operations.py +++ b/django/db/backends/base/operations.py @@ -158,7 +158,7 @@ class BaseDatabaseOperations: """ return '' - def distinct_sql(self, fields): + def distinct_sql(self, fields, params): """ Return an SQL DISTINCT clause which removes duplicate rows from the result set. If any fields are given, only check the given fields for @@ -167,7 +167,7 @@ class BaseDatabaseOperations: if fields: raise NotSupportedError('DISTINCT ON fields is not supported by this database backend') else: - return 'DISTINCT' + return ['DISTINCT'], [] def fetch_returned_insert_id(self, cursor): """ diff --git a/django/db/backends/postgresql/operations.py b/django/db/backends/postgresql/operations.py index 3b71cd4f2c5..6f48cfa2284 100644 --- a/django/db/backends/postgresql/operations.py +++ b/django/db/backends/postgresql/operations.py @@ -207,11 +207,12 @@ class DatabaseOperations(BaseDatabaseOperations): """ return 63 - def distinct_sql(self, fields): + def distinct_sql(self, fields, params): if fields: - return 'DISTINCT ON (%s)' % ', '.join(fields) + params = [param for param_list in params for param in param_list] + return (['DISTINCT ON (%s)' % ', '.join(fields)], params) else: - return 'DISTINCT' + return ['DISTINCT'], [] def last_executed_query(self, cursor, sql, params): # http://initd.org/psycopg/docs/cursor.html#cursor.query diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 1f78a8b5b2a..1fdbd156b6c 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -451,7 +451,7 @@ class SQLCompiler: raise NotSupportedError('{} is not supported on this database backend.'.format(combinator)) result, params = self.get_combinator_sql(combinator, self.query.combinator_all) else: - distinct_fields = self.get_distinct() + distinct_fields, distinct_params = 
self.get_distinct() # This must come after 'select', 'ordering', and 'distinct' # (see docstring of get_from_clause() for details). from_, f_params = self.get_from_clause() @@ -461,7 +461,12 @@ class SQLCompiler: params = [] if self.query.distinct: - result.append(self.connection.ops.distinct_sql(distinct_fields)) + distinct_result, distinct_params = self.connection.ops.distinct_sql( + distinct_fields, + distinct_params, + ) + result += distinct_result + params += distinct_params out_cols = [] col_idx = 1 @@ -621,21 +626,22 @@ class SQLCompiler: This method can alter the tables in the query, and thus it must be called before get_from_clause(). """ - qn = self.quote_name_unless_alias - qn2 = self.connection.ops.quote_name result = [] + params = [] opts = self.query.get_meta() for name in self.query.distinct_fields: parts = name.split(LOOKUP_SEP) - _, targets, alias, joins, path, _ = self._setup_joins(parts, opts, None) + _, targets, alias, joins, path, _, transform_function = self._setup_joins(parts, opts, None) targets, alias, _ = self.query.trim_joins(targets, joins, path) for target in targets: if name in self.query.annotation_select: result.append(name) else: - result.append("%s.%s" % (qn(alias), qn2(target.column))) - return result + r, p = self.compile(transform_function(target, alias)) + result.append(r) + params.append(p) + return result, params def find_ordering_name(self, name, opts, alias=None, default_order='ASC', already_seen=None): @@ -647,7 +653,7 @@ class SQLCompiler: name, order = get_order_dir(name, default_order) descending = order == 'DESC' pieces = name.split(LOOKUP_SEP) - field, targets, alias, joins, path, opts = self._setup_joins(pieces, opts, alias) + field, targets, alias, joins, path, opts, transform_function = self._setup_joins(pieces, opts, alias) # If we get to this point and the field is a relation to another model, # append the default ordering for that model unless the attribute name @@ -666,7 +672,7 @@ class SQLCompiler: order, 
already_seen)) return results targets, alias, _ = self.query.trim_joins(targets, joins, path) - return [(OrderBy(t.get_col(alias), descending=descending), False) for t in targets] + return [(OrderBy(transform_function(t, alias), descending=descending), False) for t in targets] def _setup_joins(self, pieces, opts, alias): """ @@ -677,10 +683,9 @@ class SQLCompiler: match. Executing SQL where this is not true is an error. """ alias = alias or self.query.get_initial_alias() - field, targets, opts, joins, path = self.query.setup_joins( - pieces, opts, alias) + field, targets, opts, joins, path, transform_function = self.query.setup_joins(pieces, opts, alias) alias = joins[-1] - return field, targets, alias, joins, path, opts + return field, targets, alias, joins, path, opts, transform_function def get_from_clause(self): """ @@ -786,7 +791,7 @@ class SQLCompiler: } related_klass_infos.append(klass_info) select_fields = [] - _, _, _, joins, _ = self.query.setup_joins( + _, _, _, joins, _, _ = self.query.setup_joins( [f.name], opts, root_alias) alias = joins[-1] columns = self.get_default_columns(start_alias=alias, opts=f.remote_field.model._meta) @@ -843,7 +848,7 @@ class SQLCompiler: break if name in self.query._filtered_relations: fields_found.add(name) - f, _, join_opts, joins, _ = self.query.setup_joins([name], opts, root_alias) + f, _, join_opts, joins, _, _ = self.query.setup_joins([name], opts, root_alias) model = join_opts.model alias = joins[-1] from_parent = issubclass(model, opts.model) and model is not opts.model diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index c8b557103f7..d39514a0a50 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -6,6 +6,7 @@ themselves do not have to (and could be backed by things other than SQL databases). The abstraction barrier only works one way: this module has to know all about the internals of models in order to get the information it needs. 
""" +import functools from collections import Counter, OrderedDict, namedtuple from collections.abc import Iterator, Mapping from itertools import chain, count, product @@ -18,6 +19,7 @@ from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connections from django.db.models.aggregates import Count from django.db.models.constants import LOOKUP_SEP from django.db.models.expressions import Col, Ref +from django.db.models.fields import Field from django.db.models.fields.related_lookups import MultiColSource from django.db.models.lookups import Lookup from django.db.models.query_utils import ( @@ -56,7 +58,7 @@ def get_children_from_q(q): JoinInfo = namedtuple( 'JoinInfo', - ('final_field', 'targets', 'opts', 'joins', 'path') + ('final_field', 'targets', 'opts', 'joins', 'path', 'transform_function') ) @@ -1429,8 +1431,11 @@ class Query: generate a MultiJoin exception. Return the final field involved in the joins, the target field (used - for any 'where' constraint), the final 'opts' value, the joins and the - field path travelled to generate the joins. + for any 'where' constraint), the final 'opts' value, the joins, the + field path traveled to generate the joins, and a transform function + that takes a field and alias and is equivalent to `field.get_col(alias)` + in the simple case but wraps field transforms if they were included in + names. The target field is the field containing the concrete value. Final field can be something different, for example foreign key pointing to @@ -1439,10 +1444,46 @@ class Query: key field for example). """ joins = [alias] - # First, generate the path for the names - path, final_field, targets, rest = self.names_to_path( - names, opts, allow_many, fail_on_missing=True) + # The transform can't be applied yet, as joins must be trimmed later. + # To avoid making every caller of this method look up transforms + # directly, compute transforms here and and create a partial that + # converts fields to the appropriate wrapped version. 
+ def final_transformer(field, alias): + return field.get_col(alias) + + # Try resolving all the names as fields first. If there's an error, + # treat trailing names as lookups until a field can be resolved. + last_field_exception = None + for pivot in range(len(names), 0, -1): + try: + path, final_field, targets, rest = self.names_to_path( + names[:pivot], opts, allow_many, fail_on_missing=True, + ) + except FieldError as exc: + if pivot == 1: + # The first item cannot be a lookup, so it's safe + # to raise the field error here. + raise + else: + last_field_exception = exc + else: + # The transforms are the remaining items that couldn't be + # resolved into fields. + transforms = names[pivot:] + break + for name in transforms: + def transform(field, alias, *, name, previous): + try: + wrapped = previous(field, alias) + return self.try_transform(wrapped, name) + except FieldError: + # FieldError is raised if the transform doesn't exist. + if isinstance(final_field, Field) and last_field_exception: + raise last_field_exception + else: + raise + final_transformer = functools.partial(transform, name=name, previous=final_transformer) # Then, add the path to the query's joins. Note that we can't trim # joins at this stage - we will need the information about join type # of the trimmed joins. 
@@ -1470,7 +1511,7 @@ class Query: joins.append(alias) if filtered_relation: filtered_relation.path = joins[:] - return JoinInfo(final_field, targets, opts, joins, path) + return JoinInfo(final_field, targets, opts, joins, path, final_transformer) def trim_joins(self, targets, joins, path): """ @@ -1683,7 +1724,7 @@ class Query: join_info.path, ) for target in targets: - cols.append(target.get_col(final_alias)) + cols.append(join_info.transform_function(target, final_alias)) if cols: self.set_select(cols) except MultiJoin: diff --git a/docs/howto/custom-lookups.txt b/docs/howto/custom-lookups.txt index 32037a29aa6..55fdc422372 100644 --- a/docs/howto/custom-lookups.txt +++ b/docs/howto/custom-lookups.txt @@ -138,6 +138,21 @@ SQL:: Note that in case there is no other lookup specified, Django interprets ``change__abs=27`` as ``change__abs__exact=27``. +This also allows the result to be used in ``ORDER BY`` and ``DISTINCT ON`` +clauses. For example ``Experiment.objects.order_by('change__abs')`` generates:: + + SELECT ... ORDER BY ABS("experiments"."change") ASC + +And on databases that support distinct on fields (such as PostgreSQL), +``Experiment.objects.distinct('change__abs')`` generates:: + + SELECT DISTINCT ON (ABS("experiments"."change")) ... + +.. versionchanged:: 2.1 + + Ordering and distinct support as described in the last two paragraphs was + added. + When looking for which lookups are allowable after the ``Transform`` has been applied, Django uses the ``output_field`` attribute. 
We didn't need to specify this here as it didn't change, but supposing we were applying ``AbsoluteValue`` diff --git a/docs/ref/models/expressions.txt b/docs/ref/models/expressions.txt index 8ca975cc3a6..70e3a58616c 100644 --- a/docs/ref/models/expressions.txt +++ b/docs/ref/models/expressions.txt @@ -64,10 +64,14 @@ Some examples # Aggregates can contain complex computations also Company.objects.annotate(num_offerings=Count(F('products') + F('services'))) - # Expressions can also be used in order_by() + # Expressions can also be used in order_by(), either directly Company.objects.order_by(Length('name').asc()) Company.objects.order_by(Length('name').desc()) - + # or using the double underscore lookup syntax. + from django.db.models import CharField + from django.db.models.functions import Length + CharField.register_lookup(Length) + Company.objects.order_by('name__length') Built-in Expressions ==================== diff --git a/docs/ref/models/querysets.txt b/docs/ref/models/querysets.txt index be07f0821d8..2d6702beebf 100644 --- a/docs/ref/models/querysets.txt +++ b/docs/ref/models/querysets.txt @@ -535,6 +535,19 @@ The ``values()`` method also takes optional keyword arguments, >>> Blog.objects.values(lower_name=Lower('name')) +You can use built-in and :doc:`custom lookups </howto/custom-lookups>` in +ordering. For example:: + + >>> from django.db.models import CharField + >>> from django.db.models.functions import Lower + >>> CharField.register_lookup(Lower, 'lower') + >>> Blog.objects.values('name__lower') + + +.. versionchanged:: 2.1 + + Support for lookups was added. + An aggregate within a ``values()`` clause is applied before other arguments within the same ``values()`` clause. If you need to group by another value, add it to an earlier ``values()`` clause instead. For example:: @@ -580,6 +593,25 @@ A few subtleties that are worth mentioning: * Calling :meth:`only()` and :meth:`defer()` after ``values()`` doesn't make sense, so doing so will raise a ``NotImplementedError``. 
+* Combining transforms and aggregates requires the use of two :meth:`annotate` + calls, either explicitly or as keyword arguments to :meth:`values`. As above, + if the transform has been registered on the relevant field type the first + :meth:`annotate` can be omitted, thus the following examples are equivalent:: + + >>> from django.db.models import CharField, Count + >>> from django.db.models.functions import Lower + >>> CharField.register_lookup(Lower, 'lower') + >>> Blog.objects.values('entry__authors__name__lower').annotate(entries=Count('entry')) + + >>> Blog.objects.values( + ... entry__authors__name__lower=Lower('entry__authors__name') + ... ).annotate(entries=Count('entry')) + + >>> Blog.objects.annotate( + ... entry__authors__name__lower=Lower('entry__authors__name') + ... ).values('entry__authors__name__lower').annotate(entries=Count('entry')) + + It is useful when you know you're only going to need values from a small number of the available fields and you won't need the functionality of a model instance object. It's more efficient to select only the fields you need to use. diff --git a/docs/releases/2.1.txt b/docs/releases/2.1.txt index ae3dd67bc9f..31252432bbd 100644 --- a/docs/releases/2.1.txt +++ b/docs/releases/2.1.txt @@ -187,6 +187,9 @@ Models * Query expressions can now be negated using a minus sign. +* :meth:`.QuerySet.order_by` and :meth:`distinct(*fields) <.QuerySet.distinct>` + now support using field transforms. + Requests and Responses ~~~~~~~~~~~~~~~~~~~~~~ @@ -242,6 +245,9 @@ Database backend API * Renamed the ``allow_sliced_subqueries`` database feature flag to ``allow_sliced_subqueries_with_in``. +* ``DatabaseOperations.distinct_sql()`` now requires an additional ``params`` + argument and returns a tuple of SQL and parameters instead of a SQL string. 
+ :mod:`django.contrib.gis` ------------------------- diff --git a/tests/backends/base/test_operations.py b/tests/backends/base/test_operations.py index e83857500df..510436f0d4b 100644 --- a/tests/backends/base/test_operations.py +++ b/tests/backends/base/test_operations.py @@ -17,7 +17,7 @@ class DatabaseOperationTests(SimpleTestCase): def test_distinct_on_fields(self): msg = 'DISTINCT ON fields is not supported by this database backend' with self.assertRaisesMessage(NotSupportedError, msg): - self.ops.distinct_sql(['a', 'b']) + self.ops.distinct_sql(['a', 'b'], None) def test_deferrable_sql(self): self.assertEqual(self.ops.deferrable_sql(), '') diff --git a/tests/custom_lookups/tests.py b/tests/custom_lookups/tests.py index 9661aebc495..418525c3ed4 100644 --- a/tests/custom_lookups/tests.py +++ b/tests/custom_lookups/tests.py @@ -63,6 +63,14 @@ class Mult3BilateralTransform(models.Transform): return '3 * (%s)' % lhs, lhs_params +class LastDigitTransform(models.Transform): + lookup_name = 'lastdigit' + + def as_sql(self, compiler, connection): + lhs, lhs_params = compiler.compile(self.lhs) + return 'SUBSTR(CAST(%s AS CHAR(2)), 2, 1)' % lhs, lhs_params + + class UpperBilateralTransform(models.Transform): bilateral = True lookup_name = 'upper' @@ -379,6 +387,15 @@ class BilateralTransformTests(TestCase): self.assertSequenceEqual(baseqs.filter(age__mult3__div3=42), [a1, a2, a3, a4]) self.assertSequenceEqual(baseqs.filter(age__div3__mult3=42), [a3]) + def test_transform_order_by(self): + with register_lookup(models.IntegerField, LastDigitTransform): + a1 = Author.objects.create(name='a1', age=11) + a2 = Author.objects.create(name='a2', age=23) + a3 = Author.objects.create(name='a3', age=32) + a4 = Author.objects.create(name='a4', age=40) + qs = Author.objects.order_by('age__lastdigit') + self.assertSequenceEqual(qs, [a4, a1, a3, a2]) + def test_bilateral_fexpr(self): with register_lookup(models.IntegerField, Mult3BilateralTransform): a1 = 
Author.objects.create(name='a1', age=1, average_rating=3.2) diff --git a/tests/distinct_on_fields/tests.py b/tests/distinct_on_fields/tests.py index 93a332cf830..ae4eb3bd19d 100644 --- a/tests/distinct_on_fields/tests.py +++ b/tests/distinct_on_fields/tests.py @@ -1,4 +1,5 @@ -from django.db.models import Max +from django.db.models import CharField, Max +from django.db.models.functions import Lower from django.test import TestCase, skipUnlessDBFeature from .models import Celebrity, Fan, Staff, StaffTag, Tag @@ -8,19 +9,19 @@ from .models import Celebrity, Fan, Staff, StaffTag, Tag @skipUnlessDBFeature('supports_nullable_unique_constraints') class DistinctOnTests(TestCase): def setUp(self): - t1 = Tag.objects.create(name='t1') - Tag.objects.create(name='t2', parent=t1) - t3 = Tag.objects.create(name='t3', parent=t1) - Tag.objects.create(name='t4', parent=t3) - Tag.objects.create(name='t5', parent=t3) + self.t1 = Tag.objects.create(name='t1') + self.t2 = Tag.objects.create(name='t2', parent=self.t1) + self.t3 = Tag.objects.create(name='t3', parent=self.t1) + self.t4 = Tag.objects.create(name='t4', parent=self.t3) + self.t5 = Tag.objects.create(name='t5', parent=self.t3) self.p1_o1 = Staff.objects.create(id=1, name="p1", organisation="o1") self.p2_o1 = Staff.objects.create(id=2, name="p2", organisation="o1") self.p3_o1 = Staff.objects.create(id=3, name="p3", organisation="o1") self.p1_o2 = Staff.objects.create(id=4, name="p1", organisation="o2") self.p1_o1.coworkers.add(self.p2_o1, self.p3_o1) - StaffTag.objects.create(staff=self.p1_o1, tag=t1) - StaffTag.objects.create(staff=self.p1_o1, tag=t1) + StaffTag.objects.create(staff=self.p1_o1, tag=self.t1) + StaffTag.objects.create(staff=self.p1_o1, tag=self.t1) celeb1 = Celebrity.objects.create(name="c1") celeb2 = Celebrity.objects.create(name="c2") @@ -95,6 +96,19 @@ class DistinctOnTests(TestCase): c2 = c1.distinct('pk') self.assertNotIn('OUTER JOIN', str(c2.query)) + def test_transform(self): + new_name = 
self.t1.name.upper() + self.assertNotEqual(self.t1.name, new_name) + Tag.objects.create(name=new_name) + CharField.register_lookup(Lower) + try: + self.assertCountEqual( + Tag.objects.order_by().distinct('name__lower'), + [self.t1, self.t2, self.t3, self.t4, self.t5], + ) + finally: + CharField._unregister_lookup(Lower) + def test_distinct_not_implemented_checks(self): # distinct + annotate not allowed msg = 'annotate() + distinct(fields) is not implemented.' diff --git a/tests/expressions/tests.py b/tests/expressions/tests.py index c2e3fb391a5..2980d750c8b 100644 --- a/tests/expressions/tests.py +++ b/tests/expressions/tests.py @@ -1363,6 +1363,40 @@ class ValueTests(TestCase): ExpressionList() +class FieldTransformTests(TestCase): + + @classmethod + def setUpTestData(cls): + cls.sday = sday = datetime.date(2010, 6, 25) + cls.stime = stime = datetime.datetime(2010, 6, 25, 12, 15, 30, 747000) + cls.ex1 = Experiment.objects.create( + name='Experiment 1', + assigned=sday, + completed=sday + datetime.timedelta(2), + estimated_time=datetime.timedelta(2), + start=stime, + end=stime + datetime.timedelta(2), + ) + + def test_month_aggregation(self): + self.assertEqual( + Experiment.objects.aggregate(month_count=Count('assigned__month')), + {'month_count': 1} + ) + + def test_transform_in_values(self): + self.assertQuerysetEqual( + Experiment.objects.values('assigned__month'), + ["{'assigned__month': 6}"] + ) + + def test_multiple_transforms_in_values(self): + self.assertQuerysetEqual( + Experiment.objects.values('end__date__month'), + ["{'end__date__month': 6}"] + ) + + class ReprTests(TestCase): def test_expressions(self): diff --git a/tests/postgres_tests/test_array.py b/tests/postgres_tests/test_array.py index 4776bda934a..03ffa4d6371 100644 --- a/tests/postgres_tests/test_array.py +++ b/tests/postgres_tests/test_array.py @@ -309,6 +309,22 @@ class TestQuerying(PostgreSQLTestCase): self.objs[2:3] ) + def test_order_by_slice(self): + more_objs = ( + 
NullableIntegerArrayModel.objects.create(field=[1, 637]), + NullableIntegerArrayModel.objects.create(field=[2, 1]), + NullableIntegerArrayModel.objects.create(field=[3, -98123]), + NullableIntegerArrayModel.objects.create(field=[4, 2]), + ) + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.order_by('field__1'), + [ + more_objs[2], more_objs[1], more_objs[3], self.objs[2], + self.objs[3], more_objs[0], self.objs[4], self.objs[1], + self.objs[0], + ] + ) + @unittest.expectedFailure def test_slice_nested(self): instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]]) diff --git a/tests/postgres_tests/test_hstore.py b/tests/postgres_tests/test_hstore.py index 069e570f51d..b58e5e5e207 100644 --- a/tests/postgres_tests/test_hstore.py +++ b/tests/postgres_tests/test_hstore.py @@ -148,6 +148,18 @@ class TestQuerying(HStoreTestCase): self.objs[:2] ) + def test_order_by_field(self): + more_objs = ( + HStoreModel.objects.create(field={'g': '637'}), + HStoreModel.objects.create(field={'g': '002'}), + HStoreModel.objects.create(field={'g': '042'}), + HStoreModel.objects.create(field={'g': '981'}), + ) + self.assertSequenceEqual( + HStoreModel.objects.filter(field__has_key='g').order_by('field__g'), + [more_objs[1], more_objs[2], more_objs[0], more_objs[3]] + ) + def test_keys_contains(self): self.assertSequenceEqual( HStoreModel.objects.filter(field__keys__contains=['a']), diff --git a/tests/postgres_tests/test_json.py b/tests/postgres_tests/test_json.py index a572e670ac4..b22cbfc571e 100644 --- a/tests/postgres_tests/test_json.py +++ b/tests/postgres_tests/test_json.py @@ -141,6 +141,31 @@ class TestQuerying(PostgreSQLTestCase): [self.objs[0]] ) + def test_ordering_by_transform(self): + objs = [ + JSONModel.objects.create(field={'ord': 93, 'name': 'bar'}), + JSONModel.objects.create(field={'ord': 22.1, 'name': 'foo'}), + JSONModel.objects.create(field={'ord': -1, 'name': 'baz'}), + JSONModel.objects.create(field={'ord': 21.931902, 'name': 
'spam'}), + JSONModel.objects.create(field={'ord': -100291029, 'name': 'eggs'}), + ] + query = JSONModel.objects.filter(field__name__isnull=False).order_by('field__ord') + self.assertSequenceEqual(query, [objs[4], objs[2], objs[3], objs[1], objs[0]]) + + def test_deep_values(self): + query = JSONModel.objects.values_list('field__k__l') + self.assertSequenceEqual( + query, + [ + (None,), (None,), (None,), (None,), (None,), (None,), + (None,), (None,), ('m',), (None,), (None,), (None,), + ] + ) + + def test_deep_distinct(self): + query = JSONModel.objects.distinct('field__k__l').values_list('field__k__l') + self.assertSequenceEqual(query, [('m',), (None,)]) + def test_isnull_key(self): # key__isnull works the same as has_key='key'. self.assertSequenceEqual(