django1/django/contrib/postgres/forms/ranges.py

93 lines
3.0 KiB
Python

from psycopg2.extras import DateRange, DateTimeTZRange, NumericRange
from django import forms
from django.core import exceptions
from django.forms.widgets import MultiWidget
from django.utils.translation import gettext_lazy as _
__all__ = ['IntegerRangeField', 'FloatRangeField', 'DateTimeRangeField', 'DateRangeField']
class BaseRangeField(forms.MultiValueField):
default_error_messages = {
'invalid': _('Enter two valid values.'),
'bound_ordering': _('The start of the range must not exceed the end of the range.'),
}
def __init__(self, **kwargs):
if 'widget' not in kwargs:
kwargs['widget'] = RangeWidget(self.base_field.widget)
if 'fields' not in kwargs:
kwargs['fields'] = [self.base_field(required=False), self.base_field(required=False)]
kwargs.setdefault('required', False)
kwargs.setdefault('require_all_fields', False)
super().__init__(**kwargs)
def prepare_value(self, value):
lower_base, upper_base = self.fields
if isinstance(value, self.range_type):
return [
lower_base.prepare_value(value.lower),
upper_base.prepare_value(value.upper),
]
if value is None:
return [
lower_base.prepare_value(None),
upper_base.prepare_value(None),
]
return value
def compress(self, values):
if not values:
return None
lower, upper = values
if lower is not None and upper is not None and lower > upper:
raise exceptions.ValidationError(
self.error_messages['bound_ordering'],
code='bound_ordering',
)
try:
range_value = self.range_type(lower, upper)
except TypeError:
raise exceptions.ValidationError(
self.error_messages['invalid'],
code='invalid',
)
else:
return range_value
class IntegerRangeField(BaseRangeField):
default_error_messages = {'invalid': _('Enter two whole numbers.')}
base_field = forms.IntegerField
range_type = NumericRange
class FloatRangeField(BaseRangeField):
default_error_messages = {'invalid': _('Enter two numbers.')}
base_field = forms.FloatField
range_type = NumericRange
class DateTimeRangeField(BaseRangeField):
default_error_messages = {'invalid': _('Enter two valid date/times.')}
base_field = forms.DateTimeField
range_type = DateTimeTZRange
class DateRangeField(BaseRangeField):
default_error_messages = {'invalid': _('Enter two valid dates.')}
base_field = forms.DateField
range_type = DateRange
class RangeWidget(MultiWidget):
def __init__(self, base_widget, attrs=None):
widgets = (base_widget, base_widget)
super().__init__(widgets, attrs)
def decompress(self, value):
if value:
return (value.lower, value.upper)
return (None, None)