diff --git a/django/db/backends/base/features.py b/django/db/backends/base/features.py index cbf3f024d9..4ef7cc24a6 100644 --- a/django/db/backends/base/features.py +++ b/django/db/backends/base/features.py @@ -236,6 +236,9 @@ class BaseDatabaseFeatures: # Does the backend support indexing a TextField? supports_index_on_text_field = True + # Does the backend support window expressions (expression OVER (...))? + supports_over_clause = False + # Does the backend support CAST with precision? supports_cast_with_precision = True diff --git a/django/db/backends/base/operations.py b/django/db/backends/base/operations.py index 495d7e910b..fd74263a4a 100644 --- a/django/db/backends/base/operations.py +++ b/django/db/backends/base/operations.py @@ -4,7 +4,7 @@ from importlib import import_module from django.conf import settings from django.core.exceptions import ImproperlyConfigured -from django.db import transaction +from django.db import NotSupportedError, transaction from django.db.backends import utils from django.utils import timezone from django.utils.dateparse import parse_duration @@ -39,6 +39,13 @@ class BaseDatabaseOperations: # CharField data type if the max_length argument isn't provided. cast_char_field_without_max_length = None + # Start and end points for window expressions. + PRECEDING = 'PRECEDING' + FOLLOWING = 'FOLLOWING' + UNBOUNDED_PRECEDING = 'UNBOUNDED ' + PRECEDING + UNBOUNDED_FOLLOWING = 'UNBOUNDED ' + FOLLOWING + CURRENT_ROW = 'CURRENT ROW' + def __init__(self, connection): self.connection = connection self._cache = None @@ -598,3 +605,34 @@ class BaseDatabaseOperations: rhs_sql, rhs_params = rhs return "(%s - %s)" % (lhs_sql, rhs_sql), lhs_params + rhs_params raise NotImplementedError("This backend does not support %s subtraction." 
% internal_type) + + def window_frame_start(self, start): + if isinstance(start, int): + if start < 0: + return '%d %s' % (abs(start), self.PRECEDING) + elif start == 0: + return self.CURRENT_ROW + elif start is None: + return self.UNBOUNDED_PRECEDING + raise ValueError("start argument must be a negative integer, zero, or None, but got '%s'." % start) + + def window_frame_end(self, end): + if isinstance(end, int): + if end == 0: + return self.CURRENT_ROW + elif end > 0: + return '%d %s' % (end, self.FOLLOWING) + elif end is None: + return self.UNBOUNDED_FOLLOWING + raise ValueError("end argument must be a positive integer, zero, or None, but got '%s'." % end) + + def window_frame_rows_start_end(self, start=None, end=None): + """ + Return SQL for start and end points in an OVER clause window frame. + """ + if not self.connection.features.supports_over_clause: + raise NotSupportedError('This backend does not support window expressions.') + return self.window_frame_start(start), self.window_frame_end(end) + + def window_frame_range_start_end(self, start=None, end=None): + return self.window_frame_rows_start_end(start, end) diff --git a/django/db/backends/mysql/features.py b/django/db/backends/mysql/features.py index 033460e09b..18ab088941 100644 --- a/django/db/backends/mysql/features.py +++ b/django/db/backends/mysql/features.py @@ -81,6 +81,10 @@ class DatabaseFeatures(BaseDatabaseFeatures): result = cursor.fetchone() return result and result[0] == 1 + @cached_property + def supports_over_clause(self): + return self.connection.mysql_version >= (8, 0, 2) + @cached_property def supports_transactions(self): """ diff --git a/django/db/backends/oracle/features.py b/django/db/backends/oracle/features.py index 62d7c75300..a378372947 100644 --- a/django/db/backends/oracle/features.py +++ b/django/db/backends/oracle/features.py @@ -54,3 +54,4 @@ class DatabaseFeatures(BaseDatabaseFeatures): END; """ supports_callproc_kwargs = True + supports_over_clause = True diff --git 
a/django/db/backends/postgresql/features.py b/django/db/backends/postgresql/features.py index de3fc6d1af..0349493dae 100644 --- a/django/db/backends/postgresql/features.py +++ b/django/db/backends/postgresql/features.py @@ -48,6 +48,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): V_I := P_I; END; $$ LANGUAGE plpgsql;""" + supports_over_clause = True @cached_property def supports_aggregate_filter_clause(self): diff --git a/django/db/backends/postgresql/operations.py b/django/db/backends/postgresql/operations.py index bff8e87d68..268fa99b51 100644 --- a/django/db/backends/postgresql/operations.py +++ b/django/db/backends/postgresql/operations.py @@ -1,6 +1,7 @@ from psycopg2.extras import Inet from django.conf import settings +from django.db import NotSupportedError from django.db.backends.base.operations import BaseDatabaseOperations @@ -247,3 +248,12 @@ class DatabaseOperations(BaseDatabaseOperations): rhs_sql, rhs_params = rhs return "(interval '1 day' * (%s - %s))" % (lhs_sql, rhs_sql), lhs_params + rhs_params return super().subtract_temporals(internal_type, lhs, rhs) + + def window_frame_range_start_end(self, start=None, end=None): + start_, end_ = super().window_frame_range_start_end(start, end) + if (start and start < 0) or (end and end > 0): + raise NotSupportedError( + 'PostgreSQL only supports UNBOUNDED together with PRECEDING ' + 'and FOLLOWING.' 
+ ) + return start_, end_ diff --git a/django/db/models/__init__.py b/django/db/models/__init__.py index 06806fcdc4..d29addd1f7 100644 --- a/django/db/models/__init__.py +++ b/django/db/models/__init__.py @@ -6,8 +6,8 @@ from django.db.models.deletion import ( CASCADE, DO_NOTHING, PROTECT, SET, SET_DEFAULT, SET_NULL, ProtectedError, ) from django.db.models.expressions import ( - Case, Exists, Expression, ExpressionWrapper, F, Func, OuterRef, Subquery, - Value, When, + Case, Exists, Expression, ExpressionList, ExpressionWrapper, F, Func, + OuterRef, RowRange, Subquery, Value, ValueRange, When, Window, WindowFrame, ) from django.db.models.fields import * # NOQA from django.db.models.fields import __all__ as fields_all @@ -64,8 +64,9 @@ __all__ += [ 'ObjectDoesNotExist', 'signals', 'CASCADE', 'DO_NOTHING', 'PROTECT', 'SET', 'SET_DEFAULT', 'SET_NULL', 'ProtectedError', - 'Case', 'Exists', 'Expression', 'ExpressionWrapper', 'F', 'Func', - 'OuterRef', 'Subquery', 'Value', 'When', + 'Case', 'Exists', 'Expression', 'ExpressionList', 'ExpressionWrapper', 'F', + 'Func', 'OuterRef', 'RowRange', 'Subquery', 'Value', 'ValueRange', 'When', + 'Window', 'WindowFrame', 'FileField', 'ImageField', 'OrderWrt', 'Lookup', 'Transform', 'Manager', 'Prefetch', 'Q', 'QuerySet', 'prefetch_related_objects', 'DEFERRED', 'Model', 'ForeignKey', 'ForeignObject', 'OneToOneField', 'ManyToManyField', diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py index c91200ef7f..4ed763cfe1 100644 --- a/django/db/models/aggregates.py +++ b/django/db/models/aggregates.py @@ -15,6 +15,7 @@ class Aggregate(Func): contains_aggregate = True name = None filter_template = '%s FILTER (WHERE %%(filter)s)' + window_compatible = True def __init__(self, *args, filter=None, **kwargs): self.filter = filter diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index 1937ca16c7..49ca801924 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ 
-3,6 +3,7 @@ import datetime from decimal import Decimal from django.core.exceptions import EmptyResultSet, FieldError +from django.db import connection from django.db.models import fields from django.db.models.query_utils import Q from django.utils.deconstruct import deconstructible @@ -140,6 +141,10 @@ class BaseExpression: # aggregate specific fields is_summary = False _output_field_resolved_to_none = False + # Can the expression be used in a WHERE clause? + filterable = True + # Can the expression be used as a source expression in Window? + window_compatible = False def __init__(self, output_field=None): if output_field is not None: @@ -206,6 +211,13 @@ class BaseExpression: return True return False + @cached_property + def contains_over_clause(self): + for expr in self.get_source_expressions(): + if expr and expr.contains_over_clause: + return True + return False + @cached_property def contains_column_references(self): for expr in self.get_source_expressions(): @@ -232,6 +244,7 @@ class BaseExpression: c.is_summary = summarize c.set_source_expressions([ expr.resolve_expression(query, allow_joins, reuse, summarize) + if expr else None for expr in c.get_source_expressions() ]) return c @@ -482,6 +495,9 @@ class TemporalSubtraction(CombinedExpression): @deconstructible class F(Combinable): """An object capable of resolving references to existing query objects.""" + # Can the expression be used in a WHERE clause? + filterable = True + def __init__(self, name): """ Arguments: @@ -767,6 +783,23 @@ class Ref(Expression): return [self] +class ExpressionList(Func): + """ + An expression containing multiple expressions. Can be used to provide a + list of expressions as an argument to another expression, like an + ordering clause. + """ + template = '%(expressions)s' + + def __init__(self, *expressions, **extra): + if len(expressions) == 0: + raise ValueError('%s requires at least one expression.' 
% self.__class__.__name__) + super().__init__(*expressions, **extra) + + def __str__(self): + return self.arg_joiner.join(str(arg) for arg in self.source_expressions) + + class ExpressionWrapper(Expression): """ An expression that can wrap another expression so that it can provide @@ -1118,3 +1151,168 @@ class OrderBy(BaseExpression): def desc(self): self.descending = True + + +class Window(Expression): + template = '%(expression)s OVER (%(window)s)' + # Although the main expression may either be an aggregate or an + # expression with an aggregate function, the GROUP BY that will + # be introduced in the query as a result is not desired. + contains_aggregate = False + contains_over_clause = True + filterable = False + + def __init__(self, expression, partition_by=None, order_by=None, frame=None, output_field=None): + self.partition_by = partition_by + self.order_by = order_by + self.frame = frame + + if not getattr(expression, 'window_compatible', False): + raise ValueError( + "Expression '%s' isn't compatible with OVER clauses." % + expression.__class__.__name__ + ) + + if self.partition_by is not None: + if not isinstance(self.partition_by, (tuple, list)): + self.partition_by = (self.partition_by,) + self.partition_by = ExpressionList(*self.partition_by) + + if self.order_by is not None: + if isinstance(self.order_by, (list, tuple)): + self.order_by = ExpressionList(*self.order_by) + elif not isinstance(self.order_by, BaseExpression): + raise ValueError( + 'order_by must be either an Expression or a sequence of ' + 'expressions.' 
+ ) + super().__init__(output_field=output_field) + self.source_expression = self._parse_expressions(expression)[0] + + def _resolve_output_field(self): + return self.source_expression.output_field + + def get_source_expressions(self): + return [self.source_expression, self.partition_by, self.order_by, self.frame] + + def set_source_expressions(self, exprs): + self.source_expression, self.partition_by, self.order_by, self.frame = exprs + + def as_sql(self, compiler, connection, function=None, template=None): + connection.ops.check_expression_support(self) + expr_sql, params = compiler.compile(self.source_expression) + window_sql, window_params = [], [] + + if self.partition_by is not None: + sql_expr, sql_params = self.partition_by.as_sql( + compiler=compiler, connection=connection, + template='PARTITION BY %(expressions)s', + ) + window_sql.extend(sql_expr) + window_params.extend(sql_params) + + if self.order_by is not None: + window_sql.append(' ORDER BY ') + order_sql, order_params = compiler.compile(self.order_by) + window_sql.extend(''.join(order_sql)) + window_params.extend(order_params) + + if self.frame: + frame_sql, frame_params = compiler.compile(self.frame) + window_sql.extend(' ' + frame_sql) + window_params.extend(frame_params) + + params.extend(window_params) + template = template or self.template + + return template % { + 'expression': expr_sql, + 'window': ''.join(window_sql).strip() + }, params + + def __str__(self): + return '{} OVER ({}{}{})'.format( + str(self.source_expression), + 'PARTITION BY ' + str(self.partition_by) if self.partition_by else '', + 'ORDER BY ' + str(self.order_by) if self.order_by else '', + str(self.frame or ''), + ) + + def __repr__(self): + return '<%s: %s>' % (self.__class__.__name__, self) + + def get_group_by_cols(self): + return [] + + +class WindowFrame(Expression): + """ + Model the frame clause in window expressions. 
There are two types of frame + clauses which are subclasses, however, all processing and validation (by no + means intended to be complete) is done here. Thus, providing an end for a + frame is optional (the default is UNBOUNDED FOLLOWING, which is the last + row in the frame). + """ + template = '%(frame_type)s BETWEEN %(start)s AND %(end)s' + + def __init__(self, start=None, end=None): + self.start = start + self.end = end + + def set_source_expressions(self, exprs): + self.start, self.end = exprs + + def get_source_expressions(self): + return [Value(self.start), Value(self.end)] + + def as_sql(self, compiler, connection): + connection.ops.check_expression_support(self) + start, end = self.window_frame_start_end(connection, self.start.value, self.end.value) + return self.template % { + 'frame_type': self.frame_type, + 'start': start, + 'end': end, + }, [] + + def __repr__(self): + return '<%s: %s>' % (self.__class__.__name__, self) + + def get_group_by_cols(self): + return [] + + def __str__(self): + if self.start is not None and self.start < 0: + start = '%d %s' % (abs(self.start), connection.ops.PRECEDING) + elif self.start is not None and self.start == 0: + start = connection.ops.CURRENT_ROW + else: + start = connection.ops.UNBOUNDED_PRECEDING + + if self.end is not None and self.end > 0: + end = '%d %s' % (self.end, connection.ops.FOLLOWING) + elif self.end is not None and self.end == 0: + end = connection.ops.CURRENT_ROW + else: + end = connection.ops.UNBOUNDED_FOLLOWING + return self.template % { + 'frame_type': self.frame_type, + 'start': start, + 'end': end, + } + + def window_frame_start_end(self, connection, start, end): + raise NotImplementedError('Subclasses must implement window_frame_start_end().') + + +class RowRange(WindowFrame): + frame_type = 'ROWS' + + def window_frame_start_end(self, connection, start, end): + return connection.ops.window_frame_rows_start_end(start, end) + + +class ValueRange(WindowFrame): + frame_type = 'RANGE' + + def 
window_frame_start_end(self, connection, start, end): + return connection.ops.window_frame_range_start_end(start, end) diff --git a/django/db/models/functions/__init__.py b/django/db/models/functions/__init__.py index f2e59f38ff..aab74b232a 100644 --- a/django/db/models/functions/__init__.py +++ b/django/db/models/functions/__init__.py @@ -8,6 +8,10 @@ from .datetime import ( Trunc, TruncDate, TruncDay, TruncHour, TruncMinute, TruncMonth, TruncQuarter, TruncSecond, TruncTime, TruncYear, ) +from .window import ( + CumeDist, DenseRank, FirstValue, Lag, LastValue, Lead, NthValue, Ntile, + PercentRank, Rank, RowNumber, +) __all__ = [ # base @@ -18,4 +22,7 @@ __all__ = [ 'ExtractQuarter', 'ExtractSecond', 'ExtractWeek', 'ExtractWeekDay', 'ExtractYear', 'Trunc', 'TruncDate', 'TruncDay', 'TruncHour', 'TruncMinute', 'TruncMonth', 'TruncQuarter', 'TruncSecond', 'TruncTime', 'TruncYear', + # window + 'CumeDist', 'DenseRank', 'FirstValue', 'Lag', 'LastValue', 'Lead', + 'NthValue', 'Ntile', 'PercentRank', 'Rank', 'RowNumber', ] diff --git a/django/db/models/functions/window.py b/django/db/models/functions/window.py new file mode 100644 index 0000000000..3719dfca88 --- /dev/null +++ b/django/db/models/functions/window.py @@ -0,0 +1,118 @@ +from django.db.models import FloatField, Func, IntegerField + +__all__ = [ + 'CumeDist', 'DenseRank', 'FirstValue', 'Lag', 'LastValue', 'Lead', + 'NthValue', 'Ntile', 'PercentRank', 'Rank', 'RowNumber', +] + + +class CumeDist(Func): + function = 'CUME_DIST' + name = 'CumeDist' + output_field = FloatField() + window_compatible = True + + +class DenseRank(Func): + function = 'DENSE_RANK' + name = 'DenseRank' + output_field = IntegerField() + window_compatible = True + + +class FirstValue(Func): + arity = 1 + function = 'FIRST_VALUE' + name = 'FirstValue' + window_compatible = True + + +class LagLeadFunction(Func): + window_compatible = True + + def __init__(self, expression, offset=1, default=None, **extra): + if expression is None: + raise 
ValueError( + '%s requires a non-null source expression.' % + self.__class__.__name__ + ) + if offset is None or offset <= 0: + raise ValueError( + '%s requires a positive integer for the offset.' % + self.__class__.__name__ + ) + args = (expression, offset) + if default is not None: + args += (default,) + super().__init__(*args, **extra) + + def _resolve_output_field(self): + sources = self.get_source_expressions() + return sources[0].output_field + + +class Lag(LagLeadFunction): + function = 'LAG' + name = 'Lag' + + +class LastValue(Func): + arity = 1 + function = 'LAST_VALUE' + name = 'LastValue' + window_compatible = True + + +class Lead(LagLeadFunction): + function = 'LEAD' + name = 'Lead' + + +class NthValue(Func): + function = 'NTH_VALUE' + name = 'NthValue' + window_compatible = True + + def __init__(self, expression, nth=1, **extra): + if expression is None: + raise ValueError('%s requires a non-null source expression.' % self.__class__.__name__) + if nth is None or nth <= 0: + raise ValueError('%s requires a positive integer as for nth.' 
% self.__class__.__name__) + super().__init__(expression, nth, **extra) + + def _resolve_output_field(self): + sources = self.get_source_expressions() + return sources[0].output_field + + +class Ntile(Func): + function = 'NTILE' + name = 'Ntile' + output_field = IntegerField() + window_compatible = True + + def __init__(self, num_buckets=1, **extra): + if num_buckets <= 0: + raise ValueError('num_buckets must be greater than 0.') + super().__init__(num_buckets, **extra) + + +class PercentRank(Func): + function = 'PERCENT_RANK' + name = 'PercentRank' + output_field = FloatField() + window_compatible = True + + +class Rank(Func): + function = 'RANK' + name = 'Rank' + output_field = IntegerField() + window_compatible = True + + +class RowNumber(Func): + function = 'ROW_NUMBER' + name = 'RowNumber' + output_field = IntegerField() + window_compatible = True diff --git a/django/db/models/lookups.py b/django/db/models/lookups.py index d82c29af66..f79f435515 100644 --- a/django/db/models/lookups.py +++ b/django/db/models/lookups.py @@ -115,6 +115,10 @@ class Lookup: def contains_aggregate(self): return self.lhs.contains_aggregate or getattr(self.rhs, 'contains_aggregate', False) + @cached_property + def contains_over_clause(self): + return self.lhs.contains_over_clause or getattr(self.rhs, 'contains_over_clause', False) + @property def is_summary(self): return self.lhs.is_summary or getattr(self.rhs, 'is_summary', False) diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 01c303eb7e..11ff51f60f 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -1107,6 +1107,8 @@ class SQLInsertCompiler(SQLCompiler): ) if value.contains_aggregate: raise FieldError("Aggregate functions are not allowed in this query") + if value.contains_over_clause: + raise FieldError('Window expressions are not allowed in this query.') else: value = field.get_db_prep_save(value, connection=self.connection) return value @@ -1262,6 
+1264,8 @@ class SQLUpdateCompiler(SQLCompiler): val = val.resolve_expression(self.query, allow_joins=False, for_save=True) if val.contains_aggregate: raise FieldError("Aggregate functions are not allowed in this query") + if val.contains_over_clause: + raise FieldError('Window expressions are not allowed in this query.') elif hasattr(val, 'prepare_database_save'): if field.remote_field: val = field.get_db_prep_save( diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 017edea873..4cd22c7b8a 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -13,7 +13,7 @@ from string import ascii_uppercase from django.core.exceptions import ( EmptyResultSet, FieldDoesNotExist, FieldError, ) -from django.db import DEFAULT_DB_ALIAS, connections +from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connections from django.db.models.aggregates import Count from django.db.models.constants import LOOKUP_SEP from django.db.models.expressions import Col, Ref @@ -1125,6 +1125,13 @@ class Query: if not arg: raise FieldError("Cannot parse keyword query %r" % arg) lookups, parts, reffed_expression = self.solve_lookup_type(arg) + + if not getattr(reffed_expression, 'filterable', True): + raise NotSupportedError( + reffed_expression.__class__.__name__ + ' is disallowed in ' + 'the filter clause.' 
+ ) + if not allow_joins and len(parts) > 1: raise FieldError("Joined field references are not permitted in this query") diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index ed24b08bd0..0ca95f7018 100644 --- a/django/db/models/sql/where.py +++ b/django/db/models/sql/where.py @@ -167,6 +167,16 @@ class WhereNode(tree.Node): def contains_aggregate(self): return self._contains_aggregate(self) + @classmethod + def _contains_over_clause(cls, obj): + if isinstance(obj, tree.Node): + return any(cls._contains_over_clause(c) for c in obj.children) + return obj.contains_over_clause + + @cached_property + def contains_over_clause(self): + return self._contains_over_clause(self) + @property def is_summary(self): return any(child.is_summary for child in self.children) diff --git a/docs/ref/models/database-functions.txt b/docs/ref/models/database-functions.txt index 875786c640..f5a9af3a05 100644 --- a/docs/ref/models/database-functions.txt +++ b/docs/ref/models/database-functions.txt @@ -819,3 +819,132 @@ Usage example:: 'minute': 'minute': datetime.datetime(2014, 6, 15, 14, 30, tzinfo=), 'second': datetime.datetime(2014, 6, 15, 14, 30, 50, tzinfo=) } + +.. _window-functions: + +Window functions +================ + +.. versionadded:: 2.0 + +There are a number of functions to use in a +:class:`~django.db.models.expressions.Window` expression for computing the rank +of elements or the :class:`Ntile` of some rows. + +``CumeDist`` +------------ + +.. class:: CumeDist(*expressions, **extra) + +Calculates the cumulative distribution of a value within a window or partition. +The cumulative distribution is defined as the number of rows preceding or +peered with the current row divided by the total number of rows in the frame. + +``DenseRank`` +------------- + +.. class:: DenseRank(*expressions, **extra) + +Equivalent to :class:`Rank` but does not have gaps. + +``FirstValue`` +-------------- + +.. 
class:: FirstValue(expression, **extra) + +Returns the value evaluated at the row that's the first row of the window +frame, or ``None`` if no such value exists. + +``Lag`` +------- + +.. class:: Lag(expression, offset=1, default=None, **extra) + +Calculates the value offset by ``offset``, and if no row exists there, returns +``default``. + +``default`` must have the same type as the ``expression``, however, this is +only validated by the database and not in Python. + +``LastValue`` +------------- + +.. class:: LastValue(expression, **extra) + +Comparable to :class:`FirstValue`, it calculates the last value in a given +frame clause. + +``Lead`` +-------- + +.. class:: Lead(expression, offset=1, default=None, **extra) + +Calculates the leading value in a given :ref:`frame <window-frames>`. Both +``offset`` and ``default`` are evaluated with respect to the current row. + +``default`` must have the same type as the ``expression``, however, this is +only validated by the database and not in Python. + +``NthValue`` +------------ + +.. class:: NthValue(expression, nth=1, **extra) + +Computes the row relative to the offset ``nth`` (must be a positive value) +within the window. Returns ``None`` if no row exists. + +Some databases may handle a nonexistent nth-value differently. For example, +Oracle returns an empty string rather than ``None`` for character-based +expressions. Django doesn't do any conversions in these cases. + +``Ntile`` +--------- + +.. class:: Ntile(num_buckets=1, **extra) + +Calculates a partition for each of the rows in the frame clause, distributing +numbers as evenly as possible between 1 and ``num_buckets``. If the rows don't +divide evenly into a number of buckets, one or more buckets will be represented +more frequently. + +``PercentRank`` +--------------- + +.. class:: PercentRank(*expressions, **extra) + +Computes the percentile rank of the rows in the frame clause. 
This +computation is equivalent to evaluating:: + + (rank - 1) / (total rows - 1) + +The following table explains the calculation for the percentile rank of a row: + +===== ===== ==== ============ ============ +Row # Value Rank Calculation Percent Rank +===== ===== ==== ============ ============ +1 15 1 (1-1)/(7-1) 0.0000 +2 20 2 (2-1)/(7-1) 0.1666 +3 20 2 (2-1)/(7-1) 0.1666 +4 20 2 (2-1)/(7-1) 0.1666 +5 30 5 (5-1)/(7-1) 0.6666 +6 30 5 (5-1)/(7-1) 0.6666 +7 40 7 (7-1)/(7-1) 1.0000 +===== ===== ==== ============ ============ + +``Rank`` +-------- + +.. class:: Rank(*expressions, **extra) + +Comparable to ``RowNumber``, this function ranks rows in the window. The +computed rank contains gaps. Use :class:`DenseRank` to compute rank without +gaps. + +``RowNumber`` +------------- + +.. class:: RowNumber(*expressions, **extra) + +Computes the row number according to the ordering of either the frame clause +or the ordering of the whole query if there is no partitioning of the +:ref:`window frame <window-frames>`. diff --git a/docs/ref/models/expressions.txt b/docs/ref/models/expressions.txt index 0974a9dd51..6ef1001a9f 100644 --- a/docs/ref/models/expressions.txt +++ b/docs/ref/models/expressions.txt @@ -353,6 +353,13 @@ The ``Aggregate`` API is as follows: generated. Specifically, the ``function`` will be interpolated as the ``function`` placeholder within :attr:`template`. Defaults to ``None``. + .. attribute:: window_compatible + + .. versionadded:: 2.0 + + Defaults to ``True`` since most aggregate functions can be used as the + source expression in :class:`~django.db.models.expressions.Window`. + The ``expression`` argument can be the name of a field on the model, or another expression. It will be converted to a string and used as the ``expressions`` placeholder within the ``template``. @@ -649,6 +656,184 @@ should avoid them if possible. force you to acknowledge that you're not interpolating your SQL with user provided data. +Window functions +---------------- + +.. 
versionadded:: 2.0 + +Window functions provide a way to apply functions on partitions. Unlike a +normal aggregation function which computes a final result for each set defined +by the group by, window functions operate on :ref:`frames <window-frames>` and +partitions, and compute the result for each row. + +You can specify multiple windows in the same query which in Django ORM would be +equivalent to including multiple expressions in a :doc:`QuerySet.annotate() +</ref/models/querysets>` call. The ORM doesn't make use of named windows, +instead they are part of the selected columns. + +.. class:: Window(expression, partition_by=None, order_by=None, frame=None, output_field=None) + + .. attribute:: filterable + + Defaults to ``False``. The SQL standard disallows referencing window + functions in the ``WHERE`` clause and Django raises an exception when + constructing a ``QuerySet`` that would do that. + + .. attribute:: template + + Defaults to ``%(expression)s OVER (%(window)s)``. If only the + ``expression`` argument is provided, the window clause will be blank. + +The ``Window`` class is the main expression for an ``OVER`` clause. + +The ``expression`` argument is either a :ref:`window function +<window-functions>`, an :ref:`aggregate function <aggregation-functions>`, or +an expression that's compatible in a window clause. + +The ``partition_by`` argument is a list of expressions (column names should be +wrapped in an ``F``-object) that control the partitioning of the rows. +Partitioning narrows which rows are used to compute the result set. + +The ``output_field`` is specified either as an argument or by the expression. + +The ``order_by`` argument accepts a sequence of expressions on which you can +call :meth:`~django.db.models.Expression.asc` and +:meth:`~django.db.models.Expression.desc`. The ordering controls the order in +which the expression is applied. For example, if you sum over the rows in a +partition, the first result is just the value of the first row, the second is +the sum of first and second row. 
+ +The ``frame`` parameter specifies which other rows should be used in the +computation. See :ref:`window-frames` for details. + +For example, to annotate each movie with the average rating for the movies by +the same studio in the same genre and release year:: + + >>> from django.db.models import Avg, ExtractYear, F, Window + >>> Movie.objects.annotate( + >>> avg_rating=Window( + >>> expression=Avg('rating'), + >>> partition_by=[F('studio'), F('genre')], + >>> order_by=ExtractYear('released').asc(), + >>> ), + >>> ) + +This makes it easy to check if a movie is rated better or worse than its peers. + +You may want to apply multiple expressions over the same window, i.e., the +same partition and frame. For example, you could modify the previous example +to also include the best and worst rating in each movie's group (same studio, +genre, and release year) by using three window functions in the same query. The +partition and ordering from the previous example are extracted into a dictionary +to reduce repetition:: + + >>> from django.db.models import Avg, ExtractYear, F, Max, Min, Window + >>> window = { + >>> 'partition_by': [F('studio'), F('genre')], + >>> 'order_by': ExtractYear('released').asc(), + >>> } + >>> Movie.objects.annotate( + >>> avg_rating=Window( + >>> expression=Avg('rating'), **window, + >>> ), + >>> best=Window( + >>> expression=Max('rating'), **window, + >>> ), + >>> worst=Window( + >>> expression=Min('rating'), **window, + >>> ), + >>> ) + +Among Django's built-in database backends, MySQL 8.0.2+, PostgreSQL, and Oracle +support window expressions. Support for different window expression features +varies among the different databases. For example, the options in +:meth:`~django.db.models.Expression.asc` and +:meth:`~django.db.models.Expression.desc` may not be supported. Consult the +documentation for your database as needed. + +.. 
_window-frames: + +Frames +~~~~~~ + +For a window frame, you can choose either a range-based sequence of rows or an +ordinary sequence of rows. + +.. class:: ValueRange(start=None, end=None) + + .. attribute:: frame_type + + This attribute is set to ``'RANGE'``. + + PostgreSQL has limited support for ``ValueRange`` and only supports use of + the standard start and end points, such as ``CURRENT ROW`` and ``UNBOUNDED + FOLLOWING``. + +.. class:: RowRange(start=None, end=None) + + .. attribute:: frame_type + + This attribute is set to ``'ROWS'``. + +Both classes return SQL with the template:: + + %(frame_type)s BETWEEN %(start)s AND %(end)s + +Frames narrow the rows that are used for computing the result. They shift from +some start point to some specified end point. Frames can be used with and +without partitions, but it's often a good idea to specify an ordering of the +window to ensure a deterministic result. In a frame, a peer in a frame is a row +with an equivalent value, or all rows if an ordering clause isn't present. + +The default starting point for a frame is ``UNBOUNDED PRECEDING`` which is the +first row of the partition. The end point is always explicitly included in the +SQL generated by the ORM and is by default ``UNBOUNDED FOLLOWING``. The default +frame includes all rows from the partition to the last row in the set. + +The accepted values for the ``start`` and ``end`` arguments are ``None``, an +integer, or zero. A negative integer for ``start`` results in ``N preceding``, +while ``None`` yields ``UNBOUNDED PRECEDING``. For both ``start`` and ``end``, +zero will return ``CURRENT ROW``. Positive integers are accepted for ``end``. + +There's a difference in what ``CURRENT ROW`` includes. When specified in +``ROWS`` mode, the frame starts or ends with the current row. When specified in +``RANGE`` mode, the frame starts or ends at the first or last peer according to +the ordering clause. 
Thus, ``RANGE CURRENT ROW`` evaluates the expression for +rows which have the same value specified by the ordering. Because the template +includes both the ``start`` and ``end`` points, this may be expressed with:: + + ValueRange(start=0, end=0) + +If a movie's "peers" are described as movies released by the same studio in the +same genre in the same year, this ``RowRange`` example annotates each movie +with the average rating of a movie's two prior and two following peers:: + + >>> from django.db.models import Avg, ExtractYear, F, RowRange, Window + >>> Movie.objects.annotate( + >>> avg_rating=Window( + >>> expression=Avg('rating'), + >>> partition_by=[F('studio'), F('genre')], + >>> order_by=ExtractYear('released').asc(), + >>> frame=RowRange(start=-2, end=2), + >>> ), + >>> ) + +If the database supports it, you can specify the start and end points based on +values of an expression in the partition. If the ``released`` field of the +``Movie`` model stores the release month of each movie, this ``ValueRange`` +example annotates each movie with the average rating of a movie's peers +released between twelve months before and twelve months after each movie:: + + >>> from django.db.models import Avg, ExpressionList, F, ValueRange, Window + >>> Movie.objects.annotate( + >>> avg_rating=Window( + >>> expression=Avg('rating'), + >>> partition_by=[F('studio'), F('genre')], + >>> order_by=F('released').asc(), + >>> frame=ValueRange(start=-12, end=12), + >>> ), + >>> ) + .. currentmodule:: django.db.models Technical Information @@ -677,6 +862,30 @@ calling the appropriate methods on the wrapped expression. Tells Django that this expression contains an aggregate and that a ``GROUP BY`` clause needs to be added to the query. + .. attribute:: contains_over_clause + + .. versionadded:: 2.0 + + Tells Django that this expression contains a + :class:`~django.db.models.expressions.Window` expression. 
It's used, + for example, to disallow window function expressions in queries that + modify data. Defaults to ``True``. + + .. attribute:: filterable + + .. versionadded:: 2.0 + + Tells Django that this expression can be referenced in + :meth:`.QuerySet.filter`. Defaults to ``True``. + + .. attribute:: window_compatible + + .. versionadded:: 2.0 + + Tells Django that this expression can be used as the source expression + in :class:`~django.db.models.expressions.Window`. Defaults to + ``False``. + .. method:: resolve_expression(query=None, allow_joins=True, reuse=None, summarize=False, for_save=False) Provides the chance to do any pre-processing or validation of diff --git a/docs/ref/models/querysets.txt b/docs/ref/models/querysets.txt index c7ab413f53..f62050d818 100644 --- a/docs/ref/models/querysets.txt +++ b/docs/ref/models/querysets.txt @@ -1681,6 +1681,11 @@ raised if ``select_for_update()`` is used in autocommit mode. ``select_for_update()`` you should use :class:`~django.test.TransactionTestCase`. +.. admonition:: Certain expressions may not be supported + + PostgreSQL doesn't support ``select_for_update()`` with + :class:`~django.db.models.expressions.Window` expressions. + .. versionchanged:: 1.11 The ``skip_locked`` argument was added. diff --git a/docs/releases/2.0.txt b/docs/releases/2.0.txt index da898db67d..fd9cf8dfd3 100644 --- a/docs/releases/2.0.txt +++ b/docs/releases/2.0.txt @@ -52,6 +52,14 @@ Mobile-friendly ``contrib.admin`` The admin is now responsive and supports all major mobile devices. Older browser may experience varying levels of graceful degradation. +Window expressions +------------------ + +The new :class:`~django.db.models.expressions.Window` expression allows +adding an ``OVER`` clause to querysets. You can use :ref:`window functions +<window-functions>` and :ref:`aggregate functions <aggregation-functions>` in +the expression. + Minor features -------------- @@ -404,6 +412,11 @@ backends. 
requires that the arguments to ``OF`` be columns rather than tables, set ``DatabaseFeatures.select_for_update_of_column = True``. +* To enable support for :class:`~django.db.models.expressions.Window` + expressions, set ``DatabaseFeatures.supports_over_clause`` to ``True``. You + may need to customize the ``DatabaseOperations.window_frame_rows_start_end()`` + and/or ``window_frame_range_start_end()`` methods. + * Third-party database backends should add a ``DatabaseOperations.cast_char_field_without_max_length`` attribute with the database data type that will be used in the diff --git a/tests/backends/base/test_operations.py b/tests/backends/base/test_operations.py new file mode 100644 index 0000000000..4477547c93 --- /dev/null +++ b/tests/backends/base/test_operations.py @@ -0,0 +1,10 @@ +from django.db import NotSupportedError, connection +from django.test import SimpleTestCase, skipIfDBFeature + + +class DatabaseOperationTests(SimpleTestCase): + @skipIfDBFeature('supports_over_clause') + def test_window_frame_raise_not_supported_error(self): + msg = 'This backend does not support window expressions.' 
+ with self.assertRaisesMessage(NotSupportedError, msg): + connection.ops.window_frame_rows_start_end() diff --git a/tests/db_functions/test_window.py b/tests/db_functions/test_window.py new file mode 100644 index 0000000000..2efc88fdfa --- /dev/null +++ b/tests/db_functions/test_window.py @@ -0,0 +1,39 @@ +from django.db.models.functions import Lag, Lead, NthValue, Ntile +from django.test import SimpleTestCase + + +class ValidationTests(SimpleTestCase): + def test_nth_negative_nth_value(self): + msg = 'NthValue requires a positive integer as for nth' + with self.assertRaisesMessage(ValueError, msg): + NthValue(expression='salary', nth=-1) + + def test_nth_null_expression(self): + msg = 'NthValue requires a non-null source expression' + with self.assertRaisesMessage(ValueError, msg): + NthValue(expression=None) + + def test_lag_negative_offset(self): + msg = 'Lag requires a positive integer for the offset' + with self.assertRaisesMessage(ValueError, msg): + Lag(expression='salary', offset=-1) + + def test_lead_negative_offset(self): + msg = 'Lead requires a positive integer for the offset' + with self.assertRaisesMessage(ValueError, msg): + Lead(expression='salary', offset=-1) + + def test_null_source_lead(self): + msg = 'Lead requires a non-null source expression' + with self.assertRaisesMessage(ValueError, msg): + Lead(expression=None) + + def test_null_source_lag(self): + msg = 'Lag requires a non-null source expression' + with self.assertRaisesMessage(ValueError, msg): + Lag(expression=None) + + def test_negative_num_buckets_ntile(self): + msg = 'num_buckets must be greater than 0' + with self.assertRaisesMessage(ValueError, msg): + Ntile(num_buckets=-1) diff --git a/tests/expressions/tests.py b/tests/expressions/tests.py index 2de0ec828e..e26b3ef6d8 100644 --- a/tests/expressions/tests.py +++ b/tests/expressions/tests.py @@ -10,8 +10,8 @@ from django.db.models.aggregates import ( Avg, Count, Max, Min, StdDev, Sum, Variance, ) from django.db.models.expressions 
import ( - Case, Col, Exists, ExpressionWrapper, F, Func, OrderBy, OuterRef, Random, - RawSQL, Ref, Subquery, Value, When, + Case, Col, Exists, ExpressionList, ExpressionWrapper, F, Func, OrderBy, + OuterRef, Random, RawSQL, Ref, Subquery, Value, When, ) from django.db.models.functions import ( Coalesce, Concat, Length, Lower, Substr, Upper, @@ -1330,6 +1330,11 @@ class ValueTests(TestCase): self.assertNotEqual(value, other_value) self.assertNotEqual(value, no_output_field) + def test_raise_empty_expressionlist(self): + msg = 'ExpressionList requires at least one expression' + with self.assertRaisesMessage(ValueError, msg): + ExpressionList() + class ReprTests(TestCase): @@ -1355,6 +1360,14 @@ class ReprTests(TestCase): self.assertEqual(repr(RawSQL('table.col', [])), "RawSQL(table.col, [])") self.assertEqual(repr(Ref('sum_cost', Sum('cost'))), "Ref(sum_cost, Sum(F(cost)))") self.assertEqual(repr(Value(1)), "Value(1)") + self.assertEqual( + repr(ExpressionList(F('col'), F('anothercol'))), + 'ExpressionList(F(col), F(anothercol))' + ) + self.assertEqual( + repr(ExpressionList(OrderBy(F('col'), descending=False))), + 'ExpressionList(OrderBy(F(col), descending=False))' + ) def test_functions(self): self.assertEqual(repr(Coalesce('a', 'b')), "Coalesce(F(a), F(b))") diff --git a/tests/expressions_window/__init__.py b/tests/expressions_window/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/expressions_window/models.py b/tests/expressions_window/models.py new file mode 100644 index 0000000000..94cade7ed7 --- /dev/null +++ b/tests/expressions_window/models.py @@ -0,0 +1,11 @@ +from django.db import models + + +class Employee(models.Model): + name = models.CharField(max_length=40, blank=False, null=False) + salary = models.PositiveIntegerField() + department = models.CharField(max_length=40, blank=False, null=False) + hire_date = models.DateField(blank=False, null=False) + + def __str__(self): + return '{}, {}, {}, {}'.format(self.name, 
self.department, self.salary, self.hire_date) diff --git a/tests/expressions_window/tests.py b/tests/expressions_window/tests.py new file mode 100644 index 0000000000..e61516b4b6 --- /dev/null +++ b/tests/expressions_window/tests.py @@ -0,0 +1,783 @@ +import datetime +from unittest import skipIf, skipUnless + +from django.core.exceptions import FieldError +from django.db import NotSupportedError, connection +from django.db.models import ( + F, RowRange, Value, ValueRange, Window, WindowFrame, +) +from django.db.models.aggregates import Avg, Max, Min, Sum +from django.db.models.functions import ( + CumeDist, DenseRank, ExtractYear, FirstValue, Lag, LastValue, Lead, + NthValue, Ntile, PercentRank, Rank, RowNumber, Upper, +) +from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature + +from .models import Employee + + +@skipUnlessDBFeature('supports_over_clause') +class WindowFunctionTests(TestCase): + @classmethod + def setUpTestData(cls): + Employee.objects.bulk_create([ + Employee(name=e[0], salary=e[1], department=e[2], hire_date=e[3]) + for e in [ + ('Jones', 45000, 'Accounting', datetime.datetime(2005, 11, 1)), + ('Williams', 37000, 'Accounting', datetime.datetime(2009, 6, 1)), + ('Jenson', 45000, 'Accounting', datetime.datetime(2008, 4, 1)), + ('Adams', 50000, 'Accounting', datetime.datetime(2013, 7, 1)), + ('Smith', 55000, 'Sales', datetime.datetime(2007, 6, 1)), + ('Brown', 53000, 'Sales', datetime.datetime(2009, 9, 1)), + ('Johnson', 40000, 'Marketing', datetime.datetime(2012, 3, 1)), + ('Smith', 38000, 'Marketing', datetime.datetime(2009, 10, 1)), + ('Wilkinson', 60000, 'IT', datetime.datetime(2011, 3, 1)), + ('Moore', 34000, 'IT', datetime.datetime(2013, 8, 1)), + ('Miller', 100000, 'Management', datetime.datetime(2005, 6, 1)), + ('Johnson', 80000, 'Management', datetime.datetime(2005, 7, 1)), + ] + ]) + + def test_dense_rank(self): + qs = Employee.objects.annotate(rank=Window( + expression=DenseRank(), + 
order_by=ExtractYear(F('hire_date')).asc(), + )) + self.assertQuerysetEqual(qs, [ + ('Jones', 45000, 'Accounting', datetime.date(2005, 11, 1), 1), + ('Miller', 100000, 'Management', datetime.date(2005, 6, 1), 1), + ('Johnson', 80000, 'Management', datetime.date(2005, 7, 1), 1), + ('Smith', 55000, 'Sales', datetime.date(2007, 6, 1), 2), + ('Jenson', 45000, 'Accounting', datetime.date(2008, 4, 1), 3), + ('Smith', 38000, 'Marketing', datetime.date(2009, 10, 1), 4), + ('Brown', 53000, 'Sales', datetime.date(2009, 9, 1), 4), + ('Williams', 37000, 'Accounting', datetime.date(2009, 6, 1), 4), + ('Wilkinson', 60000, 'IT', datetime.date(2011, 3, 1), 5), + ('Johnson', 40000, 'Marketing', datetime.date(2012, 3, 1), 6), + ('Moore', 34000, 'IT', datetime.date(2013, 8, 1), 7), + ('Adams', 50000, 'Accounting', datetime.date(2013, 7, 1), 7), + ], lambda entry: (entry.name, entry.salary, entry.department, entry.hire_date, entry.rank), ordered=False) + + def test_department_salary(self): + qs = Employee.objects.annotate(department_sum=Window( + expression=Sum('salary'), + partition_by=F('department'), + order_by=[F('hire_date').asc()], + )).order_by('department', 'department_sum') + self.assertQuerysetEqual(qs, [ + ('Jones', 'Accounting', 45000, 45000), + ('Jenson', 'Accounting', 45000, 90000), + ('Williams', 'Accounting', 37000, 127000), + ('Adams', 'Accounting', 50000, 177000), + ('Wilkinson', 'IT', 60000, 60000), + ('Moore', 'IT', 34000, 94000), + ('Miller', 'Management', 100000, 100000), + ('Johnson', 'Management', 80000, 180000), + ('Smith', 'Marketing', 38000, 38000), + ('Johnson', 'Marketing', 40000, 78000), + ('Smith', 'Sales', 55000, 55000), + ('Brown', 'Sales', 53000, 108000), + ], lambda entry: (entry.name, entry.department, entry.salary, entry.department_sum)) + + def test_rank(self): + """ + Rank the employees based on the year they're were hired. Since there + are multiple employees hired in different years, this will contain + gaps. 
+ """ + qs = Employee.objects.annotate(rank=Window( + expression=Rank(), + order_by=ExtractYear(F('hire_date')).asc(), + )) + self.assertQuerysetEqual(qs, [ + ('Jones', 45000, 'Accounting', datetime.date(2005, 11, 1), 1), + ('Miller', 100000, 'Management', datetime.date(2005, 6, 1), 1), + ('Johnson', 80000, 'Management', datetime.date(2005, 7, 1), 1), + ('Smith', 55000, 'Sales', datetime.date(2007, 6, 1), 4), + ('Jenson', 45000, 'Accounting', datetime.date(2008, 4, 1), 5), + ('Smith', 38000, 'Marketing', datetime.date(2009, 10, 1), 6), + ('Brown', 53000, 'Sales', datetime.date(2009, 9, 1), 6), + ('Williams', 37000, 'Accounting', datetime.date(2009, 6, 1), 6), + ('Wilkinson', 60000, 'IT', datetime.date(2011, 3, 1), 9), + ('Johnson', 40000, 'Marketing', datetime.date(2012, 3, 1), 10), + ('Moore', 34000, 'IT', datetime.date(2013, 8, 1), 11), + ('Adams', 50000, 'Accounting', datetime.date(2013, 7, 1), 11), + ], lambda entry: (entry.name, entry.salary, entry.department, entry.hire_date, entry.rank), ordered=False) + + def test_row_number(self): + """ + The row number window function computes the number based on the order + in which the tuples were inserted. Depending on the backend, + + Oracle requires an ordering-clause in the Window expression. 
+ """ + qs = Employee.objects.annotate(row_number=Window( + expression=RowNumber(), + order_by=F('pk').asc(), + )).order_by('pk') + self.assertQuerysetEqual(qs, [ + ('Jones', 'Accounting', 1), + ('Williams', 'Accounting', 2), + ('Jenson', 'Accounting', 3), + ('Adams', 'Accounting', 4), + ('Smith', 'Sales', 5), + ('Brown', 'Sales', 6), + ('Johnson', 'Marketing', 7), + ('Smith', 'Marketing', 8), + ('Wilkinson', 'IT', 9), + ('Moore', 'IT', 10), + ('Miller', 'Management', 11), + ('Johnson', 'Management', 12), + ], lambda entry: (entry.name, entry.department, entry.row_number)) + + @skipIf(connection.vendor == 'oracle', "Oracle requires ORDER BY in row_number, ANSI:SQL doesn't") + def test_row_number_no_ordering(self): + """ + The row number window function computes the number based on the order + in which the tuples were inserted. + """ + # Add a default ordering for consistent results across databases. + qs = Employee.objects.annotate(row_number=Window( + expression=RowNumber(), + )).order_by('pk') + self.assertQuerysetEqual(qs, [ + ('Jones', 'Accounting', 1), + ('Williams', 'Accounting', 2), + ('Jenson', 'Accounting', 3), + ('Adams', 'Accounting', 4), + ('Smith', 'Sales', 5), + ('Brown', 'Sales', 6), + ('Johnson', 'Marketing', 7), + ('Smith', 'Marketing', 8), + ('Wilkinson', 'IT', 9), + ('Moore', 'IT', 10), + ('Miller', 'Management', 11), + ('Johnson', 'Management', 12), + ], lambda entry: (entry.name, entry.department, entry.row_number)) + + def test_avg_salary_department(self): + qs = Employee.objects.annotate(avg_salary=Window( + expression=Avg('salary'), + order_by=F('department').asc(), + partition_by='department', + )).order_by('department', '-salary', 'name') + self.assertQuerysetEqual(qs, [ + ('Adams', 50000, 'Accounting', 44250.00), + ('Jenson', 45000, 'Accounting', 44250.00), + ('Jones', 45000, 'Accounting', 44250.00), + ('Williams', 37000, 'Accounting', 44250.00), + ('Wilkinson', 60000, 'IT', 47000.00), + ('Moore', 34000, 'IT', 47000.00), + ('Miller', 
100000, 'Management', 90000.00), + ('Johnson', 80000, 'Management', 90000.00), + ('Johnson', 40000, 'Marketing', 39000.00), + ('Smith', 38000, 'Marketing', 39000.00), + ('Smith', 55000, 'Sales', 54000.00), + ('Brown', 53000, 'Sales', 54000.00), + ], transform=lambda row: (row.name, row.salary, row.department, row.avg_salary)) + + def test_lag(self): + """ + Compute the difference between an employee's salary and the next + highest salary in the employee's department. Return None if the + employee has the lowest salary. + """ + qs = Employee.objects.annotate(lag=Window( + expression=Lag(expression='salary', offset=1), + partition_by=F('department'), + order_by=[F('salary').asc(), F('name').asc()], + )).order_by('department') + self.assertQuerysetEqual(qs, [ + ('Williams', 37000, 'Accounting', None), + ('Jenson', 45000, 'Accounting', 37000), + ('Jones', 45000, 'Accounting', 45000), + ('Adams', 50000, 'Accounting', 45000), + ('Moore', 34000, 'IT', None), + ('Wilkinson', 60000, 'IT', 34000), + ('Johnson', 80000, 'Management', None), + ('Miller', 100000, 'Management', 80000), + ('Smith', 38000, 'Marketing', None), + ('Johnson', 40000, 'Marketing', 38000), + ('Brown', 53000, 'Sales', None), + ('Smith', 55000, 'Sales', 53000), + ], transform=lambda row: (row.name, row.salary, row.department, row.lag)) + + def test_first_value(self): + qs = Employee.objects.annotate(first_value=Window( + expression=FirstValue('salary'), + partition_by=F('department'), + order_by=F('hire_date').asc(), + )).order_by('department', 'hire_date') + self.assertQuerysetEqual(qs, [ + ('Jones', 45000, 'Accounting', datetime.date(2005, 11, 1), 45000), + ('Jenson', 45000, 'Accounting', datetime.date(2008, 4, 1), 45000), + ('Williams', 37000, 'Accounting', datetime.date(2009, 6, 1), 45000), + ('Adams', 50000, 'Accounting', datetime.date(2013, 7, 1), 45000), + ('Wilkinson', 60000, 'IT', datetime.date(2011, 3, 1), 60000), + ('Moore', 34000, 'IT', datetime.date(2013, 8, 1), 60000), + ('Miller', 100000, 
'Management', datetime.date(2005, 6, 1), 100000), + ('Johnson', 80000, 'Management', datetime.date(2005, 7, 1), 100000), + ('Smith', 38000, 'Marketing', datetime.date(2009, 10, 1), 38000), + ('Johnson', 40000, 'Marketing', datetime.date(2012, 3, 1), 38000), + ('Smith', 55000, 'Sales', datetime.date(2007, 6, 1), 55000), + ('Brown', 53000, 'Sales', datetime.date(2009, 9, 1), 55000), + ], lambda row: (row.name, row.salary, row.department, row.hire_date, row.first_value)) + + def test_last_value(self): + qs = Employee.objects.annotate(last_value=Window( + expression=LastValue('hire_date'), + partition_by=F('department'), + order_by=F('hire_date').asc(), + )) + self.assertQuerysetEqual(qs, [ + ('Adams', 'Accounting', datetime.date(2013, 7, 1), 50000, datetime.date(2013, 7, 1)), + ('Jenson', 'Accounting', datetime.date(2008, 4, 1), 45000, datetime.date(2008, 4, 1)), + ('Jones', 'Accounting', datetime.date(2005, 11, 1), 45000, datetime.date(2005, 11, 1)), + ('Williams', 'Accounting', datetime.date(2009, 6, 1), 37000, datetime.date(2009, 6, 1)), + ('Moore', 'IT', datetime.date(2013, 8, 1), 34000, datetime.date(2013, 8, 1)), + ('Wilkinson', 'IT', datetime.date(2011, 3, 1), 60000, datetime.date(2011, 3, 1)), + ('Miller', 'Management', datetime.date(2005, 6, 1), 100000, datetime.date(2005, 6, 1)), + ('Johnson', 'Management', datetime.date(2005, 7, 1), 80000, datetime.date(2005, 7, 1)), + ('Johnson', 'Marketing', datetime.date(2012, 3, 1), 40000, datetime.date(2012, 3, 1)), + ('Smith', 'Marketing', datetime.date(2009, 10, 1), 38000, datetime.date(2009, 10, 1)), + ('Brown', 'Sales', datetime.date(2009, 9, 1), 53000, datetime.date(2009, 9, 1)), + ('Smith', 'Sales', datetime.date(2007, 6, 1), 55000, datetime.date(2007, 6, 1)), + ], transform=lambda row: (row.name, row.department, row.hire_date, row.salary, row.last_value), ordered=False) + + def test_function_list_of_values(self): + qs = Employee.objects.annotate(lead=Window( + expression=Lead(expression='salary'), + 
order_by=[F('hire_date').asc(), F('name').desc()], + partition_by='department', + )).values_list('name', 'salary', 'department', 'hire_date', 'lead') + self.assertNotIn('GROUP BY', str(qs.query)) + self.assertSequenceEqual(qs, [ + ('Jones', 45000, 'Accounting', datetime.date(2005, 11, 1), 45000), + ('Jenson', 45000, 'Accounting', datetime.date(2008, 4, 1), 37000), + ('Williams', 37000, 'Accounting', datetime.date(2009, 6, 1), 50000), + ('Adams', 50000, 'Accounting', datetime.date(2013, 7, 1), None), + ('Wilkinson', 60000, 'IT', datetime.date(2011, 3, 1), 34000), + ('Moore', 34000, 'IT', datetime.date(2013, 8, 1), None), + ('Miller', 100000, 'Management', datetime.date(2005, 6, 1), 80000), + ('Johnson', 80000, 'Management', datetime.date(2005, 7, 1), None), + ('Smith', 38000, 'Marketing', datetime.date(2009, 10, 1), 40000), + ('Johnson', 40000, 'Marketing', datetime.date(2012, 3, 1), None), + ('Smith', 55000, 'Sales', datetime.date(2007, 6, 1), 53000), + ('Brown', 53000, 'Sales', datetime.date(2009, 9, 1), None), + ]) + + def test_min_department(self): + """An alternative way to specify a query for FirstValue.""" + qs = Employee.objects.annotate(min_salary=Window( + expression=Min('salary'), + partition_by=F('department'), + order_by=[F('salary').asc(), F('name').asc()] + )).order_by('department', 'salary', 'name') + self.assertQuerysetEqual(qs, [ + ('Williams', 'Accounting', 37000, 37000), + ('Jenson', 'Accounting', 45000, 37000), + ('Jones', 'Accounting', 45000, 37000), + ('Adams', 'Accounting', 50000, 37000), + ('Moore', 'IT', 34000, 34000), + ('Wilkinson', 'IT', 60000, 34000), + ('Johnson', 'Management', 80000, 80000), + ('Miller', 'Management', 100000, 80000), + ('Smith', 'Marketing', 38000, 38000), + ('Johnson', 'Marketing', 40000, 38000), + ('Brown', 'Sales', 53000, 53000), + ('Smith', 'Sales', 55000, 53000), + ], lambda row: (row.name, row.department, row.salary, row.min_salary)) + + def test_max_per_year(self): + """ + Find the maximum salary awarded in the 
same year as the + employee was hired, regardless of the department. + """ + qs = Employee.objects.annotate(max_salary_year=Window( + expression=Max('salary'), + order_by=ExtractYear('hire_date').asc(), + partition_by=ExtractYear('hire_date') + )).order_by(ExtractYear('hire_date'), 'salary') + self.assertQuerysetEqual(qs, [ + ('Jones', 'Accounting', 45000, 2005, 100000), + ('Johnson', 'Management', 80000, 2005, 100000), + ('Miller', 'Management', 100000, 2005, 100000), + ('Smith', 'Sales', 55000, 2007, 55000), + ('Jenson', 'Accounting', 45000, 2008, 45000), + ('Williams', 'Accounting', 37000, 2009, 53000), + ('Smith', 'Marketing', 38000, 2009, 53000), + ('Brown', 'Sales', 53000, 2009, 53000), + ('Wilkinson', 'IT', 60000, 2011, 60000), + ('Johnson', 'Marketing', 40000, 2012, 40000), + ('Moore', 'IT', 34000, 2013, 50000), + ('Adams', 'Accounting', 50000, 2013, 50000), + ], lambda row: (row.name, row.department, row.salary, row.hire_date.year, row.max_salary_year)) + + def test_cume_dist(self): + """ + Compute the cumulative distribution for the employees based on the + salary in increasing order. Equal to rank/total number of rows (12). + """ + qs = Employee.objects.annotate(cume_dist=Window( + expression=CumeDist(), + order_by=F('salary').asc(), + )).order_by('salary', 'name') + # Round result of cume_dist because Oracle uses greater precision. 
+ self.assertQuerysetEqual(qs, [ + ('Moore', 'IT', 34000, 0.0833333333), + ('Williams', 'Accounting', 37000, 0.1666666667), + ('Smith', 'Marketing', 38000, 0.25), + ('Johnson', 'Marketing', 40000, 0.3333333333), + ('Jenson', 'Accounting', 45000, 0.5), + ('Jones', 'Accounting', 45000, 0.5), + ('Adams', 'Accounting', 50000, 0.5833333333), + ('Brown', 'Sales', 53000, 0.6666666667), + ('Smith', 'Sales', 55000, 0.75), + ('Wilkinson', 'IT', 60000, 0.8333333333), + ('Johnson', 'Management', 80000, 0.9166666667), + ('Miller', 'Management', 100000, 1), + ], lambda row: (row.name, row.department, row.salary, round(row.cume_dist, 10))) + + def test_nthvalue(self): + qs = Employee.objects.annotate( + nth_value=Window(expression=NthValue( + expression='salary', nth=2), + order_by=[F('hire_date').asc(), F('name').desc()], + partition_by=F('department'), + ) + ).order_by('department', 'hire_date', 'name') + self.assertQuerysetEqual(qs, [ + ('Jones', 'Accounting', datetime.date(2005, 11, 1), 45000, None), + ('Jenson', 'Accounting', datetime.date(2008, 4, 1), 45000, 45000), + ('Williams', 'Accounting', datetime.date(2009, 6, 1), 37000, 45000), + ('Adams', 'Accounting', datetime.date(2013, 7, 1), 50000, 45000), + ('Wilkinson', 'IT', datetime.date(2011, 3, 1), 60000, None), + ('Moore', 'IT', datetime.date(2013, 8, 1), 34000, 34000), + ('Miller', 'Management', datetime.date(2005, 6, 1), 100000, None), + ('Johnson', 'Management', datetime.date(2005, 7, 1), 80000, 80000), + ('Smith', 'Marketing', datetime.date(2009, 10, 1), 38000, None), + ('Johnson', 'Marketing', datetime.date(2012, 3, 1), 40000, 40000), + ('Smith', 'Sales', datetime.date(2007, 6, 1), 55000, None), + ('Brown', 'Sales', datetime.date(2009, 9, 1), 53000, 53000), + ], lambda row: (row.name, row.department, row.hire_date, row.salary, row.nth_value)) + + def test_lead(self): + """ + Determine what the next person hired in the same department makes. 
+ Because the dataset is ambiguous, the name is also part of the + ordering clause. No default is provided, so None/NULL should be + returned. + """ + qs = Employee.objects.annotate(lead=Window( + expression=Lead(expression='salary'), + order_by=[F('hire_date').asc(), F('name').desc()], + partition_by='department', + )).order_by('department') + self.assertQuerysetEqual(qs, [ + ('Jones', 45000, 'Accounting', datetime.date(2005, 11, 1), 45000), + ('Jenson', 45000, 'Accounting', datetime.date(2008, 4, 1), 37000), + ('Williams', 37000, 'Accounting', datetime.date(2009, 6, 1), 50000), + ('Adams', 50000, 'Accounting', datetime.date(2013, 7, 1), None), + ('Wilkinson', 60000, 'IT', datetime.date(2011, 3, 1), 34000), + ('Moore', 34000, 'IT', datetime.date(2013, 8, 1), None), + ('Miller', 100000, 'Management', datetime.date(2005, 6, 1), 80000), + ('Johnson', 80000, 'Management', datetime.date(2005, 7, 1), None), + ('Smith', 38000, 'Marketing', datetime.date(2009, 10, 1), 40000), + ('Johnson', 40000, 'Marketing', datetime.date(2012, 3, 1), None), + ('Smith', 55000, 'Sales', datetime.date(2007, 6, 1), 53000), + ('Brown', 53000, 'Sales', datetime.date(2009, 9, 1), None), + ], transform=lambda row: (row.name, row.salary, row.department, row.hire_date, row.lead)) + + def test_lead_offset(self): + """ + Determine what the person hired after someone makes. Due to + ambiguity, the name is also included in the ordering. 
+ """ + qs = Employee.objects.annotate(lead=Window( + expression=Lead('salary', offset=2), + partition_by='department', + order_by=F('hire_date').asc(), + )) + self.assertQuerysetEqual(qs, [ + ('Jones', 45000, 'Accounting', datetime.date(2005, 11, 1), 37000), + ('Jenson', 45000, 'Accounting', datetime.date(2008, 4, 1), 50000), + ('Williams', 37000, 'Accounting', datetime.date(2009, 6, 1), None), + ('Adams', 50000, 'Accounting', datetime.date(2013, 7, 1), None), + ('Wilkinson', 60000, 'IT', datetime.date(2011, 3, 1), None), + ('Moore', 34000, 'IT', datetime.date(2013, 8, 1), None), + ('Johnson', 80000, 'Management', datetime.date(2005, 7, 1), None), + ('Miller', 100000, 'Management', datetime.date(2005, 6, 1), None), + ('Smith', 38000, 'Marketing', datetime.date(2009, 10, 1), None), + ('Johnson', 40000, 'Marketing', datetime.date(2012, 3, 1), None), + ('Smith', 55000, 'Sales', datetime.date(2007, 6, 1), None), + ('Brown', 53000, 'Sales', datetime.date(2009, 9, 1), None), + ], transform=lambda row: (row.name, row.salary, row.department, row.hire_date, row.lead), + ordered=False + ) + + def test_lead_default(self): + qs = Employee.objects.annotate(lead_default=Window( + expression=Lead(expression='salary', offset=5, default=60000), + partition_by=F('department'), + order_by=F('department').asc(), + )) + self.assertEqual(list(qs.values_list('lead_default', flat=True).distinct()), [60000]) + + def test_ntile(self): + """ + Compute the group for each of the employees across the entire company, + based on how high the salary is for them. There are twelve employees + so it divides evenly into four groups. 
+ """ + qs = Employee.objects.annotate(ntile=Window( + expression=Ntile(num_buckets=4), + order_by=F('salary').desc(), + )).order_by('ntile', '-salary', 'name') + self.assertQuerysetEqual(qs, [ + ('Miller', 'Management', 100000, 1), + ('Johnson', 'Management', 80000, 1), + ('Wilkinson', 'IT', 60000, 1), + ('Smith', 'Sales', 55000, 2), + ('Brown', 'Sales', 53000, 2), + ('Adams', 'Accounting', 50000, 2), + ('Jenson', 'Accounting', 45000, 3), + ('Jones', 'Accounting', 45000, 3), + ('Johnson', 'Marketing', 40000, 3), + ('Smith', 'Marketing', 38000, 4), + ('Williams', 'Accounting', 37000, 4), + ('Moore', 'IT', 34000, 4), + ], lambda x: (x.name, x.department, x.salary, x.ntile)) + + def test_percent_rank(self): + """ + Calculate the percentage rank of the employees across the entire + company based on salary and name (in case of ambiguity). + """ + qs = Employee.objects.annotate(percent_rank=Window( + expression=PercentRank(), + order_by=[F('salary').asc(), F('name').asc()], + )).order_by('percent_rank') + # Round to account for precision differences among databases. + self.assertQuerysetEqual(qs, [ + ('Moore', 'IT', 34000, 0.0), + ('Williams', 'Accounting', 37000, 0.0909090909), + ('Smith', 'Marketing', 38000, 0.1818181818), + ('Johnson', 'Marketing', 40000, 0.2727272727), + ('Jenson', 'Accounting', 45000, 0.3636363636), + ('Jones', 'Accounting', 45000, 0.4545454545), + ('Adams', 'Accounting', 50000, 0.5454545455), + ('Brown', 'Sales', 53000, 0.6363636364), + ('Smith', 'Sales', 55000, 0.7272727273), + ('Wilkinson', 'IT', 60000, 0.8181818182), + ('Johnson', 'Management', 80000, 0.9090909091), + ('Miller', 'Management', 100000, 1.0), + ], transform=lambda row: (row.name, row.department, row.salary, round(row.percent_rank, 10))) + + def test_nth_returns_null(self): + """ + Find the nth row of the data set. None is returned since there are + fewer than 20 rows in the test data. 
+ """ + qs = Employee.objects.annotate(nth_value=Window( + expression=NthValue('salary', nth=20), + order_by=F('salary').asc() + )) + self.assertEqual(list(qs.values_list('nth_value', flat=True).distinct()), [None]) + + def test_multiple_partitioning(self): + """ + Find the maximum salary for each department for people hired in the + same year. + """ + qs = Employee.objects.annotate(max=Window( + expression=Max('salary'), + partition_by=[F('department'), ExtractYear(F('hire_date'))], + )).order_by('department', 'hire_date', 'name') + self.assertQuerysetEqual(qs, [ + ('Jones', 45000, 'Accounting', datetime.date(2005, 11, 1), 45000), + ('Jenson', 45000, 'Accounting', datetime.date(2008, 4, 1), 45000), + ('Williams', 37000, 'Accounting', datetime.date(2009, 6, 1), 37000), + ('Adams', 50000, 'Accounting', datetime.date(2013, 7, 1), 50000), + ('Wilkinson', 60000, 'IT', datetime.date(2011, 3, 1), 60000), + ('Moore', 34000, 'IT', datetime.date(2013, 8, 1), 34000), + ('Miller', 100000, 'Management', datetime.date(2005, 6, 1), 100000), + ('Johnson', 80000, 'Management', datetime.date(2005, 7, 1), 100000), + ('Smith', 38000, 'Marketing', datetime.date(2009, 10, 1), 38000), + ('Johnson', 40000, 'Marketing', datetime.date(2012, 3, 1), 40000), + ('Smith', 55000, 'Sales', datetime.date(2007, 6, 1), 55000), + ('Brown', 53000, 'Sales', datetime.date(2009, 9, 1), 53000), + ], transform=lambda row: (row.name, row.salary, row.department, row.hire_date, row.max)) + + def test_multiple_ordering(self): + """ + Accumulate the salaries over the departments based on hire_date. + If two people were hired on the same date in the same department, the + ordering clause will render a different result for those people. 
+ """ + qs = Employee.objects.annotate(sum=Window( + expression=Sum('salary'), + partition_by='department', + order_by=[F('hire_date').asc(), F('name').asc()], + )).order_by('department', 'sum') + self.assertQuerysetEqual(qs, [ + ('Jones', 45000, 'Accounting', datetime.date(2005, 11, 1), 45000), + ('Jenson', 45000, 'Accounting', datetime.date(2008, 4, 1), 90000), + ('Williams', 37000, 'Accounting', datetime.date(2009, 6, 1), 127000), + ('Adams', 50000, 'Accounting', datetime.date(2013, 7, 1), 177000), + ('Wilkinson', 60000, 'IT', datetime.date(2011, 3, 1), 60000), + ('Moore', 34000, 'IT', datetime.date(2013, 8, 1), 94000), + ('Miller', 100000, 'Management', datetime.date(2005, 6, 1), 100000), + ('Johnson', 80000, 'Management', datetime.date(2005, 7, 1), 180000), + ('Smith', 38000, 'Marketing', datetime.date(2009, 10, 1), 38000), + ('Johnson', 40000, 'Marketing', datetime.date(2012, 3, 1), 78000), + ('Smith', 55000, 'Sales', datetime.date(2007, 6, 1), 55000), + ('Brown', 53000, 'Sales', datetime.date(2009, 9, 1), 108000), + ], transform=lambda row: (row.name, row.salary, row.department, row.hire_date, row.sum)) + + @skipIf(connection.vendor == 'postgresql', 'n following/preceding not supported by PostgreSQL') + def test_range_n_preceding_and_following(self): + qs = Employee.objects.annotate(sum=Window( + expression=Sum('salary'), + order_by=F('salary').asc(), + partition_by='department', + frame=ValueRange(start=-2, end=2), + )) + self.assertIn('RANGE BETWEEN 2 PRECEDING AND 2 FOLLOWING', str(qs.query)) + self.assertQuerysetEqual(qs, [ + ('Williams', 37000, 'Accounting', datetime.date(2009, 6, 1), 37000), + ('Jones', 45000, 'Accounting', datetime.date(2005, 11, 1), 90000), + ('Jenson', 45000, 'Accounting', datetime.date(2008, 4, 1), 90000), + ('Adams', 50000, 'Accounting', datetime.date(2013, 7, 1), 50000), + ('Brown', 53000, 'Sales', datetime.date(2009, 9, 1), 53000), + ('Smith', 55000, 'Sales', datetime.date(2007, 6, 1), 55000), + ('Johnson', 40000, 'Marketing', 
datetime.date(2012, 3, 1), 40000), + ('Smith', 38000, 'Marketing', datetime.date(2009, 10, 1), 38000), + ('Wilkinson', 60000, 'IT', datetime.date(2011, 3, 1), 60000), + ('Moore', 34000, 'IT', datetime.date(2013, 8, 1), 34000), + ('Miller', 100000, 'Management', datetime.date(2005, 6, 1), 100000), + ('Johnson', 80000, 'Management', datetime.date(2005, 7, 1), 80000), + ], transform=lambda row: (row.name, row.salary, row.department, row.hire_date, row.sum), ordered=False) + + def test_range_unbound(self): + """A query with RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING.""" + qs = Employee.objects.annotate(sum=Window( + expression=Sum('salary'), + partition_by='department', + order_by=[F('hire_date').asc(), F('name').asc()], + frame=ValueRange(start=None, end=None), + )).order_by('department', 'hire_date', 'name') + self.assertIn('RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING', str(qs.query)) + self.assertQuerysetEqual(qs, [ + ('Jones', 'Accounting', 45000, datetime.date(2005, 11, 1), 177000), + ('Jenson', 'Accounting', 45000, datetime.date(2008, 4, 1), 177000), + ('Williams', 'Accounting', 37000, datetime.date(2009, 6, 1), 177000), + ('Adams', 'Accounting', 50000, datetime.date(2013, 7, 1), 177000), + ('Wilkinson', 'IT', 60000, datetime.date(2011, 3, 1), 94000), + ('Moore', 'IT', 34000, datetime.date(2013, 8, 1), 94000), + ('Miller', 'Management', 100000, datetime.date(2005, 6, 1), 180000), + ('Johnson', 'Management', 80000, datetime.date(2005, 7, 1), 180000), + ('Smith', 'Marketing', 38000, datetime.date(2009, 10, 1), 78000), + ('Johnson', 'Marketing', 40000, datetime.date(2012, 3, 1), 78000), + ('Smith', 'Sales', 55000, datetime.date(2007, 6, 1), 108000), + ('Brown', 'Sales', 53000, datetime.date(2009, 9, 1), 108000), + ], transform=lambda row: (row.name, row.department, row.salary, row.hire_date, row.sum)) + + def test_row_range_rank(self): + """ + A query with ROWS BETWEEN UNBOUNDED PRECEDING AND 3 FOLLOWING. 
+ The resulting sum is the sum of the three next (if they exist) and all + previous rows according to the ordering clause. + """ + qs = Employee.objects.annotate(sum=Window( + expression=Sum('salary'), + order_by=[F('hire_date').asc(), F('name').desc()], + frame=RowRange(start=None, end=3), + )).order_by('sum', 'hire_date') + self.assertIn('ROWS BETWEEN UNBOUNDED PRECEDING AND 3 FOLLOWING', str(qs.query)) + self.assertQuerysetEqual(qs, [ + ('Miller', 100000, 'Management', datetime.date(2005, 6, 1), 280000), + ('Johnson', 80000, 'Management', datetime.date(2005, 7, 1), 325000), + ('Jones', 45000, 'Accounting', datetime.date(2005, 11, 1), 362000), + ('Smith', 55000, 'Sales', datetime.date(2007, 6, 1), 415000), + ('Jenson', 45000, 'Accounting', datetime.date(2008, 4, 1), 453000), + ('Williams', 37000, 'Accounting', datetime.date(2009, 6, 1), 513000), + ('Brown', 53000, 'Sales', datetime.date(2009, 9, 1), 553000), + ('Smith', 38000, 'Marketing', datetime.date(2009, 10, 1), 603000), + ('Wilkinson', 60000, 'IT', datetime.date(2011, 3, 1), 637000), + ('Johnson', 40000, 'Marketing', datetime.date(2012, 3, 1), 637000), + ('Adams', 50000, 'Accounting', datetime.date(2013, 7, 1), 637000), + ('Moore', 34000, 'IT', datetime.date(2013, 8, 1), 637000), + ], transform=lambda row: (row.name, row.salary, row.department, row.hire_date, row.sum)) + + @skipUnlessDBFeature('can_distinct_on_fields') + def test_distinct_window_function(self): + """ + Window functions are not aggregates, and hence a query to filter out + duplicates may be useful. 
+ """ + qs = Employee.objects.annotate( + sum=Window( + expression=Sum('salary'), + partition_by=ExtractYear('hire_date'), + order_by=ExtractYear('hire_date') + ), + year=ExtractYear('hire_date'), + ).values('year', 'sum').distinct('year').order_by('year') + results = [ + {'year': 2005, 'sum': 225000}, {'year': 2007, 'sum': 55000}, + {'year': 2008, 'sum': 45000}, {'year': 2009, 'sum': 128000}, + {'year': 2011, 'sum': 60000}, {'year': 2012, 'sum': 40000}, + {'year': 2013, 'sum': 84000}, + ] + for idx, val in zip(range(len(results)), results): + with self.subTest(result=val): + self.assertEqual(qs[idx], val) + + def test_fail_update(self): + """Window expressions can't be used in an UPDATE statement.""" + msg = 'Window expressions are not allowed in this query' + with self.assertRaisesMessage(FieldError, msg): + Employee.objects.filter(department='Management').update( + salary=Window(expression=Max('salary'), partition_by='department'), + ) + + def test_fail_insert(self): + """Window expressions can't be used in an INSERT statement.""" + msg = 'Window expressions are not allowed in this query' + with self.assertRaisesMessage(FieldError, msg): + Employee.objects.create( + name='Jameson', department='Management', hire_date=datetime.date(2007, 7, 1), + salary=Window(expression=Sum(Value(10000), order_by=F('pk').asc())), + ) + + def test_invalid_start_value_range(self): + msg = "start argument must be a negative integer, zero, or None, but got '3'." + with self.assertRaisesMessage(ValueError, msg): + list(Employee.objects.annotate(test=Window( + expression=Sum('salary'), + order_by=F('hire_date').asc(), + frame=ValueRange(start=3), + ))) + + def test_invalid_end_value_range(self): + msg = "end argument must be a positive integer, zero, or None, but got '-3'." 
+ with self.assertRaisesMessage(ValueError, msg): + list(Employee.objects.annotate(test=Window( + expression=Sum('salary'), + order_by=F('hire_date').asc(), + frame=ValueRange(end=-3), + ))) + + def test_invalid_type_end_value_range(self): + msg = "end argument must be a positive integer, zero, or None, but got 'a'." + with self.assertRaisesMessage(ValueError, msg): + list(Employee.objects.annotate(test=Window( + expression=Sum('salary'), + order_by=F('hire_date').asc(), + frame=ValueRange(end='a'), + ))) + + def test_invalid_type_start_value_range(self): + msg = "start argument must be a negative integer, zero, or None, but got 'a'." + with self.assertRaisesMessage(ValueError, msg): + list(Employee.objects.annotate(test=Window( + expression=Sum('salary'), + frame=ValueRange(start='a'), + ))) + + def test_invalid_type_end_row_range(self): + msg = "end argument must be a positive integer, zero, or None, but got 'a'." + with self.assertRaisesMessage(ValueError, msg): + list(Employee.objects.annotate(test=Window( + expression=Sum('salary'), + frame=RowRange(end='a'), + ))) + + @skipUnless(connection.vendor == 'postgresql', 'Frame construction not allowed on PostgreSQL') + def test_postgresql_illegal_range_frame_start(self): + msg = 'PostgreSQL only supports UNBOUNDED together with PRECEDING and FOLLOWING.' + with self.assertRaisesMessage(NotSupportedError, msg): + list(Employee.objects.annotate(test=Window( + expression=Sum('salary'), + order_by=F('hire_date').asc(), + frame=ValueRange(start=-1), + ))) + + @skipUnless(connection.vendor == 'postgresql', 'Frame construction not allowed on PostgreSQL') + def test_postgresql_illegal_range_frame_end(self): + msg = 'PostgreSQL only supports UNBOUNDED together with PRECEDING and FOLLOWING.' 
+ with self.assertRaisesMessage(NotSupportedError, msg): + list(Employee.objects.annotate(test=Window( + expression=Sum('salary'), + order_by=F('hire_date').asc(), + frame=ValueRange(end=1), + ))) + + def test_invalid_type_start_row_range(self): + msg = "start argument must be a negative integer, zero, or None, but got 'a'." + with self.assertRaisesMessage(ValueError, msg): + list(Employee.objects.annotate(test=Window( + expression=Sum('salary'), + order_by=F('hire_date').asc(), + frame=RowRange(start='a'), + ))) + + +class NonQueryWindowTests(SimpleTestCase): + def test_window_repr(self): + self.assertEqual( + repr(Window(expression=Sum('salary'), partition_by='department')), + '' + ) + self.assertEqual( + repr(Window(expression=Avg('salary'), order_by=F('department').asc())), + '' + ) + + def test_window_frame_repr(self): + self.assertEqual( + repr(RowRange(start=-1)), + '' + ) + self.assertEqual( + repr(ValueRange(start=None, end=1)), + '' + ) + self.assertEqual( + repr(ValueRange(start=0, end=0)), + '' + ) + self.assertEqual( + repr(RowRange(start=0, end=0)), + '' + ) + + def test_empty_group_by_cols(self): + window = Window(expression=Sum('pk')) + self.assertEqual(window.get_group_by_cols(), []) + self.assertFalse(window.contains_aggregate) + + def test_frame_empty_group_by_cols(self): + frame = WindowFrame() + self.assertEqual(frame.get_group_by_cols(), []) + + def test_frame_window_frame_notimplemented(self): + frame = WindowFrame() + msg = 'Subclasses must implement window_frame_start_end().' 
+ with self.assertRaisesMessage(NotImplementedError, msg): + frame.window_frame_start_end(None, None, None) + + def test_invalid_filter(self): + msg = 'Window is disallowed in the filter clause' + with self.assertRaisesMessage(NotSupportedError, msg): + Employee.objects.annotate(dense_rank=Window(expression=DenseRank())).filter(dense_rank__gte=1) + + def test_invalid_order_by(self): + msg = 'order_by must be either an Expression or a sequence of expressions' + with self.assertRaisesMessage(ValueError, msg): + Window(expression=Sum('power'), order_by='-horse') + + def test_invalid_source_expression(self): + msg = "Expression 'Upper' isn't compatible with OVER clauses." + with self.assertRaisesMessage(ValueError, msg): + Window(expression=Upper('name'))