Fixed #29851 -- Fixed crash of annotations with window expressions in Subquery.

This commit is contained in:
Mariusz Felisiak 2018-12-27 20:21:57 +01:00 committed by GitHub
parent 6fe9c45b72
commit dd8ed64113
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 39 additions and 10 deletions

View File

@ -1292,14 +1292,14 @@ class WindowFrame(Expression):
template = '%(frame_type)s BETWEEN %(start)s AND %(end)s' template = '%(frame_type)s BETWEEN %(start)s AND %(end)s'
def __init__(self, start=None, end=None): def __init__(self, start=None, end=None):
self.start = start self.start = Value(start)
self.end = end self.end = Value(end)
def set_source_expressions(self, exprs): def set_source_expressions(self, exprs):
self.start, self.end = exprs self.start, self.end = exprs
def get_source_expressions(self): def get_source_expressions(self):
return [Value(self.start), Value(self.end)] return [self.start, self.end]
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
connection.ops.check_expression_support(self) connection.ops.check_expression_support(self)
@ -1317,16 +1317,16 @@ class WindowFrame(Expression):
return [] return []
def __str__(self): def __str__(self):
if self.start is not None and self.start < 0: if self.start.value is not None and self.start.value < 0:
start = '%d %s' % (abs(self.start), connection.ops.PRECEDING) start = '%d %s' % (abs(self.start.value), connection.ops.PRECEDING)
elif self.start is not None and self.start == 0: elif self.start.value is not None and self.start.value == 0:
start = connection.ops.CURRENT_ROW start = connection.ops.CURRENT_ROW
else: else:
start = connection.ops.UNBOUNDED_PRECEDING start = connection.ops.UNBOUNDED_PRECEDING
if self.end is not None and self.end > 0: if self.end.value is not None and self.end.value > 0:
end = '%d %s' % (self.end, connection.ops.FOLLOWING) end = '%d %s' % (self.end.value, connection.ops.FOLLOWING)
elif self.end is not None and self.end == 0: elif self.end.value is not None and self.end.value == 0:
end = connection.ops.CURRENT_ROW end = connection.ops.CURRENT_ROW
else: else:
end = connection.ops.UNBOUNDED_FOLLOWING end = connection.ops.UNBOUNDED_FOLLOWING

View File

@ -4,7 +4,7 @@ from unittest import skipIf, skipUnless
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db import NotSupportedError, connection from django.db import NotSupportedError, connection
from django.db.models import ( from django.db.models import (
F, RowRange, Value, ValueRange, Window, WindowFrame, F, OuterRef, RowRange, Subquery, Value, ValueRange, Window, WindowFrame,
) )
from django.db.models.aggregates import Avg, Max, Min, Sum from django.db.models.aggregates import Avg, Max, Min, Sum
from django.db.models.functions import ( from django.db.models.functions import (
@ -584,6 +584,35 @@ class WindowFunctionTests(TestCase):
('Brown', 'Sales', 53000, datetime.date(2009, 9, 1), 148000) ('Brown', 'Sales', 53000, datetime.date(2009, 9, 1), 148000)
], transform=lambda row: (row.name, row.department, row.salary, row.hire_date, row.sum)) ], transform=lambda row: (row.name, row.department, row.salary, row.hire_date, row.sum))
def test_subquery_row_range_rank(self):
qs = Employee.objects.annotate(
highest_avg_salary_date=Subquery(
Employee.objects.filter(
department=OuterRef('department'),
).annotate(
avg_salary=Window(
expression=Avg('salary'),
order_by=[F('hire_date').asc()],
frame=RowRange(start=-1, end=1),
),
).order_by('-avg_salary', '-hire_date').values('hire_date')[:1],
),
).order_by('department', 'name')
self.assertQuerysetEqual(qs, [
('Adams', 'Accounting', datetime.date(2005, 11, 1)),
('Jenson', 'Accounting', datetime.date(2005, 11, 1)),
('Jones', 'Accounting', datetime.date(2005, 11, 1)),
('Williams', 'Accounting', datetime.date(2005, 11, 1)),
('Moore', 'IT', datetime.date(2013, 8, 1)),
('Wilkinson', 'IT', datetime.date(2013, 8, 1)),
('Johnson', 'Management', datetime.date(2005, 7, 1)),
('Miller', 'Management', datetime.date(2005, 7, 1)),
('Johnson', 'Marketing', datetime.date(2012, 3, 1)),
('Smith', 'Marketing', datetime.date(2012, 3, 1)),
('Brown', 'Sales', datetime.date(2009, 9, 1)),
('Smith', 'Sales', datetime.date(2009, 9, 1)),
], transform=lambda row: (row.name, row.department, row.highest_avg_salary_date))
def test_row_range_rank(self): def test_row_range_rank(self):
""" """
A query with ROWS BETWEEN UNBOUNDED PRECEDING AND 3 FOLLOWING. A query with ROWS BETWEEN UNBOUNDED PRECEDING AND 3 FOLLOWING.