Fixed #31487 -- Added precision argument to Round().

This commit is contained in:
Nick Pope 2021-03-24 22:29:33 +00:00 committed by Mariusz Felisiak
parent 61d5e57353
commit 2f13c476ab
7 changed files with 115 additions and 10 deletions

View File

@ -65,6 +65,12 @@ class DatabaseFeatures(BaseDatabaseFeatures):
"SQLite doesn't have a constraint.": {
'model_fields.test_integerfield.PositiveIntegerFieldTests.test_negative_values',
},
"SQLite doesn't support negative precision for ROUND().": {
'db_functions.math.test_round.RoundTests.test_null_with_negative_precision',
'db_functions.math.test_round.RoundTests.test_decimal_with_negative_precision',
'db_functions.math.test_round.RoundTests.test_float_with_negative_precision',
'db_functions.math.test_round.RoundTests.test_integer_with_negative_precision',
},
}
if Database.sqlite_version_info < (3, 27):
skips.update({

View File

@ -1,6 +1,6 @@
import math
from django.db.models.expressions import Func
from django.db.models.expressions import Func, Value
from django.db.models.fields import FloatField, IntegerField
from django.db.models.functions import Cast
from django.db.models.functions.mixins import (
@ -158,9 +158,23 @@ class Random(NumericOutputFieldMixin, Func):
return []
class Round(Transform):
class Round(FixDecimalInputMixin, Transform):
function = 'ROUND'
lookup_name = 'round'
arity = None # Override Transform's arity=1 to enable passing precision.
def __init__(self, expression, precision=0, **extra):
super().__init__(expression, precision, **extra)
def as_sqlite(self, compiler, connection, **extra_context):
precision = self.get_source_expressions()[1]
if isinstance(precision, Value) and precision.value < 0:
raise ValueError('SQLite does not support negative precision.')
return super().as_sqlite(compiler, connection, **extra_context)
def _resolve_output_field(self):
source = self.get_source_expressions()[0]
return source.output_field
class Sign(Transform):

View File

@ -1147,18 +1147,19 @@ Returns a random value in the range ``0.0 ≤ x < 1.0``.
``Round``
---------
.. class:: Round(expression, **extra)
.. class:: Round(expression, precision=0, **extra)
Rounds a numeric field or expression to the nearest integer. Whether half
Rounds a numeric field or expression to ``precision`` (must be an integer)
decimal places. By default, it rounds to the nearest integer. Whether half
values are rounded up or down depends on the database.
Usage example::
>>> from django.db.models.functions import Round
>>> Vector.objects.create(x=5.4, y=-2.3)
>>> vector = Vector.objects.annotate(x_r=Round('x'), y_r=Round('y')).get()
>>> Vector.objects.create(x=5.4, y=-2.37)
>>> vector = Vector.objects.annotate(x_r=Round('x'), y_r=Round('y', precision=1)).get()
>>> vector.x_r, vector.y_r
(5.0, -2.0)
(5.0, -2.4)
It can also be registered as a transform. For example::
@ -1168,6 +1169,10 @@ It can also be registered as a transform. For example::
>>> # Get vectors whose round() is less than 20
>>> vectors = Vector.objects.filter(x__round__lt=20, y__round__lt=20)
.. versionchanged:: 4.0
The ``precision`` argument was added.
``Sign``
--------

View File

@ -222,6 +222,10 @@ Models
whether the queryset contains the given object. This tries to perform the
query in the simplest and fastest way possible.
* The new ``precision`` argument of the
:class:`Round() <django.db.models.functions.Round>` database function allows
specifying the number of decimal places after rounding.
Requests and Responses
~~~~~~~~~~~~~~~~~~~~~~

View File

@ -1,7 +1,9 @@
import unittest
from decimal import Decimal
from django.db import connection
from django.db.models import DecimalField
from django.db.models.functions import Round
from django.db.models.functions import Pi, Round
from django.test import TestCase
from django.test.utils import register_lookup
@ -15,6 +17,16 @@ class RoundTests(TestCase):
obj = IntegerModel.objects.annotate(null_round=Round('normal')).first()
self.assertIsNone(obj.null_round)
def test_null_with_precision(self):
IntegerModel.objects.create()
obj = IntegerModel.objects.annotate(null_round=Round('normal', 5)).first()
self.assertIsNone(obj.null_round)
def test_null_with_negative_precision(self):
IntegerModel.objects.create()
obj = IntegerModel.objects.annotate(null_round=Round('normal', -1)).first()
self.assertIsNone(obj.null_round)
def test_decimal(self):
DecimalModel.objects.create(n1=Decimal('-12.9'), n2=Decimal('0.6'))
obj = DecimalModel.objects.annotate(n1_round=Round('n1'), n2_round=Round('n2')).first()
@ -23,6 +35,23 @@ class RoundTests(TestCase):
self.assertAlmostEqual(obj.n1_round, obj.n1, places=0)
self.assertAlmostEqual(obj.n2_round, obj.n2, places=0)
def test_decimal_with_precision(self):
DecimalModel.objects.create(n1=Decimal('-5.75'), n2=Pi())
obj = DecimalModel.objects.annotate(
n1_round=Round('n1', 1),
n2_round=Round('n2', 5),
).first()
self.assertIsInstance(obj.n1_round, Decimal)
self.assertIsInstance(obj.n2_round, Decimal)
self.assertAlmostEqual(obj.n1_round, obj.n1, places=1)
self.assertAlmostEqual(obj.n2_round, obj.n2, places=5)
def test_decimal_with_negative_precision(self):
DecimalModel.objects.create(n1=Decimal('365.25'))
obj = DecimalModel.objects.annotate(n1_round=Round('n1', -1)).first()
self.assertIsInstance(obj.n1_round, Decimal)
self.assertEqual(obj.n1_round, 370)
def test_float(self):
FloatModel.objects.create(f1=-27.55, f2=0.55)
obj = FloatModel.objects.annotate(f1_round=Round('f1'), f2_round=Round('f2')).first()
@ -31,6 +60,23 @@ class RoundTests(TestCase):
self.assertAlmostEqual(obj.f1_round, obj.f1, places=0)
self.assertAlmostEqual(obj.f2_round, obj.f2, places=0)
def test_float_with_precision(self):
FloatModel.objects.create(f1=-5.75, f2=Pi())
obj = FloatModel.objects.annotate(
f1_round=Round('f1', 1),
f2_round=Round('f2', 5),
).first()
self.assertIsInstance(obj.f1_round, float)
self.assertIsInstance(obj.f2_round, float)
self.assertAlmostEqual(obj.f1_round, obj.f1, places=1)
self.assertAlmostEqual(obj.f2_round, obj.f2, places=5)
def test_float_with_negative_precision(self):
FloatModel.objects.create(f1=365.25)
obj = FloatModel.objects.annotate(f1_round=Round('f1', -1)).first()
self.assertIsInstance(obj.f1_round, float)
self.assertEqual(obj.f1_round, 370)
def test_integer(self):
IntegerModel.objects.create(small=-20, normal=15, big=-1)
obj = IntegerModel.objects.annotate(
@ -45,9 +91,39 @@ class RoundTests(TestCase):
self.assertAlmostEqual(obj.normal_round, obj.normal, places=0)
self.assertAlmostEqual(obj.big_round, obj.big, places=0)
def test_integer_with_precision(self):
IntegerModel.objects.create(small=-5, normal=3, big=-100)
obj = IntegerModel.objects.annotate(
small_round=Round('small', 1),
normal_round=Round('normal', 5),
big_round=Round('big', 2),
).first()
self.assertIsInstance(obj.small_round, int)
self.assertIsInstance(obj.normal_round, int)
self.assertIsInstance(obj.big_round, int)
self.assertAlmostEqual(obj.small_round, obj.small, places=1)
self.assertAlmostEqual(obj.normal_round, obj.normal, places=5)
self.assertAlmostEqual(obj.big_round, obj.big, places=2)
def test_integer_with_negative_precision(self):
IntegerModel.objects.create(normal=365)
obj = IntegerModel.objects.annotate(normal_round=Round('normal', -1)).first()
self.assertIsInstance(obj.normal_round, int)
self.assertEqual(obj.normal_round, 370)
def test_transform(self):
with register_lookup(DecimalField, Round):
DecimalModel.objects.create(n1=Decimal('2.0'), n2=Decimal('0'))
DecimalModel.objects.create(n1=Decimal('-1.0'), n2=Decimal('0'))
obj = DecimalModel.objects.filter(n1__round__gt=0).get()
self.assertEqual(obj.n1, Decimal('2.0'))
@unittest.skipUnless(
connection.vendor == 'sqlite',
"SQLite doesn't support negative precision.",
)
def test_unsupported_negative_precision(self):
FloatModel.objects.create(f1=123.45)
msg = 'SQLite does not support negative precision.'
with self.assertRaisesMessage(ValueError, msg):
FloatModel.objects.annotate(value=Round('f1', -1)).first()

View File

@ -56,7 +56,7 @@ class Migration(migrations.Migration):
name='DecimalModel',
fields=[
('n1', models.DecimalField(decimal_places=2, max_digits=6)),
('n2', models.DecimalField(decimal_places=2, max_digits=6)),
('n2', models.DecimalField(decimal_places=7, max_digits=9, null=True, blank=True)),
],
),
migrations.CreateModel(

View File

@ -42,7 +42,7 @@ class DTModel(models.Model):
class DecimalModel(models.Model):
n1 = models.DecimalField(decimal_places=2, max_digits=6)
n2 = models.DecimalField(decimal_places=2, max_digits=6)
n2 = models.DecimalField(decimal_places=7, max_digits=9, null=True, blank=True)
class IntegerModel(models.Model):