Fixed #26458 -- Based Avg's default output_field resolution on its source field type.
Thanks Tim for the review and Josh for the input.
This commit is contained in:
parent
361cb7a857
commit
a6074e8908
|
@ -3,7 +3,7 @@ Classes to represent the definitions of aggregate functions.
|
||||||
"""
|
"""
|
||||||
from django.core.exceptions import FieldError
|
from django.core.exceptions import FieldError
|
||||||
from django.db.models.expressions import Func, Star
|
from django.db.models.expressions import Func, Star
|
||||||
from django.db.models.fields import FloatField, IntegerField
|
from django.db.models.fields import DecimalField, FloatField, IntegerField
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance',
|
'Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance',
|
||||||
|
@ -41,9 +41,11 @@ class Avg(Aggregate):
|
||||||
function = 'AVG'
|
function = 'AVG'
|
||||||
name = 'Avg'
|
name = 'Avg'
|
||||||
|
|
||||||
def __init__(self, expression, **extra):
|
def _resolve_output_field(self):
|
||||||
output_field = extra.pop('output_field', FloatField())
|
source_field = self.get_source_fields()[0]
|
||||||
super(Avg, self).__init__(expression, output_field=output_field, **extra)
|
if isinstance(source_field, (IntegerField, DecimalField)):
|
||||||
|
self._output_field = FloatField()
|
||||||
|
super(Avg, self)._resolve_output_field()
|
||||||
|
|
||||||
def as_oracle(self, compiler, connection):
|
def as_oracle(self, compiler, connection):
|
||||||
if self.output_field.get_internal_type() == 'DurationField':
|
if self.output_field.get_internal_type() == 'DurationField':
|
||||||
|
|
|
@ -496,10 +496,16 @@ class AggregateTestCase(TestCase):
|
||||||
self.assertEqual(vals, {"num_authors__avg": Approximate(1.66, places=1)})
|
self.assertEqual(vals, {"num_authors__avg": Approximate(1.66, places=1)})
|
||||||
|
|
||||||
def test_avg_duration_field(self):
|
def test_avg_duration_field(self):
|
||||||
|
# Explicit `output_field`.
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
Publisher.objects.aggregate(Avg('duration', output_field=DurationField())),
|
Publisher.objects.aggregate(Avg('duration', output_field=DurationField())),
|
||||||
{'duration__avg': datetime.timedelta(days=1, hours=12)}
|
{'duration__avg': datetime.timedelta(days=1, hours=12)}
|
||||||
)
|
)
|
||||||
|
# Implicit `output_field`.
|
||||||
|
self.assertEqual(
|
||||||
|
Publisher.objects.aggregate(Avg('duration')),
|
||||||
|
{'duration__avg': datetime.timedelta(days=1, hours=12)}
|
||||||
|
)
|
||||||
|
|
||||||
def test_sum_duration_field(self):
|
def test_sum_duration_field(self):
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
|
|
Loading…
Reference in New Issue