From 29769a9942b08c86397692fa164396a670a4e1ec Mon Sep 17 00:00:00 2001 From: Sergey Fedoseev Date: Fri, 7 Jul 2017 09:04:37 +0500 Subject: [PATCH] Fixed #28382 -- Prevented BaseExpression._output_field from being set if _resolve_output_field() fails. --- django/db/models/aggregates.py | 4 ++-- django/db/models/expressions.py | 19 ++++++------------- tests/aggregation/tests.py | 8 ++++++-- 3 files changed, 14 insertions(+), 17 deletions(-) diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py index 095455f8c1..507fced8e7 100644 --- a/django/db/models/aggregates.py +++ b/django/db/models/aggregates.py @@ -44,8 +44,8 @@ class Avg(Aggregate): def _resolve_output_field(self): source_field = self.get_source_fields()[0] if isinstance(source_field, (IntegerField, DecimalField)): - self._output_field = FloatField() - super()._resolve_output_field() + return FloatField() + return super()._resolve_output_field() def as_oracle(self, compiler, connection): if self.output_field.get_internal_type() == 'DurationField': diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index 5cd558c12a..7e5c7313ce 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -236,7 +236,7 @@ class BaseExpression: None values. """ if self._output_field is None: - self._resolve_output_field() + self._output_field = self._resolve_output_field() return self._output_field def _resolve_output_field(self): @@ -253,18 +253,11 @@ class BaseExpression: this check. If all sources are `None`, then an error will be thrown higher up the stack in the `output_field` property. """ - if self._output_field is None: - sources = self.get_source_fields() - num_sources = len(sources) - if num_sources == 0: - self._output_field = None - else: - for source in sources: - if self._output_field is None: - self._output_field = source - if source is not None and not isinstance(self._output_field, source.__class__): - raise FieldError( - "Expression contains mixed types. You must set output_field") + sources_iter = (source for source in self.get_source_fields() if source is not None) + for output_field in sources_iter: + if any(not isinstance(output_field, source.__class__) for source in sources_iter): + raise FieldError('Expression contains mixed types. You must set output_field.') + return output_field def convert_value(self, value, expression, connection, context): """ diff --git a/tests/aggregation/tests.py b/tests/aggregation/tests.py index c3f7998f60..3c53877303 100644 --- a/tests/aggregation/tests.py +++ b/tests/aggregation/tests.py @@ -965,8 +965,12 @@ class AggregateTestCase(TestCase): self.assertEqual(p2, {'avg_price': Approximate(53.39, places=2)}) def test_combine_different_types(self): - with self.assertRaisesMessage(FieldError, 'Expression contains mixed types. You must set output_field'): - Book.objects.annotate(sums=Sum('rating') + Sum('pages') + Sum('price')).get(pk=self.b4.pk) + msg = 'Expression contains mixed types. You must set output_field.' + qs = Book.objects.annotate(sums=Sum('rating') + Sum('pages') + Sum('price')) + with self.assertRaisesMessage(FieldError, msg): + qs.first() + with self.assertRaisesMessage(FieldError, msg): + qs.first() b1 = Book.objects.annotate(sums=Sum(F('rating') + F('pages') + F('price'), output_field=IntegerField())).get(pk=self.b4.pk)