Fixed #31755 -- Made temporal subtraction resolve output field.

This commit is contained in:
Sergey Fedoseev 2020-06-30 15:37:35 +05:00 committed by Mariusz Felisiak
parent 2d67222472
commit 9d519d3dc4
2 changed files with 40 additions and 53 deletions

View File

@ -443,23 +443,6 @@ class CombinedExpression(SQLiteNumericMixin, Expression):
self.lhs, self.rhs = exprs self.lhs, self.rhs = exprs
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
try:
lhs_type = self.lhs.output_field.get_internal_type()
except FieldError:
lhs_type = None
try:
rhs_type = self.rhs.output_field.get_internal_type()
except FieldError:
rhs_type = None
if (
not connection.features.has_native_duration_field and
'DurationField' in {lhs_type, rhs_type} and
lhs_type != rhs_type
):
return DurationExpression(self.lhs, self.connector, self.rhs).as_sql(compiler, connection)
datetime_fields = {'DateField', 'DateTimeField', 'TimeField'}
if self.connector == self.SUB and lhs_type in datetime_fields and lhs_type == rhs_type:
return TemporalSubtraction(self.lhs, self.rhs).as_sql(compiler, connection)
expressions = [] expressions = []
expression_params = [] expression_params = []
sql, params = compiler.compile(self.lhs) sql, params = compiler.compile(self.lhs)
@ -474,10 +457,30 @@ class CombinedExpression(SQLiteNumericMixin, Expression):
return expression_wrapper % sql, expression_params return expression_wrapper % sql, expression_params
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
lhs = self.lhs.resolve_expression(query, allow_joins, reuse, summarize, for_save)
rhs = self.rhs.resolve_expression(query, allow_joins, reuse, summarize, for_save)
if not isinstance(self, (DurationExpression, TemporalSubtraction)):
try:
lhs_type = lhs.output_field.get_internal_type()
except (AttributeError, FieldError):
lhs_type = None
try:
rhs_type = rhs.output_field.get_internal_type()
except (AttributeError, FieldError):
rhs_type = None
if 'DurationField' in {lhs_type, rhs_type} and lhs_type != rhs_type:
return DurationExpression(self.lhs, self.connector, self.rhs).resolve_expression(
query, allow_joins, reuse, summarize, for_save,
)
datetime_fields = {'DateField', 'DateTimeField', 'TimeField'}
if self.connector == self.SUB and lhs_type in datetime_fields and lhs_type == rhs_type:
return TemporalSubtraction(self.lhs, self.rhs).resolve_expression(
query, allow_joins, reuse, summarize, for_save,
)
c = self.copy() c = self.copy()
c.is_summary = summarize c.is_summary = summarize
c.lhs = c.lhs.resolve_expression(query, allow_joins, reuse, summarize, for_save) c.lhs = lhs
c.rhs = c.rhs.resolve_expression(query, allow_joins, reuse, summarize, for_save) c.rhs = rhs
return c return c
@ -494,6 +497,8 @@ class DurationExpression(CombinedExpression):
return compiler.compile(side) return compiler.compile(side)
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
if connection.features.has_native_duration_field:
return super().as_sql(compiler, connection)
connection.ops.check_expression_support(self) connection.ops.check_expression_support(self)
expressions = [] expressions = []
expression_params = [] expression_params = []

View File

@ -1493,9 +1493,7 @@ class FTimeDeltaTests(TestCase):
@skipUnlessDBFeature('supports_temporal_subtraction') @skipUnlessDBFeature('supports_temporal_subtraction')
def test_date_subtraction(self): def test_date_subtraction(self):
queryset = Experiment.objects.annotate( queryset = Experiment.objects.annotate(
completion_duration=ExpressionWrapper( completion_duration=F('completed') - F('assigned'),
F('completed') - F('assigned'), output_field=DurationField()
)
) )
at_least_5_days = {e.name for e in queryset.filter(completion_duration__gte=datetime.timedelta(days=5))} at_least_5_days = {e.name for e in queryset.filter(completion_duration__gte=datetime.timedelta(days=5))}
@ -1507,10 +1505,9 @@ class FTimeDeltaTests(TestCase):
less_than_5_days = {e.name for e in queryset.filter(completion_duration__lt=datetime.timedelta(days=5))} less_than_5_days = {e.name for e in queryset.filter(completion_duration__lt=datetime.timedelta(days=5))}
self.assertEqual(less_than_5_days, {'e0', 'e1', 'e2'}) self.assertEqual(less_than_5_days, {'e0', 'e1', 'e2'})
queryset = Experiment.objects.annotate(difference=ExpressionWrapper( queryset = Experiment.objects.annotate(
F('completed') - Value(None, output_field=DateField()), difference=F('completed') - Value(None, output_field=DateField()),
output_field=DurationField(), )
))
self.assertIsNone(queryset.first().difference) self.assertIsNone(queryset.first().difference)
queryset = Experiment.objects.annotate(shifted=ExpressionWrapper( queryset = Experiment.objects.annotate(shifted=ExpressionWrapper(
@ -1523,9 +1520,7 @@ class FTimeDeltaTests(TestCase):
def test_date_subquery_subtraction(self): def test_date_subquery_subtraction(self):
subquery = Experiment.objects.filter(pk=OuterRef('pk')).values('completed') subquery = Experiment.objects.filter(pk=OuterRef('pk')).values('completed')
queryset = Experiment.objects.annotate( queryset = Experiment.objects.annotate(
difference=ExpressionWrapper( difference=subquery - F('completed'),
subquery - F('completed'), output_field=DurationField(),
),
).filter(difference=datetime.timedelta()) ).filter(difference=datetime.timedelta())
self.assertTrue(queryset.exists()) self.assertTrue(queryset.exists())
@ -1540,9 +1535,7 @@ class FTimeDeltaTests(TestCase):
self.e0.completed, self.e0.completed,
output_field=DateField(), output_field=DateField(),
), ),
difference=ExpressionWrapper( difference=F('date_case') - F('completed_value'),
F('date_case') - F('completed_value'), output_field=DurationField(),
),
).filter(difference=datetime.timedelta()) ).filter(difference=datetime.timedelta())
self.assertEqual(queryset.get(), self.e0) self.assertEqual(queryset.get(), self.e0)
@ -1550,20 +1543,16 @@ class FTimeDeltaTests(TestCase):
def test_time_subtraction(self): def test_time_subtraction(self):
Time.objects.create(time=datetime.time(12, 30, 15, 2345)) Time.objects.create(time=datetime.time(12, 30, 15, 2345))
queryset = Time.objects.annotate( queryset = Time.objects.annotate(
difference=ExpressionWrapper( difference=F('time') - Value(datetime.time(11, 15, 0), output_field=TimeField()),
F('time') - Value(datetime.time(11, 15, 0), output_field=TimeField()),
output_field=DurationField(),
)
) )
self.assertEqual( self.assertEqual(
queryset.get().difference, queryset.get().difference,
datetime.timedelta(hours=1, minutes=15, seconds=15, microseconds=2345) datetime.timedelta(hours=1, minutes=15, seconds=15, microseconds=2345)
) )
queryset = Time.objects.annotate(difference=ExpressionWrapper( queryset = Time.objects.annotate(
F('time') - Value(None, output_field=TimeField()), difference=F('time') - Value(None, output_field=TimeField()),
output_field=DurationField(), )
))
self.assertIsNone(queryset.first().difference) self.assertIsNone(queryset.first().difference)
queryset = Time.objects.annotate(shifted=ExpressionWrapper( queryset = Time.objects.annotate(shifted=ExpressionWrapper(
@ -1577,9 +1566,7 @@ class FTimeDeltaTests(TestCase):
Time.objects.create(time=datetime.time(12, 30, 15, 2345)) Time.objects.create(time=datetime.time(12, 30, 15, 2345))
subquery = Time.objects.filter(pk=OuterRef('pk')).values('time') subquery = Time.objects.filter(pk=OuterRef('pk')).values('time')
queryset = Time.objects.annotate( queryset = Time.objects.annotate(
difference=ExpressionWrapper( difference=subquery - F('time'),
subquery - F('time'), output_field=DurationField(),
),
).filter(difference=datetime.timedelta()) ).filter(difference=datetime.timedelta())
self.assertTrue(queryset.exists()) self.assertTrue(queryset.exists())
@ -1595,10 +1582,9 @@ class FTimeDeltaTests(TestCase):
] ]
self.assertEqual(over_estimate, ['e4']) self.assertEqual(over_estimate, ['e4'])
queryset = Experiment.objects.annotate(difference=ExpressionWrapper( queryset = Experiment.objects.annotate(
F('start') - Value(None, output_field=DateTimeField()), difference=F('start') - Value(None, output_field=DateTimeField()),
output_field=DurationField(), )
))
self.assertIsNone(queryset.first().difference) self.assertIsNone(queryset.first().difference)
queryset = Experiment.objects.annotate(shifted=ExpressionWrapper( queryset = Experiment.objects.annotate(shifted=ExpressionWrapper(
@ -1611,9 +1597,7 @@ class FTimeDeltaTests(TestCase):
def test_datetime_subquery_subtraction(self): def test_datetime_subquery_subtraction(self):
subquery = Experiment.objects.filter(pk=OuterRef('pk')).values('start') subquery = Experiment.objects.filter(pk=OuterRef('pk')).values('start')
queryset = Experiment.objects.annotate( queryset = Experiment.objects.annotate(
difference=ExpressionWrapper( difference=subquery - F('start'),
subquery - F('start'), output_field=DurationField(),
),
).filter(difference=datetime.timedelta()) ).filter(difference=datetime.timedelta())
self.assertTrue(queryset.exists()) self.assertTrue(queryset.exists())
@ -1621,9 +1605,7 @@ class FTimeDeltaTests(TestCase):
def test_datetime_subtraction_microseconds(self): def test_datetime_subtraction_microseconds(self):
delta = datetime.timedelta(microseconds=8999999999999999) delta = datetime.timedelta(microseconds=8999999999999999)
Experiment.objects.update(end=F('start') + delta) Experiment.objects.update(end=F('start') + delta)
qs = Experiment.objects.annotate( qs = Experiment.objects.annotate(delta=F('end') - F('start'))
delta=ExpressionWrapper(F('end') - F('start'), output_field=DurationField())
)
for e in qs: for e in qs:
self.assertEqual(e.delta, delta) self.assertEqual(e.delta, delta)