Refs #28643 -- Changed Avg() to use NumericOutputFieldMixin.

Keeps precision instead of forcing DecimalField to FloatField.
This commit is contained in:
Nick Pope 2018-12-01 23:46:28 +00:00 committed by Tim Graham
parent 3d5e0f8394
commit c690afb873
6 changed files with 23 additions and 21 deletions

View File

@ -3,7 +3,8 @@ Classes to represent the definitions of aggregate functions.
""" """
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db.models.expressions import Case, Func, Star, When from django.db.models.expressions import Case, Func, Star, When
from django.db.models.fields import DecimalField, FloatField, IntegerField from django.db.models.fields import FloatField, IntegerField
from django.db.models.functions.mixins import NumericOutputFieldMixin
__all__ = [ __all__ = [
'Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance', 'Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance',
@ -93,16 +94,10 @@ class Aggregate(Func):
return options return options
class Avg(Aggregate): class Avg(NumericOutputFieldMixin, Aggregate):
function = 'AVG' function = 'AVG'
name = 'Avg' name = 'Avg'
def _resolve_output_field(self):
source_field = self.get_source_fields()[0]
if isinstance(source_field, (IntegerField, DecimalField)):
return FloatField()
return super()._resolve_output_field()
def as_mysql(self, compiler, connection, **extra_context): def as_mysql(self, compiler, connection, **extra_context):
sql, params = super().as_sql(compiler, connection, **extra_context) sql, params = super().as_sql(compiler, connection, **extra_context)
if self.output_field.get_internal_type() == 'DurationField': if self.output_field.get_internal_type() == 'DurationField':

View File

@ -1,6 +1,6 @@
import sys import sys
from django.db.models.fields import DecimalField, FloatField from django.db.models.fields import DecimalField, FloatField, IntegerField
from django.db.models.functions import Cast from django.db.models.functions import Cast
@ -23,5 +23,9 @@ class FixDecimalInputMixin:
class NumericOutputFieldMixin: class NumericOutputFieldMixin:
def _resolve_output_field(self): def _resolve_output_field(self):
has_decimals = any(isinstance(s.output_field, DecimalField) for s in self.get_source_expressions()) source_expressions = self.get_source_expressions()
return DecimalField() if has_decimals else FloatField() if any(isinstance(s.output_field, DecimalField) for s in source_expressions):
return DecimalField()
if any(isinstance(s.output_field, IntegerField) for s in source_expressions):
return FloatField()
return super()._resolve_output_field() if source_expressions else FloatField()

View File

@ -3336,14 +3336,14 @@ by the aggregate.
``Avg`` ``Avg``
~~~~~~~ ~~~~~~~
.. class:: Avg(expression, output_field=FloatField(), filter=None, **extra) .. class:: Avg(expression, output_field=None, filter=None, **extra)
Returns the mean value of the given expression, which must be numeric Returns the mean value of the given expression, which must be numeric
unless you specify a different ``output_field``. unless you specify a different ``output_field``.
* Default alias: ``<field>__avg`` * Default alias: ``<field>__avg``
* Return type: ``float`` (or the type of whatever ``output_field`` is * Return type: ``float`` if input is ``int``, otherwise same as input
specified) field, or ``output_field`` if supplied
``Count`` ``Count``
~~~~~~~~~ ~~~~~~~~~

View File

@ -493,6 +493,9 @@ Miscellaneous
* :djadmin:`runserver` no longer supports `pyinotify` (replaced by Watchman). * :djadmin:`runserver` no longer supports `pyinotify` (replaced by Watchman).
* The :class:`~django.db.models.Avg` aggregate function now returns a
``Decimal`` instead of a ``float`` when the input is ``Decimal``.
.. _deprecated-features-2.2: .. _deprecated-features-2.2:
Features deprecated in 2.2 Features deprecated in 2.2

View File

@ -865,15 +865,15 @@ class AggregateTestCase(TestCase):
def test_avg_decimal_field(self): def test_avg_decimal_field(self):
v = Book.objects.filter(rating=4).aggregate(avg_price=(Avg('price')))['avg_price'] v = Book.objects.filter(rating=4).aggregate(avg_price=(Avg('price')))['avg_price']
self.assertIsInstance(v, float) self.assertIsInstance(v, Decimal)
self.assertEqual(v, Approximate(47.39, places=2)) self.assertEqual(v, Approximate(Decimal('47.39'), places=2))
def test_order_of_precedence(self): def test_order_of_precedence(self):
p1 = Book.objects.filter(rating=4).aggregate(avg_price=(Avg('price') + 2) * 3) p1 = Book.objects.filter(rating=4).aggregate(avg_price=(Avg('price') + 2) * 3)
self.assertEqual(p1, {'avg_price': Approximate(148.18, places=2)}) self.assertEqual(p1, {'avg_price': Approximate(Decimal('148.18'), places=2)})
p2 = Book.objects.filter(rating=4).aggregate(avg_price=Avg('price') + 2 * 3) p2 = Book.objects.filter(rating=4).aggregate(avg_price=Avg('price') + 2 * 3)
self.assertEqual(p2, {'avg_price': Approximate(53.39, places=2)}) self.assertEqual(p2, {'avg_price': Approximate(Decimal('53.39'), places=2)})
def test_combine_different_types(self): def test_combine_different_types(self):
msg = 'Expression contains mixed types. You must set output_field.' msg = 'Expression contains mixed types. You must set output_field.'
@ -1087,7 +1087,7 @@ class AggregateTestCase(TestCase):
return super().as_sql(compiler, connection, function='MAX', **extra_context) return super().as_sql(compiler, connection, function='MAX', **extra_context)
qs = Publisher.objects.annotate( qs = Publisher.objects.annotate(
price_or_median=Greatest(Avg('book__rating'), Avg('book__price')) price_or_median=Greatest(Avg('book__rating', output_field=DecimalField()), Avg('book__price'))
).filter(price_or_median__gte=F('num_awards')).order_by('num_awards') ).filter(price_or_median__gte=F('num_awards')).order_by('num_awards')
self.assertQuerysetEqual( self.assertQuerysetEqual(
qs, [1, 3, 7, 9], lambda v: v.num_awards) qs, [1, 3, 7, 9], lambda v: v.num_awards)

View File

@ -401,7 +401,7 @@ class AggregationTests(TestCase):
When(pages__lt=400, then='discount_price'), When(pages__lt=400, then='discount_price'),
output_field=DecimalField() output_field=DecimalField()
)))['test'], )))['test'],
22.27, places=2 Decimal('22.27'), places=2
) )
def test_distinct_conditional_aggregate(self): def test_distinct_conditional_aggregate(self):
@ -1041,7 +1041,7 @@ class AggregationTests(TestCase):
books = Book.objects.values_list("publisher__name").annotate( books = Book.objects.values_list("publisher__name").annotate(
Count("id"), Avg("price"), Avg("authors__age"), avg_pgs=Avg("pages") Count("id"), Avg("price"), Avg("authors__age"), avg_pgs=Avg("pages")
).order_by("-publisher__name") ).order_by("-publisher__name")
self.assertEqual(books[0], ('Sams', 1, 23.09, 45.0, 528.0)) self.assertEqual(books[0], ('Sams', 1, Decimal('23.09'), 45.0, 528.0))
def test_annotation_disjunction(self): def test_annotation_disjunction(self):
qs = Book.objects.annotate(n_authors=Count("authors")).filter( qs = Book.objects.annotate(n_authors=Count("authors")).filter(