Refs #28643 -- Changed Avg() to use NumericOutputFieldMixin.

Keeps precision instead of forcing DecimalField to FloatField.
This commit is contained in:
Nick Pope 2018-12-01 23:46:28 +00:00 committed by Tim Graham
parent 3d5e0f8394
commit c690afb873
6 changed files with 23 additions and 21 deletions

View File

@ -3,7 +3,8 @@ Classes to represent the definitions of aggregate functions.
"""
from django.core.exceptions import FieldError
from django.db.models.expressions import Case, Func, Star, When
from django.db.models.fields import DecimalField, FloatField, IntegerField
from django.db.models.fields import FloatField, IntegerField
from django.db.models.functions.mixins import NumericOutputFieldMixin
__all__ = [
'Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance',
@ -93,16 +94,10 @@ class Aggregate(Func):
return options
class Avg(Aggregate):
class Avg(NumericOutputFieldMixin, Aggregate):
function = 'AVG'
name = 'Avg'
def _resolve_output_field(self):
source_field = self.get_source_fields()[0]
if isinstance(source_field, (IntegerField, DecimalField)):
return FloatField()
return super()._resolve_output_field()
def as_mysql(self, compiler, connection, **extra_context):
sql, params = super().as_sql(compiler, connection, **extra_context)
if self.output_field.get_internal_type() == 'DurationField':

View File

@ -1,6 +1,6 @@
import sys
from django.db.models.fields import DecimalField, FloatField
from django.db.models.fields import DecimalField, FloatField, IntegerField
from django.db.models.functions import Cast
@ -23,5 +23,9 @@ class FixDecimalInputMixin:
class NumericOutputFieldMixin:
def _resolve_output_field(self):
has_decimals = any(isinstance(s.output_field, DecimalField) for s in self.get_source_expressions())
return DecimalField() if has_decimals else FloatField()
source_expressions = self.get_source_expressions()
if any(isinstance(s.output_field, DecimalField) for s in source_expressions):
return DecimalField()
if any(isinstance(s.output_field, IntegerField) for s in source_expressions):
return FloatField()
return super()._resolve_output_field() if source_expressions else FloatField()

View File

@ -3336,14 +3336,14 @@ by the aggregate.
``Avg``
~~~~~~~
.. class:: Avg(expression, output_field=FloatField(), filter=None, **extra)
.. class:: Avg(expression, output_field=None, filter=None, **extra)
Returns the mean value of the given expression, which must be numeric
unless you specify a different ``output_field``.
* Default alias: ``<field>__avg``
* Return type: ``float`` (or the type of whatever ``output_field`` is
specified)
* Return type: ``float`` if input is ``int``, otherwise same as input
field, or ``output_field`` if supplied
``Count``
~~~~~~~~~

View File

@ -493,6 +493,9 @@ Miscellaneous
* :djadmin:`runserver` no longer supports `pyinotify` (replaced by Watchman).
* The :class:`~django.db.models.Avg` aggregate function now returns a
``Decimal`` instead of a ``float`` when the input is ``Decimal``.
.. _deprecated-features-2.2:
Features deprecated in 2.2

View File

@ -865,15 +865,15 @@ class AggregateTestCase(TestCase):
def test_avg_decimal_field(self):
v = Book.objects.filter(rating=4).aggregate(avg_price=(Avg('price')))['avg_price']
self.assertIsInstance(v, float)
self.assertEqual(v, Approximate(47.39, places=2))
self.assertIsInstance(v, Decimal)
self.assertEqual(v, Approximate(Decimal('47.39'), places=2))
def test_order_of_precedence(self):
p1 = Book.objects.filter(rating=4).aggregate(avg_price=(Avg('price') + 2) * 3)
self.assertEqual(p1, {'avg_price': Approximate(148.18, places=2)})
self.assertEqual(p1, {'avg_price': Approximate(Decimal('148.18'), places=2)})
p2 = Book.objects.filter(rating=4).aggregate(avg_price=Avg('price') + 2 * 3)
self.assertEqual(p2, {'avg_price': Approximate(53.39, places=2)})
self.assertEqual(p2, {'avg_price': Approximate(Decimal('53.39'), places=2)})
def test_combine_different_types(self):
msg = 'Expression contains mixed types. You must set output_field.'
@ -1087,7 +1087,7 @@ class AggregateTestCase(TestCase):
return super().as_sql(compiler, connection, function='MAX', **extra_context)
qs = Publisher.objects.annotate(
price_or_median=Greatest(Avg('book__rating'), Avg('book__price'))
price_or_median=Greatest(Avg('book__rating', output_field=DecimalField()), Avg('book__price'))
).filter(price_or_median__gte=F('num_awards')).order_by('num_awards')
self.assertQuerysetEqual(
qs, [1, 3, 7, 9], lambda v: v.num_awards)

View File

@ -401,7 +401,7 @@ class AggregationTests(TestCase):
When(pages__lt=400, then='discount_price'),
output_field=DecimalField()
)))['test'],
22.27, places=2
Decimal('22.27'), places=2
)
def test_distinct_conditional_aggregate(self):
@ -1041,7 +1041,7 @@ class AggregationTests(TestCase):
books = Book.objects.values_list("publisher__name").annotate(
Count("id"), Avg("price"), Avg("authors__age"), avg_pgs=Avg("pages")
).order_by("-publisher__name")
self.assertEqual(books[0], ('Sams', 1, 23.09, 45.0, 528.0))
self.assertEqual(books[0], ('Sams', 1, Decimal('23.09'), 45.0, 528.0))
def test_annotation_disjunction(self):
qs = Book.objects.annotate(n_authors=Count("authors")).filter(