Fixed #26608 -- Added support for window expressions (OVER clause).

Thanks Josh Smeaton, Mariusz Felisiak, Sergey Fedoseev, Simon Charettes,
Adam Chainz/Johnson and Tim Graham for comments and reviews and Jamie
Cockburn for initial patch.
This commit is contained in:
Mads Jensen 2017-09-18 15:42:29 +02:00 committed by Tim Graham
parent da1ba03f1d
commit d549b88050
25 changed files with 1627 additions and 8 deletions

View File

@ -236,6 +236,9 @@ class BaseDatabaseFeatures:
# Does the backend support indexing a TextField? # Does the backend support indexing a TextField?
supports_index_on_text_field = True supports_index_on_text_field = True
# Does the backend support window expressions (expression OVER (...))?
supports_over_clause = False
# Does the backend support CAST with precision? # Does the backend support CAST with precision?
supports_cast_with_precision = True supports_cast_with_precision = True

View File

@ -4,7 +4,7 @@ from importlib import import_module
from django.conf import settings from django.conf import settings
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.db import transaction from django.db import NotSupportedError, transaction
from django.db.backends import utils from django.db.backends import utils
from django.utils import timezone from django.utils import timezone
from django.utils.dateparse import parse_duration from django.utils.dateparse import parse_duration
@ -39,6 +39,13 @@ class BaseDatabaseOperations:
# CharField data type if the max_length argument isn't provided. # CharField data type if the max_length argument isn't provided.
cast_char_field_without_max_length = None cast_char_field_without_max_length = None
# Start and end points for window expressions.
PRECEDING = 'PRECEDING'
FOLLOWING = 'FOLLOWING'
UNBOUNDED_PRECEDING = 'UNBOUNDED ' + PRECEDING
UNBOUNDED_FOLLOWING = 'UNBOUNDED ' + FOLLOWING
CURRENT_ROW = 'CURRENT ROW'
def __init__(self, connection): def __init__(self, connection):
self.connection = connection self.connection = connection
self._cache = None self._cache = None
@ -598,3 +605,34 @@ class BaseDatabaseOperations:
rhs_sql, rhs_params = rhs rhs_sql, rhs_params = rhs
return "(%s - %s)" % (lhs_sql, rhs_sql), lhs_params + rhs_params return "(%s - %s)" % (lhs_sql, rhs_sql), lhs_params + rhs_params
raise NotImplementedError("This backend does not support %s subtraction." % internal_type) raise NotImplementedError("This backend does not support %s subtraction." % internal_type)
def window_frame_start(self, start):
    """
    Return the SQL for the starting boundary of a window frame.

    A negative integer yields ``N PRECEDING``, zero yields ``CURRENT ROW``
    and ``None`` yields ``UNBOUNDED PRECEDING``. Any other value raises
    ValueError.
    """
    if start is None:
        return self.UNBOUNDED_PRECEDING
    if isinstance(start, int):
        if start < 0:
            return '%d %s' % (abs(start), self.PRECEDING)
        if start == 0:
            return self.CURRENT_ROW
    # Wrap the value in a one-tuple so that a tuple argument is reported in
    # the ValueError instead of breaking the string interpolation with a
    # TypeError.
    raise ValueError(
        "start argument must be a negative integer, zero, or None, but got '%s'." % (start,)
    )
def window_frame_end(self, end):
    """
    Return the SQL for the ending boundary of a window frame.

    A positive integer yields ``N FOLLOWING``, zero yields ``CURRENT ROW``
    and ``None`` yields ``UNBOUNDED FOLLOWING``. Any other value raises
    ValueError.
    """
    if end is None:
        return self.UNBOUNDED_FOLLOWING
    if isinstance(end, int):
        if end == 0:
            return self.CURRENT_ROW
        if end > 0:
            return '%d %s' % (end, self.FOLLOWING)
    # Wrap the value in a one-tuple so that a tuple argument is reported in
    # the ValueError instead of breaking the string interpolation with a
    # TypeError.
    raise ValueError(
        "end argument must be a positive integer, zero, or None, but got '%s'." % (end,)
    )
def window_frame_rows_start_end(self, start=None, end=None):
    """
    Return a (start, end) pair of SQL fragments for the boundaries of a
    window frame in an OVER clause. Raise NotSupportedError if the backend
    doesn't support window expressions.
    """
    features = self.connection.features
    if features.supports_over_clause:
        return self.window_frame_start(start), self.window_frame_end(end)
    raise NotSupportedError('This backend does not support window expressions.')
def window_frame_range_start_end(self, start=None, end=None):
    """
    Return the (start, end) SQL fragments for a RANGE frame. This
    implementation delegates to the ROWS frame handling.
    """
    boundaries = self.window_frame_rows_start_end(start, end)
    return boundaries

View File

@ -81,6 +81,10 @@ class DatabaseFeatures(BaseDatabaseFeatures):
result = cursor.fetchone() result = cursor.fetchone()
return result and result[0] == 1 return result and result[0] == 1
@cached_property
def supports_over_clause(self):
    """Return True if the server version is at least 8.0.2."""
    minimum_version = (8, 0, 2)
    return self.connection.mysql_version >= minimum_version
@cached_property @cached_property
def supports_transactions(self): def supports_transactions(self):
""" """

View File

@ -54,3 +54,4 @@ class DatabaseFeatures(BaseDatabaseFeatures):
END; END;
""" """
supports_callproc_kwargs = True supports_callproc_kwargs = True
supports_over_clause = True

View File

@ -48,6 +48,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
V_I := P_I; V_I := P_I;
END; END;
$$ LANGUAGE plpgsql;""" $$ LANGUAGE plpgsql;"""
supports_over_clause = True
@cached_property @cached_property
def supports_aggregate_filter_clause(self): def supports_aggregate_filter_clause(self):

View File

@ -1,6 +1,7 @@
from psycopg2.extras import Inet from psycopg2.extras import Inet
from django.conf import settings from django.conf import settings
from django.db import NotSupportedError
from django.db.backends.base.operations import BaseDatabaseOperations from django.db.backends.base.operations import BaseDatabaseOperations
@ -247,3 +248,12 @@ class DatabaseOperations(BaseDatabaseOperations):
rhs_sql, rhs_params = rhs rhs_sql, rhs_params = rhs
return "(interval '1 day' * (%s - %s))" % (lhs_sql, rhs_sql), lhs_params + rhs_params return "(interval '1 day' * (%s - %s))" % (lhs_sql, rhs_sql), lhs_params + rhs_params
return super().subtract_temporals(internal_type, lhs, rhs) return super().subtract_temporals(internal_type, lhs, rhs)
def window_frame_range_start_end(self, start=None, end=None):
    """
    Return the (start, end) SQL fragments for a RANGE frame, raising
    NotSupportedError when a numeric PRECEDING or FOLLOWING offset is
    requested.
    """
    frame_start, frame_end = super().window_frame_range_start_end(start, end)
    # The parent call runs first so its own validation happens before the
    # offset check below.
    uses_numeric_offset = (start and start < 0) or (end and end > 0)
    if uses_numeric_offset:
        raise NotSupportedError(
            'PostgreSQL only supports UNBOUNDED together with PRECEDING '
            'and FOLLOWING.'
        )
    return frame_start, frame_end

View File

@ -6,8 +6,8 @@ from django.db.models.deletion import (
CASCADE, DO_NOTHING, PROTECT, SET, SET_DEFAULT, SET_NULL, ProtectedError, CASCADE, DO_NOTHING, PROTECT, SET, SET_DEFAULT, SET_NULL, ProtectedError,
) )
from django.db.models.expressions import ( from django.db.models.expressions import (
Case, Exists, Expression, ExpressionWrapper, F, Func, OuterRef, Subquery, Case, Exists, Expression, ExpressionList, ExpressionWrapper, F, Func,
Value, When, OuterRef, RowRange, Subquery, Value, ValueRange, When, Window, WindowFrame,
) )
from django.db.models.fields import * # NOQA from django.db.models.fields import * # NOQA
from django.db.models.fields import __all__ as fields_all from django.db.models.fields import __all__ as fields_all
@ -64,8 +64,9 @@ __all__ += [
'ObjectDoesNotExist', 'signals', 'ObjectDoesNotExist', 'signals',
'CASCADE', 'DO_NOTHING', 'PROTECT', 'SET', 'SET_DEFAULT', 'SET_NULL', 'CASCADE', 'DO_NOTHING', 'PROTECT', 'SET', 'SET_DEFAULT', 'SET_NULL',
'ProtectedError', 'ProtectedError',
'Case', 'Exists', 'Expression', 'ExpressionWrapper', 'F', 'Func', 'Case', 'Exists', 'Expression', 'ExpressionList', 'ExpressionWrapper', 'F',
'OuterRef', 'Subquery', 'Value', 'When', 'Func', 'OuterRef', 'RowRange', 'Subquery', 'Value', 'ValueRange', 'When',
'Window', 'WindowFrame',
'FileField', 'ImageField', 'OrderWrt', 'Lookup', 'Transform', 'Manager', 'FileField', 'ImageField', 'OrderWrt', 'Lookup', 'Transform', 'Manager',
'Prefetch', 'Q', 'QuerySet', 'prefetch_related_objects', 'DEFERRED', 'Model', 'Prefetch', 'Q', 'QuerySet', 'prefetch_related_objects', 'DEFERRED', 'Model',
'ForeignKey', 'ForeignObject', 'OneToOneField', 'ManyToManyField', 'ForeignKey', 'ForeignObject', 'OneToOneField', 'ManyToManyField',

View File

@ -15,6 +15,7 @@ class Aggregate(Func):
contains_aggregate = True contains_aggregate = True
name = None name = None
filter_template = '%s FILTER (WHERE %%(filter)s)' filter_template = '%s FILTER (WHERE %%(filter)s)'
window_compatible = True
def __init__(self, *args, filter=None, **kwargs): def __init__(self, *args, filter=None, **kwargs):
self.filter = filter self.filter = filter

View File

@ -3,6 +3,7 @@ import datetime
from decimal import Decimal from decimal import Decimal
from django.core.exceptions import EmptyResultSet, FieldError from django.core.exceptions import EmptyResultSet, FieldError
from django.db import connection
from django.db.models import fields from django.db.models import fields
from django.db.models.query_utils import Q from django.db.models.query_utils import Q
from django.utils.deconstruct import deconstructible from django.utils.deconstruct import deconstructible
@ -140,6 +141,10 @@ class BaseExpression:
# aggregate specific fields # aggregate specific fields
is_summary = False is_summary = False
_output_field_resolved_to_none = False _output_field_resolved_to_none = False
# Can the expression be used in a WHERE clause?
filterable = True
# Can the expression be used as a source expression in Window?
window_compatible = False
def __init__(self, output_field=None): def __init__(self, output_field=None):
if output_field is not None: if output_field is not None:
@ -206,6 +211,13 @@ class BaseExpression:
return True return True
return False return False
@cached_property
def contains_over_clause(self):
    """
    Return True if any non-empty source expression contains an OVER clause.
    """
    return any(
        source and source.contains_over_clause
        for source in self.get_source_expressions()
    )
@cached_property @cached_property
def contains_column_references(self): def contains_column_references(self):
for expr in self.get_source_expressions(): for expr in self.get_source_expressions():
@ -232,6 +244,7 @@ class BaseExpression:
c.is_summary = summarize c.is_summary = summarize
c.set_source_expressions([ c.set_source_expressions([
expr.resolve_expression(query, allow_joins, reuse, summarize) expr.resolve_expression(query, allow_joins, reuse, summarize)
if expr else None
for expr in c.get_source_expressions() for expr in c.get_source_expressions()
]) ])
return c return c
@ -482,6 +495,9 @@ class TemporalSubtraction(CombinedExpression):
@deconstructible @deconstructible
class F(Combinable): class F(Combinable):
"""An object capable of resolving references to existing query objects.""" """An object capable of resolving references to existing query objects."""
# Can the expression be used in a WHERE clause?
filterable = True
def __init__(self, name): def __init__(self, name):
""" """
Arguments: Arguments:
@ -767,6 +783,23 @@ class Ref(Expression):
return [self] return [self]
class ExpressionList(Func):
    """
    A group of expressions rendered as a single list. Useful for passing
    several expressions as one argument to another expression, for example
    as an ordering clause.
    """
    template = '%(expressions)s'

    def __init__(self, *expressions, **extra):
        if not expressions:
            raise ValueError('%s requires at least one expression.' % self.__class__.__name__)
        super().__init__(*expressions, **extra)

    def __str__(self):
        return self.arg_joiner.join(map(str, self.source_expressions))
class ExpressionWrapper(Expression): class ExpressionWrapper(Expression):
""" """
An expression that can wrap another expression so that it can provide An expression that can wrap another expression so that it can provide
@ -1118,3 +1151,168 @@ class OrderBy(BaseExpression):
def desc(self): def desc(self):
self.descending = True self.descending = True
class Window(Expression):
    """
    Wrap a window-compatible expression in an OVER clause.

    ``expression`` must have ``window_compatible`` set to True, otherwise
    ValueError is raised. ``partition_by`` is a single expression or a
    list/tuple of expressions, ``order_by`` is an expression or a list/tuple
    of expressions, and ``frame`` is an optional frame expression. The
    output field defaults to the one of the source expression.
    """
    template = '%(expression)s OVER (%(window)s)'
    # Although the main expression may either be an aggregate or an
    # expression with an aggregate function, the GROUP BY that will
    # be introduced in the query as a result is not desired.
    contains_aggregate = False
    contains_over_clause = True
    filterable = False

    def __init__(self, expression, partition_by=None, order_by=None, frame=None, output_field=None):
        self.partition_by = partition_by
        self.order_by = order_by
        self.frame = frame

        if not getattr(expression, 'window_compatible', False):
            raise ValueError(
                "Expression '%s' isn't compatible with OVER clauses." %
                expression.__class__.__name__
            )

        if self.partition_by is not None:
            # Accept a single expression as well as a sequence of them.
            if not isinstance(self.partition_by, (tuple, list)):
                self.partition_by = (self.partition_by,)
            self.partition_by = ExpressionList(*self.partition_by)

        if self.order_by is not None:
            if isinstance(self.order_by, (list, tuple)):
                self.order_by = ExpressionList(*self.order_by)
            elif not isinstance(self.order_by, BaseExpression):
                raise ValueError(
                    'order_by must be either an Expression or a sequence of '
                    'expressions.'
                )
        super().__init__(output_field=output_field)
        self.source_expression = self._parse_expressions(expression)[0]

    def _resolve_output_field(self):
        return self.source_expression.output_field

    def get_source_expressions(self):
        return [self.source_expression, self.partition_by, self.order_by, self.frame]

    def set_source_expressions(self, exprs):
        self.source_expression, self.partition_by, self.order_by, self.frame = exprs

    def as_sql(self, compiler, connection, function=None, template=None):
        """
        Return the (sql, params) pair for ``<expression> OVER (<window>)``.

        The window part is made of the optional PARTITION BY, ORDER BY and
        frame clauses, in that order.
        """
        connection.ops.check_expression_support(self)
        expr_sql, params = compiler.compile(self.source_expression)
        window_sql, window_params = [], []

        if self.partition_by is not None:
            sql_expr, sql_params = self.partition_by.as_sql(
                compiler=compiler, connection=connection,
                template='PARTITION BY %(expressions)s',
            )
            # append() rather than extend(): the SQL is a string and
            # extending a list with it would add it character by character.
            window_sql.append(sql_expr)
            window_params.extend(sql_params)

        if self.order_by is not None:
            window_sql.append(' ORDER BY ')
            order_sql, order_params = compiler.compile(self.order_by)
            window_sql.append(order_sql)
            window_params.extend(order_params)

        if self.frame:
            frame_sql, frame_params = compiler.compile(self.frame)
            window_sql.append(' ' + frame_sql)
            window_params.extend(frame_params)

        # Build a new list so the params returned by the compiler aren't
        # mutated in place (and so a tuple of params is accepted too).
        params = [*params, *window_params]
        template = template or self.template

        return template % {
            'expression': expr_sql,
            # strip() removes the leading space left when there's no
            # PARTITION BY clause.
            'window': ''.join(window_sql).strip()
        }, params

    def __str__(self):
        # Join only the clauses that are present with a single space so that
        # e.g. PARTITION BY and ORDER BY don't run into each other.
        clauses = []
        if self.partition_by:
            clauses.append('PARTITION BY ' + str(self.partition_by))
        if self.order_by:
            clauses.append('ORDER BY ' + str(self.order_by))
        if self.frame:
            clauses.append(str(self.frame))
        return '{} OVER ({})'.format(str(self.source_expression), ' '.join(clauses))

    def __repr__(self):
        return '<%s: %s>' % (self.__class__.__name__, self)

    def get_group_by_cols(self):
        # Window expressions never contribute columns to GROUP BY.
        return []
class WindowFrame(Expression):
    """
    Model the frame clause in window expressions. There are two types of frame
    clauses which are subclasses, however, all processing and validation (by no
    means intended to be complete) is done here. Thus, providing an end for a
    frame is optional (the default is UNBOUNDED FOLLOWING, which is the last
    row in the frame).

    Subclasses must define ``frame_type`` and implement
    ``window_frame_start_end()``.
    """
    template = '%(frame_type)s BETWEEN %(start)s AND %(end)s'

    def __init__(self, start=None, end=None):
        # Store the boundaries as Value() expressions from the start so that
        # as_sql() and __str__() see the same type whether or not the frame
        # has been through resolve_expression().
        self.start = Value(start)
        self.end = Value(end)

    def set_source_expressions(self, exprs):
        self.start, self.end = exprs

    def get_source_expressions(self):
        return [self.start, self.end]

    def as_sql(self, compiler, connection):
        """Return the (sql, params) pair for the frame clause."""
        connection.ops.check_expression_support(self)
        start, end = self.window_frame_start_end(connection, self.start.value, self.end.value)
        return self.template % {
            'frame_type': self.frame_type,
            'start': start,
            'end': end,
        }, []

    def __repr__(self):
        return '<%s: %s>' % (self.__class__.__name__, self)

    def get_group_by_cols(self):
        return []

    def __str__(self):
        # The textual form uses the keywords of the default connection's
        # operations class.
        start_value, end_value = self.start.value, self.end.value

        if start_value is not None and start_value < 0:
            start = '%d %s' % (abs(start_value), connection.ops.PRECEDING)
        elif start_value is not None and start_value == 0:
            start = connection.ops.CURRENT_ROW
        else:
            start = connection.ops.UNBOUNDED_PRECEDING

        if end_value is not None and end_value > 0:
            end = '%d %s' % (end_value, connection.ops.FOLLOWING)
        elif end_value is not None and end_value == 0:
            end = connection.ops.CURRENT_ROW
        else:
            end = connection.ops.UNBOUNDED_FOLLOWING

        return self.template % {
            'frame_type': self.frame_type,
            'start': start,
            'end': end,
        }

    def window_frame_start_end(self, connection, start, end):
        raise NotImplementedError('Subclasses must implement window_frame_start_end().')
class RowRange(WindowFrame):
    """Window frame whose frame type is ROWS."""
    frame_type = 'ROWS'

    def window_frame_start_end(self, connection, start, end):
        ops = connection.ops
        return ops.window_frame_rows_start_end(start, end)
class ValueRange(WindowFrame):
    """Window frame whose frame type is RANGE."""
    frame_type = 'RANGE'

    def window_frame_start_end(self, connection, start, end):
        ops = connection.ops
        return ops.window_frame_range_start_end(start, end)

View File

@ -8,6 +8,10 @@ from .datetime import (
Trunc, TruncDate, TruncDay, TruncHour, TruncMinute, TruncMonth, Trunc, TruncDate, TruncDay, TruncHour, TruncMinute, TruncMonth,
TruncQuarter, TruncSecond, TruncTime, TruncYear, TruncQuarter, TruncSecond, TruncTime, TruncYear,
) )
from .window import (
CumeDist, DenseRank, FirstValue, Lag, LastValue, Lead, NthValue, Ntile,
PercentRank, Rank, RowNumber,
)
__all__ = [ __all__ = [
# base # base
@ -18,4 +22,7 @@ __all__ = [
'ExtractQuarter', 'ExtractSecond', 'ExtractWeek', 'ExtractWeekDay', 'ExtractQuarter', 'ExtractSecond', 'ExtractWeek', 'ExtractWeekDay',
'ExtractYear', 'Trunc', 'TruncDate', 'TruncDay', 'TruncHour', 'TruncMinute', 'ExtractYear', 'Trunc', 'TruncDate', 'TruncDay', 'TruncHour', 'TruncMinute',
'TruncMonth', 'TruncQuarter', 'TruncSecond', 'TruncTime', 'TruncYear', 'TruncMonth', 'TruncQuarter', 'TruncSecond', 'TruncTime', 'TruncYear',
# window
'CumeDist', 'DenseRank', 'FirstValue', 'Lag', 'LastValue', 'Lead',
'NthValue', 'Ntile', 'PercentRank', 'Rank', 'RowNumber',
] ]

View File

@ -0,0 +1,118 @@
from django.db.models import FloatField, Func, IntegerField
__all__ = [
'CumeDist', 'DenseRank', 'FirstValue', 'Lag', 'LastValue', 'Lead',
'NthValue', 'Ntile', 'PercentRank', 'Rank', 'RowNumber',
]
class CumeDist(Func):
    """The CUME_DIST() window function; the result is a float."""
    name = 'CumeDist'
    function = 'CUME_DIST'
    window_compatible = True
    output_field = FloatField()
class DenseRank(Func):
    """The DENSE_RANK() window function; the result is an integer."""
    name = 'DenseRank'
    function = 'DENSE_RANK'
    window_compatible = True
    output_field = IntegerField()
class FirstValue(Func):
    """The FIRST_VALUE() window function, declared with an arity of one."""
    name = 'FirstValue'
    function = 'FIRST_VALUE'
    arity = 1
    window_compatible = True
class LagLeadFunction(Func):
    """
    Shared argument handling for the Lag and Lead window functions.

    ``expression`` must not be None and ``offset`` must be greater than
    zero, otherwise ValueError is raised. ``default`` is only passed on as an
    argument when it isn't None.
    """
    window_compatible = True

    def __init__(self, expression, offset=1, default=None, **extra):
        class_name = self.__class__.__name__
        if expression is None:
            raise ValueError(
                '%s requires a non-null source expression.' % class_name
            )
        if offset is None or offset <= 0:
            raise ValueError(
                '%s requires a positive integer for the offset.' % class_name
            )
        if default is None:
            args = (expression, offset)
        else:
            args = (expression, offset, default)
        super().__init__(*args, **extra)

    def _resolve_output_field(self):
        # The result has the type of the first source expression.
        return self.get_source_expressions()[0].output_field
class Lag(LagLeadFunction):
    """The LAG() window function."""
    name = 'Lag'
    function = 'LAG'
class LastValue(Func):
    """The LAST_VALUE() window function, declared with an arity of one."""
    name = 'LastValue'
    function = 'LAST_VALUE'
    arity = 1
    window_compatible = True
class Lead(LagLeadFunction):
    """The LEAD() window function."""
    name = 'Lead'
    function = 'LEAD'
class NthValue(Func):
    """
    The NTH_VALUE() window function.

    ``expression`` must not be None and ``nth`` must be greater than zero,
    otherwise ValueError is raised.
    """
    name = 'NthValue'
    function = 'NTH_VALUE'
    window_compatible = True

    def __init__(self, expression, nth=1, **extra):
        class_name = self.__class__.__name__
        if expression is None:
            raise ValueError('%s requires a non-null source expression.' % class_name)
        if nth is None or nth <= 0:
            raise ValueError('%s requires a positive integer as for nth.' % class_name)
        super().__init__(expression, nth, **extra)

    def _resolve_output_field(self):
        # The result has the type of the first source expression.
        return self.get_source_expressions()[0].output_field
class Ntile(Func):
    """
    The NTILE() window function; the result is an integer.

    ``num_buckets`` must be greater than zero, otherwise ValueError is
    raised.
    """
    name = 'Ntile'
    function = 'NTILE'
    window_compatible = True
    output_field = IntegerField()

    def __init__(self, num_buckets=1, **extra):
        if not num_buckets > 0:
            raise ValueError('num_buckets must be greater than 0.')
        super().__init__(num_buckets, **extra)
class PercentRank(Func):
    """The PERCENT_RANK() window function; the result is a float."""
    name = 'PercentRank'
    function = 'PERCENT_RANK'
    window_compatible = True
    output_field = FloatField()
class Rank(Func):
    """The RANK() window function; the result is an integer."""
    name = 'Rank'
    function = 'RANK'
    window_compatible = True
    output_field = IntegerField()
class RowNumber(Func):
    """The ROW_NUMBER() window function; the result is an integer."""
    name = 'RowNumber'
    function = 'ROW_NUMBER'
    window_compatible = True
    output_field = IntegerField()

View File

@ -115,6 +115,10 @@ class Lookup:
def contains_aggregate(self): def contains_aggregate(self):
return self.lhs.contains_aggregate or getattr(self.rhs, 'contains_aggregate', False) return self.lhs.contains_aggregate or getattr(self.rhs, 'contains_aggregate', False)
@cached_property
def contains_over_clause(self):
    """
    Return a truthy value if either side of the lookup contains an OVER
    clause. A right-hand side without the attribute counts as False.
    """
    lhs_flag = self.lhs.contains_over_clause
    if lhs_flag:
        return lhs_flag
    return getattr(self.rhs, 'contains_over_clause', False)
@property @property
def is_summary(self): def is_summary(self):
return self.lhs.is_summary or getattr(self.rhs, 'is_summary', False) return self.lhs.is_summary or getattr(self.rhs, 'is_summary', False)

View File

@ -1107,6 +1107,8 @@ class SQLInsertCompiler(SQLCompiler):
) )
if value.contains_aggregate: if value.contains_aggregate:
raise FieldError("Aggregate functions are not allowed in this query") raise FieldError("Aggregate functions are not allowed in this query")
if value.contains_over_clause:
raise FieldError('Window expressions are not allowed in this query.')
else: else:
value = field.get_db_prep_save(value, connection=self.connection) value = field.get_db_prep_save(value, connection=self.connection)
return value return value
@ -1262,6 +1264,8 @@ class SQLUpdateCompiler(SQLCompiler):
val = val.resolve_expression(self.query, allow_joins=False, for_save=True) val = val.resolve_expression(self.query, allow_joins=False, for_save=True)
if val.contains_aggregate: if val.contains_aggregate:
raise FieldError("Aggregate functions are not allowed in this query") raise FieldError("Aggregate functions are not allowed in this query")
if val.contains_over_clause:
raise FieldError('Window expressions are not allowed in this query.')
elif hasattr(val, 'prepare_database_save'): elif hasattr(val, 'prepare_database_save'):
if field.remote_field: if field.remote_field:
val = field.get_db_prep_save( val = field.get_db_prep_save(

View File

@ -13,7 +13,7 @@ from string import ascii_uppercase
from django.core.exceptions import ( from django.core.exceptions import (
EmptyResultSet, FieldDoesNotExist, FieldError, EmptyResultSet, FieldDoesNotExist, FieldError,
) )
from django.db import DEFAULT_DB_ALIAS, connections from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connections
from django.db.models.aggregates import Count from django.db.models.aggregates import Count
from django.db.models.constants import LOOKUP_SEP from django.db.models.constants import LOOKUP_SEP
from django.db.models.expressions import Col, Ref from django.db.models.expressions import Col, Ref
@ -1125,6 +1125,13 @@ class Query:
if not arg: if not arg:
raise FieldError("Cannot parse keyword query %r" % arg) raise FieldError("Cannot parse keyword query %r" % arg)
lookups, parts, reffed_expression = self.solve_lookup_type(arg) lookups, parts, reffed_expression = self.solve_lookup_type(arg)
if not getattr(reffed_expression, 'filterable', True):
raise NotSupportedError(
reffed_expression.__class__.__name__ + ' is disallowed in '
'the filter clause.'
)
if not allow_joins and len(parts) > 1: if not allow_joins and len(parts) > 1:
raise FieldError("Joined field references are not permitted in this query") raise FieldError("Joined field references are not permitted in this query")

View File

@ -167,6 +167,16 @@ class WhereNode(tree.Node):
def contains_aggregate(self): def contains_aggregate(self):
return self._contains_aggregate(self) return self._contains_aggregate(self)
@classmethod
def _contains_over_clause(cls, obj):
    """
    Recursively check whether ``obj``, or any child of it when it's a tree
    node, contains an OVER clause.
    """
    if not isinstance(obj, tree.Node):
        return obj.contains_over_clause
    return any(cls._contains_over_clause(child) for child in obj.children)
@cached_property
def contains_over_clause(self):
    """Return whether this node contains an OVER clause anywhere below it."""
    found = self._contains_over_clause(self)
    return found
@property @property
def is_summary(self): def is_summary(self):
return any(child.is_summary for child in self.children) return any(child.is_summary for child in self.children)

View File

@ -819,3 +819,132 @@ Usage example::
'minute': 'minute': datetime.datetime(2014, 6, 15, 14, 30, tzinfo=<UTC>), 'minute': 'minute': datetime.datetime(2014, 6, 15, 14, 30, tzinfo=<UTC>),
'second': datetime.datetime(2014, 6, 15, 14, 30, 50, tzinfo=<UTC>) 'second': datetime.datetime(2014, 6, 15, 14, 30, 50, tzinfo=<UTC>)
} }
.. _window-functions:
Window functions
================
.. versionadded:: 2.0
There are a number of functions to use in a
:class:`~django.db.models.expressions.Window` expression for computing the rank
of elements or the :class:`Ntile` of some rows.
``CumeDist``
------------
.. class:: CumeDist(*expressions, **extra)
Calculates the cumulative distribution of a value within a window or partition.
The cumulative distribution is defined as the number of rows preceding or
peered with the current row divided by the total number of rows in the frame.
``DenseRank``
-------------
.. class:: DenseRank(*expressions, **extra)
Equivalent to :class:`Rank` but does not have gaps.
``FirstValue``
--------------
.. class:: FirstValue(expression, **extra)
Returns the value evaluated at the row that's the first row of the window
frame, or ``None`` if no such value exists.
``Lag``
-------
.. class:: Lag(expression, offset=1, default=None, **extra)
Calculates the value offset by ``offset``, and if no row exists there, returns
``default``.
``default`` must have the same type as the ``expression``, however, this is
only validated by the database and not in Python.
``LastValue``
-------------
.. class:: LastValue(expression, **extra)
Comparable to :class:`FirstValue`, it calculates the last value in a given
frame clause.
``Lead``
--------
.. class:: Lead(expression, offset=1, default=None, **extra)
Calculates the leading value in a given :ref:`frame <window-frames>`. Both
``offset`` and ``default`` are evaluated with respect to the current row.
``default`` must have the same type as the ``expression``, however, this is
only validated by the database and not in Python.
``NthValue``
------------
.. class:: NthValue(expression, nth=1, **extra)
Computes the row relative to the offset ``nth`` (must be a positive value)
within the window. Returns ``None`` if no row exists.
Some databases may handle a nonexistent nth-value differently. For example,
Oracle returns an empty string rather than ``None`` for character-based
expressions. Django doesn't do any conversions in these cases.
``Ntile``
---------
.. class:: Ntile(num_buckets=1, **extra)
Calculates a partition for each of the rows in the frame clause, distributing
numbers as evenly as possible between 1 and ``num_buckets``. If the rows don't
divide evenly into a number of buckets, one or more buckets will be represented
more frequently.
``PercentRank``
---------------
.. class:: PercentRank(*expressions, **extra)
Computes the percentile rank of the rows in the frame clause. This
computation is equivalent to evaluating::
(rank - 1) / (total rows - 1)
The following table explains the calculation for the percentile rank of a row:
===== ===== ==== ============ ============
Row # Value Rank Calculation Percent Rank
===== ===== ==== ============ ============
1 15 1 (1-1)/(7-1) 0.0000
2 20 2 (2-1)/(7-1) 0.1666
3 20 2 (2-1)/(7-1) 0.1666
4 20 2 (2-1)/(7-1) 0.1666
5 30 5 (5-1)/(7-1) 0.6666
6 30 5 (5-1)/(7-1) 0.6666
7 40 7 (7-1)/(7-1) 1.0000
===== ===== ==== ============ ============
``Rank``
--------
.. class:: Rank(*expressions, **extra)
Comparable to ``RowNumber``, this function ranks rows in the window. The
computed rank contains gaps. Use :class:`DenseRank` to compute rank without
gaps.
``RowNumber``
-------------
.. class:: RowNumber(*expressions, **extra)
Computes the row number according to the ordering of either the frame clause
or the ordering of the whole query if there is no partitioning of the
:ref:`window frame <window-frames>`.

View File

@ -353,6 +353,13 @@ The ``Aggregate`` API is as follows:
generated. Specifically, the ``function`` will be interpolated as the generated. Specifically, the ``function`` will be interpolated as the
``function`` placeholder within :attr:`template`. Defaults to ``None``. ``function`` placeholder within :attr:`template`. Defaults to ``None``.
.. attribute:: window_compatible
.. versionadded:: 2.0
Defaults to ``True`` since most aggregate functions can be used as the
source expression in :class:`~django.db.models.expressions.Window`.
The ``expression`` argument can be the name of a field on the model, or another The ``expression`` argument can be the name of a field on the model, or another
expression. It will be converted to a string and used as the ``expressions`` expression. It will be converted to a string and used as the ``expressions``
placeholder within the ``template``. placeholder within the ``template``.
@ -649,6 +656,184 @@ should avoid them if possible.
force you to acknowledge that you're not interpolating your SQL with user force you to acknowledge that you're not interpolating your SQL with user
provided data. provided data.
Window functions
----------------
.. versionadded:: 2.0
Window functions provide a way to apply functions on partitions. Unlike a
normal aggregation function which computes a final result for each set defined
by the group by, window functions operate on :ref:`frames <window-frames>` and
partitions, and compute the result for each row.
You can specify multiple windows in the same query which in Django ORM would be
equivalent to including multiple expressions in a :doc:`QuerySet.annotate()
</topics/db/aggregation>` call. The ORM doesn't make use of named windows,
instead they are part of the selected columns.
.. class:: Window(expression, partition_by=None, order_by=None, frame=None, output_field=None)
.. attribute:: filterable
Defaults to ``False``. The SQL standard disallows referencing window
functions in the ``WHERE`` clause and Django raises an exception when
constructing a ``QuerySet`` that would do that.
.. attribute:: template
Defaults to ``%(expression)s OVER (%(window)s)``. If only the
``expression`` argument is provided, the window clause will be blank.
The ``Window`` class is the main expression for an ``OVER`` clause.
The ``expression`` argument is either a :ref:`window function
<window-functions>`, an :ref:`aggregate function <aggregation-functions>`, or
an expression that's compatible in a window clause.
The ``partition_by`` argument is a list of expressions (column names should be
wrapped in an ``F``-object) that control the partitioning of the rows.
Partitioning narrows which rows are used to compute the result set.
The ``output_field`` is specified either as an argument or by the expression.
The ``order_by`` argument accepts a sequence of expressions on which you can
call :meth:`~django.db.models.Expression.asc` and
:meth:`~django.db.models.Expression.desc`. The ordering controls the order in
which the expression is applied. For example, if you sum over the rows in a
partition, the first result is just the value of the first row, the second is
the sum of first and second row.
The ``frame`` parameter specifies which other rows should be used in the
computation. See :ref:`window-frames` for details.
For example, to annotate each movie with the average rating for the movies by
the same studio in the same genre and release year::
>>> from django.db.models import Avg, F, Window
>>> from django.db.models.functions import ExtractYear
>>> Movie.objects.annotate(
>>> avg_rating=Window(
>>> expression=Avg('rating'),
>>> partition_by=[F('studio'), F('genre')],
>>> order_by=ExtractYear('released').asc(),
>>> ),
>>> )
This makes it easy to check if a movie is rated better or worse than its peers.
You may want to apply multiple expressions over the same window, i.e., the
same partition and frame. For example, you could modify the previous example
to also include the best and worst rating in each movie's group (same studio,
genre, and release year) by using three window functions in the same query. The
partition and ordering from the previous example is extracted into a dictionary
to reduce repetition::
>>> from django.db.models import Avg, F, Max, Min, Window
>>> from django.db.models.functions import ExtractYear
>>> window = {
>>> 'partition_by': [F('studio'), F('genre')],
>>> 'order_by': ExtractYear('released').asc(),
>>> }
>>> Movie.objects.annotate(
>>> avg_rating=Window(
>>> expression=Avg('rating'), **window,
>>> ),
>>> best=Window(
>>> expression=Max('rating'), **window,
>>> ),
>>> worst=Window(
>>> expression=Min('rating'), **window,
>>> ),
>>> )
Among Django's built-in database backends, MySQL 8.0.2+, PostgreSQL, and Oracle
support window expressions. Support for different window expression features
varies among the different databases. For example, the options in
:meth:`~django.db.models.Expression.asc` and
:meth:`~django.db.models.Expression.desc` may not be supported. Consult the
documentation for your database as needed.
.. _window-frames:
Frames
~~~~~~
For a window frame, you can choose either a range-based sequence of rows or an
ordinary sequence of rows.
.. class:: ValueRange(start=None, end=None)
.. attribute:: frame_type
This attribute is set to ``'RANGE'``.
PostgreSQL has limited support for ``ValueRange`` and only supports use of
the standard start and end points, such as ``CURRENT ROW`` and ``UNBOUNDED
FOLLOWING``.
.. class:: RowRange(start=None, end=None)
.. attribute:: frame_type
This attribute is set to ``'ROWS'``.
Both classes return SQL with the template::
%(frame_type)s BETWEEN %(start)s AND %(end)s
Frames narrow the rows that are used for computing the result. They shift from
some start point to some specified end point. Frames can be used with and
without partitions, but it's often a good idea to specify an ordering of the
window to ensure a deterministic result. In a frame, a peer in a frame is a row
with an equivalent value, or all rows if an ordering clause isn't present.
The default starting point for a frame is ``UNBOUNDED PRECEDING`` which is the
first row of the partition. The end point is always explicitly included in the
SQL generated by the ORM and is by default ``UNBOUNDED FOLLOWING``. The default
frame includes all rows from the partition to the last row in the set.
The accepted values for the ``start`` and ``end`` arguments are ``None``, an
integer, or zero. A negative integer for ``start`` results in ``N preceding``,
while ``None`` yields ``UNBOUNDED PRECEDING``. For both ``start`` and ``end``,
zero will return ``CURRENT ROW``. Positive integers are accepted for ``end``.
There's a difference in what ``CURRENT ROW`` includes. When specified in
``ROWS`` mode, the frame starts or ends with the current row. When specified in
``RANGE`` mode, the frame starts or ends at the first or last peer according to
the ordering clause. Thus, ``RANGE CURRENT ROW`` evaluates the expression for
rows which have the same value specified by the ordering. Because the template
includes both the ``start`` and ``end`` points, this may be expressed with::
ValueRange(start=0, end=0)
If a movie's "peers" are described as movies released by the same studio in the
same genre in the same year, this ``RowRange`` example annotates each movie
with the average rating of a movie's two prior and two following peers::
>>> from django.db.models import Avg, F, RowRange, Window
>>> from django.db.models.functions import ExtractYear
>>> Movie.objects.annotate(
>>> avg_rating=Window(
>>> expression=Avg('rating'),
>>> partition_by=[F('studio'), F('genre')],
>>> order_by=ExtractYear('released').asc(),
>>> frame=RowRange(start=-2, end=2),
>>> ),
>>> )
If the database supports it, you can specify the start and end points based on
values of an expression in the partition. If the ``released`` field of the
``Movie`` model stores the release month of each movie, this ``ValueRange``
example annotates each movie with the average rating of a movie's peers
released between twelve months before and twelve months after each movie::
>>> from django.db.models import Avg, ExpressionList, F, ValueRange, Window
>>> Movie.objects.annotate(
>>> avg_rating=Window(
>>> expression=Avg('rating'),
>>> partition_by=[F('studio'), F('genre')],
>>> order_by=F('released').asc(),
>>> frame=ValueRange(start=-12, end=12),
>>> ),
>>> )
.. currentmodule:: django.db.models .. currentmodule:: django.db.models
Technical Information Technical Information
@ -677,6 +862,30 @@ calling the appropriate methods on the wrapped expression.
Tells Django that this expression contains an aggregate and that a Tells Django that this expression contains an aggregate and that a
``GROUP BY`` clause needs to be added to the query. ``GROUP BY`` clause needs to be added to the query.
.. attribute:: contains_over_clause
.. versionadded:: 2.0
Tells Django that this expression contains a
:class:`~django.db.models.expressions.Window` expression. It's used,
for example, to disallow window function expressions in queries that
modify data. Defaults to ``True``.
.. attribute:: filterable
.. versionadded:: 2.0
Tells Django that this expression can be referenced in
:meth:`.QuerySet.filter`. Defaults to ``True``.
.. attribute:: window_compatible
.. versionadded:: 2.0
Tells Django that this expression can be used as the source expression
in :class:`~django.db.models.expressions.Window`. Defaults to
``False``.
.. method:: resolve_expression(query=None, allow_joins=True, reuse=None, summarize=False, for_save=False) .. method:: resolve_expression(query=None, allow_joins=True, reuse=None, summarize=False, for_save=False)
Provides the chance to do any pre-processing or validation of Provides the chance to do any pre-processing or validation of

View File

@ -1681,6 +1681,11 @@ raised if ``select_for_update()`` is used in autocommit mode.
``select_for_update()`` you should use ``select_for_update()`` you should use
:class:`~django.test.TransactionTestCase`. :class:`~django.test.TransactionTestCase`.
.. admonition:: Certain expressions may not be supported
PostgreSQL doesn't support ``select_for_update()`` with
:class:`~django.db.models.expressions.Window` expressions.
.. versionchanged:: 1.11 .. versionchanged:: 1.11
The ``skip_locked`` argument was added. The ``skip_locked`` argument was added.

View File

@ -52,6 +52,14 @@ Mobile-friendly ``contrib.admin``
The admin is now responsive and supports all major mobile devices. The admin is now responsive and supports all major mobile devices.
Older browser may experience varying levels of graceful degradation. Older browser may experience varying levels of graceful degradation.
Window expressions
------------------
The new :class:`~django.db.models.expressions.Window` expression allows
adding an ``OVER`` clause to querysets. You can use :ref:`window functions
<window-functions>` and :ref:`aggregate functions <aggregation-functions>` in
the expression.
Minor features Minor features
-------------- --------------
@ -404,6 +412,11 @@ backends.
requires that the arguments to ``OF`` be columns rather than tables, set requires that the arguments to ``OF`` be columns rather than tables, set
``DatabaseFeatures.select_for_update_of_column = True``. ``DatabaseFeatures.select_for_update_of_column = True``.
* To enable support for :class:`~django.db.models.expressions.Window`
expressions, set ``DatabaseFeatures.supports_over_clause`` to ``True``. You
may need to customize the ``DatabaseOperations.window_frame_rows_start_end()``
and/or ``window_frame_range_start_end()`` methods.
* Third-party database backends should add a * Third-party database backends should add a
``DatabaseOperations.cast_char_field_without_max_length`` attribute with the ``DatabaseOperations.cast_char_field_without_max_length`` attribute with the
database data type that will be used in the database data type that will be used in the

View File

@ -0,0 +1,10 @@
from django.db import NotSupportedError, connection
from django.test import SimpleTestCase, skipIfDBFeature
class DatabaseOperationTests(SimpleTestCase):
    """Checks for window frame operations on backends lacking OVER support."""

    @skipIfDBFeature('supports_over_clause')
    def test_window_frame_raise_not_supported_error(self):
        """Asking for ROWS frame bounds raises NotSupportedError."""
        expected_message = 'This backend does not support window expressions.'
        with self.assertRaisesMessage(NotSupportedError, expected_message):
            connection.ops.window_frame_rows_start_end()

View File

@ -0,0 +1,39 @@
from django.db.models.functions import Lag, Lead, NthValue, Ntile
from django.test import SimpleTestCase
class ValidationTests(SimpleTestCase):
    """Argument validation performed by the window function constructors."""

    def test_nth_negative_nth_value(self):
        """NthValue rejects a negative ``nth``."""
        expected = 'NthValue requires a positive integer as for nth'
        with self.assertRaisesMessage(ValueError, expected):
            NthValue(expression='salary', nth=-1)

    def test_nth_null_expression(self):
        """NthValue rejects a missing source expression."""
        expected = 'NthValue requires a non-null source expression'
        with self.assertRaisesMessage(ValueError, expected):
            NthValue(expression=None)

    def test_lag_negative_offset(self):
        """Lag rejects a negative ``offset``."""
        expected = 'Lag requires a positive integer for the offset'
        with self.assertRaisesMessage(ValueError, expected):
            Lag(expression='salary', offset=-1)

    def test_lead_negative_offset(self):
        """Lead rejects a negative ``offset``."""
        expected = 'Lead requires a positive integer for the offset'
        with self.assertRaisesMessage(ValueError, expected):
            Lead(expression='salary', offset=-1)

    def test_null_source_lead(self):
        """Lead rejects a missing source expression."""
        expected = 'Lead requires a non-null source expression'
        with self.assertRaisesMessage(ValueError, expected):
            Lead(expression=None)

    def test_null_source_lag(self):
        """Lag rejects a missing source expression."""
        expected = 'Lag requires a non-null source expression'
        with self.assertRaisesMessage(ValueError, expected):
            Lag(expression=None)

    def test_negative_num_buckets_ntile(self):
        """Ntile rejects a negative bucket count."""
        expected = 'num_buckets must be greater than 0'
        with self.assertRaisesMessage(ValueError, expected):
            Ntile(num_buckets=-1)

View File

@ -10,8 +10,8 @@ from django.db.models.aggregates import (
Avg, Count, Max, Min, StdDev, Sum, Variance, Avg, Count, Max, Min, StdDev, Sum, Variance,
) )
from django.db.models.expressions import ( from django.db.models.expressions import (
Case, Col, Exists, ExpressionWrapper, F, Func, OrderBy, OuterRef, Random, Case, Col, Exists, ExpressionList, ExpressionWrapper, F, Func, OrderBy,
RawSQL, Ref, Subquery, Value, When, OuterRef, Random, RawSQL, Ref, Subquery, Value, When,
) )
from django.db.models.functions import ( from django.db.models.functions import (
Coalesce, Concat, Length, Lower, Substr, Upper, Coalesce, Concat, Length, Lower, Substr, Upper,
@ -1330,6 +1330,11 @@ class ValueTests(TestCase):
self.assertNotEqual(value, other_value) self.assertNotEqual(value, other_value)
self.assertNotEqual(value, no_output_field) self.assertNotEqual(value, no_output_field)
def test_raise_empty_expressionlist(self):
    """Constructing an ExpressionList without arguments raises ValueError."""
    expected_message = 'ExpressionList requires at least one expression'
    with self.assertRaisesMessage(ValueError, expected_message):
        ExpressionList()
class ReprTests(TestCase): class ReprTests(TestCase):
@ -1355,6 +1360,14 @@ class ReprTests(TestCase):
self.assertEqual(repr(RawSQL('table.col', [])), "RawSQL(table.col, [])") self.assertEqual(repr(RawSQL('table.col', [])), "RawSQL(table.col, [])")
self.assertEqual(repr(Ref('sum_cost', Sum('cost'))), "Ref(sum_cost, Sum(F(cost)))") self.assertEqual(repr(Ref('sum_cost', Sum('cost'))), "Ref(sum_cost, Sum(F(cost)))")
self.assertEqual(repr(Value(1)), "Value(1)") self.assertEqual(repr(Value(1)), "Value(1)")
self.assertEqual(
repr(ExpressionList(F('col'), F('anothercol'))),
'ExpressionList(F(col), F(anothercol))'
)
self.assertEqual(
repr(ExpressionList(OrderBy(F('col'), descending=False))),
'ExpressionList(OrderBy(F(col), descending=False))'
)
def test_functions(self): def test_functions(self):
self.assertEqual(repr(Coalesce('a', 'b')), "Coalesce(F(a), F(b))") self.assertEqual(repr(Coalesce('a', 'b')), "Coalesce(F(a), F(b))")

View File

View File

@ -0,0 +1,11 @@
from django.db import models
class Employee(models.Model):
    """Employee row: name, salary, department and hire date."""
    name = models.CharField(max_length=40, blank=False, null=False)
    salary = models.PositiveIntegerField()
    department = models.CharField(max_length=40, blank=False, null=False)
    hire_date = models.DateField(blank=False, null=False)

    def __str__(self):
        # Rendered as "name, department, salary, hire_date".
        parts = (self.name, self.department, self.salary, self.hire_date)
        return '{}, {}, {}, {}'.format(*parts)

View File

@ -0,0 +1,783 @@
import datetime
from unittest import skipIf, skipUnless
from django.core.exceptions import FieldError
from django.db import NotSupportedError, connection
from django.db.models import (
F, RowRange, Value, ValueRange, Window, WindowFrame,
)
from django.db.models.aggregates import Avg, Max, Min, Sum
from django.db.models.functions import (
CumeDist, DenseRank, ExtractYear, FirstValue, Lag, LastValue, Lead,
NthValue, Ntile, PercentRank, Rank, RowNumber, Upper,
)
from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature
from .models import Employee
@skipUnlessDBFeature('supports_over_clause')
class WindowFunctionTests(TestCase):
@classmethod
def setUpTestData(cls):
    """Bulk-create the twelve employees shared by every test in the class."""
    # Each tuple is (name, salary, department, hire_date).
    rows = [
        ('Jones', 45000, 'Accounting', datetime.datetime(2005, 11, 1)),
        ('Williams', 37000, 'Accounting', datetime.datetime(2009, 6, 1)),
        ('Jenson', 45000, 'Accounting', datetime.datetime(2008, 4, 1)),
        ('Adams', 50000, 'Accounting', datetime.datetime(2013, 7, 1)),
        ('Smith', 55000, 'Sales', datetime.datetime(2007, 6, 1)),
        ('Brown', 53000, 'Sales', datetime.datetime(2009, 9, 1)),
        ('Johnson', 40000, 'Marketing', datetime.datetime(2012, 3, 1)),
        ('Smith', 38000, 'Marketing', datetime.datetime(2009, 10, 1)),
        ('Wilkinson', 60000, 'IT', datetime.datetime(2011, 3, 1)),
        ('Moore', 34000, 'IT', datetime.datetime(2013, 8, 1)),
        ('Miller', 100000, 'Management', datetime.datetime(2005, 6, 1)),
        ('Johnson', 80000, 'Management', datetime.datetime(2005, 7, 1)),
    ]
    Employee.objects.bulk_create([
        Employee(name=name, salary=salary, department=department, hire_date=hire_date)
        for name, salary, department, hire_date in rows
    ])
def test_dense_rank(self):
    """Dense-rank employees by hire year; equal years share a rank, no gaps."""
    ranked = Employee.objects.annotate(rank=Window(
        expression=DenseRank(),
        order_by=ExtractYear(F('hire_date')).asc(),
    ))
    expected = [
        ('Jones', 45000, 'Accounting', datetime.date(2005, 11, 1), 1),
        ('Miller', 100000, 'Management', datetime.date(2005, 6, 1), 1),
        ('Johnson', 80000, 'Management', datetime.date(2005, 7, 1), 1),
        ('Smith', 55000, 'Sales', datetime.date(2007, 6, 1), 2),
        ('Jenson', 45000, 'Accounting', datetime.date(2008, 4, 1), 3),
        ('Smith', 38000, 'Marketing', datetime.date(2009, 10, 1), 4),
        ('Brown', 53000, 'Sales', datetime.date(2009, 9, 1), 4),
        ('Williams', 37000, 'Accounting', datetime.date(2009, 6, 1), 4),
        ('Wilkinson', 60000, 'IT', datetime.date(2011, 3, 1), 5),
        ('Johnson', 40000, 'Marketing', datetime.date(2012, 3, 1), 6),
        ('Moore', 34000, 'IT', datetime.date(2013, 8, 1), 7),
        ('Adams', 50000, 'Accounting', datetime.date(2013, 7, 1), 7),
    ]
    self.assertQuerysetEqual(
        ranked,
        expected,
        transform=lambda emp: (emp.name, emp.salary, emp.department, emp.hire_date, emp.rank),
        ordered=False,
    )
def test_department_salary(self):
    """Running salary total per department, accumulated in hire-date order."""
    running_total = Window(
        expression=Sum('salary'),
        partition_by=F('department'),
        order_by=[F('hire_date').asc()],
    )
    employees = Employee.objects.annotate(
        department_sum=running_total,
    ).order_by('department', 'department_sum')
    expected = [
        ('Jones', 'Accounting', 45000, 45000),
        ('Jenson', 'Accounting', 45000, 90000),
        ('Williams', 'Accounting', 37000, 127000),
        ('Adams', 'Accounting', 50000, 177000),
        ('Wilkinson', 'IT', 60000, 60000),
        ('Moore', 'IT', 34000, 94000),
        ('Miller', 'Management', 100000, 100000),
        ('Johnson', 'Management', 80000, 180000),
        ('Smith', 'Marketing', 38000, 38000),
        ('Johnson', 'Marketing', 40000, 78000),
        ('Smith', 'Sales', 55000, 55000),
        ('Brown', 'Sales', 53000, 108000),
    ]
    self.assertQuerysetEqual(
        employees,
        expected,
        transform=lambda emp: (emp.name, emp.department, emp.salary, emp.department_sum),
    )
def test_rank(self):
    """
    Rank the employees by the year they were hired. Because several
    employees share a hire year, the resulting ranks contain gaps.
    """
    ranked = Employee.objects.annotate(rank=Window(
        expression=Rank(),
        order_by=ExtractYear(F('hire_date')).asc(),
    ))
    expected = [
        ('Jones', 45000, 'Accounting', datetime.date(2005, 11, 1), 1),
        ('Miller', 100000, 'Management', datetime.date(2005, 6, 1), 1),
        ('Johnson', 80000, 'Management', datetime.date(2005, 7, 1), 1),
        ('Smith', 55000, 'Sales', datetime.date(2007, 6, 1), 4),
        ('Jenson', 45000, 'Accounting', datetime.date(2008, 4, 1), 5),
        ('Smith', 38000, 'Marketing', datetime.date(2009, 10, 1), 6),
        ('Brown', 53000, 'Sales', datetime.date(2009, 9, 1), 6),
        ('Williams', 37000, 'Accounting', datetime.date(2009, 6, 1), 6),
        ('Wilkinson', 60000, 'IT', datetime.date(2011, 3, 1), 9),
        ('Johnson', 40000, 'Marketing', datetime.date(2012, 3, 1), 10),
        ('Moore', 34000, 'IT', datetime.date(2013, 8, 1), 11),
        ('Adams', 50000, 'Accounting', datetime.date(2013, 7, 1), 11),
    ]
    self.assertQuerysetEqual(
        ranked,
        expected,
        transform=lambda emp: (emp.name, emp.salary, emp.department, emp.hire_date, emp.rank),
        ordered=False,
    )
def test_row_number(self):
    """
    Number the rows in primary-key order. The window carries an explicit
    ordering so the numbering is well defined on every backend.
    """
    numbered = Employee.objects.annotate(row_number=Window(
        expression=RowNumber(),
        order_by=F('pk').asc(),
    )).order_by('pk')
    expected = [
        ('Jones', 'Accounting', 1),
        ('Williams', 'Accounting', 2),
        ('Jenson', 'Accounting', 3),
        ('Adams', 'Accounting', 4),
        ('Smith', 'Sales', 5),
        ('Brown', 'Sales', 6),
        ('Johnson', 'Marketing', 7),
        ('Smith', 'Marketing', 8),
        ('Wilkinson', 'IT', 9),
        ('Moore', 'IT', 10),
        ('Miller', 'Management', 11),
        ('Johnson', 'Management', 12),
    ]
    self.assertQuerysetEqual(
        numbered,
        expected,
        transform=lambda emp: (emp.name, emp.department, emp.row_number),
    )
@skipIf(connection.vendor == 'oracle', "Oracle requires ORDER BY in row_number, ANSI:SQL doesn't")
def test_row_number_no_ordering(self):
    """
    RowNumber() used without any ordering inside the window; only the
    outer queryset is ordered.
    """
    # The outer order_by('pk') keeps results consistent across databases.
    numbered = Employee.objects.annotate(
        row_number=Window(expression=RowNumber()),
    ).order_by('pk')
    expected = [
        ('Jones', 'Accounting', 1),
        ('Williams', 'Accounting', 2),
        ('Jenson', 'Accounting', 3),
        ('Adams', 'Accounting', 4),
        ('Smith', 'Sales', 5),
        ('Brown', 'Sales', 6),
        ('Johnson', 'Marketing', 7),
        ('Smith', 'Marketing', 8),
        ('Wilkinson', 'IT', 9),
        ('Moore', 'IT', 10),
        ('Miller', 'Management', 11),
        ('Johnson', 'Management', 12),
    ]
    self.assertQuerysetEqual(
        numbered,
        expected,
        transform=lambda emp: (emp.name, emp.department, emp.row_number),
    )
def test_avg_salary_department(self):
    """Annotate every employee with the average salary of their department."""
    department_average = Window(
        expression=Avg('salary'),
        order_by=F('department').asc(),
        partition_by='department',
    )
    employees = Employee.objects.annotate(
        avg_salary=department_average,
    ).order_by('department', '-salary', 'name')
    expected = [
        ('Adams', 50000, 'Accounting', 44250.00),
        ('Jenson', 45000, 'Accounting', 44250.00),
        ('Jones', 45000, 'Accounting', 44250.00),
        ('Williams', 37000, 'Accounting', 44250.00),
        ('Wilkinson', 60000, 'IT', 47000.00),
        ('Moore', 34000, 'IT', 47000.00),
        ('Miller', 100000, 'Management', 90000.00),
        ('Johnson', 80000, 'Management', 90000.00),
        ('Johnson', 40000, 'Marketing', 39000.00),
        ('Smith', 38000, 'Marketing', 39000.00),
        ('Smith', 55000, 'Sales', 54000.00),
        ('Brown', 53000, 'Sales', 54000.00),
    ]
    self.assertQuerysetEqual(
        employees,
        expected,
        transform=lambda emp: (emp.name, emp.salary, emp.department, emp.avg_salary),
    )
def test_lag(self):
    """
    Annotate each employee with the salary of the preceding employee in
    the same department, where employees are ordered by salary and then
    name. The first employee of each department has no predecessor, so
    None is returned for that row.
    """
    # The comparison below is order-sensitive, so the outer ordering must
    # fully determine row order: ordering by department alone leaves the
    # order within a department up to the backend. Mirror the window's
    # own ordering inside each department.
    qs = Employee.objects.annotate(lag=Window(
        expression=Lag(expression='salary', offset=1),
        partition_by=F('department'),
        order_by=[F('salary').asc(), F('name').asc()],
    )).order_by('department', F('salary').asc(), F('name').asc())
    self.assertQuerysetEqual(qs, [
        ('Williams', 37000, 'Accounting', None),
        ('Jenson', 45000, 'Accounting', 37000),
        ('Jones', 45000, 'Accounting', 45000),
        ('Adams', 50000, 'Accounting', 45000),
        ('Moore', 34000, 'IT', None),
        ('Wilkinson', 60000, 'IT', 34000),
        ('Johnson', 80000, 'Management', None),
        ('Miller', 100000, 'Management', 80000),
        ('Smith', 38000, 'Marketing', None),
        ('Johnson', 40000, 'Marketing', 38000),
        ('Brown', 53000, 'Sales', None),
        ('Smith', 55000, 'Sales', 53000),
    ], transform=lambda row: (row.name, row.salary, row.department, row.lag))
def test_first_value(self):
    """Salary of the earliest hire in each employee's department."""
    first_hired_salary = Window(
        expression=FirstValue('salary'),
        partition_by=F('department'),
        order_by=F('hire_date').asc(),
    )
    employees = Employee.objects.annotate(
        first_value=first_hired_salary,
    ).order_by('department', 'hire_date')
    expected = [
        ('Jones', 45000, 'Accounting', datetime.date(2005, 11, 1), 45000),
        ('Jenson', 45000, 'Accounting', datetime.date(2008, 4, 1), 45000),
        ('Williams', 37000, 'Accounting', datetime.date(2009, 6, 1), 45000),
        ('Adams', 50000, 'Accounting', datetime.date(2013, 7, 1), 45000),
        ('Wilkinson', 60000, 'IT', datetime.date(2011, 3, 1), 60000),
        ('Moore', 34000, 'IT', datetime.date(2013, 8, 1), 60000),
        ('Miller', 100000, 'Management', datetime.date(2005, 6, 1), 100000),
        ('Johnson', 80000, 'Management', datetime.date(2005, 7, 1), 100000),
        ('Smith', 38000, 'Marketing', datetime.date(2009, 10, 1), 38000),
        ('Johnson', 40000, 'Marketing', datetime.date(2012, 3, 1), 38000),
        ('Smith', 55000, 'Sales', datetime.date(2007, 6, 1), 55000),
        ('Brown', 53000, 'Sales', datetime.date(2009, 9, 1), 55000),
    ]
    self.assertQuerysetEqual(
        employees,
        expected,
        transform=lambda emp: (emp.name, emp.salary, emp.department, emp.hire_date, emp.first_value),
    )
def test_last_value(self):
    """
    LastValue('hire_date') per department, ordered by hire date. No frame
    is given; the expected value for every row is that row's own hire date.
    """
    employees = Employee.objects.annotate(last_value=Window(
        expression=LastValue('hire_date'),
        partition_by=F('department'),
        order_by=F('hire_date').asc(),
    ))
    expected = [
        ('Adams', 'Accounting', datetime.date(2013, 7, 1), 50000, datetime.date(2013, 7, 1)),
        ('Jenson', 'Accounting', datetime.date(2008, 4, 1), 45000, datetime.date(2008, 4, 1)),
        ('Jones', 'Accounting', datetime.date(2005, 11, 1), 45000, datetime.date(2005, 11, 1)),
        ('Williams', 'Accounting', datetime.date(2009, 6, 1), 37000, datetime.date(2009, 6, 1)),
        ('Moore', 'IT', datetime.date(2013, 8, 1), 34000, datetime.date(2013, 8, 1)),
        ('Wilkinson', 'IT', datetime.date(2011, 3, 1), 60000, datetime.date(2011, 3, 1)),
        ('Miller', 'Management', datetime.date(2005, 6, 1), 100000, datetime.date(2005, 6, 1)),
        ('Johnson', 'Management', datetime.date(2005, 7, 1), 80000, datetime.date(2005, 7, 1)),
        ('Johnson', 'Marketing', datetime.date(2012, 3, 1), 40000, datetime.date(2012, 3, 1)),
        ('Smith', 'Marketing', datetime.date(2009, 10, 1), 38000, datetime.date(2009, 10, 1)),
        ('Brown', 'Sales', datetime.date(2009, 9, 1), 53000, datetime.date(2009, 9, 1)),
        ('Smith', 'Sales', datetime.date(2007, 6, 1), 55000, datetime.date(2007, 6, 1)),
    ]
    self.assertQuerysetEqual(
        employees,
        expected,
        transform=lambda emp: (emp.name, emp.department, emp.hire_date, emp.salary, emp.last_value),
        ordered=False,
    )
def test_function_list_of_values(self):
    """
    A window annotation can be selected through values_list() and doesn't
    add a GROUP BY clause to the query.
    """
    # assertSequenceEqual() compares in order, so give the queryset an
    # explicit ordering instead of relying on the order in which the
    # backend happens to return the rows.
    qs = Employee.objects.annotate(lead=Window(
        expression=Lead(expression='salary'),
        order_by=[F('hire_date').asc(), F('name').desc()],
        partition_by='department',
    )).values_list(
        'name', 'salary', 'department', 'hire_date', 'lead',
    ).order_by('department', F('hire_date').asc(), F('name').desc())
    self.assertNotIn('GROUP BY', str(qs.query))
    self.assertSequenceEqual(qs, [
        ('Jones', 45000, 'Accounting', datetime.date(2005, 11, 1), 45000),
        ('Jenson', 45000, 'Accounting', datetime.date(2008, 4, 1), 37000),
        ('Williams', 37000, 'Accounting', datetime.date(2009, 6, 1), 50000),
        ('Adams', 50000, 'Accounting', datetime.date(2013, 7, 1), None),
        ('Wilkinson', 60000, 'IT', datetime.date(2011, 3, 1), 34000),
        ('Moore', 34000, 'IT', datetime.date(2013, 8, 1), None),
        ('Miller', 100000, 'Management', datetime.date(2005, 6, 1), 80000),
        ('Johnson', 80000, 'Management', datetime.date(2005, 7, 1), None),
        ('Smith', 38000, 'Marketing', datetime.date(2009, 10, 1), 40000),
        ('Johnson', 40000, 'Marketing', datetime.date(2012, 3, 1), None),
        ('Smith', 55000, 'Sales', datetime.date(2007, 6, 1), 53000),
        ('Brown', 53000, 'Sales', datetime.date(2009, 9, 1), None),
    ])
def test_min_department(self):
    """Min() over a salary-ordered partition, as an alternative to FirstValue."""
    lowest_salary = Window(
        expression=Min('salary'),
        partition_by=F('department'),
        order_by=[F('salary').asc(), F('name').asc()]
    )
    employees = Employee.objects.annotate(
        min_salary=lowest_salary,
    ).order_by('department', 'salary', 'name')
    expected = [
        ('Williams', 'Accounting', 37000, 37000),
        ('Jenson', 'Accounting', 45000, 37000),
        ('Jones', 'Accounting', 45000, 37000),
        ('Adams', 'Accounting', 50000, 37000),
        ('Moore', 'IT', 34000, 34000),
        ('Wilkinson', 'IT', 60000, 34000),
        ('Johnson', 'Management', 80000, 80000),
        ('Miller', 'Management', 100000, 80000),
        ('Smith', 'Marketing', 38000, 38000),
        ('Johnson', 'Marketing', 40000, 38000),
        ('Brown', 'Sales', 53000, 53000),
        ('Smith', 'Sales', 55000, 53000),
    ]
    self.assertQuerysetEqual(
        employees,
        expected,
        transform=lambda emp: (emp.name, emp.department, emp.salary, emp.min_salary),
    )
def test_max_per_year(self):
    """
    Highest salary among all employees hired in the same year as each
    employee, irrespective of department.
    """
    yearly_max = Window(
        expression=Max('salary'),
        order_by=ExtractYear('hire_date').asc(),
        partition_by=ExtractYear('hire_date')
    )
    employees = Employee.objects.annotate(
        max_salary_year=yearly_max,
    ).order_by(ExtractYear('hire_date'), 'salary')
    expected = [
        ('Jones', 'Accounting', 45000, 2005, 100000),
        ('Johnson', 'Management', 80000, 2005, 100000),
        ('Miller', 'Management', 100000, 2005, 100000),
        ('Smith', 'Sales', 55000, 2007, 55000),
        ('Jenson', 'Accounting', 45000, 2008, 45000),
        ('Williams', 'Accounting', 37000, 2009, 53000),
        ('Smith', 'Marketing', 38000, 2009, 53000),
        ('Brown', 'Sales', 53000, 2009, 53000),
        ('Wilkinson', 'IT', 60000, 2011, 60000),
        ('Johnson', 'Marketing', 40000, 2012, 40000),
        ('Moore', 'IT', 34000, 2013, 50000),
        ('Adams', 'Accounting', 50000, 2013, 50000),
    ]
    self.assertQuerysetEqual(
        employees,
        expected,
        transform=lambda emp: (emp.name, emp.department, emp.salary, emp.hire_date.year, emp.max_salary_year),
    )
def test_cume_dist(self):
    """
    Cumulative distribution of the employees by increasing salary, i.e.
    rank divided by the total number of rows (12).
    """
    employees = Employee.objects.annotate(cume_dist=Window(
        expression=CumeDist(),
        order_by=F('salary').asc(),
    )).order_by('salary', 'name')
    expected = [
        ('Moore', 'IT', 34000, 0.0833333333),
        ('Williams', 'Accounting', 37000, 0.1666666667),
        ('Smith', 'Marketing', 38000, 0.25),
        ('Johnson', 'Marketing', 40000, 0.3333333333),
        ('Jenson', 'Accounting', 45000, 0.5),
        ('Jones', 'Accounting', 45000, 0.5),
        ('Adams', 'Accounting', 50000, 0.5833333333),
        ('Brown', 'Sales', 53000, 0.6666666667),
        ('Smith', 'Sales', 55000, 0.75),
        ('Wilkinson', 'IT', 60000, 0.8333333333),
        ('Johnson', 'Management', 80000, 0.9166666667),
        ('Miller', 'Management', 100000, 1),
    ]
    # cume_dist is rounded before comparing because Oracle returns it with
    # greater precision.
    self.assertQuerysetEqual(
        employees,
        expected,
        transform=lambda emp: (emp.name, emp.department, emp.salary, round(emp.cume_dist, 10)),
    )
def test_nthvalue(self):
    """Second salary per department in hire-date order, None on the first row."""
    second_salary = Window(
        expression=NthValue(expression='salary', nth=2),
        order_by=[F('hire_date').asc(), F('name').desc()],
        partition_by=F('department'),
    )
    employees = Employee.objects.annotate(
        nth_value=second_salary,
    ).order_by('department', 'hire_date', 'name')
    expected = [
        ('Jones', 'Accounting', datetime.date(2005, 11, 1), 45000, None),
        ('Jenson', 'Accounting', datetime.date(2008, 4, 1), 45000, 45000),
        ('Williams', 'Accounting', datetime.date(2009, 6, 1), 37000, 45000),
        ('Adams', 'Accounting', datetime.date(2013, 7, 1), 50000, 45000),
        ('Wilkinson', 'IT', datetime.date(2011, 3, 1), 60000, None),
        ('Moore', 'IT', datetime.date(2013, 8, 1), 34000, 34000),
        ('Miller', 'Management', datetime.date(2005, 6, 1), 100000, None),
        ('Johnson', 'Management', datetime.date(2005, 7, 1), 80000, 80000),
        ('Smith', 'Marketing', datetime.date(2009, 10, 1), 38000, None),
        ('Johnson', 'Marketing', datetime.date(2012, 3, 1), 40000, 40000),
        ('Smith', 'Sales', datetime.date(2007, 6, 1), 55000, None),
        ('Brown', 'Sales', datetime.date(2009, 9, 1), 53000, 53000),
    ]
    self.assertQuerysetEqual(
        employees,
        expected,
        transform=lambda emp: (emp.name, emp.department, emp.hire_date, emp.salary, emp.nth_value),
    )
def test_lead(self):
    """
    Determine what the next person hired in the same department makes.
    Because the dataset is ambiguous, the name is also part of the
    ordering clause. No default is provided, so None/NULL should be
    returned.
    """
    # The comparison below is order-sensitive, so the outer ordering must
    # fully determine row order: ordering by department alone leaves the
    # order within a department up to the backend. Mirror the window's
    # own ordering inside each department.
    qs = Employee.objects.annotate(lead=Window(
        expression=Lead(expression='salary'),
        order_by=[F('hire_date').asc(), F('name').desc()],
        partition_by='department',
    )).order_by('department', F('hire_date').asc(), F('name').desc())
    self.assertQuerysetEqual(qs, [
        ('Jones', 45000, 'Accounting', datetime.date(2005, 11, 1), 45000),
        ('Jenson', 45000, 'Accounting', datetime.date(2008, 4, 1), 37000),
        ('Williams', 37000, 'Accounting', datetime.date(2009, 6, 1), 50000),
        ('Adams', 50000, 'Accounting', datetime.date(2013, 7, 1), None),
        ('Wilkinson', 60000, 'IT', datetime.date(2011, 3, 1), 34000),
        ('Moore', 34000, 'IT', datetime.date(2013, 8, 1), None),
        ('Miller', 100000, 'Management', datetime.date(2005, 6, 1), 80000),
        ('Johnson', 80000, 'Management', datetime.date(2005, 7, 1), None),
        ('Smith', 38000, 'Marketing', datetime.date(2009, 10, 1), 40000),
        ('Johnson', 40000, 'Marketing', datetime.date(2012, 3, 1), None),
        ('Smith', 55000, 'Sales', datetime.date(2007, 6, 1), 53000),
        ('Brown', 53000, 'Sales', datetime.date(2009, 9, 1), None),
    ], transform=lambda row: (row.name, row.salary, row.department, row.hire_date, row.lead))
def test_lead_offset(self):
    """
    Salary of the employee hired two positions later in the same
    department (by hire date). None is expected where no such row exists.
    """
    employees = Employee.objects.annotate(lead=Window(
        expression=Lead('salary', offset=2),
        partition_by='department',
        order_by=F('hire_date').asc(),
    ))
    expected = [
        ('Jones', 45000, 'Accounting', datetime.date(2005, 11, 1), 37000),
        ('Jenson', 45000, 'Accounting', datetime.date(2008, 4, 1), 50000),
        ('Williams', 37000, 'Accounting', datetime.date(2009, 6, 1), None),
        ('Adams', 50000, 'Accounting', datetime.date(2013, 7, 1), None),
        ('Wilkinson', 60000, 'IT', datetime.date(2011, 3, 1), None),
        ('Moore', 34000, 'IT', datetime.date(2013, 8, 1), None),
        ('Johnson', 80000, 'Management', datetime.date(2005, 7, 1), None),
        ('Miller', 100000, 'Management', datetime.date(2005, 6, 1), None),
        ('Smith', 38000, 'Marketing', datetime.date(2009, 10, 1), None),
        ('Johnson', 40000, 'Marketing', datetime.date(2012, 3, 1), None),
        ('Smith', 55000, 'Sales', datetime.date(2007, 6, 1), None),
        ('Brown', 53000, 'Sales', datetime.date(2009, 9, 1), None),
    ]
    self.assertQuerysetEqual(
        employees,
        expected,
        transform=lambda emp: (emp.name, emp.salary, emp.department, emp.hire_date, emp.lead),
        ordered=False,
    )
def test_lead_default(self):
    """Lead() with offset=5 and default=60000 yields 60000 for every row."""
    employees = Employee.objects.annotate(lead_default=Window(
        expression=Lead(expression='salary', offset=5, default=60000),
        partition_by=F('department'),
        order_by=F('department').asc(),
    ))
    distinct_values = employees.values_list('lead_default', flat=True).distinct()
    self.assertEqual(list(distinct_values), [60000])
def test_ntile(self):
    """
    Split the whole company into four buckets by decreasing salary. The
    twelve employees divide evenly, three per bucket.
    """
    employees = Employee.objects.annotate(ntile=Window(
        expression=Ntile(num_buckets=4),
        order_by=F('salary').desc(),
    )).order_by('ntile', '-salary', 'name')
    expected = [
        ('Miller', 'Management', 100000, 1),
        ('Johnson', 'Management', 80000, 1),
        ('Wilkinson', 'IT', 60000, 1),
        ('Smith', 'Sales', 55000, 2),
        ('Brown', 'Sales', 53000, 2),
        ('Adams', 'Accounting', 50000, 2),
        ('Jenson', 'Accounting', 45000, 3),
        ('Jones', 'Accounting', 45000, 3),
        ('Johnson', 'Marketing', 40000, 3),
        ('Smith', 'Marketing', 38000, 4),
        ('Williams', 'Accounting', 37000, 4),
        ('Moore', 'IT', 34000, 4),
    ]
    self.assertQuerysetEqual(
        employees,
        expected,
        transform=lambda emp: (emp.name, emp.department, emp.salary, emp.ntile),
    )
def test_percent_rank(self):
    """
    Percentage rank of each employee across the whole company, ordered by
    salary and then name to break ties.
    """
    employees = Employee.objects.annotate(percent_rank=Window(
        expression=PercentRank(),
        order_by=[F('salary').asc(), F('name').asc()],
    )).order_by('percent_rank')
    expected = [
        ('Moore', 'IT', 34000, 0.0),
        ('Williams', 'Accounting', 37000, 0.0909090909),
        ('Smith', 'Marketing', 38000, 0.1818181818),
        ('Johnson', 'Marketing', 40000, 0.2727272727),
        ('Jenson', 'Accounting', 45000, 0.3636363636),
        ('Jones', 'Accounting', 45000, 0.4545454545),
        ('Adams', 'Accounting', 50000, 0.5454545455),
        ('Brown', 'Sales', 53000, 0.6363636364),
        ('Smith', 'Sales', 55000, 0.7272727273),
        ('Wilkinson', 'IT', 60000, 0.8181818182),
        ('Johnson', 'Management', 80000, 0.9090909091),
        ('Miller', 'Management', 100000, 1.0),
    ]
    # Rounding absorbs precision differences among databases.
    self.assertQuerysetEqual(
        employees,
        expected,
        transform=lambda emp: (emp.name, emp.department, emp.salary, round(emp.percent_rank, 10)),
    )
def test_nth_returns_null(self):
    """
    Ask for the 20th row of the data set. The test data has fewer than 20
    rows, so None is the only value returned.
    """
    employees = Employee.objects.annotate(nth_value=Window(
        expression=NthValue('salary', nth=20),
        order_by=F('salary').asc()
    ))
    distinct_values = employees.values_list('nth_value', flat=True).distinct()
    self.assertEqual(list(distinct_values), [None])
def test_multiple_partitioning(self):
    """
    Highest salary within each (department, hire year) pair, i.e. a window
    partitioned by two expressions.
    """
    max_in_group = Window(
        expression=Max('salary'),
        partition_by=[F('department'), ExtractYear(F('hire_date'))],
    )
    employees = Employee.objects.annotate(
        max=max_in_group,
    ).order_by('department', 'hire_date', 'name')
    expected = [
        ('Jones', 45000, 'Accounting', datetime.date(2005, 11, 1), 45000),
        ('Jenson', 45000, 'Accounting', datetime.date(2008, 4, 1), 45000),
        ('Williams', 37000, 'Accounting', datetime.date(2009, 6, 1), 37000),
        ('Adams', 50000, 'Accounting', datetime.date(2013, 7, 1), 50000),
        ('Wilkinson', 60000, 'IT', datetime.date(2011, 3, 1), 60000),
        ('Moore', 34000, 'IT', datetime.date(2013, 8, 1), 34000),
        ('Miller', 100000, 'Management', datetime.date(2005, 6, 1), 100000),
        ('Johnson', 80000, 'Management', datetime.date(2005, 7, 1), 100000),
        ('Smith', 38000, 'Marketing', datetime.date(2009, 10, 1), 38000),
        ('Johnson', 40000, 'Marketing', datetime.date(2012, 3, 1), 40000),
        ('Smith', 55000, 'Sales', datetime.date(2007, 6, 1), 55000),
        ('Brown', 53000, 'Sales', datetime.date(2009, 9, 1), 53000),
    ]
    self.assertQuerysetEqual(
        employees,
        expected,
        transform=lambda emp: (emp.name, emp.salary, emp.department, emp.hire_date, emp.max),
    )
def test_multiple_ordering(self):
    """
    Running salary total per department, ordered by hire date and then
    name. The second ordering term separates people hired on the same
    date in the same department.
    """
    running_total = Window(
        expression=Sum('salary'),
        partition_by='department',
        order_by=[F('hire_date').asc(), F('name').asc()],
    )
    employees = Employee.objects.annotate(
        sum=running_total,
    ).order_by('department', 'sum')
    expected = [
        ('Jones', 45000, 'Accounting', datetime.date(2005, 11, 1), 45000),
        ('Jenson', 45000, 'Accounting', datetime.date(2008, 4, 1), 90000),
        ('Williams', 37000, 'Accounting', datetime.date(2009, 6, 1), 127000),
        ('Adams', 50000, 'Accounting', datetime.date(2013, 7, 1), 177000),
        ('Wilkinson', 60000, 'IT', datetime.date(2011, 3, 1), 60000),
        ('Moore', 34000, 'IT', datetime.date(2013, 8, 1), 94000),
        ('Miller', 100000, 'Management', datetime.date(2005, 6, 1), 100000),
        ('Johnson', 80000, 'Management', datetime.date(2005, 7, 1), 180000),
        ('Smith', 38000, 'Marketing', datetime.date(2009, 10, 1), 38000),
        ('Johnson', 40000, 'Marketing', datetime.date(2012, 3, 1), 78000),
        ('Smith', 55000, 'Sales', datetime.date(2007, 6, 1), 55000),
        ('Brown', 53000, 'Sales', datetime.date(2009, 9, 1), 108000),
    ]
    self.assertQuerysetEqual(
        employees,
        expected,
        transform=lambda emp: (emp.name, emp.salary, emp.department, emp.hire_date, emp.sum),
    )
@skipIf(connection.vendor == 'postgresql', 'n following/preceding not supported by PostgreSQL')
def test_range_n_preceding_and_following(self):
    """Sum salaries over a RANGE frame of 2 PRECEDING to 2 FOLLOWING."""
    employees = Employee.objects.annotate(sum=Window(
        expression=Sum('salary'),
        order_by=F('salary').asc(),
        partition_by='department',
        frame=ValueRange(start=-2, end=2),
    ))
    generated_sql = str(employees.query)
    self.assertIn('RANGE BETWEEN 2 PRECEDING AND 2 FOLLOWING', generated_sql)
    expected = [
        ('Williams', 37000, 'Accounting', datetime.date(2009, 6, 1), 37000),
        ('Jones', 45000, 'Accounting', datetime.date(2005, 11, 1), 90000),
        ('Jenson', 45000, 'Accounting', datetime.date(2008, 4, 1), 90000),
        ('Adams', 50000, 'Accounting', datetime.date(2013, 7, 1), 50000),
        ('Brown', 53000, 'Sales', datetime.date(2009, 9, 1), 53000),
        ('Smith', 55000, 'Sales', datetime.date(2007, 6, 1), 55000),
        ('Johnson', 40000, 'Marketing', datetime.date(2012, 3, 1), 40000),
        ('Smith', 38000, 'Marketing', datetime.date(2009, 10, 1), 38000),
        ('Wilkinson', 60000, 'IT', datetime.date(2011, 3, 1), 60000),
        ('Moore', 34000, 'IT', datetime.date(2013, 8, 1), 34000),
        ('Miller', 100000, 'Management', datetime.date(2005, 6, 1), 100000),
        ('Johnson', 80000, 'Management', datetime.date(2005, 7, 1), 80000),
    ]
    self.assertQuerysetEqual(
        employees,
        expected,
        transform=lambda emp: (emp.name, emp.salary, emp.department, emp.hire_date, emp.sum),
        ordered=False,
    )
def test_range_unbound(self):
    """
    A query with RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING.
    Every expected row carries the salary total of its whole department.
    """
    def as_tuple(row):
        return (row.name, row.department, row.salary, row.hire_date, row.sum)

    department_window = Window(
        expression=Sum('salary'),
        partition_by='department',
        order_by=[F('hire_date').asc(), F('name').asc()],
        frame=ValueRange(start=None, end=None),
    )
    qs = Employee.objects.annotate(sum=department_window)
    qs = qs.order_by('department', 'hire_date', 'name')
    self.assertIn('RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING', str(qs.query))
    expected_rows = [
        ('Jones', 'Accounting', 45000, datetime.date(2005, 11, 1), 177000),
        ('Jenson', 'Accounting', 45000, datetime.date(2008, 4, 1), 177000),
        ('Williams', 'Accounting', 37000, datetime.date(2009, 6, 1), 177000),
        ('Adams', 'Accounting', 50000, datetime.date(2013, 7, 1), 177000),
        ('Wilkinson', 'IT', 60000, datetime.date(2011, 3, 1), 94000),
        ('Moore', 'IT', 34000, datetime.date(2013, 8, 1), 94000),
        ('Miller', 'Management', 100000, datetime.date(2005, 6, 1), 180000),
        ('Johnson', 'Management', 80000, datetime.date(2005, 7, 1), 180000),
        ('Smith', 'Marketing', 38000, datetime.date(2009, 10, 1), 78000),
        ('Johnson', 'Marketing', 40000, datetime.date(2012, 3, 1), 78000),
        ('Smith', 'Sales', 55000, datetime.date(2007, 6, 1), 108000),
        ('Brown', 'Sales', 53000, datetime.date(2009, 9, 1), 108000),
    ]
    self.assertQuerysetEqual(qs, expected_rows, transform=as_tuple)
def test_row_range_rank(self):
    """
    A query with ROWS BETWEEN UNBOUNDED PRECEDING AND 3 FOLLOWING.
    The resulting sum is the sum of the three next (if they exist) and all
    previous rows according to the ordering clause.
    """
    def as_tuple(row):
        return (row.name, row.salary, row.department, row.hire_date, row.sum)

    running_window = Window(
        expression=Sum('salary'),
        order_by=[F('hire_date').asc(), F('name').desc()],
        frame=RowRange(start=None, end=3),
    )
    qs = Employee.objects.annotate(sum=running_window).order_by('sum', 'hire_date')
    self.assertIn('ROWS BETWEEN UNBOUNDED PRECEDING AND 3 FOLLOWING', str(qs.query))
    expected_rows = [
        ('Miller', 100000, 'Management', datetime.date(2005, 6, 1), 280000),
        ('Johnson', 80000, 'Management', datetime.date(2005, 7, 1), 325000),
        ('Jones', 45000, 'Accounting', datetime.date(2005, 11, 1), 362000),
        ('Smith', 55000, 'Sales', datetime.date(2007, 6, 1), 415000),
        ('Jenson', 45000, 'Accounting', datetime.date(2008, 4, 1), 453000),
        ('Williams', 37000, 'Accounting', datetime.date(2009, 6, 1), 513000),
        ('Brown', 53000, 'Sales', datetime.date(2009, 9, 1), 553000),
        ('Smith', 38000, 'Marketing', datetime.date(2009, 10, 1), 603000),
        ('Wilkinson', 60000, 'IT', datetime.date(2011, 3, 1), 637000),
        ('Johnson', 40000, 'Marketing', datetime.date(2012, 3, 1), 637000),
        ('Adams', 50000, 'Accounting', datetime.date(2013, 7, 1), 637000),
        ('Moore', 34000, 'IT', datetime.date(2013, 8, 1), 637000),
    ]
    self.assertQuerysetEqual(qs, expected_rows, transform=as_tuple)
@skipUnlessDBFeature('can_distinct_on_fields')
def test_distinct_window_function(self):
    """
    Window functions are not aggregates, and hence a query to filter out
    duplicates may be useful.
    """
    qs = Employee.objects.annotate(
        sum=Window(
            expression=Sum('salary'),
            partition_by=ExtractYear('hire_date'),
            order_by=ExtractYear('hire_date')
        ),
        year=ExtractYear('hire_date'),
    ).values('year', 'sum').distinct('year').order_by('year')
    results = [
        {'year': 2005, 'sum': 225000}, {'year': 2007, 'sum': 55000},
        {'year': 2008, 'sum': 45000}, {'year': 2009, 'sum': 128000},
        {'year': 2011, 'sum': 60000}, {'year': 2012, 'sum': 40000},
        {'year': 2013, 'sum': 84000},
    ]
    # Compare position by position so each expected row is reported in its
    # own subtest.
    for idx, val in enumerate(results):
        with self.subTest(result=val):
            self.assertEqual(qs[idx], val)
def test_fail_update(self):
    """Using a window expression as an UPDATE value raises FieldError."""
    expected_message = 'Window expressions are not allowed in this query'
    with self.assertRaisesMessage(FieldError, expected_message):
        managers = Employee.objects.filter(department='Management')
        top_salary = Window(expression=Max('salary'), partition_by='department')
        managers.update(salary=top_salary)
def test_fail_insert(self):
    """Window expressions can't be used in an INSERT statement."""
    msg = 'Window expressions are not allowed in this query'
    with self.assertRaisesMessage(FieldError, msg):
        Employee.objects.create(
            name='Jameson', department='Management', hire_date=datetime.date(2007, 7, 1),
            # A bare window is enough here; the test only checks that a
            # Window value is rejected.
            salary=Window(expression=Sum(Value(10000))),
        )
def test_invalid_start_value_range(self):
    """A positive start for a ValueRange frame raises ValueError."""
    expected_message = "start argument must be a negative integer, zero, or None, but got '3'."
    with self.assertRaisesMessage(ValueError, expected_message):
        window = Window(
            expression=Sum('salary'),
            order_by=F('hire_date').asc(),
            frame=ValueRange(start=3),
        )
        # Force evaluation of the annotated queryset.
        list(Employee.objects.annotate(test=window))
def test_invalid_end_value_range(self):
    """A negative end for a ValueRange frame raises ValueError."""
    expected_message = "end argument must be a positive integer, zero, or None, but got '-3'."
    with self.assertRaisesMessage(ValueError, expected_message):
        window = Window(
            expression=Sum('salary'),
            order_by=F('hire_date').asc(),
            frame=ValueRange(end=-3),
        )
        # Force evaluation of the annotated queryset.
        list(Employee.objects.annotate(test=window))
def test_invalid_type_end_value_range(self):
    """A string end for a ValueRange frame raises ValueError."""
    expected_message = "end argument must be a positive integer, zero, or None, but got 'a'."
    with self.assertRaisesMessage(ValueError, expected_message):
        window = Window(
            expression=Sum('salary'),
            order_by=F('hire_date').asc(),
            frame=ValueRange(end='a'),
        )
        # Force evaluation of the annotated queryset.
        list(Employee.objects.annotate(test=window))
def test_invalid_type_start_value_range(self):
    """A string start for a ValueRange frame raises ValueError."""
    expected_message = "start argument must be a negative integer, zero, or None, but got 'a'."
    with self.assertRaisesMessage(ValueError, expected_message):
        window = Window(
            expression=Sum('salary'),
            frame=ValueRange(start='a'),
        )
        # Force evaluation of the annotated queryset.
        list(Employee.objects.annotate(test=window))
def test_invalid_type_end_row_range(self):
    """A string end for a RowRange frame raises ValueError."""
    expected_message = "end argument must be a positive integer, zero, or None, but got 'a'."
    with self.assertRaisesMessage(ValueError, expected_message):
        window = Window(
            expression=Sum('salary'),
            frame=RowRange(end='a'),
        )
        # Force evaluation of the annotated queryset.
        list(Employee.objects.annotate(test=window))
@skipUnless(connection.vendor == 'postgresql', 'Requires PostgreSQL: checks its RANGE frame restriction')
def test_postgresql_illegal_range_frame_start(self):
    """
    On PostgreSQL, evaluating a query whose RANGE frame starts at a numeric
    offset raises NotSupportedError.
    """
    msg = 'PostgreSQL only supports UNBOUNDED together with PRECEDING and FOLLOWING.'
    with self.assertRaisesMessage(NotSupportedError, msg):
        list(Employee.objects.annotate(test=Window(
            expression=Sum('salary'),
            order_by=F('hire_date').asc(),
            frame=ValueRange(start=-1),
        )))
@skipUnless(connection.vendor == 'postgresql', 'Requires PostgreSQL: checks its RANGE frame restriction')
def test_postgresql_illegal_range_frame_end(self):
    """
    On PostgreSQL, evaluating a query whose RANGE frame ends at a numeric
    offset raises NotSupportedError.
    """
    msg = 'PostgreSQL only supports UNBOUNDED together with PRECEDING and FOLLOWING.'
    with self.assertRaisesMessage(NotSupportedError, msg):
        list(Employee.objects.annotate(test=Window(
            expression=Sum('salary'),
            order_by=F('hire_date').asc(),
            frame=ValueRange(end=1),
        )))
def test_invalid_type_start_row_range(self):
    """A string start for a RowRange frame raises ValueError."""
    expected_message = "start argument must be a negative integer, zero, or None, but got 'a'."
    with self.assertRaisesMessage(ValueError, expected_message):
        window = Window(
            expression=Sum('salary'),
            order_by=F('hire_date').asc(),
            frame=RowRange(start='a'),
        )
        # Force evaluation of the annotated queryset.
        list(Employee.objects.annotate(test=window))
class NonQueryWindowTests(SimpleTestCase):
    """
    Checks on Window and WindowFrame expressions. None of these tests
    evaluates a queryset.
    """

    def test_window_repr(self):
        """repr() of a Window shows the expression and its OVER clause."""
        partitioned = Window(expression=Sum('salary'), partition_by='department')
        self.assertEqual(
            repr(partitioned),
            '<Window: Sum(F(salary)) OVER (PARTITION BY F(department))>'
        )
        ordered = Window(expression=Avg('salary'), order_by=F('department').asc())
        self.assertEqual(
            repr(ordered),
            '<Window: Avg(F(salary)) OVER (ORDER BY OrderBy(F(department), descending=False))>'
        )

    def test_window_frame_repr(self):
        """repr() of a frame spells out its frame type and both bounds."""
        cases = [
            (
                RowRange(start=-1),
                '<RowRange: ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING>',
            ),
            (
                ValueRange(start=None, end=1),
                '<ValueRange: RANGE BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING>',
            ),
            (
                ValueRange(start=0, end=0),
                '<ValueRange: RANGE BETWEEN CURRENT ROW AND CURRENT ROW>',
            ),
            (
                RowRange(start=0, end=0),
                '<RowRange: ROWS BETWEEN CURRENT ROW AND CURRENT ROW>',
            ),
        ]
        for frame, expected_repr in cases:
            self.assertEqual(repr(frame), expected_repr)

    def test_empty_group_by_cols(self):
        """A Window has no GROUP BY columns and doesn't count as an aggregate."""
        pk_sum_window = Window(expression=Sum('pk'))
        group_by_cols = pk_sum_window.get_group_by_cols()
        self.assertEqual(group_by_cols, [])
        self.assertFalse(pk_sum_window.contains_aggregate)

    def test_frame_empty_group_by_cols(self):
        """A bare WindowFrame has no GROUP BY columns."""
        base_frame = WindowFrame()
        group_by_cols = base_frame.get_group_by_cols()
        self.assertEqual(group_by_cols, [])

    def test_frame_window_frame_notimplemented(self):
        """The base WindowFrame leaves window_frame_start_end() to subclasses."""
        base_frame = WindowFrame()
        expected_message = 'Subclasses must implement window_frame_start_end().'
        with self.assertRaisesMessage(NotImplementedError, expected_message):
            base_frame.window_frame_start_end(None, None, None)

    def test_invalid_filter(self):
        """Filtering on a window annotation raises NotSupportedError."""
        expected_message = 'Window is disallowed in the filter clause'
        with self.assertRaisesMessage(NotSupportedError, expected_message):
            ranked = Employee.objects.annotate(dense_rank=Window(expression=DenseRank()))
            ranked.filter(dense_rank__gte=1)

    def test_invalid_order_by(self):
        """A plain string passed as order_by is rejected with ValueError."""
        expected_message = 'order_by must be either an Expression or a sequence of expressions'
        with self.assertRaisesMessage(ValueError, expected_message):
            Window(expression=Sum('power'), order_by='-horse')

    def test_invalid_source_expression(self):
        """Window rejects Upper as its source expression with ValueError."""
        expected_message = "Expression 'Upper' isn't compatible with OVER clauses."
        with self.assertRaisesMessage(ValueError, expected_message):
            Window(expression=Upper('name'))