Fixed #27021 -- Allowed lookup expressions in annotations, aggregations, and QuerySet.filter().

Thanks Hannes Ljungberg and Simon Charette for reviews.

Co-authored-by: Mariusz Felisiak <felisiak.mariusz@gmail.com>
This commit is contained in:
Ian Foote 2021-04-02 18:25:20 +01:00 committed by Mariusz Felisiak
parent f5dccbafb9
commit f42ccdd835
11 changed files with 268 additions and 48 deletions

View File

@ -6,7 +6,7 @@ from django.conf import settings
from django.db import DatabaseError, NotSupportedError from django.db import DatabaseError, NotSupportedError
from django.db.backends.base.operations import BaseDatabaseOperations from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.backends.utils import strip_quotes, truncate_name from django.db.backends.utils import strip_quotes, truncate_name
from django.db.models import AutoField, Exists, ExpressionWrapper from django.db.models import AutoField, Exists, ExpressionWrapper, Lookup
from django.db.models.expressions import RawSQL from django.db.models.expressions import RawSQL
from django.db.models.sql.where import WhereNode from django.db.models.sql.where import WhereNode
from django.utils import timezone from django.utils import timezone
@ -202,7 +202,7 @@ END;
# Oracle stores empty strings as null. If the field accepts the empty # Oracle stores empty strings as null. If the field accepts the empty
# string, undo this to adhere to the Django convention of using # string, undo this to adhere to the Django convention of using
# the empty string instead of null. # the empty string instead of null.
if expression.field.empty_strings_allowed: if expression.output_field.empty_strings_allowed:
converters.append( converters.append(
self.convert_empty_bytes self.convert_empty_bytes
if internal_type == 'BinaryField' else if internal_type == 'BinaryField' else
@ -639,7 +639,7 @@ END;
Oracle supports only EXISTS(...) or filters in the WHERE clause, others Oracle supports only EXISTS(...) or filters in the WHERE clause, others
must be compared with True. must be compared with True.
""" """
if isinstance(expression, (Exists, WhereNode)): if isinstance(expression, (Exists, Lookup, WhereNode)):
return True return True
if isinstance(expression, ExpressionWrapper) and expression.conditional: if isinstance(expression, ExpressionWrapper) and expression.conditional:
return self.conditional_expression_supported_in_where_clause(expression.expression) return self.conditional_expression_supported_in_where_clause(expression.expression)

View File

@ -1248,9 +1248,9 @@ class OrderBy(Expression):
return (template % placeholders).rstrip(), params return (template % placeholders).rstrip(), params
def as_oracle(self, compiler, connection): def as_oracle(self, compiler, connection):
# Oracle doesn't allow ORDER BY EXISTS() unless it's wrapped in # Oracle doesn't allow ORDER BY EXISTS() or filters unless it's wrapped
# a CASE WHEN. # in a CASE WHEN.
if isinstance(self.expression, Exists): if connection.ops.conditional_expression_supported_in_where_clause(self.expression):
copy = self.copy() copy = self.copy()
copy.expression = Case( copy.expression = Case(
When(self.expression, then=True), When(self.expression, then=True),

View File

@ -22,6 +22,9 @@ class MultiColSource:
def get_lookup(self, lookup): def get_lookup(self, lookup):
return self.output_field.get_lookup(lookup) return self.output_field.get_lookup(lookup)
def resolve_expression(self, *args, **kwargs):
return self
def get_normalized_value(value, lhs): def get_normalized_value(value, lhs):
from django.db.models import Model from django.db.models import Model

View File

@ -1,11 +1,10 @@
import itertools import itertools
import math import math
from copy import copy
from django.core.exceptions import EmptyResultSet from django.core.exceptions import EmptyResultSet
from django.db.models.expressions import Case, Func, Value, When from django.db.models.expressions import Case, Expression, Func, Value, When
from django.db.models.fields import ( from django.db.models.fields import (
CharField, DateTimeField, Field, IntegerField, UUIDField, BooleanField, CharField, DateTimeField, Field, IntegerField, UUIDField,
) )
from django.db.models.query_utils import RegisterLookupMixin from django.db.models.query_utils import RegisterLookupMixin
from django.utils.datastructures import OrderedSet from django.utils.datastructures import OrderedSet
@ -13,7 +12,7 @@ from django.utils.functional import cached_property
from django.utils.hashable import make_hashable from django.utils.hashable import make_hashable
class Lookup: class Lookup(Expression):
lookup_name = None lookup_name = None
prepare_rhs = True prepare_rhs = True
can_use_none_as_rhs = False can_use_none_as_rhs = False
@ -21,6 +20,7 @@ class Lookup:
def __init__(self, lhs, rhs): def __init__(self, lhs, rhs):
self.lhs, self.rhs = lhs, rhs self.lhs, self.rhs = lhs, rhs
self.rhs = self.get_prep_lookup() self.rhs = self.get_prep_lookup()
self.lhs = self.get_prep_lhs()
if hasattr(self.lhs, 'get_bilateral_transforms'): if hasattr(self.lhs, 'get_bilateral_transforms'):
bilateral_transforms = self.lhs.get_bilateral_transforms() bilateral_transforms = self.lhs.get_bilateral_transforms()
else: else:
@ -72,12 +72,20 @@ class Lookup:
self.lhs, self.rhs = new_exprs self.lhs, self.rhs = new_exprs
def get_prep_lookup(self): def get_prep_lookup(self):
if hasattr(self.rhs, 'resolve_expression'): if not self.prepare_rhs or hasattr(self.rhs, 'resolve_expression'):
return self.rhs return self.rhs
if self.prepare_rhs and hasattr(self.lhs.output_field, 'get_prep_value'): if hasattr(self.lhs, 'output_field'):
if hasattr(self.lhs.output_field, 'get_prep_value'):
return self.lhs.output_field.get_prep_value(self.rhs) return self.lhs.output_field.get_prep_value(self.rhs)
elif self.rhs_is_direct_value():
return Value(self.rhs)
return self.rhs return self.rhs
def get_prep_lhs(self):
if hasattr(self.lhs, 'resolve_expression'):
return self.lhs
return Value(self.lhs)
def get_db_prep_lookup(self, value, connection): def get_db_prep_lookup(self, value, connection):
return ('%s', [value]) return ('%s', [value])
@ -85,7 +93,11 @@ class Lookup:
lhs = lhs or self.lhs lhs = lhs or self.lhs
if hasattr(lhs, 'resolve_expression'): if hasattr(lhs, 'resolve_expression'):
lhs = lhs.resolve_expression(compiler.query) lhs = lhs.resolve_expression(compiler.query)
return compiler.compile(lhs) sql, params = compiler.compile(lhs)
if isinstance(lhs, Lookup):
# Wrapped in parentheses to respect operator precedence.
sql = f'({sql})'
return sql, params
def process_rhs(self, compiler, connection): def process_rhs(self, compiler, connection):
value = self.rhs value = self.rhs
@ -110,22 +122,12 @@ class Lookup:
def rhs_is_direct_value(self): def rhs_is_direct_value(self):
return not hasattr(self.rhs, 'as_sql') return not hasattr(self.rhs, 'as_sql')
def relabeled_clone(self, relabels):
new = copy(self)
new.lhs = new.lhs.relabeled_clone(relabels)
if hasattr(new.rhs, 'relabeled_clone'):
new.rhs = new.rhs.relabeled_clone(relabels)
return new
def get_group_by_cols(self, alias=None): def get_group_by_cols(self, alias=None):
cols = self.lhs.get_group_by_cols() cols = []
if hasattr(self.rhs, 'get_group_by_cols'): for source in self.get_source_expressions():
cols.extend(self.rhs.get_group_by_cols()) cols.extend(source.get_group_by_cols())
return cols return cols
def as_sql(self, compiler, connection):
raise NotImplementedError
def as_oracle(self, compiler, connection): def as_oracle(self, compiler, connection):
# Oracle doesn't allow EXISTS() and filters to be compared to another # Oracle doesn't allow EXISTS() and filters to be compared to another
# expression unless they're wrapped in a CASE WHEN. # expression unless they're wrapped in a CASE WHEN.
@ -140,16 +142,8 @@ class Lookup:
return lookup.as_sql(compiler, connection) return lookup.as_sql(compiler, connection)
@cached_property @cached_property
def contains_aggregate(self): def output_field(self):
return self.lhs.contains_aggregate or getattr(self.rhs, 'contains_aggregate', False) return BooleanField()
@cached_property
def contains_over_clause(self):
return self.lhs.contains_over_clause or getattr(self.rhs, 'contains_over_clause', False)
@property
def is_summary(self):
return self.lhs.is_summary or getattr(self.rhs, 'is_summary', False)
@property @property
def identity(self): def identity(self):
@ -163,6 +157,21 @@ class Lookup:
def __hash__(self): def __hash__(self):
return hash(make_hashable(self.identity)) return hash(make_hashable(self.identity))
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
c = self.copy()
c.is_summary = summarize
c.lhs = self.lhs.resolve_expression(query, allow_joins, reuse, summarize, for_save)
c.rhs = self.rhs.resolve_expression(query, allow_joins, reuse, summarize, for_save)
return c
def select_format(self, compiler, sql, params):
# Wrap filters with a CASE WHEN expression if a database backend
# (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
# BY list.
if not compiler.connection.features.supports_boolean_expr_in_select_clause:
sql = f'CASE WHEN {sql} THEN 1 ELSE 0 END'
return sql, params
class Transform(RegisterLookupMixin, Func): class Transform(RegisterLookupMixin, Func):
""" """

View File

@ -1262,9 +1262,9 @@ class Query(BaseExpression):
if hasattr(filter_expr, 'resolve_expression'): if hasattr(filter_expr, 'resolve_expression'):
if not getattr(filter_expr, 'conditional', False): if not getattr(filter_expr, 'conditional', False):
raise TypeError('Cannot filter against a non-conditional expression.') raise TypeError('Cannot filter against a non-conditional expression.')
condition = self.build_lookup( condition = filter_expr.resolve_expression(self, allow_joins=allow_joins)
['exact'], filter_expr.resolve_expression(self, allow_joins=allow_joins), True if not isinstance(condition, Lookup):
) condition = self.build_lookup(['exact'], condition, True)
clause = self.where_class() clause = self.where_class()
clause.add(condition, AND) clause.add(condition, AND)
return clause, [] return clause, []

View File

@ -208,6 +208,25 @@ class WhereNode(tree.Node):
clone.resolved = True clone.resolved = True
return clone return clone
@cached_property
def output_field(self):
from django.db.models import BooleanField
return BooleanField()
def select_format(self, compiler, sql, params):
# Wrap filters with a CASE WHEN expression if a database backend
# (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
# BY list.
if not compiler.connection.features.supports_boolean_expr_in_select_clause:
sql = f'CASE WHEN {sql} THEN 1 ELSE 0 END'
return sql, params
def get_db_converters(self, connection):
return self.output_field.get_db_converters(connection)
def get_lookup(self, lookup):
return self.output_field.get_lookup(lookup)
class NothingNode: class NothingNode:
"""A node that matches nothing.""" """A node that matches nothing."""

View File

@ -48,6 +48,10 @@ objects that have an ``output_field`` that is a
:class:`~django.db.models.BooleanField`. The result is provided using the :class:`~django.db.models.BooleanField`. The result is provided using the
``then`` keyword. ``then`` keyword.
.. versionchanged:: 4.0
Support for lookup expressions was added.
Some examples:: Some examples::
>>> from django.db.models import F, Q, When >>> from django.db.models import F, Q, When
@ -68,6 +72,13 @@ Some examples::
... account_type=OuterRef('account_type'), ... account_type=OuterRef('account_type'),
... ).exclude(pk=OuterRef('pk')).values('pk') ... ).exclude(pk=OuterRef('pk')).values('pk')
>>> When(Exists(non_unique_account_type), then=Value('non unique')) >>> When(Exists(non_unique_account_type), then=Value('non unique'))
>>> # Condition can be created using lookup expressions.
>>> from django.db.models.lookups import GreaterThan, LessThan
>>> When(
... GreaterThan(F('registered_on'), date(2014, 1, 1)) &
... LessThan(F('registered_on'), date(2015, 1, 1)),
... then='account_type',
... )
Keep in mind that each of these values can be an expression. Keep in mind that each of these values can be an expression.

View File

@ -25,6 +25,7 @@ Some examples
from django.db.models import Count, F, Value from django.db.models import Count, F, Value
from django.db.models.functions import Length, Upper from django.db.models.functions import Length, Upper
from django.db.models.lookups import GreaterThan
# Find companies that have more employees than chairs. # Find companies that have more employees than chairs.
Company.objects.filter(num_employees__gt=F('num_chairs')) Company.objects.filter(num_employees__gt=F('num_chairs'))
@ -76,6 +77,13 @@ Some examples
Exists(Employee.objects.filter(company=OuterRef('pk'), salary__gt=10)) Exists(Employee.objects.filter(company=OuterRef('pk'), salary__gt=10))
) )
# Lookup expressions can also be used directly in filters
Company.objects.filter(GreaterThan(F('num_employees'), F('num_chairs')))
# or annotations.
Company.objects.annotate(
need_chairs=GreaterThan(F('num_employees'), F('num_chairs')),
)
Built-in Expressions Built-in Expressions
==================== ====================

View File

@ -177,16 +177,21 @@ following methods:
comparison between ``lhs`` and ``rhs`` such as ``lhs in rhs`` or comparison between ``lhs`` and ``rhs`` such as ``lhs in rhs`` or
``lhs > rhs``. ``lhs > rhs``.
The notation to use a lookup in an expression is The primary notation to use a lookup in an expression is
``<lhs>__<lookup_name>=<rhs>``. ``<lhs>__<lookup_name>=<rhs>``. Lookups can also be used directly in
``QuerySet`` filters::
This class acts as a query expression, but, since it has ``=<rhs>`` on its Book.objects.filter(LessThan(F('word_count'), 7500))
construction, lookups must always be the end of a lookup expression.
…or annotations::
Book.objects.annotate(is_short_story=LessThan(F('word_count'), 7500))
.. attribute:: lhs .. attribute:: lhs
The left-hand side - what is being looked up. The object must follow The left-hand side - what is being looked up. The object typically
the :ref:`Query Expression API <query-expression>`. follows the :ref:`Query Expression API <query-expression>`. It may also
be a plain value.
.. attribute:: rhs .. attribute:: rhs
@ -213,3 +218,8 @@ following methods:
.. method:: process_rhs(compiler, connection) .. method:: process_rhs(compiler, connection)
Behaves the same way as :meth:`process_lhs`, for the right-hand side. Behaves the same way as :meth:`process_lhs`, for the right-hand side.
.. versionchanged:: 4.0
Support for using lookups in ``QuerySet`` annotations, aggregations,
and directly in filters was added.

View File

@ -277,6 +277,9 @@ Models
* The ``skip_locked`` argument of :meth:`.QuerySet.select_for_update()` is now * The ``skip_locked`` argument of :meth:`.QuerySet.select_for_update()` is now
allowed on MariaDB 10.6+. allowed on MariaDB 10.6+.
* :class:`~django.db.models.Lookup` expressions may now be used in ``QuerySet``
annotations, aggregations, and directly in filters.
Requests and Responses Requests and Responses
~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~

View File

@ -6,9 +6,13 @@ from operator import attrgetter
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db import connection, models from django.db import connection, models
from django.db.models import ( from django.db.models import (
BooleanField, Exists, ExpressionWrapper, F, Max, OuterRef, Q, BooleanField, Case, Exists, ExpressionWrapper, F, Max, OuterRef, Q,
Subquery, Value, When,
)
from django.db.models.functions import Cast, Substr
from django.db.models.lookups import (
Exact, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual,
) )
from django.db.models.functions import Substr
from django.test import TestCase, skipUnlessDBFeature from django.test import TestCase, skipUnlessDBFeature
from django.test.utils import isolate_apps from django.test.utils import isolate_apps
@ -1020,3 +1024,156 @@ class LookupTests(TestCase):
)), )),
[stock_1, stock_2], [stock_1, stock_2],
) )
class LookupQueryingTests(TestCase):
@classmethod
def setUpTestData(cls):
cls.s1 = Season.objects.create(year=1942, gt=1942)
cls.s2 = Season.objects.create(year=1842, gt=1942)
cls.s3 = Season.objects.create(year=2042, gt=1942)
def test_annotate(self):
qs = Season.objects.annotate(equal=Exact(F('year'), 1942))
self.assertCountEqual(
qs.values_list('year', 'equal'),
((1942, True), (1842, False), (2042, False)),
)
def test_alias(self):
qs = Season.objects.alias(greater=GreaterThan(F('year'), 1910))
self.assertCountEqual(qs.filter(greater=True), [self.s1, self.s3])
def test_annotate_value_greater_than_value(self):
qs = Season.objects.annotate(greater=GreaterThan(Value(40), Value(30)))
self.assertCountEqual(
qs.values_list('year', 'greater'),
((1942, True), (1842, True), (2042, True)),
)
def test_annotate_field_greater_than_field(self):
qs = Season.objects.annotate(greater=GreaterThan(F('year'), F('gt')))
self.assertCountEqual(
qs.values_list('year', 'greater'),
((1942, False), (1842, False), (2042, True)),
)
def test_annotate_field_greater_than_value(self):
qs = Season.objects.annotate(greater=GreaterThan(F('year'), Value(1930)))
self.assertCountEqual(
qs.values_list('year', 'greater'),
((1942, True), (1842, False), (2042, True)),
)
def test_annotate_field_greater_than_literal(self):
qs = Season.objects.annotate(greater=GreaterThan(F('year'), 1930))
self.assertCountEqual(
qs.values_list('year', 'greater'),
((1942, True), (1842, False), (2042, True)),
)
def test_annotate_literal_greater_than_field(self):
qs = Season.objects.annotate(greater=GreaterThan(1930, F('year')))
self.assertCountEqual(
qs.values_list('year', 'greater'),
((1942, False), (1842, True), (2042, False)),
)
def test_annotate_less_than_float(self):
qs = Season.objects.annotate(lesser=LessThan(F('year'), 1942.1))
self.assertCountEqual(
qs.values_list('year', 'lesser'),
((1942, True), (1842, True), (2042, False)),
)
def test_annotate_greater_than_or_equal(self):
qs = Season.objects.annotate(greater=GreaterThanOrEqual(F('year'), 1942))
self.assertCountEqual(
qs.values_list('year', 'greater'),
((1942, True), (1842, False), (2042, True)),
)
def test_annotate_greater_than_or_equal_float(self):
qs = Season.objects.annotate(greater=GreaterThanOrEqual(F('year'), 1942.1))
self.assertCountEqual(
qs.values_list('year', 'greater'),
((1942, False), (1842, False), (2042, True)),
)
def test_combined_lookups(self):
expression = Exact(F('year'), 1942) | GreaterThan(F('year'), 1942)
qs = Season.objects.annotate(gte=expression)
self.assertCountEqual(
qs.values_list('year', 'gte'),
((1942, True), (1842, False), (2042, True)),
)
def test_lookup_in_filter(self):
qs = Season.objects.filter(GreaterThan(F('year'), 1910))
self.assertCountEqual(qs, [self.s1, self.s3])
def test_filter_lookup_lhs(self):
qs = Season.objects.annotate(before_20=LessThan(F('year'), 2000)).filter(
before_20=LessThan(F('year'), 1900),
)
self.assertCountEqual(qs, [self.s2, self.s3])
def test_filter_wrapped_lookup_lhs(self):
qs = Season.objects.annotate(before_20=ExpressionWrapper(
Q(year__lt=2000),
output_field=BooleanField(),
)).filter(before_20=LessThan(F('year'), 1900)).values_list('year', flat=True)
self.assertCountEqual(qs, [1842, 2042])
def test_filter_exists_lhs(self):
qs = Season.objects.annotate(before_20=Exists(
Season.objects.filter(pk=OuterRef('pk'), year__lt=2000),
)).filter(before_20=LessThan(F('year'), 1900))
self.assertCountEqual(qs, [self.s2, self.s3])
def test_filter_subquery_lhs(self):
qs = Season.objects.annotate(before_20=Subquery(
Season.objects.filter(pk=OuterRef('pk')).values(
lesser=LessThan(F('year'), 2000),
),
)).filter(before_20=LessThan(F('year'), 1900))
self.assertCountEqual(qs, [self.s2, self.s3])
def test_combined_lookups_in_filter(self):
expression = Exact(F('year'), 1942) | GreaterThan(F('year'), 1942)
qs = Season.objects.filter(expression)
self.assertCountEqual(qs, [self.s1, self.s3])
def test_combined_annotated_lookups_in_filter(self):
expression = Exact(F('year'), 1942) | GreaterThan(F('year'), 1942)
qs = Season.objects.annotate(gte=expression).filter(gte=True)
self.assertCountEqual(qs, [self.s1, self.s3])
def test_combined_annotated_lookups_in_filter_false(self):
expression = Exact(F('year'), 1942) | GreaterThan(F('year'), 1942)
qs = Season.objects.annotate(gte=expression).filter(gte=False)
self.assertSequenceEqual(qs, [self.s2])
def test_lookup_in_order_by(self):
qs = Season.objects.order_by(LessThan(F('year'), 1910), F('year'))
self.assertSequenceEqual(qs, [self.s1, self.s3, self.s2])
@skipUnlessDBFeature('supports_boolean_expr_in_select_clause')
def test_aggregate_combined_lookup(self):
expression = Cast(GreaterThan(F('year'), 1900), models.IntegerField())
qs = Season.objects.aggregate(modern=models.Sum(expression))
self.assertEqual(qs['modern'], 2)
def test_conditional_expression(self):
qs = Season.objects.annotate(century=Case(
When(
GreaterThan(F('year'), 1900) & LessThanOrEqual(F('year'), 2000),
then=Value('20th'),
),
default=Value('other'),
)).values('year', 'century')
self.assertCountEqual(qs, [
{'year': 1942, 'century': '20th'},
{'year': 1842, 'century': 'other'},
{'year': 2042, 'century': 'other'},
])