From a6074e8908e36286de7e31f165be801df8ee1eab Mon Sep 17 00:00:00 2001 From: Simon Charette Date: Tue, 5 Apr 2016 23:48:08 -0400 Subject: [PATCH] Fixed #26458 -- Based Avg's default output_field resolution on its source field type. Thanks Tim for the review and Josh for the input. --- django/db/models/aggregates.py | 10 ++++++---- tests/aggregation/tests.py | 6 ++++++ 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py index 0ddb281cd9..6c320ed828 100644 --- a/django/db/models/aggregates.py +++ b/django/db/models/aggregates.py @@ -3,7 +3,7 @@ Classes to represent the definitions of aggregate functions. """ from django.core.exceptions import FieldError from django.db.models.expressions import Func, Star -from django.db.models.fields import FloatField, IntegerField +from django.db.models.fields import DecimalField, FloatField, IntegerField __all__ = [ 'Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance', @@ -41,9 +41,11 @@ class Avg(Aggregate): function = 'AVG' name = 'Avg' - def __init__(self, expression, **extra): - output_field = extra.pop('output_field', FloatField()) - super(Avg, self).__init__(expression, output_field=output_field, **extra) + def _resolve_output_field(self): + source_field = self.get_source_fields()[0] + if isinstance(source_field, (IntegerField, DecimalField)): + self._output_field = FloatField() + super(Avg, self)._resolve_output_field() def as_oracle(self, compiler, connection): if self.output_field.get_internal_type() == 'DurationField': diff --git a/tests/aggregation/tests.py b/tests/aggregation/tests.py index 2e68d0a7f3..758b101157 100644 --- a/tests/aggregation/tests.py +++ b/tests/aggregation/tests.py @@ -496,10 +496,16 @@ class AggregateTestCase(TestCase): self.assertEqual(vals, {"num_authors__avg": Approximate(1.66, places=1)}) def test_avg_duration_field(self): + # Explicit `output_field`. self.assertEqual( Publisher.objects.aggregate(Avg('duration', output_field=DurationField())), {'duration__avg': datetime.timedelta(days=1, hours=12)} ) + # Implicit `output_field`. + self.assertEqual( + Publisher.objects.aggregate(Avg('duration')), + {'duration__avg': datetime.timedelta(days=1, hours=12)} + ) def test_sum_duration_field(self): self.assertEqual(