Fixed #28382 -- Prevented BaseExpression._output_field from being set if _resolve_output_field() fails.
This commit is contained in:
parent
306b961a4d
commit
29769a9942
|
@ -44,8 +44,8 @@ class Avg(Aggregate):
|
||||||
def _resolve_output_field(self):
|
def _resolve_output_field(self):
|
||||||
source_field = self.get_source_fields()[0]
|
source_field = self.get_source_fields()[0]
|
||||||
if isinstance(source_field, (IntegerField, DecimalField)):
|
if isinstance(source_field, (IntegerField, DecimalField)):
|
||||||
self._output_field = FloatField()
|
return FloatField()
|
||||||
super()._resolve_output_field()
|
return super()._resolve_output_field()
|
||||||
|
|
||||||
def as_oracle(self, compiler, connection):
|
def as_oracle(self, compiler, connection):
|
||||||
if self.output_field.get_internal_type() == 'DurationField':
|
if self.output_field.get_internal_type() == 'DurationField':
|
||||||
|
|
|
@ -236,7 +236,7 @@ class BaseExpression:
|
||||||
None values.
|
None values.
|
||||||
"""
|
"""
|
||||||
if self._output_field is None:
|
if self._output_field is None:
|
||||||
self._resolve_output_field()
|
self._output_field = self._resolve_output_field()
|
||||||
return self._output_field
|
return self._output_field
|
||||||
|
|
||||||
def _resolve_output_field(self):
|
def _resolve_output_field(self):
|
||||||
|
@ -253,18 +253,11 @@ class BaseExpression:
|
||||||
this check. If all sources are `None`, then an error will be thrown
|
this check. If all sources are `None`, then an error will be thrown
|
||||||
higher up the stack in the `output_field` property.
|
higher up the stack in the `output_field` property.
|
||||||
"""
|
"""
|
||||||
if self._output_field is None:
|
sources_iter = (source for source in self.get_source_fields() if source is not None)
|
||||||
sources = self.get_source_fields()
|
for output_field in sources_iter:
|
||||||
num_sources = len(sources)
|
if any(not isinstance(output_field, source.__class__) for source in sources_iter):
|
||||||
if num_sources == 0:
|
raise FieldError('Expression contains mixed types. You must set output_field.')
|
||||||
self._output_field = None
|
return output_field
|
||||||
else:
|
|
||||||
for source in sources:
|
|
||||||
if self._output_field is None:
|
|
||||||
self._output_field = source
|
|
||||||
if source is not None and not isinstance(self._output_field, source.__class__):
|
|
||||||
raise FieldError(
|
|
||||||
"Expression contains mixed types. You must set output_field")
|
|
||||||
|
|
||||||
def convert_value(self, value, expression, connection, context):
|
def convert_value(self, value, expression, connection, context):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -965,8 +965,12 @@ class AggregateTestCase(TestCase):
|
||||||
self.assertEqual(p2, {'avg_price': Approximate(53.39, places=2)})
|
self.assertEqual(p2, {'avg_price': Approximate(53.39, places=2)})
|
||||||
|
|
||||||
def test_combine_different_types(self):
|
def test_combine_different_types(self):
|
||||||
with self.assertRaisesMessage(FieldError, 'Expression contains mixed types. You must set output_field'):
|
msg = 'Expression contains mixed types. You must set output_field.'
|
||||||
Book.objects.annotate(sums=Sum('rating') + Sum('pages') + Sum('price')).get(pk=self.b4.pk)
|
qs = Book.objects.annotate(sums=Sum('rating') + Sum('pages') + Sum('price'))
|
||||||
|
with self.assertRaisesMessage(FieldError, msg):
|
||||||
|
qs.first()
|
||||||
|
with self.assertRaisesMessage(FieldError, msg):
|
||||||
|
qs.first()
|
||||||
|
|
||||||
b1 = Book.objects.annotate(sums=Sum(F('rating') + F('pages') + F('price'),
|
b1 = Book.objects.annotate(sums=Sum(F('rating') + F('pages') + F('price'),
|
||||||
output_field=IntegerField())).get(pk=self.b4.pk)
|
output_field=IntegerField())).get(pk=self.b4.pk)
|
||||||
|
|
Loading…
Reference in New Issue