252 lines
8.2 KiB
Python
252 lines
8.2 KiB
Python
import copy
|
|
from itertools import chain
|
|
|
|
from django import forms
|
|
from django.contrib.postgres.validators import (
|
|
ArrayMaxLengthValidator,
|
|
ArrayMinLengthValidator,
|
|
)
|
|
from django.core.exceptions import ValidationError
|
|
from django.utils.translation import gettext_lazy as _
|
|
|
|
from ..utils import prefix_validation_error
|
|
|
|
|
|
class SimpleArrayField(forms.CharField):
|
|
default_error_messages = {
|
|
"item_invalid": _("Item %(nth)s in the array did not validate:"),
|
|
}
|
|
|
|
def __init__(
|
|
self, base_field, *, delimiter=",", max_length=None, min_length=None, **kwargs
|
|
):
|
|
self.base_field = base_field
|
|
self.delimiter = delimiter
|
|
super().__init__(**kwargs)
|
|
if min_length is not None:
|
|
self.min_length = min_length
|
|
self.validators.append(ArrayMinLengthValidator(int(min_length)))
|
|
if max_length is not None:
|
|
self.max_length = max_length
|
|
self.validators.append(ArrayMaxLengthValidator(int(max_length)))
|
|
|
|
def clean(self, value):
|
|
value = super().clean(value)
|
|
return [self.base_field.clean(val) for val in value]
|
|
|
|
def prepare_value(self, value):
|
|
if isinstance(value, list):
|
|
return self.delimiter.join(
|
|
str(self.base_field.prepare_value(v)) for v in value
|
|
)
|
|
return value
|
|
|
|
def to_python(self, value):
|
|
if isinstance(value, list):
|
|
items = value
|
|
elif value:
|
|
items = value.split(self.delimiter)
|
|
else:
|
|
items = []
|
|
errors = []
|
|
values = []
|
|
for index, item in enumerate(items):
|
|
try:
|
|
values.append(self.base_field.to_python(item))
|
|
except ValidationError as error:
|
|
errors.append(
|
|
prefix_validation_error(
|
|
error,
|
|
prefix=self.error_messages["item_invalid"],
|
|
code="item_invalid",
|
|
params={"nth": index + 1},
|
|
)
|
|
)
|
|
if errors:
|
|
raise ValidationError(errors)
|
|
return values
|
|
|
|
def validate(self, value):
|
|
super().validate(value)
|
|
errors = []
|
|
for index, item in enumerate(value):
|
|
try:
|
|
self.base_field.validate(item)
|
|
except ValidationError as error:
|
|
errors.append(
|
|
prefix_validation_error(
|
|
error,
|
|
prefix=self.error_messages["item_invalid"],
|
|
code="item_invalid",
|
|
params={"nth": index + 1},
|
|
)
|
|
)
|
|
if errors:
|
|
raise ValidationError(errors)
|
|
|
|
def run_validators(self, value):
|
|
super().run_validators(value)
|
|
errors = []
|
|
for index, item in enumerate(value):
|
|
try:
|
|
self.base_field.run_validators(item)
|
|
except ValidationError as error:
|
|
errors.append(
|
|
prefix_validation_error(
|
|
error,
|
|
prefix=self.error_messages["item_invalid"],
|
|
code="item_invalid",
|
|
params={"nth": index + 1},
|
|
)
|
|
)
|
|
if errors:
|
|
raise ValidationError(errors)
|
|
|
|
def has_changed(self, initial, data):
|
|
try:
|
|
value = self.to_python(data)
|
|
except ValidationError:
|
|
pass
|
|
else:
|
|
if initial in self.empty_values and value in self.empty_values:
|
|
return False
|
|
return super().has_changed(initial, data)
|
|
|
|
|
|
class SplitArrayWidget(forms.Widget):
|
|
template_name = "postgres/widgets/split_array.html"
|
|
|
|
def __init__(self, widget, size, **kwargs):
|
|
self.widget = widget() if isinstance(widget, type) else widget
|
|
self.size = size
|
|
super().__init__(**kwargs)
|
|
|
|
@property
|
|
def is_hidden(self):
|
|
return self.widget.is_hidden
|
|
|
|
def value_from_datadict(self, data, files, name):
|
|
return [
|
|
self.widget.value_from_datadict(data, files, "%s_%s" % (name, index))
|
|
for index in range(self.size)
|
|
]
|
|
|
|
def value_omitted_from_data(self, data, files, name):
|
|
return all(
|
|
self.widget.value_omitted_from_data(data, files, "%s_%s" % (name, index))
|
|
for index in range(self.size)
|
|
)
|
|
|
|
def id_for_label(self, id_):
|
|
# See the comment for RadioSelect.id_for_label()
|
|
if id_:
|
|
id_ += "_0"
|
|
return id_
|
|
|
|
def get_context(self, name, value, attrs=None):
|
|
attrs = {} if attrs is None else attrs
|
|
context = super().get_context(name, value, attrs)
|
|
if self.is_localized:
|
|
self.widget.is_localized = self.is_localized
|
|
value = value or []
|
|
context["widget"]["subwidgets"] = []
|
|
final_attrs = self.build_attrs(attrs)
|
|
id_ = final_attrs.get("id")
|
|
for i in range(max(len(value), self.size)):
|
|
try:
|
|
widget_value = value[i]
|
|
except IndexError:
|
|
widget_value = None
|
|
if id_:
|
|
final_attrs = {**final_attrs, "id": "%s_%s" % (id_, i)}
|
|
context["widget"]["subwidgets"].append(
|
|
self.widget.get_context(name + "_%s" % i, widget_value, final_attrs)[
|
|
"widget"
|
|
]
|
|
)
|
|
return context
|
|
|
|
@property
|
|
def media(self):
|
|
return self.widget.media
|
|
|
|
def __deepcopy__(self, memo):
|
|
obj = super().__deepcopy__(memo)
|
|
obj.widget = copy.deepcopy(self.widget)
|
|
return obj
|
|
|
|
@property
|
|
def needs_multipart_form(self):
|
|
return self.widget.needs_multipart_form
|
|
|
|
|
|
class SplitArrayField(forms.Field):
|
|
default_error_messages = {
|
|
"item_invalid": _("Item %(nth)s in the array did not validate:"),
|
|
}
|
|
|
|
def __init__(self, base_field, size, *, remove_trailing_nulls=False, **kwargs):
|
|
self.base_field = base_field
|
|
self.size = size
|
|
self.remove_trailing_nulls = remove_trailing_nulls
|
|
widget = SplitArrayWidget(widget=base_field.widget, size=size)
|
|
kwargs.setdefault("widget", widget)
|
|
super().__init__(**kwargs)
|
|
|
|
def _remove_trailing_nulls(self, values):
|
|
index = None
|
|
if self.remove_trailing_nulls:
|
|
for i, value in reversed(list(enumerate(values))):
|
|
if value in self.base_field.empty_values:
|
|
index = i
|
|
else:
|
|
break
|
|
if index is not None:
|
|
values = values[:index]
|
|
return values, index
|
|
|
|
def to_python(self, value):
|
|
value = super().to_python(value)
|
|
return [self.base_field.to_python(item) for item in value]
|
|
|
|
def clean(self, value):
|
|
cleaned_data = []
|
|
errors = []
|
|
if not any(value) and self.required:
|
|
raise ValidationError(self.error_messages["required"])
|
|
max_size = max(self.size, len(value))
|
|
for index in range(max_size):
|
|
item = value[index]
|
|
try:
|
|
cleaned_data.append(self.base_field.clean(item))
|
|
except ValidationError as error:
|
|
errors.append(
|
|
prefix_validation_error(
|
|
error,
|
|
self.error_messages["item_invalid"],
|
|
code="item_invalid",
|
|
params={"nth": index + 1},
|
|
)
|
|
)
|
|
cleaned_data.append(None)
|
|
else:
|
|
errors.append(None)
|
|
cleaned_data, null_index = self._remove_trailing_nulls(cleaned_data)
|
|
if null_index is not None:
|
|
errors = errors[:null_index]
|
|
errors = list(filter(None, errors))
|
|
if errors:
|
|
raise ValidationError(list(chain.from_iterable(errors)))
|
|
return cleaned_data
|
|
|
|
def has_changed(self, initial, data):
|
|
try:
|
|
data = self.to_python(data)
|
|
except ValidationError:
|
|
pass
|
|
else:
|
|
data, _ = self._remove_trailing_nulls(data)
|
|
if initial in self.empty_values and data in self.empty_values:
|
|
return False
|
|
return super().has_changed(initial, data)
|