Refs #29850 -- Added RowRange support for positive integer start and negative integer end.

This commit is contained in:
Sarah Boyce 2023-10-22 16:19:28 +02:00 committed by Mariusz Felisiak
parent a6c7db1d1d
commit 6375cee490
5 changed files with 155 additions and 33 deletions

View File

@ -714,42 +714,50 @@ class BaseDatabaseOperations:
"This backend does not support %s subtraction." % internal_type
)
def window_frame_start(self, start):
if isinstance(start, int):
if start < 0:
return "%d %s" % (abs(start), self.PRECEDING)
elif start == 0:
def window_frame_value(self, value):
if isinstance(value, int):
if value == 0:
return self.CURRENT_ROW
elif start is None:
return self.UNBOUNDED_PRECEDING
raise ValueError(
"start argument must be a negative integer, zero, or None, but got '%s'."
% start
)
def window_frame_end(self, end):
if isinstance(end, int):
if end == 0:
return self.CURRENT_ROW
elif end > 0:
return "%d %s" % (end, self.FOLLOWING)
elif end is None:
return self.UNBOUNDED_FOLLOWING
raise ValueError(
"end argument must be a positive integer, zero, or None, but got '%s'."
% end
)
elif value < 0:
return "%d %s" % (abs(value), self.PRECEDING)
else:
return "%d %s" % (value, self.FOLLOWING)
def window_frame_rows_start_end(self, start=None, end=None):
"""
Return SQL for start and end points in an OVER clause window frame.
"""
if not self.connection.features.supports_over_clause:
raise NotSupportedError("This backend does not support window expressions.")
return self.window_frame_start(start), self.window_frame_end(end)
if isinstance(start, int) and isinstance(end, int) and start > end:
raise ValueError("start cannot be greater than end.")
if start is not None and not isinstance(start, int):
raise ValueError(
f"start argument must be an integer, zero, or None, but got '{start}'."
)
if end is not None and not isinstance(end, int):
raise ValueError(
f"end argument must be an integer, zero, or None, but got '{end}'."
)
start_ = self.window_frame_value(start) or self.UNBOUNDED_PRECEDING
end_ = self.window_frame_value(end) or self.UNBOUNDED_FOLLOWING
return start_, end_
def window_frame_range_start_end(self, start=None, end=None):
start_, end_ = self.window_frame_rows_start_end(start, end)
if (start is not None and not isinstance(start, int)) or (
isinstance(start, int) and start > 0
):
raise ValueError(
"start argument must be a negative integer, zero, or None, "
"but got '%s'." % start
)
if (end is not None and not isinstance(end, int)) or (
isinstance(end, int) and end < 0
):
raise ValueError(
"end argument must be a positive integer, zero, or None, but got '%s'."
% end
)
start_ = self.window_frame_value(start) or self.UNBOUNDED_PRECEDING
end_ = self.window_frame_value(end) or self.UNBOUNDED_FOLLOWING
features = self.connection.features
if features.only_supports_unbounded_with_preceding_and_following and (
(start and start < 0) or (end and end > 0)

View File

@ -1895,6 +1895,8 @@ class WindowFrame(Expression):
start = "%d %s" % (abs(self.start.value), connection.ops.PRECEDING)
elif self.start.value is not None and self.start.value == 0:
start = connection.ops.CURRENT_ROW
elif self.start.value is not None and self.start.value > 0:
start = "%d %s" % (self.start.value, connection.ops.FOLLOWING)
else:
start = connection.ops.UNBOUNDED_PRECEDING
@ -1902,6 +1904,8 @@ class WindowFrame(Expression):
end = "%d %s" % (self.end.value, connection.ops.FOLLOWING)
elif self.end.value is not None and self.end.value == 0:
end = connection.ops.CURRENT_ROW
elif self.end.value is not None and self.end.value < 0:
end = "%d %s" % (abs(self.end.value), connection.ops.PRECEDING)
else:
end = connection.ops.UNBOUNDED_FOLLOWING
return self.template % {

View File

@ -923,9 +923,12 @@ SQL generated by the ORM and is by default ``UNBOUNDED FOLLOWING``. The default
frame includes all rows from the partition to the last row in the set.
The accepted values for the ``start`` and ``end`` arguments are ``None``, an
integer, or zero. A negative integer for ``start`` results in ``N preceding``,
while ``None`` yields ``UNBOUNDED PRECEDING``. For both ``start`` and ``end``,
zero will return ``CURRENT ROW``. Positive integers are accepted for ``end``.
integer, or zero. A negative integer for ``start`` results in ``N PRECEDING``,
while ``None`` yields ``UNBOUNDED PRECEDING``. In ``ROWS`` mode, a positive
integer can be used for ``start`` resulting in ``N FOLLOWING``. Positive
integers are accepted for ``end`` and result in ``N FOLLOWING``. In ``ROWS``
mode, a negative integer can be used for ``end`` resulting in ``N PRECEDING``.
For both ``start`` and ``end``, zero will return ``CURRENT ROW``.
There's a difference in what ``CURRENT ROW`` includes. When specified in
``ROWS`` mode, the frame starts or ends with the current row. When specified in
@ -970,6 +973,11 @@ released between twelve months before and twelve months after each movie:
... ),
... )
.. versionchanged:: 5.1
Support for positive integer ``start`` and negative integer ``end`` was
added for ``RowRange``.
.. currentmodule:: django.db.models
Technical Information

View File

@ -171,6 +171,9 @@ Models
* :meth:`.QuerySet.explain` now supports the ``generic_plan`` option on
PostgreSQL 16+.
* :class:`~django.db.models.expressions.RowRange` now accepts positive integers
for the ``start`` argument and negative integers for the ``end`` argument.
Requests and Responses
~~~~~~~~~~~~~~~~~~~~~~

View File

@ -1328,6 +1328,84 @@ class WindowFunctionTests(TestCase):
),
)
def test_row_range_both_preceding(self):
    """
    ROWS BETWEEN 2 PRECEDING AND 1 PRECEDING: each row's sum covers the
    (up to) two rows immediately before it in the window ordering and
    excludes the row itself, so the first row's sum is None.
    """
    frame_sum = Window(
        expression=Sum("salary"),
        order_by=[F("hire_date").asc(), F("name").desc()],
        frame=RowRange(start=-2, end=-1),
    )
    employees = Employee.objects.annotate(sum=frame_sum).order_by("hire_date")
    # The frame clause must be rendered verbatim in the generated SQL.
    sql = str(employees.query)
    self.assertIn("ROWS BETWEEN 2 PRECEDING AND 1 PRECEDING", sql)
    # (name, salary, department, hire_date, sum of the two preceding rows)
    expected_rows = [
        ("Miller", 100000, "Management", datetime.date(2005, 6, 1), None),
        ("Johnson", 80000, "Management", datetime.date(2005, 7, 1), 100000),
        ("Jones", 45000, "Accounting", datetime.date(2005, 11, 1), 180000),
        ("Smith", 55000, "Sales", datetime.date(2007, 6, 1), 125000),
        ("Jenson", 45000, "Accounting", datetime.date(2008, 4, 1), 100000),
        ("Williams", 37000, "Accounting", datetime.date(2009, 6, 1), 100000),
        ("Brown", 53000, "Sales", datetime.date(2009, 9, 1), 82000),
        ("Smith", 38000, "Marketing", datetime.date(2009, 10, 1), 90000),
        ("Wilkinson", 60000, "IT", datetime.date(2011, 3, 1), 91000),
        ("Johnson", 40000, "Marketing", datetime.date(2012, 3, 1), 98000),
        ("Adams", 50000, "Accounting", datetime.date(2013, 7, 1), 100000),
        ("Moore", 34000, "IT", datetime.date(2013, 8, 1), 90000),
    ]
    self.assertQuerySetEqual(
        employees,
        expected_rows,
        transform=lambda employee: (
            employee.name,
            employee.salary,
            employee.department,
            employee.hire_date,
            employee.sum,
        ),
    )
def test_row_range_both_following(self):
    """
    ROWS BETWEEN 1 FOLLOWING AND 2 FOLLOWING: each row's sum covers the
    (up to) two rows immediately after it in the window ordering and
    excludes the row itself, so the last row's sum is None.
    """
    frame_sum = Window(
        expression=Sum("salary"),
        order_by=[F("hire_date").asc(), F("name").desc()],
        frame=RowRange(start=1, end=2),
    )
    employees = Employee.objects.annotate(sum=frame_sum).order_by("hire_date")
    # The frame clause must be rendered verbatim in the generated SQL.
    sql = str(employees.query)
    self.assertIn("ROWS BETWEEN 1 FOLLOWING AND 2 FOLLOWING", sql)
    # (name, salary, department, hire_date, sum of the two following rows)
    expected_rows = [
        ("Miller", 100000, "Management", datetime.date(2005, 6, 1), 125000),
        ("Johnson", 80000, "Management", datetime.date(2005, 7, 1), 100000),
        ("Jones", 45000, "Accounting", datetime.date(2005, 11, 1), 100000),
        ("Smith", 55000, "Sales", datetime.date(2007, 6, 1), 82000),
        ("Jenson", 45000, "Accounting", datetime.date(2008, 4, 1), 90000),
        ("Williams", 37000, "Accounting", datetime.date(2009, 6, 1), 91000),
        ("Brown", 53000, "Sales", datetime.date(2009, 9, 1), 98000),
        ("Smith", 38000, "Marketing", datetime.date(2009, 10, 1), 100000),
        ("Wilkinson", 60000, "IT", datetime.date(2011, 3, 1), 90000),
        ("Johnson", 40000, "Marketing", datetime.date(2012, 3, 1), 84000),
        ("Adams", 50000, "Accounting", datetime.date(2013, 7, 1), 34000),
        ("Moore", 34000, "IT", datetime.date(2013, 8, 1), None),
    ]
    self.assertQuerySetEqual(
        employees,
        expected_rows,
        transform=lambda employee: (
            employee.name,
            employee.salary,
            employee.department,
            employee.hire_date,
            employee.sum,
        ),
    )
@skipUnlessDBFeature("can_distinct_on_fields")
def test_distinct_window_function(self):
"""
@ -1479,6 +1557,19 @@ class WindowFunctionTests(TestCase):
)
)
def test_invalid_start_end_value_for_row_range(self):
    # A frame whose start (4) lies after its end (-3) must be rejected with
    # a ValueError carrying exactly this message.
    expected_message = "start cannot be greater than end."
    with self.assertRaisesMessage(ValueError, expected_message):
        # Everything is built inside the context manager so the error is
        # caught wherever it is raised; list() forces query evaluation.
        inverted_frame = Window(
            expression=Sum("salary"),
            order_by=F("hire_date").asc(),
            frame=RowRange(start=4, end=-3),
        )
        list(Employee.objects.annotate(test=inverted_frame))
def test_invalid_type_end_value_range(self):
msg = "end argument must be a positive integer, zero, or None, but got 'a'."
with self.assertRaisesMessage(ValueError, msg):
@ -1505,7 +1596,7 @@ class WindowFunctionTests(TestCase):
)
def test_invalid_type_end_row_range(self):
msg = "end argument must be a positive integer, zero, or None, but got 'a'."
msg = "end argument must be an integer, zero, or None, but got 'a'."
with self.assertRaisesMessage(ValueError, msg):
list(
Employee.objects.annotate(
@ -1551,7 +1642,7 @@ class WindowFunctionTests(TestCase):
)
def test_invalid_type_start_row_range(self):
msg = "start argument must be a negative integer, zero, or None, but got 'a'."
msg = "start argument must be an integer, zero, or None, but got 'a'."
with self.assertRaisesMessage(ValueError, msg):
list(
Employee.objects.annotate(
@ -1636,6 +1727,14 @@ class NonQueryWindowTests(SimpleTestCase):
repr(RowRange(start=0, end=0)),
"<RowRange: ROWS BETWEEN CURRENT ROW AND CURRENT ROW>",
)
self.assertEqual(
repr(RowRange(start=-2, end=-1)),
"<RowRange: ROWS BETWEEN 2 PRECEDING AND 1 PRECEDING>",
)
self.assertEqual(
repr(RowRange(start=1, end=2)),
"<RowRange: ROWS BETWEEN 1 FOLLOWING AND 2 FOLLOWING>",
)
def test_empty_group_by_cols(self):
window = Window(expression=Sum("pk"))