mirror of https://github.com/django/django.git
Fixed #26608 -- Added support for window expressions (OVER clause).
Thanks Josh Smeaton, Mariusz Felisiak, Sergey Fedoseev, Simon Charettes, Adam Chainz/Johnson and Tim Graham for comments and reviews and Jamie Cockburn for initial patch.
This commit is contained in:
parent
da1ba03f1d
commit
d549b88050
|
@ -236,6 +236,9 @@ class BaseDatabaseFeatures:
|
|||
# Does the backend support indexing a TextField?
|
||||
supports_index_on_text_field = True
|
||||
|
||||
# Does the backed support window expressions (expression OVER (...))?
|
||||
supports_over_clause = False
|
||||
|
||||
# Does the backend support CAST with precision?
|
||||
supports_cast_with_precision = True
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ from importlib import import_module
|
|||
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.db import transaction
|
||||
from django.db import NotSupportedError, transaction
|
||||
from django.db.backends import utils
|
||||
from django.utils import timezone
|
||||
from django.utils.dateparse import parse_duration
|
||||
|
@ -39,6 +39,13 @@ class BaseDatabaseOperations:
|
|||
# CharField data type if the max_length argument isn't provided.
|
||||
cast_char_field_without_max_length = None
|
||||
|
||||
# Start and end points for window expressions.
|
||||
PRECEDING = 'PRECEDING'
|
||||
FOLLOWING = 'FOLLOWING'
|
||||
UNBOUNDED_PRECEDING = 'UNBOUNDED ' + PRECEDING
|
||||
UNBOUNDED_FOLLOWING = 'UNBOUNDED ' + FOLLOWING
|
||||
CURRENT_ROW = 'CURRENT ROW'
|
||||
|
||||
def __init__(self, connection):
|
||||
self.connection = connection
|
||||
self._cache = None
|
||||
|
@ -598,3 +605,34 @@ class BaseDatabaseOperations:
|
|||
rhs_sql, rhs_params = rhs
|
||||
return "(%s - %s)" % (lhs_sql, rhs_sql), lhs_params + rhs_params
|
||||
raise NotImplementedError("This backend does not support %s subtraction." % internal_type)
|
||||
|
||||
def window_frame_start(self, start):
|
||||
if isinstance(start, int):
|
||||
if start < 0:
|
||||
return '%d %s' % (abs(start), self.PRECEDING)
|
||||
elif start == 0:
|
||||
return self.CURRENT_ROW
|
||||
elif start is None:
|
||||
return self.UNBOUNDED_PRECEDING
|
||||
raise ValueError("start argument must be a negative integer, zero, or None, but got '%s'." % start)
|
||||
|
||||
def window_frame_end(self, end):
|
||||
if isinstance(end, int):
|
||||
if end == 0:
|
||||
return self.CURRENT_ROW
|
||||
elif end > 0:
|
||||
return '%d %s' % (end, self.FOLLOWING)
|
||||
elif end is None:
|
||||
return self.UNBOUNDED_FOLLOWING
|
||||
raise ValueError("end argument must be a positive integer, zero, or None, but got '%s'." % end)
|
||||
|
||||
def window_frame_rows_start_end(self, start=None, end=None):
|
||||
"""
|
||||
Return SQL for start and end points in an OVER clause window frame.
|
||||
"""
|
||||
if not self.connection.features.supports_over_clause:
|
||||
raise NotSupportedError('This backend does not support window expressions.')
|
||||
return self.window_frame_start(start), self.window_frame_end(end)
|
||||
|
||||
def window_frame_range_start_end(self, start=None, end=None):
|
||||
return self.window_frame_rows_start_end(start, end)
|
||||
|
|
|
@ -81,6 +81,10 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
|||
result = cursor.fetchone()
|
||||
return result and result[0] == 1
|
||||
|
||||
@cached_property
|
||||
def supports_over_clause(self):
|
||||
return self.connection.mysql_version >= (8, 0, 2)
|
||||
|
||||
@cached_property
|
||||
def supports_transactions(self):
|
||||
"""
|
||||
|
|
|
@ -54,3 +54,4 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
|||
END;
|
||||
"""
|
||||
supports_callproc_kwargs = True
|
||||
supports_over_clause = True
|
||||
|
|
|
@ -48,6 +48,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
|||
V_I := P_I;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;"""
|
||||
supports_over_clause = True
|
||||
|
||||
@cached_property
|
||||
def supports_aggregate_filter_clause(self):
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from psycopg2.extras import Inet
|
||||
|
||||
from django.conf import settings
|
||||
from django.db import NotSupportedError
|
||||
from django.db.backends.base.operations import BaseDatabaseOperations
|
||||
|
||||
|
||||
|
@ -247,3 +248,12 @@ class DatabaseOperations(BaseDatabaseOperations):
|
|||
rhs_sql, rhs_params = rhs
|
||||
return "(interval '1 day' * (%s - %s))" % (lhs_sql, rhs_sql), lhs_params + rhs_params
|
||||
return super().subtract_temporals(internal_type, lhs, rhs)
|
||||
|
||||
def window_frame_range_start_end(self, start=None, end=None):
|
||||
start_, end_ = super().window_frame_range_start_end(start, end)
|
||||
if (start and start < 0) or (end and end > 0):
|
||||
raise NotSupportedError(
|
||||
'PostgreSQL only supports UNBOUNDED together with PRECEDING '
|
||||
'and FOLLOWING.'
|
||||
)
|
||||
return start_, end_
|
||||
|
|
|
@ -6,8 +6,8 @@ from django.db.models.deletion import (
|
|||
CASCADE, DO_NOTHING, PROTECT, SET, SET_DEFAULT, SET_NULL, ProtectedError,
|
||||
)
|
||||
from django.db.models.expressions import (
|
||||
Case, Exists, Expression, ExpressionWrapper, F, Func, OuterRef, Subquery,
|
||||
Value, When,
|
||||
Case, Exists, Expression, ExpressionList, ExpressionWrapper, F, Func,
|
||||
OuterRef, RowRange, Subquery, Value, ValueRange, When, Window, WindowFrame,
|
||||
)
|
||||
from django.db.models.fields import * # NOQA
|
||||
from django.db.models.fields import __all__ as fields_all
|
||||
|
@ -64,8 +64,9 @@ __all__ += [
|
|||
'ObjectDoesNotExist', 'signals',
|
||||
'CASCADE', 'DO_NOTHING', 'PROTECT', 'SET', 'SET_DEFAULT', 'SET_NULL',
|
||||
'ProtectedError',
|
||||
'Case', 'Exists', 'Expression', 'ExpressionWrapper', 'F', 'Func',
|
||||
'OuterRef', 'Subquery', 'Value', 'When',
|
||||
'Case', 'Exists', 'Expression', 'ExpressionList', 'ExpressionWrapper', 'F',
|
||||
'Func', 'OuterRef', 'RowRange', 'Subquery', 'Value', 'ValueRange', 'When',
|
||||
'Window', 'WindowFrame',
|
||||
'FileField', 'ImageField', 'OrderWrt', 'Lookup', 'Transform', 'Manager',
|
||||
'Prefetch', 'Q', 'QuerySet', 'prefetch_related_objects', 'DEFERRED', 'Model',
|
||||
'ForeignKey', 'ForeignObject', 'OneToOneField', 'ManyToManyField',
|
||||
|
|
|
@ -15,6 +15,7 @@ class Aggregate(Func):
|
|||
contains_aggregate = True
|
||||
name = None
|
||||
filter_template = '%s FILTER (WHERE %%(filter)s)'
|
||||
window_compatible = True
|
||||
|
||||
def __init__(self, *args, filter=None, **kwargs):
|
||||
self.filter = filter
|
||||
|
|
|
@ -3,6 +3,7 @@ import datetime
|
|||
from decimal import Decimal
|
||||
|
||||
from django.core.exceptions import EmptyResultSet, FieldError
|
||||
from django.db import connection
|
||||
from django.db.models import fields
|
||||
from django.db.models.query_utils import Q
|
||||
from django.utils.deconstruct import deconstructible
|
||||
|
@ -140,6 +141,10 @@ class BaseExpression:
|
|||
# aggregate specific fields
|
||||
is_summary = False
|
||||
_output_field_resolved_to_none = False
|
||||
# Can the expression be used in a WHERE clause?
|
||||
filterable = True
|
||||
# Can the expression can be used as a source expression in Window?
|
||||
window_compatible = False
|
||||
|
||||
def __init__(self, output_field=None):
|
||||
if output_field is not None:
|
||||
|
@ -206,6 +211,13 @@ class BaseExpression:
|
|||
return True
|
||||
return False
|
||||
|
||||
@cached_property
|
||||
def contains_over_clause(self):
|
||||
for expr in self.get_source_expressions():
|
||||
if expr and expr.contains_over_clause:
|
||||
return True
|
||||
return False
|
||||
|
||||
@cached_property
|
||||
def contains_column_references(self):
|
||||
for expr in self.get_source_expressions():
|
||||
|
@ -232,6 +244,7 @@ class BaseExpression:
|
|||
c.is_summary = summarize
|
||||
c.set_source_expressions([
|
||||
expr.resolve_expression(query, allow_joins, reuse, summarize)
|
||||
if expr else None
|
||||
for expr in c.get_source_expressions()
|
||||
])
|
||||
return c
|
||||
|
@ -482,6 +495,9 @@ class TemporalSubtraction(CombinedExpression):
|
|||
@deconstructible
|
||||
class F(Combinable):
|
||||
"""An object capable of resolving references to existing query objects."""
|
||||
# Can the expression be used in a WHERE clause?
|
||||
filterable = True
|
||||
|
||||
def __init__(self, name):
|
||||
"""
|
||||
Arguments:
|
||||
|
@ -767,6 +783,23 @@ class Ref(Expression):
|
|||
return [self]
|
||||
|
||||
|
||||
class ExpressionList(Func):
|
||||
"""
|
||||
An expression containing multiple expressions. Can be used to provide a
|
||||
list of expressions as an argument to another expression, like an
|
||||
ordering clause.
|
||||
"""
|
||||
template = '%(expressions)s'
|
||||
|
||||
def __init__(self, *expressions, **extra):
|
||||
if len(expressions) == 0:
|
||||
raise ValueError('%s requires at least one expression.' % self.__class__.__name__)
|
||||
super().__init__(*expressions, **extra)
|
||||
|
||||
def __str__(self):
|
||||
return self.arg_joiner.join(str(arg) for arg in self.source_expressions)
|
||||
|
||||
|
||||
class ExpressionWrapper(Expression):
|
||||
"""
|
||||
An expression that can wrap another expression so that it can provide
|
||||
|
@ -1118,3 +1151,168 @@ class OrderBy(BaseExpression):
|
|||
|
||||
def desc(self):
|
||||
self.descending = True
|
||||
|
||||
|
||||
class Window(Expression):
|
||||
template = '%(expression)s OVER (%(window)s)'
|
||||
# Although the main expression may either be an aggregate or an
|
||||
# expression with an aggregate function, the GROUP BY that will
|
||||
# be introduced in the query as a result is not desired.
|
||||
contains_aggregate = False
|
||||
contains_over_clause = True
|
||||
filterable = False
|
||||
|
||||
def __init__(self, expression, partition_by=None, order_by=None, frame=None, output_field=None):
|
||||
self.partition_by = partition_by
|
||||
self.order_by = order_by
|
||||
self.frame = frame
|
||||
|
||||
if not getattr(expression, 'window_compatible', False):
|
||||
raise ValueError(
|
||||
"Expression '%s' isn't compatible with OVER clauses." %
|
||||
expression.__class__.__name__
|
||||
)
|
||||
|
||||
if self.partition_by is not None:
|
||||
if not isinstance(self.partition_by, (tuple, list)):
|
||||
self.partition_by = (self.partition_by,)
|
||||
self.partition_by = ExpressionList(*self.partition_by)
|
||||
|
||||
if self.order_by is not None:
|
||||
if isinstance(self.order_by, (list, tuple)):
|
||||
self.order_by = ExpressionList(*self.order_by)
|
||||
elif not isinstance(self.order_by, BaseExpression):
|
||||
raise ValueError(
|
||||
'order_by must be either an Expression or a sequence of '
|
||||
'expressions.'
|
||||
)
|
||||
super().__init__(output_field=output_field)
|
||||
self.source_expression = self._parse_expressions(expression)[0]
|
||||
|
||||
def _resolve_output_field(self):
|
||||
return self.source_expression.output_field
|
||||
|
||||
def get_source_expressions(self):
|
||||
return [self.source_expression, self.partition_by, self.order_by, self.frame]
|
||||
|
||||
def set_source_expressions(self, exprs):
|
||||
self.source_expression, self.partition_by, self.order_by, self.frame = exprs
|
||||
|
||||
def as_sql(self, compiler, connection, function=None, template=None):
|
||||
connection.ops.check_expression_support(self)
|
||||
expr_sql, params = compiler.compile(self.source_expression)
|
||||
window_sql, window_params = [], []
|
||||
|
||||
if self.partition_by is not None:
|
||||
sql_expr, sql_params = self.partition_by.as_sql(
|
||||
compiler=compiler, connection=connection,
|
||||
template='PARTITION BY %(expressions)s',
|
||||
)
|
||||
window_sql.extend(sql_expr)
|
||||
window_params.extend(sql_params)
|
||||
|
||||
if self.order_by is not None:
|
||||
window_sql.append(' ORDER BY ')
|
||||
order_sql, order_params = compiler.compile(self.order_by)
|
||||
window_sql.extend(''.join(order_sql))
|
||||
window_params.extend(order_params)
|
||||
|
||||
if self.frame:
|
||||
frame_sql, frame_params = compiler.compile(self.frame)
|
||||
window_sql.extend(' ' + frame_sql)
|
||||
window_params.extend(frame_params)
|
||||
|
||||
params.extend(window_params)
|
||||
template = template or self.template
|
||||
|
||||
return template % {
|
||||
'expression': expr_sql,
|
||||
'window': ''.join(window_sql).strip()
|
||||
}, params
|
||||
|
||||
def __str__(self):
|
||||
return '{} OVER ({}{}{})'.format(
|
||||
str(self.source_expression),
|
||||
'PARTITION BY ' + str(self.partition_by) if self.partition_by else '',
|
||||
'ORDER BY ' + str(self.order_by) if self.order_by else '',
|
||||
str(self.frame or ''),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return '<%s: %s>' % (self.__class__.__name__, self)
|
||||
|
||||
def get_group_by_cols(self):
|
||||
return []
|
||||
|
||||
|
||||
class WindowFrame(Expression):
|
||||
"""
|
||||
Model the frame clause in window expressions. There are two types of frame
|
||||
clauses which are subclasses, however, all processing and validation (by no
|
||||
means intended to be complete) is done here. Thus, providing an end for a
|
||||
frame is optional (the default is UNBOUNDED FOLLOWING, which is the last
|
||||
row in the frame).
|
||||
"""
|
||||
template = '%(frame_type)s BETWEEN %(start)s AND %(end)s'
|
||||
|
||||
def __init__(self, start=None, end=None):
|
||||
self.start = start
|
||||
self.end = end
|
||||
|
||||
def set_source_expressions(self, exprs):
|
||||
self.start, self.end = exprs
|
||||
|
||||
def get_source_expressions(self):
|
||||
return [Value(self.start), Value(self.end)]
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
connection.ops.check_expression_support(self)
|
||||
start, end = self.window_frame_start_end(connection, self.start.value, self.end.value)
|
||||
return self.template % {
|
||||
'frame_type': self.frame_type,
|
||||
'start': start,
|
||||
'end': end,
|
||||
}, []
|
||||
|
||||
def __repr__(self):
|
||||
return '<%s: %s>' % (self.__class__.__name__, self)
|
||||
|
||||
def get_group_by_cols(self):
|
||||
return []
|
||||
|
||||
def __str__(self):
|
||||
if self.start is not None and self.start < 0:
|
||||
start = '%d %s' % (abs(self.start), connection.ops.PRECEDING)
|
||||
elif self.start is not None and self.start == 0:
|
||||
start = connection.ops.CURRENT_ROW
|
||||
else:
|
||||
start = connection.ops.UNBOUNDED_PRECEDING
|
||||
|
||||
if self.end is not None and self.end > 0:
|
||||
end = '%d %s' % (self.end, connection.ops.FOLLOWING)
|
||||
elif self.end is not None and self.end == 0:
|
||||
end = connection.ops.CURRENT_ROW
|
||||
else:
|
||||
end = connection.ops.UNBOUNDED_FOLLOWING
|
||||
return self.template % {
|
||||
'frame_type': self.frame_type,
|
||||
'start': start,
|
||||
'end': end,
|
||||
}
|
||||
|
||||
def window_frame_start_end(self, connection, start, end):
|
||||
raise NotImplementedError('Subclasses must implement window_frame_start_end().')
|
||||
|
||||
|
||||
class RowRange(WindowFrame):
|
||||
frame_type = 'ROWS'
|
||||
|
||||
def window_frame_start_end(self, connection, start, end):
|
||||
return connection.ops.window_frame_rows_start_end(start, end)
|
||||
|
||||
|
||||
class ValueRange(WindowFrame):
|
||||
frame_type = 'RANGE'
|
||||
|
||||
def window_frame_start_end(self, connection, start, end):
|
||||
return connection.ops.window_frame_range_start_end(start, end)
|
||||
|
|
|
@ -8,6 +8,10 @@ from .datetime import (
|
|||
Trunc, TruncDate, TruncDay, TruncHour, TruncMinute, TruncMonth,
|
||||
TruncQuarter, TruncSecond, TruncTime, TruncYear,
|
||||
)
|
||||
from .window import (
|
||||
CumeDist, DenseRank, FirstValue, Lag, LastValue, Lead, NthValue, Ntile,
|
||||
PercentRank, Rank, RowNumber,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# base
|
||||
|
@ -18,4 +22,7 @@ __all__ = [
|
|||
'ExtractQuarter', 'ExtractSecond', 'ExtractWeek', 'ExtractWeekDay',
|
||||
'ExtractYear', 'Trunc', 'TruncDate', 'TruncDay', 'TruncHour', 'TruncMinute',
|
||||
'TruncMonth', 'TruncQuarter', 'TruncSecond', 'TruncTime', 'TruncYear',
|
||||
# window
|
||||
'CumeDist', 'DenseRank', 'FirstValue', 'Lag', 'LastValue', 'Lead',
|
||||
'NthValue', 'Ntile', 'PercentRank', 'Rank', 'RowNumber',
|
||||
]
|
||||
|
|
|
@ -0,0 +1,118 @@
|
|||
from django.db.models import FloatField, Func, IntegerField
|
||||
|
||||
__all__ = [
|
||||
'CumeDist', 'DenseRank', 'FirstValue', 'Lag', 'LastValue', 'Lead',
|
||||
'NthValue', 'Ntile', 'PercentRank', 'Rank', 'RowNumber',
|
||||
]
|
||||
|
||||
|
||||
class CumeDist(Func):
|
||||
function = 'CUME_DIST'
|
||||
name = 'CumeDist'
|
||||
output_field = FloatField()
|
||||
window_compatible = True
|
||||
|
||||
|
||||
class DenseRank(Func):
|
||||
function = 'DENSE_RANK'
|
||||
name = 'DenseRank'
|
||||
output_field = IntegerField()
|
||||
window_compatible = True
|
||||
|
||||
|
||||
class FirstValue(Func):
|
||||
arity = 1
|
||||
function = 'FIRST_VALUE'
|
||||
name = 'FirstValue'
|
||||
window_compatible = True
|
||||
|
||||
|
||||
class LagLeadFunction(Func):
|
||||
window_compatible = True
|
||||
|
||||
def __init__(self, expression, offset=1, default=None, **extra):
|
||||
if expression is None:
|
||||
raise ValueError(
|
||||
'%s requires a non-null source expression.' %
|
||||
self.__class__.__name__
|
||||
)
|
||||
if offset is None or offset <= 0:
|
||||
raise ValueError(
|
||||
'%s requires a positive integer for the offset.' %
|
||||
self.__class__.__name__
|
||||
)
|
||||
args = (expression, offset)
|
||||
if default is not None:
|
||||
args += (default,)
|
||||
super().__init__(*args, **extra)
|
||||
|
||||
def _resolve_output_field(self):
|
||||
sources = self.get_source_expressions()
|
||||
return sources[0].output_field
|
||||
|
||||
|
||||
class Lag(LagLeadFunction):
|
||||
function = 'LAG'
|
||||
name = 'Lag'
|
||||
|
||||
|
||||
class LastValue(Func):
|
||||
arity = 1
|
||||
function = 'LAST_VALUE'
|
||||
name = 'LastValue'
|
||||
window_compatible = True
|
||||
|
||||
|
||||
class Lead(LagLeadFunction):
|
||||
function = 'LEAD'
|
||||
name = 'Lead'
|
||||
|
||||
|
||||
class NthValue(Func):
|
||||
function = 'NTH_VALUE'
|
||||
name = 'NthValue'
|
||||
window_compatible = True
|
||||
|
||||
def __init__(self, expression, nth=1, **extra):
|
||||
if expression is None:
|
||||
raise ValueError('%s requires a non-null source expression.' % self.__class__.__name__)
|
||||
if nth is None or nth <= 0:
|
||||
raise ValueError('%s requires a positive integer as for nth.' % self.__class__.__name__)
|
||||
super().__init__(expression, nth, **extra)
|
||||
|
||||
def _resolve_output_field(self):
|
||||
sources = self.get_source_expressions()
|
||||
return sources[0].output_field
|
||||
|
||||
|
||||
class Ntile(Func):
|
||||
function = 'NTILE'
|
||||
name = 'Ntile'
|
||||
output_field = IntegerField()
|
||||
window_compatible = True
|
||||
|
||||
def __init__(self, num_buckets=1, **extra):
|
||||
if num_buckets <= 0:
|
||||
raise ValueError('num_buckets must be greater than 0.')
|
||||
super().__init__(num_buckets, **extra)
|
||||
|
||||
|
||||
class PercentRank(Func):
|
||||
function = 'PERCENT_RANK'
|
||||
name = 'PercentRank'
|
||||
output_field = FloatField()
|
||||
window_compatible = True
|
||||
|
||||
|
||||
class Rank(Func):
|
||||
function = 'RANK'
|
||||
name = 'Rank'
|
||||
output_field = IntegerField()
|
||||
window_compatible = True
|
||||
|
||||
|
||||
class RowNumber(Func):
|
||||
function = 'ROW_NUMBER'
|
||||
name = 'RowNumber'
|
||||
output_field = IntegerField()
|
||||
window_compatible = True
|
|
@ -115,6 +115,10 @@ class Lookup:
|
|||
def contains_aggregate(self):
|
||||
return self.lhs.contains_aggregate or getattr(self.rhs, 'contains_aggregate', False)
|
||||
|
||||
@cached_property
|
||||
def contains_over_clause(self):
|
||||
return self.lhs.contains_over_clause or getattr(self.rhs, 'contains_over_clause', False)
|
||||
|
||||
@property
|
||||
def is_summary(self):
|
||||
return self.lhs.is_summary or getattr(self.rhs, 'is_summary', False)
|
||||
|
|
|
@ -1107,6 +1107,8 @@ class SQLInsertCompiler(SQLCompiler):
|
|||
)
|
||||
if value.contains_aggregate:
|
||||
raise FieldError("Aggregate functions are not allowed in this query")
|
||||
if value.contains_over_clause:
|
||||
raise FieldError('Window expressions are not allowed in this query.')
|
||||
else:
|
||||
value = field.get_db_prep_save(value, connection=self.connection)
|
||||
return value
|
||||
|
@ -1262,6 +1264,8 @@ class SQLUpdateCompiler(SQLCompiler):
|
|||
val = val.resolve_expression(self.query, allow_joins=False, for_save=True)
|
||||
if val.contains_aggregate:
|
||||
raise FieldError("Aggregate functions are not allowed in this query")
|
||||
if val.contains_over_clause:
|
||||
raise FieldError('Window expressions are not allowed in this query.')
|
||||
elif hasattr(val, 'prepare_database_save'):
|
||||
if field.remote_field:
|
||||
val = field.get_db_prep_save(
|
||||
|
|
|
@ -13,7 +13,7 @@ from string import ascii_uppercase
|
|||
from django.core.exceptions import (
|
||||
EmptyResultSet, FieldDoesNotExist, FieldError,
|
||||
)
|
||||
from django.db import DEFAULT_DB_ALIAS, connections
|
||||
from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connections
|
||||
from django.db.models.aggregates import Count
|
||||
from django.db.models.constants import LOOKUP_SEP
|
||||
from django.db.models.expressions import Col, Ref
|
||||
|
@ -1125,6 +1125,13 @@ class Query:
|
|||
if not arg:
|
||||
raise FieldError("Cannot parse keyword query %r" % arg)
|
||||
lookups, parts, reffed_expression = self.solve_lookup_type(arg)
|
||||
|
||||
if not getattr(reffed_expression, 'filterable', True):
|
||||
raise NotSupportedError(
|
||||
reffed_expression.__class__.__name__ + ' is disallowed in '
|
||||
'the filter clause.'
|
||||
)
|
||||
|
||||
if not allow_joins and len(parts) > 1:
|
||||
raise FieldError("Joined field references are not permitted in this query")
|
||||
|
||||
|
|
|
@ -167,6 +167,16 @@ class WhereNode(tree.Node):
|
|||
def contains_aggregate(self):
|
||||
return self._contains_aggregate(self)
|
||||
|
||||
@classmethod
|
||||
def _contains_over_clause(cls, obj):
|
||||
if isinstance(obj, tree.Node):
|
||||
return any(cls._contains_over_clause(c) for c in obj.children)
|
||||
return obj.contains_over_clause
|
||||
|
||||
@cached_property
|
||||
def contains_over_clause(self):
|
||||
return self._contains_over_clause(self)
|
||||
|
||||
@property
|
||||
def is_summary(self):
|
||||
return any(child.is_summary for child in self.children)
|
||||
|
|
|
@ -819,3 +819,132 @@ Usage example::
|
|||
'minute': 'minute': datetime.datetime(2014, 6, 15, 14, 30, tzinfo=<UTC>),
|
||||
'second': datetime.datetime(2014, 6, 15, 14, 30, 50, tzinfo=<UTC>)
|
||||
}
|
||||
|
||||
.. _window-functions:
|
||||
|
||||
Window functions
|
||||
================
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
There are a number of functions to use in a
|
||||
:class:`~django.db.models.expressions.Window` expression for computing the rank
|
||||
of elements or the :class:`Ntile` of some rows.
|
||||
|
||||
``CumeDist``
|
||||
------------
|
||||
|
||||
.. class:: CumeDist(*expressions, **extra)
|
||||
|
||||
Calculates the cumulative distribution of a value within a window or partition.
|
||||
The cumulative distribution is defined as the number of rows preceding or
|
||||
peered with the current row divided by the total number of rows in the frame.
|
||||
|
||||
``DenseRank``
|
||||
-------------
|
||||
|
||||
.. class:: DenseRank(*expressions, **extra)
|
||||
|
||||
Equivalent to :class:`Rank` but does not have gaps.
|
||||
|
||||
``FirstValue``
|
||||
--------------
|
||||
|
||||
.. class:: FirstValue(expression, **extra)
|
||||
|
||||
Returns the value evaluated at the row that's the first row of the window
|
||||
frame, or ``None`` if no such value exists.
|
||||
|
||||
``Lag``
|
||||
-------
|
||||
|
||||
.. class:: Lag(expression, offset=1, default=None, **extra)
|
||||
|
||||
Calculates the value offset by ``offset``, and if no row exists there, returns
|
||||
``default``.
|
||||
|
||||
``default`` must have the same type as the ``expression``, however, this is
|
||||
only validated by the database and not in Python.
|
||||
|
||||
``LastValue``
|
||||
-------------
|
||||
|
||||
.. class:: LastValue(expression, **extra)
|
||||
|
||||
Comparable to :class:`FirstValue`, it calculates the last value in a given
|
||||
frame clause.
|
||||
|
||||
``Lead``
|
||||
--------
|
||||
|
||||
.. class:: Lead(expression, offset=1, default=None, **extra)
|
||||
|
||||
Calculates the leading value in a given :ref:`frame <window-frames>`. Both
|
||||
``offset`` and ``default`` are evaluated with respect to the current row.
|
||||
|
||||
``default`` must have the same type as the ``expression``, however, this is
|
||||
only validated by the database and not in Python.
|
||||
|
||||
``NthValue``
|
||||
------------
|
||||
|
||||
.. class:: NthValue(expression, nth=1, **extra)
|
||||
|
||||
Computes the row relative to the offset ``nth`` (must be a positive value)
|
||||
within the window. Returns ``None`` if no row exists.
|
||||
|
||||
Some databases may handle a nonexistent nth-value differently. For example,
|
||||
Oracle returns an empty string rather than ``None`` for character-based
|
||||
expressions. Django doesn't do any conversions in these cases.
|
||||
|
||||
``Ntile``
|
||||
---------
|
||||
|
||||
.. class:: Ntile(num_buckets=1, **extra)
|
||||
|
||||
Calculates a partition for each of the rows in the frame clause, distributing
|
||||
numbers as evenly as possible between 1 and ``num_buckets``. If the rows don't
|
||||
divide evenly into a number of buckets, one or more buckets will be represented
|
||||
more frequently.
|
||||
|
||||
``PercentRank``
|
||||
---------------
|
||||
|
||||
.. class:: PercentRank(*expressions, **extra)
|
||||
|
||||
Computes the percentile rank of the rows in the frame clause. This
|
||||
computation is equivalent to evaluating::
|
||||
|
||||
(rank - 1) / (total rows - 1)
|
||||
|
||||
The following table explains the calculation for the percentile rank of a row:
|
||||
|
||||
===== ===== ==== ============ ============
|
||||
Row # Value Rank Calculation Percent Rank
|
||||
===== ===== ==== ============ ============
|
||||
1 15 1 (1-1)/(7-1) 0.0000
|
||||
2 20 2 (2-1)/(7-1) 0.1666
|
||||
3 20 2 (2-1)/(7-1) 0.1666
|
||||
4 20 2 (2-1)/(7-1) 0.1666
|
||||
5 30 5 (5-1)/(7-1) 0.6666
|
||||
6 30 5 (5-1)/(7-1) 0.6666
|
||||
7 40 7 (7-1)/(7-1) 1.0000
|
||||
===== ===== ==== ============ ============
|
||||
|
||||
``Rank``
|
||||
--------
|
||||
|
||||
.. class:: Rank(*expressions, **extra)
|
||||
|
||||
Comparable to ``RowNumber``, this function ranks rows in the window. The
|
||||
computed rank contains gaps. Use :class:`DenseRank` to compute rank without
|
||||
gaps.
|
||||
|
||||
``RowNumber``
|
||||
-------------
|
||||
|
||||
.. class:: RowNumber(*expressions, **extra)
|
||||
|
||||
Computes the row number according to the ordering of either the frame clause
|
||||
or the ordering of the whole query if there is no partitioning of the
|
||||
:ref:`window frame <window-frames>`.
|
||||
|
|
|
@ -353,6 +353,13 @@ The ``Aggregate`` API is as follows:
|
|||
generated. Specifically, the ``function`` will be interpolated as the
|
||||
``function`` placeholder within :attr:`template`. Defaults to ``None``.
|
||||
|
||||
.. attribute:: window_compatible
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
Defaults to ``True`` since most aggregate functions can be used as the
|
||||
source expression in :class:`~django.db.models.expressions.Window`.
|
||||
|
||||
The ``expression`` argument can be the name of a field on the model, or another
|
||||
expression. It will be converted to a string and used as the ``expressions``
|
||||
placeholder within the ``template``.
|
||||
|
@ -649,6 +656,184 @@ should avoid them if possible.
|
|||
force you to acknowledge that you're not interpolating your SQL with user
|
||||
provided data.
|
||||
|
||||
Window functions
|
||||
----------------
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
Window functions provide a way to apply functions on partitions. Unlike a
|
||||
normal aggregation function which computes a final result for each set defined
|
||||
by the group by, window functions operate on :ref:`frames <window-frames>` and
|
||||
partitions, and compute the result for each row.
|
||||
|
||||
You can specify multiple windows in the same query which in Django ORM would be
|
||||
equivalent to including multiple expressions in a :doc:`QuerySet.annotate()
|
||||
</topics/db/aggregation>` call. The ORM doesn't make use of named windows,
|
||||
instead they are part of the selected columns.
|
||||
|
||||
.. class:: Window(expression, partition_by=None, order_by=None, frame=None, output_field=None)
|
||||
|
||||
.. attribute:: filterable
|
||||
|
||||
Defaults to ``False``. The SQL standard disallows referencing window
|
||||
functions in the ``WHERE`` clause and Django raises an exception when
|
||||
constructing a ``QuerySet`` that would do that.
|
||||
|
||||
.. attribute:: template
|
||||
|
||||
Defaults to ``%(expression)s OVER (%(window)s)'``. If only the
|
||||
``expression`` argument is provided, the window clause will be blank.
|
||||
|
||||
The ``Window`` class is the main expression for an ``OVER`` clause.
|
||||
|
||||
The ``expression`` argument is either a :ref:`window function
|
||||
<window-functions>`, an :ref:`aggregate function <aggregation-functions>`, or
|
||||
an expression that's compatible in a window clause.
|
||||
|
||||
The ``partition_by`` argument is a list of expressions (column names should be
|
||||
wrapped in an ``F``-object) that control the partitioning of the rows.
|
||||
Partitioning narrows which rows are used to compute the result set.
|
||||
|
||||
The ``output_field`` is specified either as an argument or by the expression.
|
||||
|
||||
The ``order_by`` argument accepts a sequence of expressions on which you can
|
||||
call :meth:`~django.db.models.Expression.asc` and
|
||||
:meth:`~django.db.models.Expression.desc`. The ordering controls the order in
|
||||
which the expression is applied. For example, if you sum over the rows in a
|
||||
partition, the first result is just the value of the first row, the second is
|
||||
the sum of first and second row.
|
||||
|
||||
The ``frame`` parameter specifies which other rows that should be used in the
|
||||
computation. See :ref:`window-frames` for details.
|
||||
|
||||
For example, to annotate each movie with the average rating for the movies by
|
||||
the same studio in the same genre and release year::
|
||||
|
||||
>>> from django.db.models import Avg, ExtractYear, F, Window
|
||||
>>> Movie.objects.annotate(
|
||||
>>> avg_rating=Window(
|
||||
>>> expression=Avg('rating'),
|
||||
>>> partition_by=[F('studio'), F('genre')],
|
||||
>>> order_by=ExtractYear('released').asc(),
|
||||
>>> ),
|
||||
>>> )
|
||||
|
||||
This makes it easy to check if a movie is rated better or worse than its peers.
|
||||
|
||||
You may want to apply multiple expressions over the same window, i.e., the
|
||||
same partition and frame. For example, you could modify the previous example
|
||||
to also include the best and worst rating in each movie's group (same studio,
|
||||
genre, and release year) by using three window functions in the same query. The
|
||||
partition and ordering from the previous example is extracted into a dictionary
|
||||
to reduce repetition::
|
||||
|
||||
>>> from django.db.models import Avg, ExtractYear, F, Max, Min, Window
|
||||
>>> window = {
|
||||
>>> 'partition': [F('studio'), F('genre')],
|
||||
>>> 'order_by': ExtractYear('released').asc(),
|
||||
>>> }
|
||||
>>> Movie.objects.annotate(
|
||||
>>> avg_rating=Window(
|
||||
>>> expression=Avg('rating'), **window,
|
||||
>>> ),
|
||||
>>> best=Window(
|
||||
>>> expression=Max('rating'), **window,
|
||||
>>> ),
|
||||
>>> worst=Window(
|
||||
>>> expression=Min('rating'), **window,
|
||||
>>> ),
|
||||
>>> )
|
||||
|
||||
Among Django's built-in database backends, MySQL 8.0.2+, PostgreSQL, and Oracle
|
||||
support window expressions. Support for different window expression features
|
||||
varies among the different databases. For example, the options in
|
||||
:meth:`~django.db.models.Expression.asc` and
|
||||
:meth:`~django.db.models.Expression.desc` may not be supported. Consult the
|
||||
documentation for your database as needed.
|
||||
|
||||
.. _window-frames:
|
||||
|
||||
Frames
|
||||
~~~~~~
|
||||
|
||||
For a window frame, you can choose either a range-based sequence of rows or an
|
||||
ordinary sequence of rows.
|
||||
|
||||
.. class:: ValueRange(start=None, end=None)
|
||||
|
||||
.. attribute:: frame_type
|
||||
|
||||
This attribute is set to ``'RANGE'``.
|
||||
|
||||
PostgreSQL has limited support for ``ValueRange`` and only supports use of
|
||||
the standard start and end points, such as ``CURRENT ROW`` and ``UNBOUNDED
|
||||
FOLLOWING``.
|
||||
|
||||
.. class:: RowRange(start=None, end=None)
|
||||
|
||||
.. attribute:: frame_type
|
||||
|
||||
This attribute is set to ``'ROWS'``.
|
||||
|
||||
Both classes return SQL with the template::
|
||||
|
||||
%(frame_type)s BETWEEN %(start)s AND %(end)s
|
||||
|
||||
Frames narrow the rows that are used for computing the result. They shift from
|
||||
some start point to some specified end point. Frames can be used with and
|
||||
without partitions, but it's often a good idea to specify an ordering of the
|
||||
window to ensure a deterministic result. In a frame, a peer in a frame is a row
|
||||
with an equivalent value, or all rows if an ordering clause isn't present.
|
||||
|
||||
The default starting point for a frame is ``UNBOUNDED PRECEDING`` which is the
|
||||
first row of the partition. The end point is always explicitly included in the
|
||||
SQL generated by the ORM and is by default ``UNBOUNDED FOLLOWING``. The default
|
||||
frame includes all rows from the partition to the last row in the set.
|
||||
|
||||
The accepted values for the ``start`` and ``end`` arguments are ``None``, an
|
||||
integer, or zero. A negative integer for ``start`` results in ``N preceding``,
|
||||
while ``None`` yields ``UNBOUNDED PRECEDING``. For both ``start`` and ``end``,
|
||||
zero will return ``CURRENT ROW``. Positive integers are accepted for ``end``.
|
||||
|
||||
There's a difference in what ``CURRENT ROW`` includes. When specified in
|
||||
``ROWS`` mode, the frame starts or ends with the current row. When specified in
|
||||
``RANGE`` mode, the frame starts or ends at the first or last peer according to
|
||||
the ordering clause. Thus, ``RANGE CURRENT ROW`` evaluates the expression for
|
||||
rows which have the same value specified by the ordering. Because the template
|
||||
includes both the ``start`` and ``end`` points, this may be expressed with::
|
||||
|
||||
ValueRange(start=0, end=0)
|
||||
|
||||
If a movie's "peers" are described as movies released by the same studio in the
|
||||
same genre in the same year, this ``RowRange`` example annotates each movie
|
||||
with the average rating of a movie's two prior and two following peers::
|
||||
|
||||
>>> from django.db.models import Avg, ExtractYear, F, RowRange, Window
|
||||
>>> Movie.objects.annotate(
|
||||
>>> avg_rating=Window(
|
||||
>>> expression=Avg('rating'),
|
||||
>>> partition_by=[F('studio'), F('genre')],
|
||||
>>> order_by=ExtractYear('released').asc(),
|
||||
>>> frame=RowRange(start=-2, end=2),
|
||||
>>> ),
|
||||
>>> )
|
||||
|
||||
If the database supports it, you can specify the start and end points based on
|
||||
values of an expression in the partition. If the ``released`` field of the
|
||||
``Movie`` model stores the release month of each movies, this ``ValueRange``
|
||||
example annotates each movie with the average rating of a movie's peers
|
||||
released between twelve months before and twelve months after the each movie.
|
||||
|
||||
>>> from django.db.models import Avg, ExpressionList, F, ValueRange, Window
|
||||
>>> Movie.objects.annotate(
|
||||
>>> avg_rating=Window(
|
||||
>>> expression=Avg('rating'),
|
||||
>>> partition_by=[F('studio'), F('genre')],
|
||||
>>> order_by=F('released').asc(),
|
||||
>>> frame=ValueRange(start=-12, end=12),
|
||||
>>> ),
|
||||
>>> )
|
||||
|
||||
.. currentmodule:: django.db.models
|
||||
|
||||
Technical Information
|
||||
|
@ -677,6 +862,30 @@ calling the appropriate methods on the wrapped expression.
|
|||
Tells Django that this expression contains an aggregate and that a
|
||||
``GROUP BY`` clause needs to be added to the query.
|
||||
|
||||
.. attribute:: contains_over_clause
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
Tells Django that this expression contains a
|
||||
:class:`~django.db.models.expressions.Window` expression. It's used,
|
||||
for example, to disallow window function expressions in queries that
|
||||
modify data. Defaults to ``True``.
|
||||
|
||||
.. attribute:: filterable
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
Tells Django that this expression can be referenced in
|
||||
:meth:`.QuerySet.filter`. Defaults to ``True``.
|
||||
|
||||
.. attribute:: window_compatible
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
Tells Django that this expression can be used as the source expression
|
||||
in :class:`~django.db.models.expressions.Window`. Defaults to
|
||||
``False``.
|
||||
|
||||
.. method:: resolve_expression(query=None, allow_joins=True, reuse=None, summarize=False, for_save=False)
|
||||
|
||||
Provides the chance to do any pre-processing or validation of
|
||||
|
|
|
@ -1681,6 +1681,11 @@ raised if ``select_for_update()`` is used in autocommit mode.
|
|||
``select_for_update()`` you should use
|
||||
:class:`~django.test.TransactionTestCase`.
|
||||
|
||||
.. admonition:: Certain expressions may not be supported
|
||||
|
||||
PostgreSQL doesn't support ``select_for_update()`` with
|
||||
:class:`~django.db.models.expressions.Window` expressions.
|
||||
|
||||
.. versionchanged:: 1.11
|
||||
|
||||
The ``skip_locked`` argument was added.
|
||||
|
|
|
@ -52,6 +52,14 @@ Mobile-friendly ``contrib.admin``
|
|||
The admin is now responsive and supports all major mobile devices.
|
||||
Older browser may experience varying levels of graceful degradation.
|
||||
|
||||
Window expressions
|
||||
------------------
|
||||
|
||||
The new :class:`~django.db.models.expressions.Window` expression allows
|
||||
adding an ``OVER`` clause to querysets. You can use :ref:`window functions
|
||||
<window-functions>` and :ref:`aggregate functions <aggregation-functions>` in
|
||||
the expression.
|
||||
|
||||
Minor features
|
||||
--------------
|
||||
|
||||
|
@ -404,6 +412,11 @@ backends.
|
|||
requires that the arguments to ``OF`` be columns rather than tables, set
|
||||
``DatabaseFeatures.select_for_update_of_column = True``.
|
||||
|
||||
* To enable support for :class:`~django.db.models.expressions.Window`
|
||||
expressions, set ``DatabaseFeatures.supports_over_clause`` to ``True``. You
|
||||
may need to customize the ``DatabaseOperations.window_start_rows_start_end()``
|
||||
and/or ``window_start_range_start_end()`` methods.
|
||||
|
||||
* Third-party database backends should add a
|
||||
``DatabaseOperations.cast_char_field_without_max_length`` attribute with the
|
||||
database data type that will be used in the
|
||||
|
|
|
@ -0,0 +1,10 @@
|
|||
from django.db import NotSupportedError, connection
|
||||
from django.test import SimpleTestCase, skipIfDBFeature
|
||||
|
||||
|
||||
class DatabaseOperationTests(SimpleTestCase):
|
||||
@skipIfDBFeature('supports_over_clause')
|
||||
def test_window_frame_raise_not_supported_error(self):
|
||||
msg = 'This backend does not support window expressions.'
|
||||
with self.assertRaisesMessage(NotSupportedError, msg):
|
||||
connection.ops.window_frame_rows_start_end()
|
|
@ -0,0 +1,39 @@
|
|||
from django.db.models.functions import Lag, Lead, NthValue, Ntile
|
||||
from django.test import SimpleTestCase
|
||||
|
||||
|
||||
class ValidationTests(SimpleTestCase):
|
||||
def test_nth_negative_nth_value(self):
|
||||
msg = 'NthValue requires a positive integer as for nth'
|
||||
with self.assertRaisesMessage(ValueError, msg):
|
||||
NthValue(expression='salary', nth=-1)
|
||||
|
||||
def test_nth_null_expression(self):
|
||||
msg = 'NthValue requires a non-null source expression'
|
||||
with self.assertRaisesMessage(ValueError, msg):
|
||||
NthValue(expression=None)
|
||||
|
||||
def test_lag_negative_offset(self):
|
||||
msg = 'Lag requires a positive integer for the offset'
|
||||
with self.assertRaisesMessage(ValueError, msg):
|
||||
Lag(expression='salary', offset=-1)
|
||||
|
||||
def test_lead_negative_offset(self):
|
||||
msg = 'Lead requires a positive integer for the offset'
|
||||
with self.assertRaisesMessage(ValueError, msg):
|
||||
Lead(expression='salary', offset=-1)
|
||||
|
||||
def test_null_source_lead(self):
|
||||
msg = 'Lead requires a non-null source expression'
|
||||
with self.assertRaisesMessage(ValueError, msg):
|
||||
Lead(expression=None)
|
||||
|
||||
def test_null_source_lag(self):
|
||||
msg = 'Lag requires a non-null source expression'
|
||||
with self.assertRaisesMessage(ValueError, msg):
|
||||
Lag(expression=None)
|
||||
|
||||
def test_negative_num_buckets_ntile(self):
|
||||
msg = 'num_buckets must be greater than 0'
|
||||
with self.assertRaisesMessage(ValueError, msg):
|
||||
Ntile(num_buckets=-1)
|
|
@ -10,8 +10,8 @@ from django.db.models.aggregates import (
|
|||
Avg, Count, Max, Min, StdDev, Sum, Variance,
|
||||
)
|
||||
from django.db.models.expressions import (
|
||||
Case, Col, Exists, ExpressionWrapper, F, Func, OrderBy, OuterRef, Random,
|
||||
RawSQL, Ref, Subquery, Value, When,
|
||||
Case, Col, Exists, ExpressionList, ExpressionWrapper, F, Func, OrderBy,
|
||||
OuterRef, Random, RawSQL, Ref, Subquery, Value, When,
|
||||
)
|
||||
from django.db.models.functions import (
|
||||
Coalesce, Concat, Length, Lower, Substr, Upper,
|
||||
|
@ -1330,6 +1330,11 @@ class ValueTests(TestCase):
|
|||
self.assertNotEqual(value, other_value)
|
||||
self.assertNotEqual(value, no_output_field)
|
||||
|
||||
def test_raise_empty_expressionlist(self):
|
||||
msg = 'ExpressionList requires at least one expression'
|
||||
with self.assertRaisesMessage(ValueError, msg):
|
||||
ExpressionList()
|
||||
|
||||
|
||||
class ReprTests(TestCase):
|
||||
|
||||
|
@ -1355,6 +1360,14 @@ class ReprTests(TestCase):
|
|||
self.assertEqual(repr(RawSQL('table.col', [])), "RawSQL(table.col, [])")
|
||||
self.assertEqual(repr(Ref('sum_cost', Sum('cost'))), "Ref(sum_cost, Sum(F(cost)))")
|
||||
self.assertEqual(repr(Value(1)), "Value(1)")
|
||||
self.assertEqual(
|
||||
repr(ExpressionList(F('col'), F('anothercol'))),
|
||||
'ExpressionList(F(col), F(anothercol))'
|
||||
)
|
||||
self.assertEqual(
|
||||
repr(ExpressionList(OrderBy(F('col'), descending=False))),
|
||||
'ExpressionList(OrderBy(F(col), descending=False))'
|
||||
)
|
||||
|
||||
def test_functions(self):
|
||||
self.assertEqual(repr(Coalesce('a', 'b')), "Coalesce(F(a), F(b))")
|
||||
|
|
|
@ -0,0 +1,11 @@
|
|||
from django.db import models
|
||||
|
||||
|
||||
class Employee(models.Model):
|
||||
name = models.CharField(max_length=40, blank=False, null=False)
|
||||
salary = models.PositiveIntegerField()
|
||||
department = models.CharField(max_length=40, blank=False, null=False)
|
||||
hire_date = models.DateField(blank=False, null=False)
|
||||
|
||||
def __str__(self):
|
||||
return '{}, {}, {}, {}'.format(self.name, self.department, self.salary, self.hire_date)
|
|
@ -0,0 +1,783 @@
|
|||
import datetime
|
||||
from unittest import skipIf, skipUnless
|
||||
|
||||
from django.core.exceptions import FieldError
|
||||
from django.db import NotSupportedError, connection
|
||||
from django.db.models import (
|
||||
F, RowRange, Value, ValueRange, Window, WindowFrame,
|
||||
)
|
||||
from django.db.models.aggregates import Avg, Max, Min, Sum
|
||||
from django.db.models.functions import (
|
||||
CumeDist, DenseRank, ExtractYear, FirstValue, Lag, LastValue, Lead,
|
||||
NthValue, Ntile, PercentRank, Rank, RowNumber, Upper,
|
||||
)
|
||||
from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature
|
||||
|
||||
from .models import Employee
|
||||
|
||||
|
||||
@skipUnlessDBFeature('supports_over_clause')
|
||||
class WindowFunctionTests(TestCase):
|
||||
@classmethod
|
||||
def setUpTestData(cls):
|
||||
Employee.objects.bulk_create([
|
||||
Employee(name=e[0], salary=e[1], department=e[2], hire_date=e[3])
|
||||
for e in [
|
||||
('Jones', 45000, 'Accounting', datetime.datetime(2005, 11, 1)),
|
||||
('Williams', 37000, 'Accounting', datetime.datetime(2009, 6, 1)),
|
||||
('Jenson', 45000, 'Accounting', datetime.datetime(2008, 4, 1)),
|
||||
('Adams', 50000, 'Accounting', datetime.datetime(2013, 7, 1)),
|
||||
('Smith', 55000, 'Sales', datetime.datetime(2007, 6, 1)),
|
||||
('Brown', 53000, 'Sales', datetime.datetime(2009, 9, 1)),
|
||||
('Johnson', 40000, 'Marketing', datetime.datetime(2012, 3, 1)),
|
||||
('Smith', 38000, 'Marketing', datetime.datetime(2009, 10, 1)),
|
||||
('Wilkinson', 60000, 'IT', datetime.datetime(2011, 3, 1)),
|
||||
('Moore', 34000, 'IT', datetime.datetime(2013, 8, 1)),
|
||||
('Miller', 100000, 'Management', datetime.datetime(2005, 6, 1)),
|
||||
('Johnson', 80000, 'Management', datetime.datetime(2005, 7, 1)),
|
||||
]
|
||||
])
|
||||
|
||||
def test_dense_rank(self):
|
||||
qs = Employee.objects.annotate(rank=Window(
|
||||
expression=DenseRank(),
|
||||
order_by=ExtractYear(F('hire_date')).asc(),
|
||||
))
|
||||
self.assertQuerysetEqual(qs, [
|
||||
('Jones', 45000, 'Accounting', datetime.date(2005, 11, 1), 1),
|
||||
('Miller', 100000, 'Management', datetime.date(2005, 6, 1), 1),
|
||||
('Johnson', 80000, 'Management', datetime.date(2005, 7, 1), 1),
|
||||
('Smith', 55000, 'Sales', datetime.date(2007, 6, 1), 2),
|
||||
('Jenson', 45000, 'Accounting', datetime.date(2008, 4, 1), 3),
|
||||
('Smith', 38000, 'Marketing', datetime.date(2009, 10, 1), 4),
|
||||
('Brown', 53000, 'Sales', datetime.date(2009, 9, 1), 4),
|
||||
('Williams', 37000, 'Accounting', datetime.date(2009, 6, 1), 4),
|
||||
('Wilkinson', 60000, 'IT', datetime.date(2011, 3, 1), 5),
|
||||
('Johnson', 40000, 'Marketing', datetime.date(2012, 3, 1), 6),
|
||||
('Moore', 34000, 'IT', datetime.date(2013, 8, 1), 7),
|
||||
('Adams', 50000, 'Accounting', datetime.date(2013, 7, 1), 7),
|
||||
], lambda entry: (entry.name, entry.salary, entry.department, entry.hire_date, entry.rank), ordered=False)
|
||||
|
||||
def test_department_salary(self):
|
||||
qs = Employee.objects.annotate(department_sum=Window(
|
||||
expression=Sum('salary'),
|
||||
partition_by=F('department'),
|
||||
order_by=[F('hire_date').asc()],
|
||||
)).order_by('department', 'department_sum')
|
||||
self.assertQuerysetEqual(qs, [
|
||||
('Jones', 'Accounting', 45000, 45000),
|
||||
('Jenson', 'Accounting', 45000, 90000),
|
||||
('Williams', 'Accounting', 37000, 127000),
|
||||
('Adams', 'Accounting', 50000, 177000),
|
||||
('Wilkinson', 'IT', 60000, 60000),
|
||||
('Moore', 'IT', 34000, 94000),
|
||||
('Miller', 'Management', 100000, 100000),
|
||||
('Johnson', 'Management', 80000, 180000),
|
||||
('Smith', 'Marketing', 38000, 38000),
|
||||
('Johnson', 'Marketing', 40000, 78000),
|
||||
('Smith', 'Sales', 55000, 55000),
|
||||
('Brown', 'Sales', 53000, 108000),
|
||||
], lambda entry: (entry.name, entry.department, entry.salary, entry.department_sum))
|
||||
|
||||
def test_rank(self):
|
||||
"""
|
||||
Rank the employees based on the year they're were hired. Since there
|
||||
are multiple employees hired in different years, this will contain
|
||||
gaps.
|
||||
"""
|
||||
qs = Employee.objects.annotate(rank=Window(
|
||||
expression=Rank(),
|
||||
order_by=ExtractYear(F('hire_date')).asc(),
|
||||
))
|
||||
self.assertQuerysetEqual(qs, [
|
||||
('Jones', 45000, 'Accounting', datetime.date(2005, 11, 1), 1),
|
||||
('Miller', 100000, 'Management', datetime.date(2005, 6, 1), 1),
|
||||
('Johnson', 80000, 'Management', datetime.date(2005, 7, 1), 1),
|
||||
('Smith', 55000, 'Sales', datetime.date(2007, 6, 1), 4),
|
||||
('Jenson', 45000, 'Accounting', datetime.date(2008, 4, 1), 5),
|
||||
('Smith', 38000, 'Marketing', datetime.date(2009, 10, 1), 6),
|
||||
('Brown', 53000, 'Sales', datetime.date(2009, 9, 1), 6),
|
||||
('Williams', 37000, 'Accounting', datetime.date(2009, 6, 1), 6),
|
||||
('Wilkinson', 60000, 'IT', datetime.date(2011, 3, 1), 9),
|
||||
('Johnson', 40000, 'Marketing', datetime.date(2012, 3, 1), 10),
|
||||
('Moore', 34000, 'IT', datetime.date(2013, 8, 1), 11),
|
||||
('Adams', 50000, 'Accounting', datetime.date(2013, 7, 1), 11),
|
||||
], lambda entry: (entry.name, entry.salary, entry.department, entry.hire_date, entry.rank), ordered=False)
|
||||
|
||||
def test_row_number(self):
|
||||
"""
|
||||
The row number window function computes the number based on the order
|
||||
in which the tuples were inserted. Depending on the backend,
|
||||
|
||||
Oracle requires an ordering-clause in the Window expression.
|
||||
"""
|
||||
qs = Employee.objects.annotate(row_number=Window(
|
||||
expression=RowNumber(),
|
||||
order_by=F('pk').asc(),
|
||||
)).order_by('pk')
|
||||
self.assertQuerysetEqual(qs, [
|
||||
('Jones', 'Accounting', 1),
|
||||
('Williams', 'Accounting', 2),
|
||||
('Jenson', 'Accounting', 3),
|
||||
('Adams', 'Accounting', 4),
|
||||
('Smith', 'Sales', 5),
|
||||
('Brown', 'Sales', 6),
|
||||
('Johnson', 'Marketing', 7),
|
||||
('Smith', 'Marketing', 8),
|
||||
('Wilkinson', 'IT', 9),
|
||||
('Moore', 'IT', 10),
|
||||
('Miller', 'Management', 11),
|
||||
('Johnson', 'Management', 12),
|
||||
], lambda entry: (entry.name, entry.department, entry.row_number))
|
||||
|
||||
@skipIf(connection.vendor == 'oracle', "Oracle requires ORDER BY in row_number, ANSI:SQL doesn't")
|
||||
def test_row_number_no_ordering(self):
|
||||
"""
|
||||
The row number window function computes the number based on the order
|
||||
in which the tuples were inserted.
|
||||
"""
|
||||
# Add a default ordering for consistent results across databases.
|
||||
qs = Employee.objects.annotate(row_number=Window(
|
||||
expression=RowNumber(),
|
||||
)).order_by('pk')
|
||||
self.assertQuerysetEqual(qs, [
|
||||
('Jones', 'Accounting', 1),
|
||||
('Williams', 'Accounting', 2),
|
||||
('Jenson', 'Accounting', 3),
|
||||
('Adams', 'Accounting', 4),
|
||||
('Smith', 'Sales', 5),
|
||||
('Brown', 'Sales', 6),
|
||||
('Johnson', 'Marketing', 7),
|
||||
('Smith', 'Marketing', 8),
|
||||
('Wilkinson', 'IT', 9),
|
||||
('Moore', 'IT', 10),
|
||||
('Miller', 'Management', 11),
|
||||
('Johnson', 'Management', 12),
|
||||
], lambda entry: (entry.name, entry.department, entry.row_number))
|
||||
|
||||
def test_avg_salary_department(self):
|
||||
qs = Employee.objects.annotate(avg_salary=Window(
|
||||
expression=Avg('salary'),
|
||||
order_by=F('department').asc(),
|
||||
partition_by='department',
|
||||
)).order_by('department', '-salary', 'name')
|
||||
self.assertQuerysetEqual(qs, [
|
||||
('Adams', 50000, 'Accounting', 44250.00),
|
||||
('Jenson', 45000, 'Accounting', 44250.00),
|
||||
('Jones', 45000, 'Accounting', 44250.00),
|
||||
('Williams', 37000, 'Accounting', 44250.00),
|
||||
('Wilkinson', 60000, 'IT', 47000.00),
|
||||
('Moore', 34000, 'IT', 47000.00),
|
||||
('Miller', 100000, 'Management', 90000.00),
|
||||
('Johnson', 80000, 'Management', 90000.00),
|
||||
('Johnson', 40000, 'Marketing', 39000.00),
|
||||
('Smith', 38000, 'Marketing', 39000.00),
|
||||
('Smith', 55000, 'Sales', 54000.00),
|
||||
('Brown', 53000, 'Sales', 54000.00),
|
||||
], transform=lambda row: (row.name, row.salary, row.department, row.avg_salary))
|
||||
|
||||
def test_lag(self):
|
||||
"""
|
||||
Compute the difference between an employee's salary and the next
|
||||
highest salary in the employee's department. Return None if the
|
||||
employee has the lowest salary.
|
||||
"""
|
||||
qs = Employee.objects.annotate(lag=Window(
|
||||
expression=Lag(expression='salary', offset=1),
|
||||
partition_by=F('department'),
|
||||
order_by=[F('salary').asc(), F('name').asc()],
|
||||
)).order_by('department')
|
||||
self.assertQuerysetEqual(qs, [
|
||||
('Williams', 37000, 'Accounting', None),
|
||||
('Jenson', 45000, 'Accounting', 37000),
|
||||
('Jones', 45000, 'Accounting', 45000),
|
||||
('Adams', 50000, 'Accounting', 45000),
|
||||
('Moore', 34000, 'IT', None),
|
||||
('Wilkinson', 60000, 'IT', 34000),
|
||||
('Johnson', 80000, 'Management', None),
|
||||
('Miller', 100000, 'Management', 80000),
|
||||
('Smith', 38000, 'Marketing', None),
|
||||
('Johnson', 40000, 'Marketing', 38000),
|
||||
('Brown', 53000, 'Sales', None),
|
||||
('Smith', 55000, 'Sales', 53000),
|
||||
], transform=lambda row: (row.name, row.salary, row.department, row.lag))
|
||||
|
||||
def test_first_value(self):
|
||||
qs = Employee.objects.annotate(first_value=Window(
|
||||
expression=FirstValue('salary'),
|
||||
partition_by=F('department'),
|
||||
order_by=F('hire_date').asc(),
|
||||
)).order_by('department', 'hire_date')
|
||||
self.assertQuerysetEqual(qs, [
|
||||
('Jones', 45000, 'Accounting', datetime.date(2005, 11, 1), 45000),
|
||||
('Jenson', 45000, 'Accounting', datetime.date(2008, 4, 1), 45000),
|
||||
('Williams', 37000, 'Accounting', datetime.date(2009, 6, 1), 45000),
|
||||
('Adams', 50000, 'Accounting', datetime.date(2013, 7, 1), 45000),
|
||||
('Wilkinson', 60000, 'IT', datetime.date(2011, 3, 1), 60000),
|
||||
('Moore', 34000, 'IT', datetime.date(2013, 8, 1), 60000),
|
||||
('Miller', 100000, 'Management', datetime.date(2005, 6, 1), 100000),
|
||||
('Johnson', 80000, 'Management', datetime.date(2005, 7, 1), 100000),
|
||||
('Smith', 38000, 'Marketing', datetime.date(2009, 10, 1), 38000),
|
||||
('Johnson', 40000, 'Marketing', datetime.date(2012, 3, 1), 38000),
|
||||
('Smith', 55000, 'Sales', datetime.date(2007, 6, 1), 55000),
|
||||
('Brown', 53000, 'Sales', datetime.date(2009, 9, 1), 55000),
|
||||
], lambda row: (row.name, row.salary, row.department, row.hire_date, row.first_value))
|
||||
|
||||
def test_last_value(self):
|
||||
qs = Employee.objects.annotate(last_value=Window(
|
||||
expression=LastValue('hire_date'),
|
||||
partition_by=F('department'),
|
||||
order_by=F('hire_date').asc(),
|
||||
))
|
||||
self.assertQuerysetEqual(qs, [
|
||||
('Adams', 'Accounting', datetime.date(2013, 7, 1), 50000, datetime.date(2013, 7, 1)),
|
||||
('Jenson', 'Accounting', datetime.date(2008, 4, 1), 45000, datetime.date(2008, 4, 1)),
|
||||
('Jones', 'Accounting', datetime.date(2005, 11, 1), 45000, datetime.date(2005, 11, 1)),
|
||||
('Williams', 'Accounting', datetime.date(2009, 6, 1), 37000, datetime.date(2009, 6, 1)),
|
||||
('Moore', 'IT', datetime.date(2013, 8, 1), 34000, datetime.date(2013, 8, 1)),
|
||||
('Wilkinson', 'IT', datetime.date(2011, 3, 1), 60000, datetime.date(2011, 3, 1)),
|
||||
('Miller', 'Management', datetime.date(2005, 6, 1), 100000, datetime.date(2005, 6, 1)),
|
||||
('Johnson', 'Management', datetime.date(2005, 7, 1), 80000, datetime.date(2005, 7, 1)),
|
||||
('Johnson', 'Marketing', datetime.date(2012, 3, 1), 40000, datetime.date(2012, 3, 1)),
|
||||
('Smith', 'Marketing', datetime.date(2009, 10, 1), 38000, datetime.date(2009, 10, 1)),
|
||||
('Brown', 'Sales', datetime.date(2009, 9, 1), 53000, datetime.date(2009, 9, 1)),
|
||||
('Smith', 'Sales', datetime.date(2007, 6, 1), 55000, datetime.date(2007, 6, 1)),
|
||||
], transform=lambda row: (row.name, row.department, row.hire_date, row.salary, row.last_value), ordered=False)
|
||||
|
||||
def test_function_list_of_values(self):
|
||||
qs = Employee.objects.annotate(lead=Window(
|
||||
expression=Lead(expression='salary'),
|
||||
order_by=[F('hire_date').asc(), F('name').desc()],
|
||||
partition_by='department',
|
||||
)).values_list('name', 'salary', 'department', 'hire_date', 'lead')
|
||||
self.assertNotIn('GROUP BY', str(qs.query))
|
||||
self.assertSequenceEqual(qs, [
|
||||
('Jones', 45000, 'Accounting', datetime.date(2005, 11, 1), 45000),
|
||||
('Jenson', 45000, 'Accounting', datetime.date(2008, 4, 1), 37000),
|
||||
('Williams', 37000, 'Accounting', datetime.date(2009, 6, 1), 50000),
|
||||
('Adams', 50000, 'Accounting', datetime.date(2013, 7, 1), None),
|
||||
('Wilkinson', 60000, 'IT', datetime.date(2011, 3, 1), 34000),
|
||||
('Moore', 34000, 'IT', datetime.date(2013, 8, 1), None),
|
||||
('Miller', 100000, 'Management', datetime.date(2005, 6, 1), 80000),
|
||||
('Johnson', 80000, 'Management', datetime.date(2005, 7, 1), None),
|
||||
('Smith', 38000, 'Marketing', datetime.date(2009, 10, 1), 40000),
|
||||
('Johnson', 40000, 'Marketing', datetime.date(2012, 3, 1), None),
|
||||
('Smith', 55000, 'Sales', datetime.date(2007, 6, 1), 53000),
|
||||
('Brown', 53000, 'Sales', datetime.date(2009, 9, 1), None),
|
||||
])
|
||||
|
||||
def test_min_department(self):
|
||||
"""An alternative way to specify a query for FirstValue."""
|
||||
qs = Employee.objects.annotate(min_salary=Window(
|
||||
expression=Min('salary'),
|
||||
partition_by=F('department'),
|
||||
order_by=[F('salary').asc(), F('name').asc()]
|
||||
)).order_by('department', 'salary', 'name')
|
||||
self.assertQuerysetEqual(qs, [
|
||||
('Williams', 'Accounting', 37000, 37000),
|
||||
('Jenson', 'Accounting', 45000, 37000),
|
||||
('Jones', 'Accounting', 45000, 37000),
|
||||
('Adams', 'Accounting', 50000, 37000),
|
||||
('Moore', 'IT', 34000, 34000),
|
||||
('Wilkinson', 'IT', 60000, 34000),
|
||||
('Johnson', 'Management', 80000, 80000),
|
||||
('Miller', 'Management', 100000, 80000),
|
||||
('Smith', 'Marketing', 38000, 38000),
|
||||
('Johnson', 'Marketing', 40000, 38000),
|
||||
('Brown', 'Sales', 53000, 53000),
|
||||
('Smith', 'Sales', 55000, 53000),
|
||||
], lambda row: (row.name, row.department, row.salary, row.min_salary))
|
||||
|
||||
def test_max_per_year(self):
|
||||
"""
|
||||
Find the maximum salary awarded in the same year as the
|
||||
employee was hired, regardless of the department.
|
||||
"""
|
||||
qs = Employee.objects.annotate(max_salary_year=Window(
|
||||
expression=Max('salary'),
|
||||
order_by=ExtractYear('hire_date').asc(),
|
||||
partition_by=ExtractYear('hire_date')
|
||||
)).order_by(ExtractYear('hire_date'), 'salary')
|
||||
self.assertQuerysetEqual(qs, [
|
||||
('Jones', 'Accounting', 45000, 2005, 100000),
|
||||
('Johnson', 'Management', 80000, 2005, 100000),
|
||||
('Miller', 'Management', 100000, 2005, 100000),
|
||||
('Smith', 'Sales', 55000, 2007, 55000),
|
||||
('Jenson', 'Accounting', 45000, 2008, 45000),
|
||||
('Williams', 'Accounting', 37000, 2009, 53000),
|
||||
('Smith', 'Marketing', 38000, 2009, 53000),
|
||||
('Brown', 'Sales', 53000, 2009, 53000),
|
||||
('Wilkinson', 'IT', 60000, 2011, 60000),
|
||||
('Johnson', 'Marketing', 40000, 2012, 40000),
|
||||
('Moore', 'IT', 34000, 2013, 50000),
|
||||
('Adams', 'Accounting', 50000, 2013, 50000),
|
||||
], lambda row: (row.name, row.department, row.salary, row.hire_date.year, row.max_salary_year))
|
||||
|
||||
def test_cume_dist(self):
|
||||
"""
|
||||
Compute the cumulative distribution for the employees based on the
|
||||
salary in increasing order. Equal to rank/total number of rows (12).
|
||||
"""
|
||||
qs = Employee.objects.annotate(cume_dist=Window(
|
||||
expression=CumeDist(),
|
||||
order_by=F('salary').asc(),
|
||||
)).order_by('salary', 'name')
|
||||
# Round result of cume_dist because Oracle uses greater precision.
|
||||
self.assertQuerysetEqual(qs, [
|
||||
('Moore', 'IT', 34000, 0.0833333333),
|
||||
('Williams', 'Accounting', 37000, 0.1666666667),
|
||||
('Smith', 'Marketing', 38000, 0.25),
|
||||
('Johnson', 'Marketing', 40000, 0.3333333333),
|
||||
('Jenson', 'Accounting', 45000, 0.5),
|
||||
('Jones', 'Accounting', 45000, 0.5),
|
||||
('Adams', 'Accounting', 50000, 0.5833333333),
|
||||
('Brown', 'Sales', 53000, 0.6666666667),
|
||||
('Smith', 'Sales', 55000, 0.75),
|
||||
('Wilkinson', 'IT', 60000, 0.8333333333),
|
||||
('Johnson', 'Management', 80000, 0.9166666667),
|
||||
('Miller', 'Management', 100000, 1),
|
||||
], lambda row: (row.name, row.department, row.salary, round(row.cume_dist, 10)))
|
||||
|
||||
def test_nthvalue(self):
|
||||
qs = Employee.objects.annotate(
|
||||
nth_value=Window(expression=NthValue(
|
||||
expression='salary', nth=2),
|
||||
order_by=[F('hire_date').asc(), F('name').desc()],
|
||||
partition_by=F('department'),
|
||||
)
|
||||
).order_by('department', 'hire_date', 'name')
|
||||
self.assertQuerysetEqual(qs, [
|
||||
('Jones', 'Accounting', datetime.date(2005, 11, 1), 45000, None),
|
||||
('Jenson', 'Accounting', datetime.date(2008, 4, 1), 45000, 45000),
|
||||
('Williams', 'Accounting', datetime.date(2009, 6, 1), 37000, 45000),
|
||||
('Adams', 'Accounting', datetime.date(2013, 7, 1), 50000, 45000),
|
||||
('Wilkinson', 'IT', datetime.date(2011, 3, 1), 60000, None),
|
||||
('Moore', 'IT', datetime.date(2013, 8, 1), 34000, 34000),
|
||||
('Miller', 'Management', datetime.date(2005, 6, 1), 100000, None),
|
||||
('Johnson', 'Management', datetime.date(2005, 7, 1), 80000, 80000),
|
||||
('Smith', 'Marketing', datetime.date(2009, 10, 1), 38000, None),
|
||||
('Johnson', 'Marketing', datetime.date(2012, 3, 1), 40000, 40000),
|
||||
('Smith', 'Sales', datetime.date(2007, 6, 1), 55000, None),
|
||||
('Brown', 'Sales', datetime.date(2009, 9, 1), 53000, 53000),
|
||||
], lambda row: (row.name, row.department, row.hire_date, row.salary, row.nth_value))
|
||||
|
||||
def test_lead(self):
|
||||
"""
|
||||
Determine what the next person hired in the same department makes.
|
||||
Because the dataset is ambiguous, the name is also part of the
|
||||
ordering clause. No default is provided, so None/NULL should be
|
||||
returned.
|
||||
"""
|
||||
qs = Employee.objects.annotate(lead=Window(
|
||||
expression=Lead(expression='salary'),
|
||||
order_by=[F('hire_date').asc(), F('name').desc()],
|
||||
partition_by='department',
|
||||
)).order_by('department')
|
||||
self.assertQuerysetEqual(qs, [
|
||||
('Jones', 45000, 'Accounting', datetime.date(2005, 11, 1), 45000),
|
||||
('Jenson', 45000, 'Accounting', datetime.date(2008, 4, 1), 37000),
|
||||
('Williams', 37000, 'Accounting', datetime.date(2009, 6, 1), 50000),
|
||||
('Adams', 50000, 'Accounting', datetime.date(2013, 7, 1), None),
|
||||
('Wilkinson', 60000, 'IT', datetime.date(2011, 3, 1), 34000),
|
||||
('Moore', 34000, 'IT', datetime.date(2013, 8, 1), None),
|
||||
('Miller', 100000, 'Management', datetime.date(2005, 6, 1), 80000),
|
||||
('Johnson', 80000, 'Management', datetime.date(2005, 7, 1), None),
|
||||
('Smith', 38000, 'Marketing', datetime.date(2009, 10, 1), 40000),
|
||||
('Johnson', 40000, 'Marketing', datetime.date(2012, 3, 1), None),
|
||||
('Smith', 55000, 'Sales', datetime.date(2007, 6, 1), 53000),
|
||||
('Brown', 53000, 'Sales', datetime.date(2009, 9, 1), None),
|
||||
], transform=lambda row: (row.name, row.salary, row.department, row.hire_date, row.lead))
|
||||
|
||||
def test_lead_offset(self):
|
||||
"""
|
||||
Determine what the person hired after someone makes. Due to
|
||||
ambiguity, the name is also included in the ordering.
|
||||
"""
|
||||
qs = Employee.objects.annotate(lead=Window(
|
||||
expression=Lead('salary', offset=2),
|
||||
partition_by='department',
|
||||
order_by=F('hire_date').asc(),
|
||||
))
|
||||
self.assertQuerysetEqual(qs, [
|
||||
('Jones', 45000, 'Accounting', datetime.date(2005, 11, 1), 37000),
|
||||
('Jenson', 45000, 'Accounting', datetime.date(2008, 4, 1), 50000),
|
||||
('Williams', 37000, 'Accounting', datetime.date(2009, 6, 1), None),
|
||||
('Adams', 50000, 'Accounting', datetime.date(2013, 7, 1), None),
|
||||
('Wilkinson', 60000, 'IT', datetime.date(2011, 3, 1), None),
|
||||
('Moore', 34000, 'IT', datetime.date(2013, 8, 1), None),
|
||||
('Johnson', 80000, 'Management', datetime.date(2005, 7, 1), None),
|
||||
('Miller', 100000, 'Management', datetime.date(2005, 6, 1), None),
|
||||
('Smith', 38000, 'Marketing', datetime.date(2009, 10, 1), None),
|
||||
('Johnson', 40000, 'Marketing', datetime.date(2012, 3, 1), None),
|
||||
('Smith', 55000, 'Sales', datetime.date(2007, 6, 1), None),
|
||||
('Brown', 53000, 'Sales', datetime.date(2009, 9, 1), None),
|
||||
], transform=lambda row: (row.name, row.salary, row.department, row.hire_date, row.lead),
|
||||
ordered=False
|
||||
)
|
||||
|
||||
def test_lead_default(self):
|
||||
qs = Employee.objects.annotate(lead_default=Window(
|
||||
expression=Lead(expression='salary', offset=5, default=60000),
|
||||
partition_by=F('department'),
|
||||
order_by=F('department').asc(),
|
||||
))
|
||||
self.assertEqual(list(qs.values_list('lead_default', flat=True).distinct()), [60000])
|
||||
|
||||
def test_ntile(self):
|
||||
"""
|
||||
Compute the group for each of the employees across the entire company,
|
||||
based on how high the salary is for them. There are twelve employees
|
||||
so it divides evenly into four groups.
|
||||
"""
|
||||
qs = Employee.objects.annotate(ntile=Window(
|
||||
expression=Ntile(num_buckets=4),
|
||||
order_by=F('salary').desc(),
|
||||
)).order_by('ntile', '-salary', 'name')
|
||||
self.assertQuerysetEqual(qs, [
|
||||
('Miller', 'Management', 100000, 1),
|
||||
('Johnson', 'Management', 80000, 1),
|
||||
('Wilkinson', 'IT', 60000, 1),
|
||||
('Smith', 'Sales', 55000, 2),
|
||||
('Brown', 'Sales', 53000, 2),
|
||||
('Adams', 'Accounting', 50000, 2),
|
||||
('Jenson', 'Accounting', 45000, 3),
|
||||
('Jones', 'Accounting', 45000, 3),
|
||||
('Johnson', 'Marketing', 40000, 3),
|
||||
('Smith', 'Marketing', 38000, 4),
|
||||
('Williams', 'Accounting', 37000, 4),
|
||||
('Moore', 'IT', 34000, 4),
|
||||
], lambda x: (x.name, x.department, x.salary, x.ntile))
|
||||
|
||||
def test_percent_rank(self):
|
||||
"""
|
||||
Calculate the percentage rank of the employees across the entire
|
||||
company based on salary and name (in case of ambiguity).
|
||||
"""
|
||||
qs = Employee.objects.annotate(percent_rank=Window(
|
||||
expression=PercentRank(),
|
||||
order_by=[F('salary').asc(), F('name').asc()],
|
||||
)).order_by('percent_rank')
|
||||
# Round to account for precision differences among databases.
|
||||
self.assertQuerysetEqual(qs, [
|
||||
('Moore', 'IT', 34000, 0.0),
|
||||
('Williams', 'Accounting', 37000, 0.0909090909),
|
||||
('Smith', 'Marketing', 38000, 0.1818181818),
|
||||
('Johnson', 'Marketing', 40000, 0.2727272727),
|
||||
('Jenson', 'Accounting', 45000, 0.3636363636),
|
||||
('Jones', 'Accounting', 45000, 0.4545454545),
|
||||
('Adams', 'Accounting', 50000, 0.5454545455),
|
||||
('Brown', 'Sales', 53000, 0.6363636364),
|
||||
('Smith', 'Sales', 55000, 0.7272727273),
|
||||
('Wilkinson', 'IT', 60000, 0.8181818182),
|
||||
('Johnson', 'Management', 80000, 0.9090909091),
|
||||
('Miller', 'Management', 100000, 1.0),
|
||||
], transform=lambda row: (row.name, row.department, row.salary, round(row.percent_rank, 10)))
|
||||
|
||||
def test_nth_returns_null(self):
|
||||
"""
|
||||
Find the nth row of the data set. None is returned since there are
|
||||
fewer than 20 rows in the test data.
|
||||
"""
|
||||
qs = Employee.objects.annotate(nth_value=Window(
|
||||
expression=NthValue('salary', nth=20),
|
||||
order_by=F('salary').asc()
|
||||
))
|
||||
self.assertEqual(list(qs.values_list('nth_value', flat=True).distinct()), [None])
|
||||
|
||||
def test_multiple_partitioning(self):
|
||||
"""
|
||||
Find the maximum salary for each department for people hired in the
|
||||
same year.
|
||||
"""
|
||||
qs = Employee.objects.annotate(max=Window(
|
||||
expression=Max('salary'),
|
||||
partition_by=[F('department'), ExtractYear(F('hire_date'))],
|
||||
)).order_by('department', 'hire_date', 'name')
|
||||
self.assertQuerysetEqual(qs, [
|
||||
('Jones', 45000, 'Accounting', datetime.date(2005, 11, 1), 45000),
|
||||
('Jenson', 45000, 'Accounting', datetime.date(2008, 4, 1), 45000),
|
||||
('Williams', 37000, 'Accounting', datetime.date(2009, 6, 1), 37000),
|
||||
('Adams', 50000, 'Accounting', datetime.date(2013, 7, 1), 50000),
|
||||
('Wilkinson', 60000, 'IT', datetime.date(2011, 3, 1), 60000),
|
||||
('Moore', 34000, 'IT', datetime.date(2013, 8, 1), 34000),
|
||||
('Miller', 100000, 'Management', datetime.date(2005, 6, 1), 100000),
|
||||
('Johnson', 80000, 'Management', datetime.date(2005, 7, 1), 100000),
|
||||
('Smith', 38000, 'Marketing', datetime.date(2009, 10, 1), 38000),
|
||||
('Johnson', 40000, 'Marketing', datetime.date(2012, 3, 1), 40000),
|
||||
('Smith', 55000, 'Sales', datetime.date(2007, 6, 1), 55000),
|
||||
('Brown', 53000, 'Sales', datetime.date(2009, 9, 1), 53000),
|
||||
], transform=lambda row: (row.name, row.salary, row.department, row.hire_date, row.max))
|
||||
|
||||
def test_multiple_ordering(self):
|
||||
"""
|
||||
Accumulate the salaries over the departments based on hire_date.
|
||||
If two people were hired on the same date in the same department, the
|
||||
ordering clause will render a different result for those people.
|
||||
"""
|
||||
qs = Employee.objects.annotate(sum=Window(
|
||||
expression=Sum('salary'),
|
||||
partition_by='department',
|
||||
order_by=[F('hire_date').asc(), F('name').asc()],
|
||||
)).order_by('department', 'sum')
|
||||
self.assertQuerysetEqual(qs, [
|
||||
('Jones', 45000, 'Accounting', datetime.date(2005, 11, 1), 45000),
|
||||
('Jenson', 45000, 'Accounting', datetime.date(2008, 4, 1), 90000),
|
||||
('Williams', 37000, 'Accounting', datetime.date(2009, 6, 1), 127000),
|
||||
('Adams', 50000, 'Accounting', datetime.date(2013, 7, 1), 177000),
|
||||
('Wilkinson', 60000, 'IT', datetime.date(2011, 3, 1), 60000),
|
||||
('Moore', 34000, 'IT', datetime.date(2013, 8, 1), 94000),
|
||||
('Miller', 100000, 'Management', datetime.date(2005, 6, 1), 100000),
|
||||
('Johnson', 80000, 'Management', datetime.date(2005, 7, 1), 180000),
|
||||
('Smith', 38000, 'Marketing', datetime.date(2009, 10, 1), 38000),
|
||||
('Johnson', 40000, 'Marketing', datetime.date(2012, 3, 1), 78000),
|
||||
('Smith', 55000, 'Sales', datetime.date(2007, 6, 1), 55000),
|
||||
('Brown', 53000, 'Sales', datetime.date(2009, 9, 1), 108000),
|
||||
], transform=lambda row: (row.name, row.salary, row.department, row.hire_date, row.sum))
|
||||
|
||||
@skipIf(connection.vendor == 'postgresql', 'n following/preceding not supported by PostgreSQL')
|
||||
def test_range_n_preceding_and_following(self):
|
||||
qs = Employee.objects.annotate(sum=Window(
|
||||
expression=Sum('salary'),
|
||||
order_by=F('salary').asc(),
|
||||
partition_by='department',
|
||||
frame=ValueRange(start=-2, end=2),
|
||||
))
|
||||
self.assertIn('RANGE BETWEEN 2 PRECEDING AND 2 FOLLOWING', str(qs.query))
|
||||
self.assertQuerysetEqual(qs, [
|
||||
('Williams', 37000, 'Accounting', datetime.date(2009, 6, 1), 37000),
|
||||
('Jones', 45000, 'Accounting', datetime.date(2005, 11, 1), 90000),
|
||||
('Jenson', 45000, 'Accounting', datetime.date(2008, 4, 1), 90000),
|
||||
('Adams', 50000, 'Accounting', datetime.date(2013, 7, 1), 50000),
|
||||
('Brown', 53000, 'Sales', datetime.date(2009, 9, 1), 53000),
|
||||
('Smith', 55000, 'Sales', datetime.date(2007, 6, 1), 55000),
|
||||
('Johnson', 40000, 'Marketing', datetime.date(2012, 3, 1), 40000),
|
||||
('Smith', 38000, 'Marketing', datetime.date(2009, 10, 1), 38000),
|
||||
('Wilkinson', 60000, 'IT', datetime.date(2011, 3, 1), 60000),
|
||||
('Moore', 34000, 'IT', datetime.date(2013, 8, 1), 34000),
|
||||
('Miller', 100000, 'Management', datetime.date(2005, 6, 1), 100000),
|
||||
('Johnson', 80000, 'Management', datetime.date(2005, 7, 1), 80000),
|
||||
], transform=lambda row: (row.name, row.salary, row.department, row.hire_date, row.sum), ordered=False)
|
||||
|
||||
def test_range_unbound(self):
|
||||
"""A query with RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING."""
|
||||
qs = Employee.objects.annotate(sum=Window(
|
||||
expression=Sum('salary'),
|
||||
partition_by='department',
|
||||
order_by=[F('hire_date').asc(), F('name').asc()],
|
||||
frame=ValueRange(start=None, end=None),
|
||||
)).order_by('department', 'hire_date', 'name')
|
||||
self.assertIn('RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING', str(qs.query))
|
||||
self.assertQuerysetEqual(qs, [
|
||||
('Jones', 'Accounting', 45000, datetime.date(2005, 11, 1), 177000),
|
||||
('Jenson', 'Accounting', 45000, datetime.date(2008, 4, 1), 177000),
|
||||
('Williams', 'Accounting', 37000, datetime.date(2009, 6, 1), 177000),
|
||||
('Adams', 'Accounting', 50000, datetime.date(2013, 7, 1), 177000),
|
||||
('Wilkinson', 'IT', 60000, datetime.date(2011, 3, 1), 94000),
|
||||
('Moore', 'IT', 34000, datetime.date(2013, 8, 1), 94000),
|
||||
('Miller', 'Management', 100000, datetime.date(2005, 6, 1), 180000),
|
||||
('Johnson', 'Management', 80000, datetime.date(2005, 7, 1), 180000),
|
||||
('Smith', 'Marketing', 38000, datetime.date(2009, 10, 1), 78000),
|
||||
('Johnson', 'Marketing', 40000, datetime.date(2012, 3, 1), 78000),
|
||||
('Smith', 'Sales', 55000, datetime.date(2007, 6, 1), 108000),
|
||||
('Brown', 'Sales', 53000, datetime.date(2009, 9, 1), 108000),
|
||||
], transform=lambda row: (row.name, row.department, row.salary, row.hire_date, row.sum))
|
||||
|
||||
def test_row_range_rank(self):
|
||||
"""
|
||||
A query with ROWS BETWEEN UNBOUNDED PRECEDING AND 3 FOLLOWING.
|
||||
The resulting sum is the sum of the three next (if they exist) and all
|
||||
previous rows according to the ordering clause.
|
||||
"""
|
||||
qs = Employee.objects.annotate(sum=Window(
|
||||
expression=Sum('salary'),
|
||||
order_by=[F('hire_date').asc(), F('name').desc()],
|
||||
frame=RowRange(start=None, end=3),
|
||||
)).order_by('sum', 'hire_date')
|
||||
self.assertIn('ROWS BETWEEN UNBOUNDED PRECEDING AND 3 FOLLOWING', str(qs.query))
|
||||
self.assertQuerysetEqual(qs, [
|
||||
('Miller', 100000, 'Management', datetime.date(2005, 6, 1), 280000),
|
||||
('Johnson', 80000, 'Management', datetime.date(2005, 7, 1), 325000),
|
||||
('Jones', 45000, 'Accounting', datetime.date(2005, 11, 1), 362000),
|
||||
('Smith', 55000, 'Sales', datetime.date(2007, 6, 1), 415000),
|
||||
('Jenson', 45000, 'Accounting', datetime.date(2008, 4, 1), 453000),
|
||||
('Williams', 37000, 'Accounting', datetime.date(2009, 6, 1), 513000),
|
||||
('Brown', 53000, 'Sales', datetime.date(2009, 9, 1), 553000),
|
||||
('Smith', 38000, 'Marketing', datetime.date(2009, 10, 1), 603000),
|
||||
('Wilkinson', 60000, 'IT', datetime.date(2011, 3, 1), 637000),
|
||||
('Johnson', 40000, 'Marketing', datetime.date(2012, 3, 1), 637000),
|
||||
('Adams', 50000, 'Accounting', datetime.date(2013, 7, 1), 637000),
|
||||
('Moore', 34000, 'IT', datetime.date(2013, 8, 1), 637000),
|
||||
], transform=lambda row: (row.name, row.salary, row.department, row.hire_date, row.sum))
|
||||
|
||||
@skipUnlessDBFeature('can_distinct_on_fields')
|
||||
def test_distinct_window_function(self):
|
||||
"""
|
||||
Window functions are not aggregates, and hence a query to filter out
|
||||
duplicates may be useful.
|
||||
"""
|
||||
qs = Employee.objects.annotate(
|
||||
sum=Window(
|
||||
expression=Sum('salary'),
|
||||
partition_by=ExtractYear('hire_date'),
|
||||
order_by=ExtractYear('hire_date')
|
||||
),
|
||||
year=ExtractYear('hire_date'),
|
||||
).values('year', 'sum').distinct('year').order_by('year')
|
||||
results = [
|
||||
{'year': 2005, 'sum': 225000}, {'year': 2007, 'sum': 55000},
|
||||
{'year': 2008, 'sum': 45000}, {'year': 2009, 'sum': 128000},
|
||||
{'year': 2011, 'sum': 60000}, {'year': 2012, 'sum': 40000},
|
||||
{'year': 2013, 'sum': 84000},
|
||||
]
|
||||
for idx, val in zip(range(len(results)), results):
|
||||
with self.subTest(result=val):
|
||||
self.assertEqual(qs[idx], val)
|
||||
|
||||
def test_fail_update(self):
|
||||
"""Window expressions can't be used in an UPDATE statement."""
|
||||
msg = 'Window expressions are not allowed in this query'
|
||||
with self.assertRaisesMessage(FieldError, msg):
|
||||
Employee.objects.filter(department='Management').update(
|
||||
salary=Window(expression=Max('salary'), partition_by='department'),
|
||||
)
|
||||
|
||||
def test_fail_insert(self):
|
||||
"""Window expressions can't be used in an INSERT statement."""
|
||||
msg = 'Window expressions are not allowed in this query'
|
||||
with self.assertRaisesMessage(FieldError, msg):
|
||||
Employee.objects.create(
|
||||
name='Jameson', department='Management', hire_date=datetime.date(2007, 7, 1),
|
||||
salary=Window(expression=Sum(Value(10000), order_by=F('pk').asc())),
|
||||
)
|
||||
|
||||
def test_invalid_start_value_range(self):
|
||||
msg = "start argument must be a negative integer, zero, or None, but got '3'."
|
||||
with self.assertRaisesMessage(ValueError, msg):
|
||||
list(Employee.objects.annotate(test=Window(
|
||||
expression=Sum('salary'),
|
||||
order_by=F('hire_date').asc(),
|
||||
frame=ValueRange(start=3),
|
||||
)))
|
||||
|
||||
def test_invalid_end_value_range(self):
|
||||
msg = "end argument must be a positive integer, zero, or None, but got '-3'."
|
||||
with self.assertRaisesMessage(ValueError, msg):
|
||||
list(Employee.objects.annotate(test=Window(
|
||||
expression=Sum('salary'),
|
||||
order_by=F('hire_date').asc(),
|
||||
frame=ValueRange(end=-3),
|
||||
)))
|
||||
|
||||
def test_invalid_type_end_value_range(self):
|
||||
msg = "end argument must be a positive integer, zero, or None, but got 'a'."
|
||||
with self.assertRaisesMessage(ValueError, msg):
|
||||
list(Employee.objects.annotate(test=Window(
|
||||
expression=Sum('salary'),
|
||||
order_by=F('hire_date').asc(),
|
||||
frame=ValueRange(end='a'),
|
||||
)))
|
||||
|
||||
def test_invalid_type_start_value_range(self):
|
||||
msg = "start argument must be a negative integer, zero, or None, but got 'a'."
|
||||
with self.assertRaisesMessage(ValueError, msg):
|
||||
list(Employee.objects.annotate(test=Window(
|
||||
expression=Sum('salary'),
|
||||
frame=ValueRange(start='a'),
|
||||
)))
|
||||
|
||||
def test_invalid_type_end_row_range(self):
|
||||
msg = "end argument must be a positive integer, zero, or None, but got 'a'."
|
||||
with self.assertRaisesMessage(ValueError, msg):
|
||||
list(Employee.objects.annotate(test=Window(
|
||||
expression=Sum('salary'),
|
||||
frame=RowRange(end='a'),
|
||||
)))
|
||||
|
||||
@skipUnless(connection.vendor == 'postgresql', 'Frame construction not allowed on PostgreSQL')
|
||||
def test_postgresql_illegal_range_frame_start(self):
|
||||
msg = 'PostgreSQL only supports UNBOUNDED together with PRECEDING and FOLLOWING.'
|
||||
with self.assertRaisesMessage(NotSupportedError, msg):
|
||||
list(Employee.objects.annotate(test=Window(
|
||||
expression=Sum('salary'),
|
||||
order_by=F('hire_date').asc(),
|
||||
frame=ValueRange(start=-1),
|
||||
)))
|
||||
|
||||
@skipUnless(connection.vendor == 'postgresql', 'Frame construction not allowed on PostgreSQL')
|
||||
def test_postgresql_illegal_range_frame_end(self):
|
||||
msg = 'PostgreSQL only supports UNBOUNDED together with PRECEDING and FOLLOWING.'
|
||||
with self.assertRaisesMessage(NotSupportedError, msg):
|
||||
list(Employee.objects.annotate(test=Window(
|
||||
expression=Sum('salary'),
|
||||
order_by=F('hire_date').asc(),
|
||||
frame=ValueRange(end=1),
|
||||
)))
|
||||
|
||||
def test_invalid_type_start_row_range(self):
|
||||
msg = "start argument must be a negative integer, zero, or None, but got 'a'."
|
||||
with self.assertRaisesMessage(ValueError, msg):
|
||||
list(Employee.objects.annotate(test=Window(
|
||||
expression=Sum('salary'),
|
||||
order_by=F('hire_date').asc(),
|
||||
frame=RowRange(start='a'),
|
||||
)))
|
||||
|
||||
|
||||
class NonQueryWindowTests(SimpleTestCase):
|
||||
def test_window_repr(self):
|
||||
self.assertEqual(
|
||||
repr(Window(expression=Sum('salary'), partition_by='department')),
|
||||
'<Window: Sum(F(salary)) OVER (PARTITION BY F(department))>'
|
||||
)
|
||||
self.assertEqual(
|
||||
repr(Window(expression=Avg('salary'), order_by=F('department').asc())),
|
||||
'<Window: Avg(F(salary)) OVER (ORDER BY OrderBy(F(department), descending=False))>'
|
||||
)
|
||||
|
||||
def test_window_frame_repr(self):
|
||||
self.assertEqual(
|
||||
repr(RowRange(start=-1)),
|
||||
'<RowRange: ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING>'
|
||||
)
|
||||
self.assertEqual(
|
||||
repr(ValueRange(start=None, end=1)),
|
||||
'<ValueRange: RANGE BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING>'
|
||||
)
|
||||
self.assertEqual(
|
||||
repr(ValueRange(start=0, end=0)),
|
||||
'<ValueRange: RANGE BETWEEN CURRENT ROW AND CURRENT ROW>'
|
||||
)
|
||||
self.assertEqual(
|
||||
repr(RowRange(start=0, end=0)),
|
||||
'<RowRange: ROWS BETWEEN CURRENT ROW AND CURRENT ROW>'
|
||||
)
|
||||
|
||||
def test_empty_group_by_cols(self):
|
||||
window = Window(expression=Sum('pk'))
|
||||
self.assertEqual(window.get_group_by_cols(), [])
|
||||
self.assertFalse(window.contains_aggregate)
|
||||
|
||||
def test_frame_empty_group_by_cols(self):
|
||||
frame = WindowFrame()
|
||||
self.assertEqual(frame.get_group_by_cols(), [])
|
||||
|
||||
def test_frame_window_frame_notimplemented(self):
|
||||
frame = WindowFrame()
|
||||
msg = 'Subclasses must implement window_frame_start_end().'
|
||||
with self.assertRaisesMessage(NotImplementedError, msg):
|
||||
frame.window_frame_start_end(None, None, None)
|
||||
|
||||
def test_invalid_filter(self):
|
||||
msg = 'Window is disallowed in the filter clause'
|
||||
with self.assertRaisesMessage(NotSupportedError, msg):
|
||||
Employee.objects.annotate(dense_rank=Window(expression=DenseRank())).filter(dense_rank__gte=1)
|
||||
|
||||
def test_invalid_order_by(self):
|
||||
msg = 'order_by must be either an Expression or a sequence of expressions'
|
||||
with self.assertRaisesMessage(ValueError, msg):
|
||||
Window(expression=Sum('power'), order_by='-horse')
|
||||
|
||||
def test_invalid_source_expression(self):
|
||||
msg = "Expression 'Upper' isn't compatible with OVER clauses."
|
||||
with self.assertRaisesMessage(ValueError, msg):
|
||||
Window(expression=Upper('name'))
|
Loading…
Reference in New Issue