Fixed #24636 -- Added model field validation for decimal places and max digits.

This commit is contained in:
Iulia Chiriac 2015-04-14 18:11:12 -04:00 committed by Simon Charette
parent 6f1b09bb5c
commit 75ed590032
6 changed files with 134 additions and 49 deletions

View File

@ -346,3 +346,72 @@ class MaxLengthValidator(BaseValidator):
'Ensure this value has at most %(limit_value)d characters (it has %(show_value)d).',
'limit_value')
code = 'max_length'
@deconstructible
class DecimalValidator(object):
"""
Validate that the input does not exceed the maximum number of digits
expected, otherwise raise ValidationError.
"""
messages = {
'max_digits': ungettext_lazy(
'Ensure that there are no more than %(max)s digit in total.',
'Ensure that there are no more than %(max)s digits in total.',
'max'
),
'max_decimal_places': ungettext_lazy(
'Ensure that there are no more than %(max)s decimal place.',
'Ensure that there are no more than %(max)s decimal places.',
'max'
),
'max_whole_digits': ungettext_lazy(
'Ensure that there are no more than %(max)s digit before the decimal point.',
'Ensure that there are no more than %(max)s digits before the decimal point.',
'max'
),
}
def __init__(self, max_digits, decimal_places):
self.max_digits = max_digits
self.decimal_places = decimal_places
def __call__(self, value):
digit_tuple, exponent = value.as_tuple()[1:]
decimals = abs(exponent)
# digit_tuple doesn't include any leading zeros.
digits = len(digit_tuple)
if decimals > digits:
# We have leading zeros up to or past the decimal point. Count
# everything past the decimal point as a digit. We do not count
# 0 before the decimal point as a digit since that would mean
# we would not allow max_digits = decimal_places.
digits = decimals
whole_digits = digits - decimals
if self.max_digits is not None and digits > self.max_digits:
raise ValidationError(
self.messages['max_digits'],
code='max_digits',
params={'max': self.max_digits},
)
if self.decimal_places is not None and decimals > self.decimal_places:
raise ValidationError(
self.messages['max_decimal_places'],
code='max_decimal_places',
params={'max': self.decimal_places},
)
if (self.max_digits is not None and self.decimal_places is not None
and whole_digits > (self.max_digits - self.decimal_places)):
raise ValidationError(
self.messages['max_whole_digits'],
code='max_whole_digits',
params={'max': (self.max_digits - self.decimal_places)},
)
def __eq__(self, other):
return (
isinstance(other, self.__class__) and
self.max_digits == other.max_digits and
self.decimal_places == other.decimal_places
)

View File

@ -1578,6 +1578,12 @@ class DecimalField(Field):
]
return []
@cached_property
def validators(self):
return super(DecimalField, self).validators + [
validators.DecimalValidator(self.max_digits, self.decimal_places)
]
def deconstruct(self):
name, path, args, kwargs = super(DecimalField, self).deconstruct()
if self.max_digits is not None:

View File

@ -334,23 +334,12 @@ class FloatField(IntegerField):
class DecimalField(IntegerField):
default_error_messages = {
'invalid': _('Enter a number.'),
'max_digits': ungettext_lazy(
'Ensure that there are no more than %(max)s digit in total.',
'Ensure that there are no more than %(max)s digits in total.',
'max'),
'max_decimal_places': ungettext_lazy(
'Ensure that there are no more than %(max)s decimal place.',
'Ensure that there are no more than %(max)s decimal places.',
'max'),
'max_whole_digits': ungettext_lazy(
'Ensure that there are no more than %(max)s digit before the decimal point.',
'Ensure that there are no more than %(max)s digits before the decimal point.',
'max'),
}
def __init__(self, max_value=None, min_value=None, max_digits=None, decimal_places=None, *args, **kwargs):
self.max_digits, self.decimal_places = max_digits, decimal_places
super(DecimalField, self).__init__(max_value, min_value, *args, **kwargs)
self.validators.append(validators.DecimalValidator(max_digits, decimal_places))
def to_python(self, value):
"""
@ -379,38 +368,6 @@ class DecimalField(IntegerField):
# isn't equal to itself, so we can use this to identify NaN
if value != value or value == Decimal("Inf") or value == Decimal("-Inf"):
raise ValidationError(self.error_messages['invalid'], code='invalid')
sign, digittuple, exponent = value.as_tuple()
decimals = abs(exponent)
# digittuple doesn't include any leading zeros.
digits = len(digittuple)
if decimals > digits:
# We have leading zeros up to or past the decimal point. Count
# everything past the decimal point as a digit. We do not count
# 0 before the decimal point as a digit since that would mean
# we would not allow max_digits = decimal_places.
digits = decimals
whole_digits = digits - decimals
if self.max_digits is not None and digits > self.max_digits:
raise ValidationError(
self.error_messages['max_digits'],
code='max_digits',
params={'max': self.max_digits},
)
if self.decimal_places is not None and decimals > self.decimal_places:
raise ValidationError(
self.error_messages['max_decimal_places'],
code='max_decimal_places',
params={'max': self.decimal_places},
)
if (self.max_digits is not None and self.decimal_places is not None
and whole_digits > (self.max_digits - self.decimal_places)):
raise ValidationError(
self.error_messages['max_whole_digits'],
code='max_whole_digits',
params={'max': (self.max_digits - self.decimal_places)},
)
return value
def widget_attrs(self, widget):
attrs = super(DecimalField, self).widget_attrs(widget)

View File

@ -281,3 +281,19 @@ to, or in lieu of custom ``field.clean()`` methods.
.. versionchanged:: 1.8
The ``message`` parameter was added.
``DecimalValidator``
--------------------
.. class:: DecimalValidator(max_digits, decimal_places)
.. versionadded:: 1.9
Raises :exc:`~django.core.exceptions.ValidationError` with the following
codes:
- ``'max_digits'`` if the number of digits is larger than ``max_digits``.
- ``'max_decimal_places'`` if the number of decimals is larger than
``decimal_places``.
- ``'max_whole_digits'`` if the number of whole digits is larger than
the difference between ``max_digits`` and ``decimal_places``.

View File

@ -165,6 +165,24 @@ class DecimalFieldTests(test.TestCase):
# This should not crash. That counts as a win for our purposes.
Foo.objects.filter(d__gte=100000000000)
def test_max_digits_validation(self):
field = models.DecimalField(max_digits=2)
expected_message = validators.DecimalValidator.messages['max_digits'] % {'max': 2}
with self.assertRaisesMessage(ValidationError, expected_message):
field.clean(100, None)
def test_max_decimal_places_validation(self):
field = models.DecimalField(decimal_places=1)
expected_message = validators.DecimalValidator.messages['max_decimal_places'] % {'max': 1}
with self.assertRaisesMessage(ValidationError, expected_message):
field.clean(Decimal('0.99'), None)
def test_max_whole_digits_validation(self):
field = models.DecimalField(max_digits=3, decimal_places=1)
expected_message = validators.DecimalValidator.messages['max_whole_digits'] % {'max': 2}
with self.assertRaisesMessage(ValidationError, expected_message):
field.clean(Decimal('999'), None)
class ForeignKeyTests(test.TestCase):
def test_callable_default(self):

View File

@ -10,11 +10,12 @@ from unittest import TestCase
from django.core.exceptions import ValidationError
from django.core.validators import (
BaseValidator, EmailValidator, MaxLengthValidator, MaxValueValidator,
MinLengthValidator, MinValueValidator, RegexValidator, URLValidator,
int_list_validator, validate_comma_separated_integer_list, validate_email,
validate_integer, validate_ipv4_address, validate_ipv6_address,
validate_ipv46_address, validate_slug, validate_unicode_slug,
BaseValidator, DecimalValidator, EmailValidator, MaxLengthValidator,
MaxValueValidator, MinLengthValidator, MinValueValidator, RegexValidator,
URLValidator, int_list_validator, validate_comma_separated_integer_list,
validate_email, validate_integer, validate_ipv4_address,
validate_ipv6_address, validate_ipv46_address, validate_slug,
validate_unicode_slug,
)
from django.test import SimpleTestCase
from django.test.utils import str_prefix
@ -401,3 +402,21 @@ class TestValidatorEquality(TestCase):
MinValueValidator(45),
MinValueValidator(11),
)
def test_decimal_equality(self):
self.assertEqual(
DecimalValidator(1, 2),
DecimalValidator(1, 2),
)
self.assertNotEqual(
DecimalValidator(1, 2),
DecimalValidator(1, 1),
)
self.assertNotEqual(
DecimalValidator(1, 2),
DecimalValidator(2, 2),
)
self.assertNotEqual(
DecimalValidator(1, 2),
MinValueValidator(11),
)