diff --git a/django/db/backends/mysql/compiler.py b/django/db/backends/mysql/compiler.py index a8ab03a55e9..bd2715fb43a 100644 --- a/django/db/backends/mysql/compiler.py +++ b/django/db/backends/mysql/compiler.py @@ -28,10 +28,13 @@ class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler): # the SQLDeleteCompiler's default implementation when multiple tables # are involved since MySQL/MariaDB will generate a more efficient query # plan than when using a subquery. - where, having = self.query.where.split_having() - if self.single_alias or having: - # DELETE FROM cannot be used when filtering against aggregates - # since it doesn't allow for GROUP BY and HAVING clauses. + where, having, qualify = self.query.where.split_having_qualify( + must_group_by=self.query.group_by is not None + ) + if self.single_alias or having or qualify: + # DELETE FROM cannot be used when filtering against aggregates or + # window functions as it doesn't allow for GROUP BY/HAVING clauses + # and the subquery wrapping (necessary to emulate QUALIFY). return super().as_sql() result = [ "DELETE %s FROM" diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index 7d212851188..58e945dd44e 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -836,6 +836,7 @@ class ResolvedOuterRef(F): """ contains_aggregate = False + contains_over_clause = False def as_sql(self, *args, **kwargs): raise ValueError( @@ -1210,6 +1211,12 @@ class OrderByList(Func): return "", () return super().as_sql(*args, **kwargs) + def get_group_by_cols(self): + group_by_cols = [] + for order_by in self.get_source_expressions(): + group_by_cols.extend(order_by.get_group_by_cols()) + return group_by_cols + @deconstructible(path="django.db.models.ExpressionWrapper") class ExpressionWrapper(SQLiteNumericMixin, Expression): @@ -1631,7 +1638,6 @@ class Window(SQLiteNumericMixin, Expression): # be introduced in the query as a result is not desired. contains_aggregate = False contains_over_clause = True - filterable = False def __init__( self, @@ -1733,7 +1739,12 @@ class Window(SQLiteNumericMixin, Expression): return "<%s: %s>" % (self.__class__.__name__, self) def get_group_by_cols(self, alias=None): - return [] + group_by_cols = [] + if self.partition_by: + group_by_cols.extend(self.partition_by.get_group_by_cols()) + if self.order_by is not None: + group_by_cols.extend(self.order_by.get_group_by_cols()) + return group_by_cols class WindowFrame(Expression): diff --git a/django/db/models/fields/related_lookups.py b/django/db/models/fields/related_lookups.py index 17a8622ff9a..1a845a1f7f1 100644 --- a/django/db/models/fields/related_lookups.py +++ b/django/db/models/fields/related_lookups.py @@ -14,6 +14,7 @@ from django.utils.deprecation import RemovedInDjango50Warning class MultiColSource: contains_aggregate = False + contains_over_clause = False def __init__(self, alias, targets, sources, field): self.targets, self.sources, self.field, self.alias = ( diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 1c0ab2d2127..858142913b6 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -9,6 +9,7 @@ from django.db import DatabaseError, NotSupportedError from django.db.models.constants import LOOKUP_SEP from django.db.models.expressions import F, OrderBy, RawSQL, Ref, Value from django.db.models.functions import Cast, Random +from django.db.models.lookups import Lookup from django.db.models.query_utils import select_related_descend from django.db.models.sql.constants import ( CURSOR, @@ -73,7 +74,9 @@ class SQLCompiler: """ self.setup_query(with_col_aliases=with_col_aliases) order_by = self.get_order_by() - self.where, self.having = self.query.where.split_having() + self.where, self.having, self.qualify = self.query.where.split_having_qualify( + must_group_by=self.query.group_by is not None + ) extra_select = self.get_extra_select(order_by, self.select) self.has_extra_select = bool(extra_select) group_by = self.get_group_by(self.select + extra_select, order_by) @@ -584,6 +587,74 @@ class SQLCompiler: params.extend(part) return result, params + def get_qualify_sql(self): + where_parts = [] + if self.where: + where_parts.append(self.where) + if self.having: + where_parts.append(self.having) + inner_query = self.query.clone() + inner_query.subquery = True + inner_query.where = inner_query.where.__class__(where_parts) + # Augment the inner query with any window function references that + # might have been masked via values() and alias(). If any masked + # aliases are added they'll be masked again to avoid fetching + # the data in the `if qual_aliases` branch below. + select = { + expr: alias for expr, _, alias in self.get_select(with_col_aliases=True)[0] + } + qual_aliases = set() + replacements = {} + expressions = list(self.qualify.leaves()) + while expressions: + expr = expressions.pop() + if select_alias := (select.get(expr) or replacements.get(expr)): + replacements[expr] = select_alias + elif isinstance(expr, Lookup): + expressions.extend(expr.get_source_expressions()) + else: + num_qual_alias = len(qual_aliases) + select_alias = f"qual{num_qual_alias}" + qual_aliases.add(select_alias) + inner_query.add_annotation(expr, select_alias) + replacements[expr] = select_alias + self.qualify = self.qualify.replace_expressions( + {expr: Ref(alias, expr) for expr, alias in replacements.items()} + ) + inner_query_compiler = inner_query.get_compiler( + self.using, elide_empty=self.elide_empty + ) + inner_sql, inner_params = inner_query_compiler.as_sql( + # The limits must be applied to the outer query to avoid pruning + # results too eagerly. + with_limits=False, + # Force unique aliasing of selected columns to avoid collisions + # and make rhs predicates referencing easier. + with_col_aliases=True, + ) + qualify_sql, qualify_params = self.compile(self.qualify) + result = [ + "SELECT * FROM (", + inner_sql, + ")", + self.connection.ops.quote_name("qualify"), + "WHERE", + qualify_sql, + ] + if qual_aliases: + # If some select aliases were unmasked for filtering purposes they + # must be masked back. + cols = [self.connection.ops.quote_name(alias) for alias in select.values()] + result = [ + "SELECT", + ", ".join(cols), + "FROM (", + *result, + ")", + self.connection.ops.quote_name("qualify_mask"), + ] + return result, list(inner_params) + qualify_params + def as_sql(self, with_limits=True, with_col_aliases=False): """ Create the SQL for this query. Return the SQL string and list of @@ -614,6 +685,9 @@ class SQLCompiler: result, params = self.get_combinator_sql( combinator, self.query.combinator_all ) + elif self.qualify: + result, params = self.get_qualify_sql() + order_by = None else: distinct_fields, distinct_params = self.get_distinct() # This must come after 'select', 'ordering', and 'distinct' diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index 42a4b054a59..e2af46a3092 100644 --- a/django/db/models/sql/where.py +++ b/django/db/models/sql/where.py @@ -35,48 +35,81 @@ class WhereNode(tree.Node): resolved = False conditional = True - def split_having(self, negated=False): + def split_having_qualify(self, negated=False, must_group_by=False): """ - Return two possibly None nodes: one for those parts of self that - should be included in the WHERE clause and one for those parts of - self that must be included in the HAVING clause. + Return three possibly None nodes: one for those parts of self that + should be included in the WHERE clause, one for those parts of self + that must be included in the HAVING clause, and one for those parts + that refer to window functions. """ - if not self.contains_aggregate: - return self, None + if not self.contains_aggregate and not self.contains_over_clause: + return self, None, None in_negated = negated ^ self.negated - # If the effective connector is OR or XOR and this node contains an - # aggregate, then we need to push the whole branch to HAVING clause. - may_need_split = ( + # Whether or not children must be connected in the same filtering + # clause (WHERE > HAVING > QUALIFY) to maintain logical semantic. + must_remain_connected = ( (in_negated and self.connector == AND) or (not in_negated and self.connector == OR) or self.connector == XOR ) - if may_need_split and self.contains_aggregate: - return None, self + if ( + must_remain_connected + and self.contains_aggregate + and not self.contains_over_clause + ): + # It's must cheaper to short-circuit and stash everything in the + # HAVING clause than split children if possible. + return None, self, None where_parts = [] having_parts = [] + qualify_parts = [] for c in self.children: - if hasattr(c, "split_having"): - where_part, having_part = c.split_having(in_negated) + if hasattr(c, "split_having_qualify"): + where_part, having_part, qualify_part = c.split_having_qualify( + in_negated, must_group_by + ) if where_part is not None: where_parts.append(where_part) if having_part is not None: having_parts.append(having_part) + if qualify_part is not None: + qualify_parts.append(qualify_part) + elif c.contains_over_clause: + qualify_parts.append(c) elif c.contains_aggregate: having_parts.append(c) else: where_parts.append(c) - having_node = ( - self.create(having_parts, self.connector, self.negated) - if having_parts - else None - ) + if must_remain_connected and qualify_parts: + # Disjunctive heterogeneous predicates can be pushed down to + # qualify as long as no conditional aggregation is involved. + if not where_parts or (where_parts and not must_group_by): + return None, None, self + elif where_parts: + # In theory this should only be enforced when dealing with + # where_parts containing predicates against multi-valued + # relationships that could affect aggregation results but this + # is complex to infer properly. + raise NotImplementedError( + "Heterogeneous disjunctive predicates against window functions are " + "not implemented when performing conditional aggregation." + ) where_node = ( self.create(where_parts, self.connector, self.negated) if where_parts else None ) - return where_node, having_node + having_node = ( + self.create(having_parts, self.connector, self.negated) + if having_parts + else None + ) + qualify_node = ( + self.create(qualify_parts, self.connector, self.negated) + if qualify_parts + else None + ) + return where_node, having_node, qualify_node def as_sql(self, compiler, connection): """ @@ -183,6 +216,14 @@ class WhereNode(tree.Node): clone.relabel_aliases(change_map) return clone + def replace_expressions(self, replacements): + if replacement := replacements.get(self): + return replacement + clone = self.create(connector=self.connector, negated=self.negated) + for child in self.children: + clone.children.append(child.replace_expressions(replacements)) + return clone + @classmethod def _contains_aggregate(cls, obj): if isinstance(obj, tree.Node): @@ -231,6 +272,10 @@ class WhereNode(tree.Node): return BooleanField() + @property + def _output_field_or_none(self): + return self.output_field + def select_format(self, compiler, sql, params): # Wrap filters with a CASE WHEN expression if a database backend # (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP @@ -245,19 +290,28 @@ class WhereNode(tree.Node): def get_lookup(self, lookup): return self.output_field.get_lookup(lookup) + def leaves(self): + for child in self.children: + if isinstance(child, WhereNode): + yield from child.leaves() + else: + yield child + class NothingNode: """A node that matches nothing.""" contains_aggregate = False + contains_over_clause = False def as_sql(self, compiler=None, connection=None): raise EmptyResultSet class ExtraWhere: - # The contents are a black box - assume no aggregates are used. + # The contents are a black box - assume no aggregates or windows are used. contains_aggregate = False + contains_over_clause = False def __init__(self, sqls, params): self.sqls = sqls @@ -269,9 +323,10 @@ class ExtraWhere: class SubqueryConstraint: - # Even if aggregates would be used in a subquery, the outer query isn't - # interested about those. + # Even if aggregates or windows would be used in a subquery, + # the outer query isn't interested about those. contains_aggregate = False + contains_over_clause = False def __init__(self, alias, columns, targets, query_object): self.alias = alias diff --git a/docs/ref/models/expressions.txt b/docs/ref/models/expressions.txt index 25edd1f3e82..95f093e2a35 100644 --- a/docs/ref/models/expressions.txt +++ b/docs/ref/models/expressions.txt @@ -741,12 +741,6 @@ instead they are part of the selected columns. .. class:: Window(expression, partition_by=None, order_by=None, frame=None, output_field=None) - .. attribute:: filterable - - Defaults to ``False``. The SQL standard disallows referencing window - functions in the ``WHERE`` clause and Django raises an exception when - constructing a ``QuerySet`` that would do that. - .. attribute:: template Defaults to ``%(expression)s OVER (%(window)s)'``. If only the @@ -819,6 +813,31 @@ to reduce repetition:: >>> ), >>> ) +Filtering against window functions is supported as long as lookups are not +disjunctive (not using ``OR`` or ``XOR`` as a connector) and against a queryset +performing aggregation. + +For example, a query that relies on aggregation and has an ``OR``-ed filter +against a window function and a field is not supported. Applying combined +predicates post-aggregation could cause rows that would normally be excluded +from groups to be included:: + + >>> qs = Movie.objects.annotate( + >>> category_rank=Window( + >>> Rank(), partition_by='category', order_by='-rating' + >>> ), + >>> scenes_count=Count('actors'), + >>> ).filter( + >>> Q(category_rank__lte=3) | Q(title__contains='Batman') + >>> ) + >>> list(qs) + NotImplementedError: Heterogeneous disjunctive predicates against window functions + are not implemented when performing conditional aggregation. + +.. versionchanged:: 4.2 + + Support for filtering against window functions was added. + Among Django's built-in database backends, MySQL 8.0.2+, PostgreSQL, and Oracle support window expressions. Support for different window expression features varies among the different databases. For example, the options in diff --git a/docs/releases/4.2.txt b/docs/releases/4.2.txt index cb647500096..ed00ee1350f 100644 --- a/docs/releases/4.2.txt +++ b/docs/releases/4.2.txt @@ -189,7 +189,9 @@ Migrations Models ~~~~~~ -* ... +* ``QuerySet`` now extensively supports filtering against + :ref:`window-functions` with the exception of disjunctive filter lookups + against window functions when performing aggregation. Requests and Responses ~~~~~~~~~~~~~~~~~~~~~~ diff --git a/tests/expressions_window/models.py b/tests/expressions_window/models.py index 631e876e15f..cf324ea8f6e 100644 --- a/tests/expressions_window/models.py +++ b/tests/expressions_window/models.py @@ -17,6 +17,13 @@ class Employee(models.Model): bonus = models.DecimalField(decimal_places=2, max_digits=15, null=True) +class PastEmployeeDepartment(models.Model): + employee = models.ForeignKey( + Employee, related_name="past_departments", on_delete=models.CASCADE + ) + department = models.CharField(max_length=40, blank=False, null=False) + + class Detail(models.Model): value = models.JSONField() diff --git a/tests/expressions_window/tests.py b/tests/expressions_window/tests.py index 15f8a4d6b28..a71a3f947d9 100644 --- a/tests/expressions_window/tests.py +++ b/tests/expressions_window/tests.py @@ -6,10 +6,9 @@ from django.core.exceptions import FieldError from django.db import NotSupportedError, connection from django.db.models import ( Avg, - BooleanField, Case, + Count, F, - Func, IntegerField, Max, Min, @@ -41,15 +40,17 @@ from django.db.models.functions import ( RowNumber, Upper, ) +from django.db.models.lookups import Exact from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature -from .models import Detail, Employee +from .models import Classification, Detail, Employee, PastEmployeeDepartment @skipUnlessDBFeature("supports_over_clause") class WindowFunctionTests(TestCase): @classmethod def setUpTestData(cls): + classification = Classification.objects.create() Employee.objects.bulk_create( [ Employee( @@ -59,6 +60,7 @@ class WindowFunctionTests(TestCase): hire_date=e[3], age=e[4], bonus=Decimal(e[1]) / 400, + classification=classification, ) for e in [ ("Jones", 45000, "Accounting", datetime.datetime(2005, 11, 1), 20), @@ -82,6 +84,13 @@ class WindowFunctionTests(TestCase): ] ] ) + employees = list(Employee.objects.order_by("pk")) + PastEmployeeDepartment.objects.bulk_create( + [ + PastEmployeeDepartment(employee=employees[6], department="Sales"), + PastEmployeeDepartment(employee=employees[10], department="IT"), + ] + ) def test_dense_rank(self): tests = [ @@ -902,6 +911,263 @@ class WindowFunctionTests(TestCase): ) self.assertEqual(qs.count(), 12) + def test_filter(self): + qs = Employee.objects.annotate( + department_salary_rank=Window( + Rank(), partition_by="department", order_by="-salary" + ), + department_avg_age_diff=( + Window(Avg("age"), partition_by="department") - F("age") + ), + ).order_by("department", "name") + # Direct window reference. + self.assertQuerysetEqual( + qs.filter(department_salary_rank=1), + ["Adams", "Wilkinson", "Miller", "Johnson", "Smith"], + lambda employee: employee.name, + ) + # Through a combined expression containing a window. + self.assertQuerysetEqual( + qs.filter(department_avg_age_diff__gt=0), + ["Jenson", "Jones", "Williams", "Miller", "Smith"], + lambda employee: employee.name, + ) + # Intersection of multiple windows. + self.assertQuerysetEqual( + qs.filter(department_salary_rank=1, department_avg_age_diff__gt=0), + ["Miller"], + lambda employee: employee.name, + ) + # Union of multiple windows. + self.assertQuerysetEqual( + qs.filter(Q(department_salary_rank=1) | Q(department_avg_age_diff__gt=0)), + [ + "Adams", + "Jenson", + "Jones", + "Williams", + "Wilkinson", + "Miller", + "Johnson", + "Smith", + "Smith", + ], + lambda employee: employee.name, + ) + + def test_filter_conditional_annotation(self): + qs = ( + Employee.objects.annotate( + rank=Window(Rank(), partition_by="department", order_by="-salary"), + case_first_rank=Case( + When(rank=1, then=True), + default=False, + ), + q_first_rank=Q(rank=1), + ) + .order_by("name") + .values_list("name", flat=True) + ) + for annotation in ["case_first_rank", "q_first_rank"]: + with self.subTest(annotation=annotation): + self.assertSequenceEqual( + qs.filter(**{annotation: True}), + ["Adams", "Johnson", "Miller", "Smith", "Wilkinson"], + ) + + def test_filter_conditional_expression(self): + qs = ( + Employee.objects.filter( + Exact(Window(Rank(), partition_by="department", order_by="-salary"), 1) + ) + .order_by("name") + .values_list("name", flat=True) + ) + self.assertSequenceEqual( + qs, ["Adams", "Johnson", "Miller", "Smith", "Wilkinson"] + ) + + def test_filter_column_ref_rhs(self): + qs = ( + Employee.objects.annotate( + max_dept_salary=Window(Max("salary"), partition_by="department") + ) + .filter(max_dept_salary=F("salary")) + .order_by("name") + .values_list("name", flat=True) + ) + self.assertSequenceEqual( + qs, ["Adams", "Johnson", "Miller", "Smith", "Wilkinson"] + ) + + def test_filter_values(self): + qs = ( + Employee.objects.annotate( + department_salary_rank=Window( + Rank(), partition_by="department", order_by="-salary" + ), + ) + .order_by("department", "name") + .values_list(Upper("name"), flat=True) + ) + self.assertSequenceEqual( + qs.filter(department_salary_rank=1), + ["ADAMS", "WILKINSON", "MILLER", "JOHNSON", "SMITH"], + ) + + def test_filter_alias(self): + qs = Employee.objects.alias( + department_avg_age_diff=( + Window(Avg("age"), partition_by="department") - F("age") + ), + ).order_by("department", "name") + self.assertQuerysetEqual( + qs.filter(department_avg_age_diff__gt=0), + ["Jenson", "Jones", "Williams", "Miller", "Smith"], + lambda employee: employee.name, + ) + + def test_filter_select_related(self): + qs = ( + Employee.objects.alias( + department_avg_age_diff=( + Window(Avg("age"), partition_by="department") - F("age") + ), + ) + .select_related("classification") + .filter(department_avg_age_diff__gt=0) + .order_by("department", "name") + ) + self.assertQuerysetEqual( + qs, + ["Jenson", "Jones", "Williams", "Miller", "Smith"], + lambda employee: employee.name, + ) + with self.assertNumQueries(0): + qs[0].classification + + def test_exclude(self): + qs = Employee.objects.annotate( + department_salary_rank=Window( + Rank(), partition_by="department", order_by="-salary" + ), + department_avg_age_diff=( + Window(Avg("age"), partition_by="department") - F("age") + ), + ).order_by("department", "name") + # Direct window reference. + self.assertQuerysetEqual( + qs.exclude(department_salary_rank__gt=1), + ["Adams", "Wilkinson", "Miller", "Johnson", "Smith"], + lambda employee: employee.name, + ) + # Through a combined expression containing a window. + self.assertQuerysetEqual( + qs.exclude(department_avg_age_diff__lte=0), + ["Jenson", "Jones", "Williams", "Miller", "Smith"], + lambda employee: employee.name, + ) + # Union of multiple windows. + self.assertQuerysetEqual( + qs.exclude( + Q(department_salary_rank__gt=1) | Q(department_avg_age_diff__lte=0) + ), + ["Miller"], + lambda employee: employee.name, + ) + # Intersection of multiple windows. + self.assertQuerysetEqual( + qs.exclude(department_salary_rank__gt=1, department_avg_age_diff__lte=0), + [ + "Adams", + "Jenson", + "Jones", + "Williams", + "Wilkinson", + "Miller", + "Johnson", + "Smith", + "Smith", + ], + lambda employee: employee.name, + ) + + def test_heterogeneous_filter(self): + qs = ( + Employee.objects.annotate( + department_salary_rank=Window( + Rank(), partition_by="department", order_by="-salary" + ), + ) + .order_by("name") + .values_list("name", flat=True) + ) + # Heterogeneous filter between window function and aggregates pushes + # the WHERE clause to the QUALIFY outer query. + self.assertSequenceEqual( + qs.filter( + department_salary_rank=1, department__in=["Accounting", "Management"] + ), + ["Adams", "Miller"], + ) + self.assertSequenceEqual( + qs.filter( + Q(department_salary_rank=1) + | Q(department__in=["Accounting", "Management"]) + ), + [ + "Adams", + "Jenson", + "Johnson", + "Johnson", + "Jones", + "Miller", + "Smith", + "Wilkinson", + "Williams", + ], + ) + # Heterogeneous filter between window function and aggregates pushes + # the HAVING clause to the QUALIFY outer query. + qs = qs.annotate(past_department_count=Count("past_departments")) + self.assertSequenceEqual( + qs.filter(department_salary_rank=1, past_department_count__gte=1), + ["Johnson", "Miller"], + ) + self.assertSequenceEqual( + qs.filter(Q(department_salary_rank=1) | Q(past_department_count__gte=1)), + ["Adams", "Johnson", "Miller", "Smith", "Wilkinson"], + ) + + def test_limited_filter(self): + """ + A query filtering against a window function have its limit applied + after window filtering takes place. + """ + self.assertQuerysetEqual( + Employee.objects.annotate( + department_salary_rank=Window( + Rank(), partition_by="department", order_by="-salary" + ) + ) + .filter(department_salary_rank=1) + .order_by("department")[0:3], + ["Adams", "Wilkinson", "Miller"], + lambda employee: employee.name, + ) + + def test_filter_count(self): + self.assertEqual( + Employee.objects.annotate( + department_salary_rank=Window( + Rank(), partition_by="department", order_by="-salary" + ) + ) + .filter(department_salary_rank=1) + .count(), + 5, + ) + @skipUnlessDBFeature("supports_frame_range_fixed_distance") def test_range_n_preceding_and_following(self): qs = Employee.objects.annotate( @@ -1071,6 +1337,7 @@ class WindowFunctionTests(TestCase): ), year=ExtractYear("hire_date"), ) + .filter(sum__gte=45000) .values("year", "sum") .distinct("year") .order_by("year") @@ -1081,7 +1348,6 @@ class WindowFunctionTests(TestCase): {"year": 2008, "sum": 45000}, {"year": 2009, "sum": 128000}, {"year": 2011, "sum": 60000}, - {"year": 2012, "sum": 40000}, {"year": 2013, "sum": 84000}, ] for idx, val in zip(range(len(results)), results): @@ -1348,34 +1614,18 @@ class NonQueryWindowTests(SimpleTestCase): frame.window_frame_start_end(None, None, None) def test_invalid_filter(self): - msg = "Window is disallowed in the filter clause" - qs = Employee.objects.annotate(dense_rank=Window(expression=DenseRank())) - with self.assertRaisesMessage(NotSupportedError, msg): - qs.filter(dense_rank__gte=1) - with self.assertRaisesMessage(NotSupportedError, msg): - qs.annotate(inc_rank=F("dense_rank") + Value(1)).filter(inc_rank__gte=1) - with self.assertRaisesMessage(NotSupportedError, msg): - qs.filter(id=F("dense_rank")) - with self.assertRaisesMessage(NotSupportedError, msg): - qs.filter(id=Func("dense_rank", 2, function="div")) - with self.assertRaisesMessage(NotSupportedError, msg): - qs.annotate(total=Sum("dense_rank", filter=Q(name="Jones"))).filter(total=1) - - def test_conditional_annotation(self): - qs = Employee.objects.annotate( - dense_rank=Window(expression=DenseRank()), - ).annotate( - equal=Case( - When(id=F("dense_rank"), then=Value(True)), - default=Value(False), - output_field=BooleanField(), - ), + msg = ( + "Heterogeneous disjunctive predicates against window functions are not " + "implemented when performing conditional aggregation." ) - # The SQL standard disallows referencing window functions in the WHERE - # clause. - msg = "Window is disallowed in the filter clause" - with self.assertRaisesMessage(NotSupportedError, msg): - qs.filter(equal=True) + qs = Employee.objects.annotate( + window=Window(Rank()), + past_dept_cnt=Count("past_departments"), + ) + with self.assertRaisesMessage(NotImplementedError, msg): + list(qs.filter(Q(window=1) | Q(department="Accounting"))) + with self.assertRaisesMessage(NotImplementedError, msg): + list(qs.exclude(window=1, department="Accounting")) def test_invalid_order_by(self): msg = (