From f42ccdd835e5b3f0914b5e6f87621c648136ea36 Mon Sep 17 00:00:00 2001 From: Ian Foote Date: Fri, 2 Apr 2021 18:25:20 +0100 Subject: [PATCH] Fixed #27021 -- Allowed lookup expressions in annotations, aggregations, and QuerySet.filter(). Thanks Hannes Ljungberg and Simon Charette for reviews. Co-authored-by: Mariusz Felisiak --- django/db/backends/oracle/operations.py | 6 +- django/db/models/expressions.py | 6 +- django/db/models/fields/related_lookups.py | 3 + django/db/models/lookups.py | 71 +++++---- django/db/models/sql/query.py | 6 +- django/db/models/sql/where.py | 19 +++ docs/ref/models/conditional-expressions.txt | 11 ++ docs/ref/models/expressions.txt | 8 + docs/ref/models/lookups.txt | 22 ++- docs/releases/4.0.txt | 3 + tests/lookup/tests.py | 161 +++++++++++++++++++- 11 files changed, 268 insertions(+), 48 deletions(-) diff --git a/django/db/backends/oracle/operations.py b/django/db/backends/oracle/operations.py index 095a125ced..4cfc7da070 100644 --- a/django/db/backends/oracle/operations.py +++ b/django/db/backends/oracle/operations.py @@ -6,7 +6,7 @@ from django.conf import settings from django.db import DatabaseError, NotSupportedError from django.db.backends.base.operations import BaseDatabaseOperations from django.db.backends.utils import strip_quotes, truncate_name -from django.db.models import AutoField, Exists, ExpressionWrapper +from django.db.models import AutoField, Exists, ExpressionWrapper, Lookup from django.db.models.expressions import RawSQL from django.db.models.sql.where import WhereNode from django.utils import timezone @@ -202,7 +202,7 @@ END; # Oracle stores empty strings as null. If the field accepts the empty # string, undo this to adhere to the Django convention of using # the empty string instead of null. 
- if expression.field.empty_strings_allowed: + if expression.output_field.empty_strings_allowed: converters.append( self.convert_empty_bytes if internal_type == 'BinaryField' else @@ -639,7 +639,7 @@ END; Oracle supports only EXISTS(...) or filters in the WHERE clause, others must be compared with True. """ - if isinstance(expression, (Exists, WhereNode)): + if isinstance(expression, (Exists, Lookup, WhereNode)): return True if isinstance(expression, ExpressionWrapper) and expression.conditional: return self.conditional_expression_supported_in_where_clause(expression.expression) diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index b30d1f959b..9381257bb2 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -1248,9 +1248,9 @@ class OrderBy(Expression): return (template % placeholders).rstrip(), params def as_oracle(self, compiler, connection): - # Oracle doesn't allow ORDER BY EXISTS() unless it's wrapped in - # a CASE WHEN. - if isinstance(self.expression, Exists): + # Oracle doesn't allow ORDER BY EXISTS() or filters unless it's wrapped + # in a CASE WHEN. 
+ if connection.ops.conditional_expression_supported_in_where_clause(self.expression): copy = self.copy() copy.expression = Case( When(self.expression, then=True), diff --git a/django/db/models/fields/related_lookups.py b/django/db/models/fields/related_lookups.py index d745ecd5f9..34cca8ba5e 100644 --- a/django/db/models/fields/related_lookups.py +++ b/django/db/models/fields/related_lookups.py @@ -22,6 +22,9 @@ class MultiColSource: def get_lookup(self, lookup): return self.output_field.get_lookup(lookup) + def resolve_expression(self, *args, **kwargs): + return self + def get_normalized_value(value, lhs): from django.db.models import Model diff --git a/django/db/models/lookups.py b/django/db/models/lookups.py index 8eb6702204..0f104416de 100644 --- a/django/db/models/lookups.py +++ b/django/db/models/lookups.py @@ -1,11 +1,10 @@ import itertools import math -from copy import copy from django.core.exceptions import EmptyResultSet -from django.db.models.expressions import Case, Func, Value, When +from django.db.models.expressions import Case, Expression, Func, Value, When from django.db.models.fields import ( - CharField, DateTimeField, Field, IntegerField, UUIDField, + BooleanField, CharField, DateTimeField, Field, IntegerField, UUIDField, ) from django.db.models.query_utils import RegisterLookupMixin from django.utils.datastructures import OrderedSet @@ -13,7 +12,7 @@ from django.utils.functional import cached_property from django.utils.hashable import make_hashable -class Lookup: +class Lookup(Expression): lookup_name = None prepare_rhs = True can_use_none_as_rhs = False @@ -21,6 +20,7 @@ class Lookup: def __init__(self, lhs, rhs): self.lhs, self.rhs = lhs, rhs self.rhs = self.get_prep_lookup() + self.lhs = self.get_prep_lhs() if hasattr(self.lhs, 'get_bilateral_transforms'): bilateral_transforms = self.lhs.get_bilateral_transforms() else: @@ -72,12 +72,20 @@ class Lookup: self.lhs, self.rhs = new_exprs def get_prep_lookup(self): - if hasattr(self.rhs, 
'resolve_expression'): + if not self.prepare_rhs or hasattr(self.rhs, 'resolve_expression'): return self.rhs - if self.prepare_rhs and hasattr(self.lhs.output_field, 'get_prep_value'): - return self.lhs.output_field.get_prep_value(self.rhs) + if hasattr(self.lhs, 'output_field'): + if hasattr(self.lhs.output_field, 'get_prep_value'): + return self.lhs.output_field.get_prep_value(self.rhs) + elif self.rhs_is_direct_value(): + return Value(self.rhs) return self.rhs + def get_prep_lhs(self): + if hasattr(self.lhs, 'resolve_expression'): + return self.lhs + return Value(self.lhs) + def get_db_prep_lookup(self, value, connection): return ('%s', [value]) @@ -85,7 +93,11 @@ class Lookup: lhs = lhs or self.lhs if hasattr(lhs, 'resolve_expression'): lhs = lhs.resolve_expression(compiler.query) - return compiler.compile(lhs) + sql, params = compiler.compile(lhs) + if isinstance(lhs, Lookup): + # Wrapped in parentheses to respect operator precedence. + sql = f'({sql})' + return sql, params def process_rhs(self, compiler, connection): value = self.rhs @@ -110,22 +122,12 @@ class Lookup: def rhs_is_direct_value(self): return not hasattr(self.rhs, 'as_sql') - def relabeled_clone(self, relabels): - new = copy(self) - new.lhs = new.lhs.relabeled_clone(relabels) - if hasattr(new.rhs, 'relabeled_clone'): - new.rhs = new.rhs.relabeled_clone(relabels) - return new - def get_group_by_cols(self, alias=None): - cols = self.lhs.get_group_by_cols() - if hasattr(self.rhs, 'get_group_by_cols'): - cols.extend(self.rhs.get_group_by_cols()) + cols = [] + for source in self.get_source_expressions(): + cols.extend(source.get_group_by_cols()) return cols - def as_sql(self, compiler, connection): - raise NotImplementedError - def as_oracle(self, compiler, connection): # Oracle doesn't allow EXISTS() and filters to be compared to another # expression unless they're wrapped in a CASE WHEN. 
@@ -140,16 +142,8 @@ class Lookup: return lookup.as_sql(compiler, connection) @cached_property - def contains_aggregate(self): - return self.lhs.contains_aggregate or getattr(self.rhs, 'contains_aggregate', False) - - @cached_property - def contains_over_clause(self): - return self.lhs.contains_over_clause or getattr(self.rhs, 'contains_over_clause', False) - - @property - def is_summary(self): - return self.lhs.is_summary or getattr(self.rhs, 'is_summary', False) + def output_field(self): + return BooleanField() @property def identity(self): @@ -163,6 +157,21 @@ class Lookup: def __hash__(self): return hash(make_hashable(self.identity)) + def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): + c = self.copy() + c.is_summary = summarize + c.lhs = self.lhs.resolve_expression(query, allow_joins, reuse, summarize, for_save) + c.rhs = self.rhs.resolve_expression(query, allow_joins, reuse, summarize, for_save) + return c + + def select_format(self, compiler, sql, params): + # Wrap filters with a CASE WHEN expression if a database backend + # (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP + # BY list. 
+ if not compiler.connection.features.supports_boolean_expr_in_select_clause: + sql = f'CASE WHEN {sql} THEN 1 ELSE 0 END' + return sql, params + class Transform(RegisterLookupMixin, Func): """ diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 87be8ea9f0..2412e6ad4e 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -1262,9 +1262,9 @@ class Query(BaseExpression): if hasattr(filter_expr, 'resolve_expression'): if not getattr(filter_expr, 'conditional', False): raise TypeError('Cannot filter against a non-conditional expression.') - condition = self.build_lookup( - ['exact'], filter_expr.resolve_expression(self, allow_joins=allow_joins), True - ) + condition = filter_expr.resolve_expression(self, allow_joins=allow_joins) + if not isinstance(condition, Lookup): + condition = self.build_lookup(['exact'], condition, True) clause = self.where_class() clause.add(condition, AND) return clause, [] diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index 2577e1d7a5..5a4da97396 100644 --- a/django/db/models/sql/where.py +++ b/django/db/models/sql/where.py @@ -208,6 +208,25 @@ class WhereNode(tree.Node): clone.resolved = True return clone + @cached_property + def output_field(self): + from django.db.models import BooleanField + return BooleanField() + + def select_format(self, compiler, sql, params): + # Wrap filters with a CASE WHEN expression if a database backend + # (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP + # BY list. 
+ if not compiler.connection.features.supports_boolean_expr_in_select_clause: + sql = f'CASE WHEN {sql} THEN 1 ELSE 0 END' + return sql, params + + def get_db_converters(self, connection): + return self.output_field.get_db_converters(connection) + + def get_lookup(self, lookup): + return self.output_field.get_lookup(lookup) + class NothingNode: """A node that matches nothing.""" diff --git a/docs/ref/models/conditional-expressions.txt b/docs/ref/models/conditional-expressions.txt index 546733dd8d..b7ab3ec3dc 100644 --- a/docs/ref/models/conditional-expressions.txt +++ b/docs/ref/models/conditional-expressions.txt @@ -48,6 +48,10 @@ objects that have an ``output_field`` that is a :class:`~django.db.models.BooleanField`. The result is provided using the ``then`` keyword. +.. versionchanged:: 4.0 + + Support for lookup expressions was added. + Some examples:: >>> from django.db.models import F, Q, When @@ -68,6 +72,13 @@ Some examples:: ... account_type=OuterRef('account_type'), ... ).exclude(pk=OuterRef('pk')).values('pk') >>> When(Exists(non_unique_account_type), then=Value('non unique')) + >>> # Condition can be created using lookup expressions. + >>> from django.db.models.lookups import GreaterThan, LessThan + >>> When( + ... GreaterThan(F('registered_on'), date(2014, 1, 1)) & + ... LessThan(F('registered_on'), date(2015, 1, 1)), + ... then='account_type', + ... ) Keep in mind that each of these values can be an expression. diff --git a/docs/ref/models/expressions.txt b/docs/ref/models/expressions.txt index d45cfbe00a..845e734dc1 100644 --- a/docs/ref/models/expressions.txt +++ b/docs/ref/models/expressions.txt @@ -25,6 +25,7 @@ Some examples from django.db.models import Count, F, Value from django.db.models.functions import Length, Upper + from django.db.models.lookups import GreaterThan # Find companies that have more employees than chairs. 
Company.objects.filter(num_employees__gt=F('num_chairs')) @@ -76,6 +77,13 @@ Some examples Exists(Employee.objects.filter(company=OuterRef('pk'), salary__gt=10)) ) + # Lookup expressions can also be used directly in filters + Company.objects.filter(GreaterThan(F('num_employees'), F('num_chairs'))) + # or annotations. + Company.objects.annotate( + need_chairs=GreaterThan(F('num_employees'), F('num_chairs')), + ) + Built-in Expressions ==================== diff --git a/docs/ref/models/lookups.txt b/docs/ref/models/lookups.txt index fdbde328de..f4fa0f899d 100644 --- a/docs/ref/models/lookups.txt +++ b/docs/ref/models/lookups.txt @@ -177,16 +177,21 @@ following methods: comparison between ``lhs`` and ``rhs`` such as ``lhs in rhs`` or ``lhs > rhs``. - The notation to use a lookup in an expression is - ``<lhs>__<lookup_name>=<rhs>``. + The primary notation to use a lookup in an expression is + ``<lhs>__<lookup_name>=<rhs>``. Lookups can also be used directly in + ``QuerySet`` filters:: - This class acts as a query expression, but, since it has ``=<rhs>`` on its - construction, lookups must always be the end of a lookup expression. + Book.objects.filter(LessThan(F('word_count'), 7500)) + + …or annotations:: + + Book.objects.annotate(is_short_story=LessThan(F('word_count'), 7500)) .. attribute:: lhs - The left-hand side - what is being looked up. The object must follow - the :ref:`Query Expression API <query-expression>`. + The left-hand side - what is being looked up. The object typically + follows the :ref:`Query Expression API <query-expression>`. It may also + be a plain value. .. attribute:: rhs @@ -213,3 +218,8 @@ following methods: .. method:: process_rhs(compiler, connection) Behaves the same way as :meth:`process_lhs`, for the right-hand side. + + .. versionchanged:: 4.0 + + Support for using lookups in ``QuerySet`` annotations, aggregations, + and directly in filters was added. 
diff --git a/docs/releases/4.0.txt b/docs/releases/4.0.txt index 3ef008e5fe..be59681fda 100644 --- a/docs/releases/4.0.txt +++ b/docs/releases/4.0.txt @@ -277,6 +277,9 @@ Models * The ``skip_locked`` argument of :meth:`.QuerySet.select_for_update()` is now allowed on MariaDB 10.6+. +* :class:`~django.db.models.Lookup` expressions may now be used in ``QuerySet`` + annotations, aggregations, and directly in filters. + Requests and Responses ~~~~~~~~~~~~~~~~~~~~~~ diff --git a/tests/lookup/tests.py b/tests/lookup/tests.py index e38eae087c..168a621a7e 100644 --- a/tests/lookup/tests.py +++ b/tests/lookup/tests.py @@ -6,9 +6,13 @@ from operator import attrgetter from django.core.exceptions import FieldError from django.db import connection, models from django.db.models import ( - BooleanField, Exists, ExpressionWrapper, F, Max, OuterRef, Q, + BooleanField, Case, Exists, ExpressionWrapper, F, Max, OuterRef, Q, + Subquery, Value, When, +) +from django.db.models.functions import Cast, Substr +from django.db.models.lookups import ( + Exact, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, ) -from django.db.models.functions import Substr from django.test import TestCase, skipUnlessDBFeature from django.test.utils import isolate_apps @@ -1020,3 +1024,156 @@ class LookupTests(TestCase): )), [stock_1, stock_2], ) + + +class LookupQueryingTests(TestCase): + @classmethod + def setUpTestData(cls): + cls.s1 = Season.objects.create(year=1942, gt=1942) + cls.s2 = Season.objects.create(year=1842, gt=1942) + cls.s3 = Season.objects.create(year=2042, gt=1942) + + def test_annotate(self): + qs = Season.objects.annotate(equal=Exact(F('year'), 1942)) + self.assertCountEqual( + qs.values_list('year', 'equal'), + ((1942, True), (1842, False), (2042, False)), + ) + + def test_alias(self): + qs = Season.objects.alias(greater=GreaterThan(F('year'), 1910)) + self.assertCountEqual(qs.filter(greater=True), [self.s1, self.s3]) + + def test_annotate_value_greater_than_value(self): + qs = 
Season.objects.annotate(greater=GreaterThan(Value(40), Value(30))) + self.assertCountEqual( + qs.values_list('year', 'greater'), + ((1942, True), (1842, True), (2042, True)), + ) + + def test_annotate_field_greater_than_field(self): + qs = Season.objects.annotate(greater=GreaterThan(F('year'), F('gt'))) + self.assertCountEqual( + qs.values_list('year', 'greater'), + ((1942, False), (1842, False), (2042, True)), + ) + + def test_annotate_field_greater_than_value(self): + qs = Season.objects.annotate(greater=GreaterThan(F('year'), Value(1930))) + self.assertCountEqual( + qs.values_list('year', 'greater'), + ((1942, True), (1842, False), (2042, True)), + ) + + def test_annotate_field_greater_than_literal(self): + qs = Season.objects.annotate(greater=GreaterThan(F('year'), 1930)) + self.assertCountEqual( + qs.values_list('year', 'greater'), + ((1942, True), (1842, False), (2042, True)), + ) + + def test_annotate_literal_greater_than_field(self): + qs = Season.objects.annotate(greater=GreaterThan(1930, F('year'))) + self.assertCountEqual( + qs.values_list('year', 'greater'), + ((1942, False), (1842, True), (2042, False)), + ) + + def test_annotate_less_than_float(self): + qs = Season.objects.annotate(lesser=LessThan(F('year'), 1942.1)) + self.assertCountEqual( + qs.values_list('year', 'lesser'), + ((1942, True), (1842, True), (2042, False)), + ) + + def test_annotate_greater_than_or_equal(self): + qs = Season.objects.annotate(greater=GreaterThanOrEqual(F('year'), 1942)) + self.assertCountEqual( + qs.values_list('year', 'greater'), + ((1942, True), (1842, False), (2042, True)), + ) + + def test_annotate_greater_than_or_equal_float(self): + qs = Season.objects.annotate(greater=GreaterThanOrEqual(F('year'), 1942.1)) + self.assertCountEqual( + qs.values_list('year', 'greater'), + ((1942, False), (1842, False), (2042, True)), + ) + + def test_combined_lookups(self): + expression = Exact(F('year'), 1942) | GreaterThan(F('year'), 1942) + qs = 
Season.objects.annotate(gte=expression) + self.assertCountEqual( + qs.values_list('year', 'gte'), + ((1942, True), (1842, False), (2042, True)), + ) + + def test_lookup_in_filter(self): + qs = Season.objects.filter(GreaterThan(F('year'), 1910)) + self.assertCountEqual(qs, [self.s1, self.s3]) + + def test_filter_lookup_lhs(self): + qs = Season.objects.annotate(before_20=LessThan(F('year'), 2000)).filter( + before_20=LessThan(F('year'), 1900), + ) + self.assertCountEqual(qs, [self.s2, self.s3]) + + def test_filter_wrapped_lookup_lhs(self): + qs = Season.objects.annotate(before_20=ExpressionWrapper( + Q(year__lt=2000), + output_field=BooleanField(), + )).filter(before_20=LessThan(F('year'), 1900)).values_list('year', flat=True) + self.assertCountEqual(qs, [1842, 2042]) + + def test_filter_exists_lhs(self): + qs = Season.objects.annotate(before_20=Exists( + Season.objects.filter(pk=OuterRef('pk'), year__lt=2000), + )).filter(before_20=LessThan(F('year'), 1900)) + self.assertCountEqual(qs, [self.s2, self.s3]) + + def test_filter_subquery_lhs(self): + qs = Season.objects.annotate(before_20=Subquery( + Season.objects.filter(pk=OuterRef('pk')).values( + lesser=LessThan(F('year'), 2000), + ), + )).filter(before_20=LessThan(F('year'), 1900)) + self.assertCountEqual(qs, [self.s2, self.s3]) + + def test_combined_lookups_in_filter(self): + expression = Exact(F('year'), 1942) | GreaterThan(F('year'), 1942) + qs = Season.objects.filter(expression) + self.assertCountEqual(qs, [self.s1, self.s3]) + + def test_combined_annotated_lookups_in_filter(self): + expression = Exact(F('year'), 1942) | GreaterThan(F('year'), 1942) + qs = Season.objects.annotate(gte=expression).filter(gte=True) + self.assertCountEqual(qs, [self.s1, self.s3]) + + def test_combined_annotated_lookups_in_filter_false(self): + expression = Exact(F('year'), 1942) | GreaterThan(F('year'), 1942) + qs = Season.objects.annotate(gte=expression).filter(gte=False) + self.assertSequenceEqual(qs, [self.s2]) + + def 
test_lookup_in_order_by(self): + qs = Season.objects.order_by(LessThan(F('year'), 1910), F('year')) + self.assertSequenceEqual(qs, [self.s1, self.s3, self.s2]) + + @skipUnlessDBFeature('supports_boolean_expr_in_select_clause') + def test_aggregate_combined_lookup(self): + expression = Cast(GreaterThan(F('year'), 1900), models.IntegerField()) + qs = Season.objects.aggregate(modern=models.Sum(expression)) + self.assertEqual(qs['modern'], 2) + + def test_conditional_expression(self): + qs = Season.objects.annotate(century=Case( + When( + GreaterThan(F('year'), 1900) & LessThanOrEqual(F('year'), 2000), + then=Value('20th'), + ), + default=Value('other'), + )).values('year', 'century') + self.assertCountEqual(qs, [ + {'year': 1942, 'century': '20th'}, + {'year': 1842, 'century': 'other'}, + {'year': 2042, 'century': 'other'}, + ])