Fixed #26458 -- Based Avg's default output_field resolution on its source field type.

Thanks Tim for the review and Josh for the input.
This commit is contained in:
Simon Charette 2016-04-05 23:48:08 -04:00
parent 361cb7a857
commit a6074e8908
2 changed files with 12 additions and 4 deletions

View File

@ -3,7 +3,7 @@ Classes to represent the definitions of aggregate functions.
""" """
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db.models.expressions import Func, Star from django.db.models.expressions import Func, Star
from django.db.models.fields import FloatField, IntegerField from django.db.models.fields import DecimalField, FloatField, IntegerField
__all__ = [ __all__ = [
'Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance', 'Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance',
@ -41,9 +41,11 @@ class Avg(Aggregate):
function = 'AVG' function = 'AVG'
name = 'Avg' name = 'Avg'
def __init__(self, expression, **extra): def _resolve_output_field(self):
output_field = extra.pop('output_field', FloatField()) source_field = self.get_source_fields()[0]
super(Avg, self).__init__(expression, output_field=output_field, **extra) if isinstance(source_field, (IntegerField, DecimalField)):
self._output_field = FloatField()
super(Avg, self)._resolve_output_field()
def as_oracle(self, compiler, connection): def as_oracle(self, compiler, connection):
if self.output_field.get_internal_type() == 'DurationField': if self.output_field.get_internal_type() == 'DurationField':

View File

@ -496,10 +496,16 @@ class AggregateTestCase(TestCase):
self.assertEqual(vals, {"num_authors__avg": Approximate(1.66, places=1)}) self.assertEqual(vals, {"num_authors__avg": Approximate(1.66, places=1)})
def test_avg_duration_field(self): def test_avg_duration_field(self):
# Explicit `output_field`.
self.assertEqual( self.assertEqual(
Publisher.objects.aggregate(Avg('duration', output_field=DurationField())), Publisher.objects.aggregate(Avg('duration', output_field=DurationField())),
{'duration__avg': datetime.timedelta(days=1, hours=12)} {'duration__avg': datetime.timedelta(days=1, hours=12)}
) )
# Implicit `output_field`.
self.assertEqual(
Publisher.objects.aggregate(Avg('duration')),
{'duration__avg': datetime.timedelta(days=1, hours=12)}
)
def test_sum_duration_field(self): def test_sum_duration_field(self):
self.assertEqual( self.assertEqual(