Refs #28643 -- Changed Avg() to use NumericOutputFieldMixin.
Keeps precision instead of forcing DecimalField to FloatField.
This commit is contained in:
parent
3d5e0f8394
commit
c690afb873
|
@ -3,7 +3,8 @@ Classes to represent the definitions of aggregate functions.
|
||||||
"""
|
"""
|
||||||
from django.core.exceptions import FieldError
|
from django.core.exceptions import FieldError
|
||||||
from django.db.models.expressions import Case, Func, Star, When
|
from django.db.models.expressions import Case, Func, Star, When
|
||||||
from django.db.models.fields import DecimalField, FloatField, IntegerField
|
from django.db.models.fields import FloatField, IntegerField
|
||||||
|
from django.db.models.functions.mixins import NumericOutputFieldMixin
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance',
|
'Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance',
|
||||||
|
@ -93,16 +94,10 @@ class Aggregate(Func):
|
||||||
return options
|
return options
|
||||||
|
|
||||||
|
|
||||||
class Avg(Aggregate):
|
class Avg(NumericOutputFieldMixin, Aggregate):
|
||||||
function = 'AVG'
|
function = 'AVG'
|
||||||
name = 'Avg'
|
name = 'Avg'
|
||||||
|
|
||||||
def _resolve_output_field(self):
|
|
||||||
source_field = self.get_source_fields()[0]
|
|
||||||
if isinstance(source_field, (IntegerField, DecimalField)):
|
|
||||||
return FloatField()
|
|
||||||
return super()._resolve_output_field()
|
|
||||||
|
|
||||||
def as_mysql(self, compiler, connection, **extra_context):
|
def as_mysql(self, compiler, connection, **extra_context):
|
||||||
sql, params = super().as_sql(compiler, connection, **extra_context)
|
sql, params = super().as_sql(compiler, connection, **extra_context)
|
||||||
if self.output_field.get_internal_type() == 'DurationField':
|
if self.output_field.get_internal_type() == 'DurationField':
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from django.db.models.fields import DecimalField, FloatField
|
from django.db.models.fields import DecimalField, FloatField, IntegerField
|
||||||
from django.db.models.functions import Cast
|
from django.db.models.functions import Cast
|
||||||
|
|
||||||
|
|
||||||
|
@ -23,5 +23,9 @@ class FixDecimalInputMixin:
|
||||||
class NumericOutputFieldMixin:
|
class NumericOutputFieldMixin:
|
||||||
|
|
||||||
def _resolve_output_field(self):
|
def _resolve_output_field(self):
|
||||||
has_decimals = any(isinstance(s.output_field, DecimalField) for s in self.get_source_expressions())
|
source_expressions = self.get_source_expressions()
|
||||||
return DecimalField() if has_decimals else FloatField()
|
if any(isinstance(s.output_field, DecimalField) for s in source_expressions):
|
||||||
|
return DecimalField()
|
||||||
|
if any(isinstance(s.output_field, IntegerField) for s in source_expressions):
|
||||||
|
return FloatField()
|
||||||
|
return super()._resolve_output_field() if source_expressions else FloatField()
|
||||||
|
|
|
@ -3336,14 +3336,14 @@ by the aggregate.
|
||||||
``Avg``
|
``Avg``
|
||||||
~~~~~~~
|
~~~~~~~
|
||||||
|
|
||||||
.. class:: Avg(expression, output_field=FloatField(), filter=None, **extra)
|
.. class:: Avg(expression, output_field=None, filter=None, **extra)
|
||||||
|
|
||||||
Returns the mean value of the given expression, which must be numeric
|
Returns the mean value of the given expression, which must be numeric
|
||||||
unless you specify a different ``output_field``.
|
unless you specify a different ``output_field``.
|
||||||
|
|
||||||
* Default alias: ``<field>__avg``
|
* Default alias: ``<field>__avg``
|
||||||
* Return type: ``float`` (or the type of whatever ``output_field`` is
|
* Return type: ``float`` if input is ``int``, otherwise same as input
|
||||||
specified)
|
field, or ``output_field`` if supplied
|
||||||
|
|
||||||
``Count``
|
``Count``
|
||||||
~~~~~~~~~
|
~~~~~~~~~
|
||||||
|
|
|
@ -493,6 +493,9 @@ Miscellaneous
|
||||||
|
|
||||||
* :djadmin:`runserver` no longer supports `pyinotify` (replaced by Watchman).
|
* :djadmin:`runserver` no longer supports `pyinotify` (replaced by Watchman).
|
||||||
|
|
||||||
|
* The :class:`~django.db.models.Avg` aggregate function now returns a
|
||||||
|
``Decimal`` instead of a ``float`` when the input is ``Decimal``.
|
||||||
|
|
||||||
.. _deprecated-features-2.2:
|
.. _deprecated-features-2.2:
|
||||||
|
|
||||||
Features deprecated in 2.2
|
Features deprecated in 2.2
|
||||||
|
|
|
@ -865,15 +865,15 @@ class AggregateTestCase(TestCase):
|
||||||
|
|
||||||
def test_avg_decimal_field(self):
|
def test_avg_decimal_field(self):
|
||||||
v = Book.objects.filter(rating=4).aggregate(avg_price=(Avg('price')))['avg_price']
|
v = Book.objects.filter(rating=4).aggregate(avg_price=(Avg('price')))['avg_price']
|
||||||
self.assertIsInstance(v, float)
|
self.assertIsInstance(v, Decimal)
|
||||||
self.assertEqual(v, Approximate(47.39, places=2))
|
self.assertEqual(v, Approximate(Decimal('47.39'), places=2))
|
||||||
|
|
||||||
def test_order_of_precedence(self):
|
def test_order_of_precedence(self):
|
||||||
p1 = Book.objects.filter(rating=4).aggregate(avg_price=(Avg('price') + 2) * 3)
|
p1 = Book.objects.filter(rating=4).aggregate(avg_price=(Avg('price') + 2) * 3)
|
||||||
self.assertEqual(p1, {'avg_price': Approximate(148.18, places=2)})
|
self.assertEqual(p1, {'avg_price': Approximate(Decimal('148.18'), places=2)})
|
||||||
|
|
||||||
p2 = Book.objects.filter(rating=4).aggregate(avg_price=Avg('price') + 2 * 3)
|
p2 = Book.objects.filter(rating=4).aggregate(avg_price=Avg('price') + 2 * 3)
|
||||||
self.assertEqual(p2, {'avg_price': Approximate(53.39, places=2)})
|
self.assertEqual(p2, {'avg_price': Approximate(Decimal('53.39'), places=2)})
|
||||||
|
|
||||||
def test_combine_different_types(self):
|
def test_combine_different_types(self):
|
||||||
msg = 'Expression contains mixed types. You must set output_field.'
|
msg = 'Expression contains mixed types. You must set output_field.'
|
||||||
|
@ -1087,7 +1087,7 @@ class AggregateTestCase(TestCase):
|
||||||
return super().as_sql(compiler, connection, function='MAX', **extra_context)
|
return super().as_sql(compiler, connection, function='MAX', **extra_context)
|
||||||
|
|
||||||
qs = Publisher.objects.annotate(
|
qs = Publisher.objects.annotate(
|
||||||
price_or_median=Greatest(Avg('book__rating'), Avg('book__price'))
|
price_or_median=Greatest(Avg('book__rating', output_field=DecimalField()), Avg('book__price'))
|
||||||
).filter(price_or_median__gte=F('num_awards')).order_by('num_awards')
|
).filter(price_or_median__gte=F('num_awards')).order_by('num_awards')
|
||||||
self.assertQuerysetEqual(
|
self.assertQuerysetEqual(
|
||||||
qs, [1, 3, 7, 9], lambda v: v.num_awards)
|
qs, [1, 3, 7, 9], lambda v: v.num_awards)
|
||||||
|
|
|
@ -401,7 +401,7 @@ class AggregationTests(TestCase):
|
||||||
When(pages__lt=400, then='discount_price'),
|
When(pages__lt=400, then='discount_price'),
|
||||||
output_field=DecimalField()
|
output_field=DecimalField()
|
||||||
)))['test'],
|
)))['test'],
|
||||||
22.27, places=2
|
Decimal('22.27'), places=2
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_distinct_conditional_aggregate(self):
|
def test_distinct_conditional_aggregate(self):
|
||||||
|
@ -1041,7 +1041,7 @@ class AggregationTests(TestCase):
|
||||||
books = Book.objects.values_list("publisher__name").annotate(
|
books = Book.objects.values_list("publisher__name").annotate(
|
||||||
Count("id"), Avg("price"), Avg("authors__age"), avg_pgs=Avg("pages")
|
Count("id"), Avg("price"), Avg("authors__age"), avg_pgs=Avg("pages")
|
||||||
).order_by("-publisher__name")
|
).order_by("-publisher__name")
|
||||||
self.assertEqual(books[0], ('Sams', 1, 23.09, 45.0, 528.0))
|
self.assertEqual(books[0], ('Sams', 1, Decimal('23.09'), 45.0, 528.0))
|
||||||
|
|
||||||
def test_annotation_disjunction(self):
|
def test_annotation_disjunction(self):
|
||||||
qs = Book.objects.annotate(n_authors=Count("authors")).filter(
|
qs = Book.objects.annotate(n_authors=Count("authors")).filter(
|
||||||
|
|
Loading…
Reference in New Issue