Fixed #28382 -- Prevented BaseExpression._output_field from being set if _resolve_output_field() fails.

This commit is contained in:
Sergey Fedoseev 2017-07-07 09:04:37 +05:00 committed by Tim Graham
parent 306b961a4d
commit 29769a9942
3 changed files with 14 additions and 17 deletions

View File

@ -44,8 +44,8 @@ class Avg(Aggregate):
def _resolve_output_field(self): def _resolve_output_field(self):
source_field = self.get_source_fields()[0] source_field = self.get_source_fields()[0]
if isinstance(source_field, (IntegerField, DecimalField)): if isinstance(source_field, (IntegerField, DecimalField)):
self._output_field = FloatField() return FloatField()
super()._resolve_output_field() return super()._resolve_output_field()
def as_oracle(self, compiler, connection): def as_oracle(self, compiler, connection):
if self.output_field.get_internal_type() == 'DurationField': if self.output_field.get_internal_type() == 'DurationField':

View File

@ -236,7 +236,7 @@ class BaseExpression:
None values. None values.
""" """
if self._output_field is None: if self._output_field is None:
self._resolve_output_field() self._output_field = self._resolve_output_field()
return self._output_field return self._output_field
def _resolve_output_field(self): def _resolve_output_field(self):
@ -253,18 +253,11 @@ class BaseExpression:
this check. If all sources are `None`, then an error will be thrown this check. If all sources are `None`, then an error will be thrown
higher up the stack in the `output_field` property. higher up the stack in the `output_field` property.
""" """
if self._output_field is None: sources_iter = (source for source in self.get_source_fields() if source is not None)
sources = self.get_source_fields() for output_field in sources_iter:
num_sources = len(sources) if any(not isinstance(output_field, source.__class__) for source in sources_iter):
if num_sources == 0: raise FieldError('Expression contains mixed types. You must set output_field.')
self._output_field = None return output_field
else:
for source in sources:
if self._output_field is None:
self._output_field = source
if source is not None and not isinstance(self._output_field, source.__class__):
raise FieldError(
"Expression contains mixed types. You must set output_field")
def convert_value(self, value, expression, connection, context): def convert_value(self, value, expression, connection, context):
""" """

View File

@ -965,8 +965,12 @@ class AggregateTestCase(TestCase):
self.assertEqual(p2, {'avg_price': Approximate(53.39, places=2)}) self.assertEqual(p2, {'avg_price': Approximate(53.39, places=2)})
def test_combine_different_types(self): def test_combine_different_types(self):
with self.assertRaisesMessage(FieldError, 'Expression contains mixed types. You must set output_field'): msg = 'Expression contains mixed types. You must set output_field.'
Book.objects.annotate(sums=Sum('rating') + Sum('pages') + Sum('price')).get(pk=self.b4.pk) qs = Book.objects.annotate(sums=Sum('rating') + Sum('pages') + Sum('price'))
with self.assertRaisesMessage(FieldError, msg):
qs.first()
with self.assertRaisesMessage(FieldError, msg):
qs.first()
b1 = Book.objects.annotate(sums=Sum(F('rating') + F('pages') + F('price'), b1 = Book.objects.annotate(sums=Sum(F('rating') + F('pages') + F('price'),
output_field=IntegerField())).get(pk=self.b4.pk) output_field=IntegerField())).get(pk=self.b4.pk)