diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index bef93d36a2..16975fb98c 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -500,8 +500,6 @@ class TemporalSubtraction(CombinedExpression): @deconstructible class F(Combinable): """An object capable of resolving references to existing query objects.""" - # Can the expression be used in a WHERE clause? - filterable = True def __init__(self, name): """ diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 0369e01348..1849d42081 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -1114,6 +1114,17 @@ class Query(BaseExpression): for v in value: self.check_query_object_type(v, opts, field) + def check_filterable(self, expression): + """Raise an error if expression cannot be used in a WHERE clause.""" + if not getattr(expression, 'filterable', 'True'): + raise NotSupportedError( + expression.__class__.__name__ + ' is disallowed in the filter ' + 'clause.' + ) + if hasattr(expression, 'get_source_expressions'): + for expr in expression.get_source_expressions(): + self.check_filterable(expr) + def build_lookup(self, lookups, lhs, rhs): """ Try to extract transforms and lookup from given lhs. @@ -1217,11 +1228,7 @@ class Query(BaseExpression): raise FieldError("Cannot parse keyword query %r" % arg) lookups, parts, reffed_expression = self.solve_lookup_type(arg) - if not getattr(reffed_expression, 'filterable', True): - raise NotSupportedError( - reffed_expression.__class__.__name__ + ' is disallowed in ' - 'the filter clause.' - ) + self.check_filterable(reffed_expression) if not allow_joins and len(parts) > 1: raise FieldError("Joined field references are not permitted in this query") @@ -1230,6 +1237,8 @@ class Query(BaseExpression): value = self.resolve_lookup_value(value, can_reuse, allow_joins, simple_col) used_joins = {k for k, v in self.alias_refcount.items() if v > pre_joins.get(k, 0)} + self.check_filterable(value) + clause = self.where_class() if reffed_expression: condition = self.build_lookup(lookups, reffed_expression, value) diff --git a/tests/expressions_window/tests.py b/tests/expressions_window/tests.py index f2ea95fa3c..8102ffe621 100644 --- a/tests/expressions_window/tests.py +++ b/tests/expressions_window/tests.py @@ -4,7 +4,8 @@ from unittest import mock, skipIf, skipUnless from django.core.exceptions import FieldError from django.db import NotSupportedError, connection from django.db.models import ( - F, OuterRef, RowRange, Subquery, Value, ValueRange, Window, WindowFrame, + F, Func, OuterRef, Q, RowRange, Subquery, Value, ValueRange, Window, + WindowFrame, ) from django.db.models.aggregates import Avg, Max, Min, Sum from django.db.models.functions import ( @@ -833,8 +834,17 @@ class NonQueryWindowTests(SimpleTestCase): def test_invalid_filter(self): msg = 'Window is disallowed in the filter clause' + qs = Employee.objects.annotate(dense_rank=Window(expression=DenseRank())) with self.assertRaisesMessage(NotSupportedError, msg): - Employee.objects.annotate(dense_rank=Window(expression=DenseRank())).filter(dense_rank__gte=1) + qs.filter(dense_rank__gte=1) + with self.assertRaisesMessage(NotSupportedError, msg): + qs.annotate(inc_rank=F('dense_rank') + Value(1)).filter(inc_rank__gte=1) + with self.assertRaisesMessage(NotSupportedError, msg): + qs.filter(id=F('dense_rank')) + with self.assertRaisesMessage(NotSupportedError, msg): + qs.filter(id=Func('dense_rank', 2, function='div')) + with self.assertRaisesMessage(NotSupportedError, msg): + qs.annotate(total=Sum('dense_rank', filter=Q(name='Jones'))).filter(total=1) def test_invalid_order_by(self): msg = 'order_by must be either an Expression or a sequence of expressions'