Refs #28643 -- Changed Avg() to use NumericOutputFieldMixin.
Keeps precision instead of forcing DecimalField to FloatField.
This commit is contained in:
parent
3d5e0f8394
commit
c690afb873
|
@ -3,7 +3,8 @@ Classes to represent the definitions of aggregate functions.
|
|||
"""
|
||||
from django.core.exceptions import FieldError
|
||||
from django.db.models.expressions import Case, Func, Star, When
|
||||
from django.db.models.fields import DecimalField, FloatField, IntegerField
|
||||
from django.db.models.fields import FloatField, IntegerField
|
||||
from django.db.models.functions.mixins import NumericOutputFieldMixin
|
||||
|
||||
__all__ = [
|
||||
'Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance',
|
||||
|
@ -93,16 +94,10 @@ class Aggregate(Func):
|
|||
return options
|
||||
|
||||
|
||||
class Avg(Aggregate):
|
||||
class Avg(NumericOutputFieldMixin, Aggregate):
|
||||
function = 'AVG'
|
||||
name = 'Avg'
|
||||
|
||||
def _resolve_output_field(self):
|
||||
source_field = self.get_source_fields()[0]
|
||||
if isinstance(source_field, (IntegerField, DecimalField)):
|
||||
return FloatField()
|
||||
return super()._resolve_output_field()
|
||||
|
||||
def as_mysql(self, compiler, connection, **extra_context):
|
||||
sql, params = super().as_sql(compiler, connection, **extra_context)
|
||||
if self.output_field.get_internal_type() == 'DurationField':
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import sys
|
||||
|
||||
from django.db.models.fields import DecimalField, FloatField
|
||||
from django.db.models.fields import DecimalField, FloatField, IntegerField
|
||||
from django.db.models.functions import Cast
|
||||
|
||||
|
||||
|
@ -23,5 +23,9 @@ class FixDecimalInputMixin:
|
|||
class NumericOutputFieldMixin:
|
||||
|
||||
def _resolve_output_field(self):
|
||||
has_decimals = any(isinstance(s.output_field, DecimalField) for s in self.get_source_expressions())
|
||||
return DecimalField() if has_decimals else FloatField()
|
||||
source_expressions = self.get_source_expressions()
|
||||
if any(isinstance(s.output_field, DecimalField) for s in source_expressions):
|
||||
return DecimalField()
|
||||
if any(isinstance(s.output_field, IntegerField) for s in source_expressions):
|
||||
return FloatField()
|
||||
return super()._resolve_output_field() if source_expressions else FloatField()
|
||||
|
|
|
@ -3336,14 +3336,14 @@ by the aggregate.
|
|||
``Avg``
|
||||
~~~~~~~
|
||||
|
||||
.. class:: Avg(expression, output_field=FloatField(), filter=None, **extra)
|
||||
.. class:: Avg(expression, output_field=None, filter=None, **extra)
|
||||
|
||||
Returns the mean value of the given expression, which must be numeric
|
||||
unless you specify a different ``output_field``.
|
||||
|
||||
* Default alias: ``<field>__avg``
|
||||
* Return type: ``float`` (or the type of whatever ``output_field`` is
|
||||
specified)
|
||||
* Return type: ``float`` if input is ``int``, otherwise same as input
|
||||
field, or ``output_field`` if supplied
|
||||
|
||||
``Count``
|
||||
~~~~~~~~~
|
||||
|
|
|
@ -493,6 +493,9 @@ Miscellaneous
|
|||
|
||||
* :djadmin:`runserver` no longer supports `pyinotify` (replaced by Watchman).
|
||||
|
||||
* The :class:`~django.db.models.Avg` aggregate function now returns a
|
||||
``Decimal`` instead of a ``float`` when the input is ``Decimal``.
|
||||
|
||||
.. _deprecated-features-2.2:
|
||||
|
||||
Features deprecated in 2.2
|
||||
|
|
|
@ -865,15 +865,15 @@ class AggregateTestCase(TestCase):
|
|||
|
||||
def test_avg_decimal_field(self):
|
||||
v = Book.objects.filter(rating=4).aggregate(avg_price=(Avg('price')))['avg_price']
|
||||
self.assertIsInstance(v, float)
|
||||
self.assertEqual(v, Approximate(47.39, places=2))
|
||||
self.assertIsInstance(v, Decimal)
|
||||
self.assertEqual(v, Approximate(Decimal('47.39'), places=2))
|
||||
|
||||
def test_order_of_precedence(self):
|
||||
p1 = Book.objects.filter(rating=4).aggregate(avg_price=(Avg('price') + 2) * 3)
|
||||
self.assertEqual(p1, {'avg_price': Approximate(148.18, places=2)})
|
||||
self.assertEqual(p1, {'avg_price': Approximate(Decimal('148.18'), places=2)})
|
||||
|
||||
p2 = Book.objects.filter(rating=4).aggregate(avg_price=Avg('price') + 2 * 3)
|
||||
self.assertEqual(p2, {'avg_price': Approximate(53.39, places=2)})
|
||||
self.assertEqual(p2, {'avg_price': Approximate(Decimal('53.39'), places=2)})
|
||||
|
||||
def test_combine_different_types(self):
|
||||
msg = 'Expression contains mixed types. You must set output_field.'
|
||||
|
@ -1087,7 +1087,7 @@ class AggregateTestCase(TestCase):
|
|||
return super().as_sql(compiler, connection, function='MAX', **extra_context)
|
||||
|
||||
qs = Publisher.objects.annotate(
|
||||
price_or_median=Greatest(Avg('book__rating'), Avg('book__price'))
|
||||
price_or_median=Greatest(Avg('book__rating', output_field=DecimalField()), Avg('book__price'))
|
||||
).filter(price_or_median__gte=F('num_awards')).order_by('num_awards')
|
||||
self.assertQuerysetEqual(
|
||||
qs, [1, 3, 7, 9], lambda v: v.num_awards)
|
||||
|
|
|
@ -401,7 +401,7 @@ class AggregationTests(TestCase):
|
|||
When(pages__lt=400, then='discount_price'),
|
||||
output_field=DecimalField()
|
||||
)))['test'],
|
||||
22.27, places=2
|
||||
Decimal('22.27'), places=2
|
||||
)
|
||||
|
||||
def test_distinct_conditional_aggregate(self):
|
||||
|
@ -1041,7 +1041,7 @@ class AggregationTests(TestCase):
|
|||
books = Book.objects.values_list("publisher__name").annotate(
|
||||
Count("id"), Avg("price"), Avg("authors__age"), avg_pgs=Avg("pages")
|
||||
).order_by("-publisher__name")
|
||||
self.assertEqual(books[0], ('Sams', 1, 23.09, 45.0, 528.0))
|
||||
self.assertEqual(books[0], ('Sams', 1, Decimal('23.09'), 45.0, 528.0))
|
||||
|
||||
def test_annotation_disjunction(self):
|
||||
qs = Book.objects.annotate(n_authors=Count("authors")).filter(
|
||||
|
|
Loading…
Reference in New Issue