diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py index a7dc55ee98..f9202543a3 100644 --- a/django/db/models/aggregates.py +++ b/django/db/models/aggregates.py @@ -3,7 +3,8 @@ Classes to represent the definitions of aggregate functions. """ from django.core.exceptions import FieldError from django.db.models.expressions import Case, Func, Star, When -from django.db.models.fields import DecimalField, FloatField, IntegerField +from django.db.models.fields import FloatField, IntegerField +from django.db.models.functions.mixins import NumericOutputFieldMixin __all__ = [ 'Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance', @@ -93,16 +94,10 @@ class Aggregate(Func): return options -class Avg(Aggregate): +class Avg(NumericOutputFieldMixin, Aggregate): function = 'AVG' name = 'Avg' - def _resolve_output_field(self): - source_field = self.get_source_fields()[0] - if isinstance(source_field, (IntegerField, DecimalField)): - return FloatField() - return super()._resolve_output_field() - def as_mysql(self, compiler, connection, **extra_context): sql, params = super().as_sql(compiler, connection, **extra_context) if self.output_field.get_internal_type() == 'DurationField': diff --git a/django/db/models/functions/mixins.py b/django/db/models/functions/mixins.py index 1bf3d6cbd0..9b46987788 100644 --- a/django/db/models/functions/mixins.py +++ b/django/db/models/functions/mixins.py @@ -1,6 +1,6 @@ import sys -from django.db.models.fields import DecimalField, FloatField +from django.db.models.fields import DecimalField, FloatField, IntegerField from django.db.models.functions import Cast @@ -23,5 +23,9 @@ class FixDecimalInputMixin: class NumericOutputFieldMixin: def _resolve_output_field(self): - has_decimals = any(isinstance(s.output_field, DecimalField) for s in self.get_source_expressions()) - return DecimalField() if has_decimals else FloatField() + source_expressions = self.get_source_expressions() + if any(isinstance(s.output_field, DecimalField) for s in source_expressions): + return DecimalField() + if any(isinstance(s.output_field, IntegerField) for s in source_expressions): + return FloatField() + return super()._resolve_output_field() if source_expressions else FloatField() diff --git a/docs/ref/models/querysets.txt b/docs/ref/models/querysets.txt index 78eb175329..46fcd50e37 100644 --- a/docs/ref/models/querysets.txt +++ b/docs/ref/models/querysets.txt @@ -3336,14 +3336,14 @@ by the aggregate. ``Avg`` ~~~~~~~ -.. class:: Avg(expression, output_field=FloatField(), filter=None, **extra) +.. class:: Avg(expression, output_field=None, filter=None, **extra) Returns the mean value of the given expression, which must be numeric unless you specify a different ``output_field``. * Default alias: ``__avg`` - * Return type: ``float`` (or the type of whatever ``output_field`` is - specified) + * Return type: ``float`` if input is ``int``, otherwise same as input + field, or ``output_field`` if supplied ``Count`` ~~~~~~~~~ diff --git a/docs/releases/2.2.txt b/docs/releases/2.2.txt index 13f7617888..b96b0ed1ef 100644 --- a/docs/releases/2.2.txt +++ b/docs/releases/2.2.txt @@ -493,6 +493,9 @@ Miscellaneous * :djadmin:`runserver` no longer supports `pyinotify` (replaced by Watchman). +* The :class:`~django.db.models.Avg` aggregate function now returns a + ``Decimal`` instead of a ``float`` when the input is ``Decimal``. + .. _deprecated-features-2.2: Features deprecated in 2.2 diff --git a/tests/aggregation/tests.py b/tests/aggregation/tests.py index 75d2ecb1c5..8cac90f020 100644 --- a/tests/aggregation/tests.py +++ b/tests/aggregation/tests.py @@ -865,15 +865,15 @@ class AggregateTestCase(TestCase): def test_avg_decimal_field(self): v = Book.objects.filter(rating=4).aggregate(avg_price=(Avg('price')))['avg_price'] - self.assertIsInstance(v, float) - self.assertEqual(v, Approximate(47.39, places=2)) + self.assertIsInstance(v, Decimal) + self.assertEqual(v, Approximate(Decimal('47.39'), places=2)) def test_order_of_precedence(self): p1 = Book.objects.filter(rating=4).aggregate(avg_price=(Avg('price') + 2) * 3) - self.assertEqual(p1, {'avg_price': Approximate(148.18, places=2)}) + self.assertEqual(p1, {'avg_price': Approximate(Decimal('148.18'), places=2)}) p2 = Book.objects.filter(rating=4).aggregate(avg_price=Avg('price') + 2 * 3) - self.assertEqual(p2, {'avg_price': Approximate(53.39, places=2)}) + self.assertEqual(p2, {'avg_price': Approximate(Decimal('53.39'), places=2)}) def test_combine_different_types(self): msg = 'Expression contains mixed types. You must set output_field.' @@ -1087,7 +1087,7 @@ class AggregateTestCase(TestCase): return super().as_sql(compiler, connection, function='MAX', **extra_context) qs = Publisher.objects.annotate( - price_or_median=Greatest(Avg('book__rating'), Avg('book__price')) + price_or_median=Greatest(Avg('book__rating', output_field=DecimalField()), Avg('book__price')) ).filter(price_or_median__gte=F('num_awards')).order_by('num_awards') self.assertQuerysetEqual( qs, [1, 3, 7, 9], lambda v: v.num_awards) diff --git a/tests/aggregation_regress/tests.py b/tests/aggregation_regress/tests.py index 2b3948a0b4..64bbc13f80 100644 --- a/tests/aggregation_regress/tests.py +++ b/tests/aggregation_regress/tests.py @@ -401,7 +401,7 @@ class AggregationTests(TestCase): When(pages__lt=400, then='discount_price'), output_field=DecimalField() )))['test'], - 22.27, places=2 + Decimal('22.27'), places=2 ) def test_distinct_conditional_aggregate(self): @@ -1041,7 +1041,7 @@ class AggregationTests(TestCase): books = Book.objects.values_list("publisher__name").annotate( Count("id"), Avg("price"), Avg("authors__age"), avg_pgs=Avg("pages") ).order_by("-publisher__name") - self.assertEqual(books[0], ('Sams', 1, 23.09, 45.0, 528.0)) + self.assertEqual(books[0], ('Sams', 1, Decimal('23.09'), 45.0, 528.0)) def test_annotation_disjunction(self): qs = Book.objects.annotate(n_authors=Count("authors")).filter(