2014-03-27 00:44:21 +08:00
|
|
|
import json
|
|
|
|
|
2015-01-11 00:11:15 +08:00
|
|
|
from django.contrib.postgres import lookups
|
2014-03-27 00:44:21 +08:00
|
|
|
from django.contrib.postgres.forms import SimpleArrayField
|
|
|
|
from django.contrib.postgres.validators import ArrayMaxLengthValidator
|
|
|
|
from django.core import checks, exceptions
|
2020-11-03 03:38:34 +08:00
|
|
|
from django.db.models import Field, Func, IntegerField, Transform, Value
|
2019-10-17 17:36:39 +08:00
|
|
|
from django.db.models.fields.mixins import CheckFieldDefaultMixin
|
2015-11-07 23:06:06 +08:00
|
|
|
from django.db.models.lookups import Exact, In
|
2017-01-27 03:58:33 +08:00
|
|
|
from django.utils.translation import gettext_lazy as _
|
2014-03-27 00:44:21 +08:00
|
|
|
|
2019-03-04 02:33:48 +08:00
|
|
|
from ..utils import prefix_validation_error
|
2015-06-06 19:55:04 +08:00
|
|
|
from .utils import AttributeSetter
|
2014-03-27 00:44:21 +08:00
|
|
|
|
2015-06-06 19:55:04 +08:00
|
|
|
# Only the field class is public; the lookups and transforms below attach
# themselves to it via @ArrayField.register_lookup.
__all__ = ['ArrayField']
|
2014-03-27 00:44:21 +08:00
|
|
|
|
|
|
|
|
2017-08-18 07:21:35 +08:00
|
|
|
class ArrayField(CheckFieldDefaultMixin, Field):
    """
    Model field storing a PostgreSQL array whose items are ``base_field`` values.

    ``size``, when truthy, is written into the column type (``<type>[size]``)
    and enforced by an ArrayMaxLengthValidator added to the default validators.
    """
    empty_strings_allowed = False
    default_error_messages = {
        'item_invalid': _('Item %(nth)s in the array did not validate:'),
        'nested_array_mismatch': _('Nested arrays must have the same length.'),
    }
    # NOTE(review): presumably read by CheckFieldDefaultMixin to suggest a
    # callable default (``list``) instead of a shared ``[]`` — confirm there.
    _default_hint = ('list', '[]')

    def __init__(self, base_field, size=None, **kwargs):
        self.base_field = base_field
        self.size = size
        if self.size:
            self.default_validators = [*self.default_validators, ArrayMaxLengthValidator(self.size)]
        # For performance, only add a from_db_value() method if the base field
        # implements it.
        if hasattr(self.base_field, 'from_db_value'):
            self.from_db_value = self._from_db_value
        super().__init__(**kwargs)

    @property
    def model(self):
        """The owning model; raises AttributeError until one has been assigned."""
        try:
            return self.__dict__['model']
        except KeyError:
            raise AttributeError("'%s' object has no attribute 'model'" % self.__class__.__name__)

    @model.setter
    def model(self, model):
        # Keep the base field pointed at the same model as the array field.
        self.__dict__['model'] = model
        self.base_field.model = model

    @classmethod
    def _choices_is_value(cls, value):
        # A list/tuple choice value is an array value, not a nested choice group.
        return isinstance(value, (list, tuple)) or super()._choices_is_value(value)

    def check(self, **kwargs):
        """
        Run the regular field checks plus checks on the base field.

        A related base field yields postgres.E002; any errors reported by the
        base field's own check() are folded into a single postgres.E001.
        """
        errors = super().check(**kwargs)
        if self.base_field.remote_field:
            errors.append(
                checks.Error(
                    'Base field for array cannot be a related field.',
                    obj=self,
                    id='postgres.E002'
                )
            )
        else:
            # Remove the field name checks as they are not needed here
            # (base_field.check() is called without the outer kwargs).
            base_errors = self.base_field.check()
            if base_errors:
                messages = '\n    '.join('%s (%s)' % (error.msg, error.id) for error in base_errors)
                errors.append(
                    checks.Error(
                        'Base field for array has errors:\n    %s' % messages,
                        obj=self,
                        id='postgres.E001'
                    )
                )
        return errors

    def set_attributes_from_name(self, name):
        # The base field shares the array field's name/attname.
        super().set_attributes_from_name(name)
        self.base_field.set_attributes_from_name(name)

    @property
    def description(self):
        return 'Array of %s' % self.base_field.description

    def db_type(self, connection):
        """Column type: the base field's type followed by ``[size]`` (or ``[]``)."""
        size = self.size or ''
        return '%s[%s]' % (self.base_field.db_type(connection), size)

    def cast_db_type(self, connection):
        """Type used in casts: the base field's cast type followed by ``[size]``."""
        size = self.size or ''
        return '%s[%s]' % (self.base_field.cast_db_type(connection), size)

    def get_placeholder(self, value, compiler, connection):
        # Cast the bound parameter to the column's array type.
        return '%s::{}'.format(self.db_type(connection))

    def get_db_prep_value(self, value, connection, prepared=False):
        """Prepare each item of a list/tuple with the base field; pass anything else through."""
        if isinstance(value, (list, tuple)):
            return [self.base_field.get_db_prep_value(i, connection, prepared=False) for i in value]
        return value

    def deconstruct(self):
        name, path, args, kwargs = super().deconstruct()
        # Serialize the shorter public import path in migrations.
        if path == 'django.contrib.postgres.fields.array.ArrayField':
            path = 'django.contrib.postgres.fields.ArrayField'
        kwargs.update({
            'base_field': self.base_field.clone(),
            'size': self.size,
        })
        return name, path, args, kwargs

    def to_python(self, value):
        """
        Convert a JSON string (as produced by value_to_string()) into a list
        of base-field values; any non-string value is returned unchanged.

        The JSON string ``"null"`` yields None instead of failing when it is
        iterated.
        """
        if isinstance(value, str):
            # Assume we're deserializing
            vals = json.loads(value)
            if vals is None:
                return None
            value = [self.base_field.to_python(val) for val in vals]
        return value

    def _from_db_value(self, value, expression, connection):
        # Installed as from_db_value in __init__ only when the base field has one.
        if value is None:
            return value
        return [
            self.base_field.from_db_value(item, expression, connection)
            for item in value
        ]

    def value_to_string(self, obj):
        """
        Serialize this field's value on ``obj`` as a JSON list.

        Each non-None item is stringified by the base field's own
        value_to_string() (via an AttributeSetter stand-in object); None items
        stay JSON null. A None array serializes to ``"null"``, which
        to_python() maps back to None.
        """
        vals = self.value_from_object(obj)
        if vals is None:
            return json.dumps(None)
        base_field = self.base_field
        values = []
        for val in vals:
            if val is None:
                values.append(None)
            else:
                holder = AttributeSetter(base_field.attname, val)
                values.append(base_field.value_to_string(holder))
        return json.dumps(values)

    def get_transform(self, name):
        """
        Resolve registered transforms first, then the array-specific ones:

        * ``<n>``: IndexTransformFactory for 0-based index n (PostgreSQL is 1-based).
        * ``<a>_<b>``: SliceTransformFactory for the Python-style slice [a:b].

        Return None when ``name`` matches neither form.
        """
        transform = super().get_transform(name)
        if transform:
            return transform
        if '_' not in name:
            try:
                index = int(name)
            except ValueError:
                pass
            else:
                index += 1  # postgres uses 1-indexing
                return IndexTransformFactory(index, self.base_field)
        try:
            start, end = name.split('_')
            start = int(start) + 1
            # PostgreSQL slice bounds are 1-based and inclusive, so the
            # exclusive 0-based end maps to the same number unchanged.
            end = int(end)
        except ValueError:
            pass
        else:
            return SliceTransformFactory(start, end)
        return None

    def validate(self, value, model_instance):
        """
        Validate the array, then each item with the base field.

        Item errors are re-raised prefixed with 'item_invalid' and the 1-based
        item position. For nested arrays, all inner arrays must share a length.
        """
        super().validate(value, model_instance)
        if value is None:
            # super().validate() decides whether None is acceptable; there are
            # no items to check (iterating None would raise TypeError).
            return
        for index, part in enumerate(value):
            try:
                self.base_field.validate(part, model_instance)
            except exceptions.ValidationError as error:
                raise prefix_validation_error(
                    error,
                    prefix=self.error_messages['item_invalid'],
                    code='item_invalid',
                    params={'nth': index + 1},
                )
        if isinstance(self.base_field, ArrayField):
            if len({len(i) for i in value}) > 1:
                raise exceptions.ValidationError(
                    self.error_messages['nested_array_mismatch'],
                    code='nested_array_mismatch',
                )

    def run_validators(self, value):
        """Run the array's validators, then the base field's validators per item."""
        super().run_validators(value)
        if value is None:
            # No items to validate (iterating None would raise TypeError).
            return
        for index, part in enumerate(value):
            try:
                self.base_field.run_validators(part)
            except exceptions.ValidationError as error:
                raise prefix_validation_error(
                    error,
                    prefix=self.error_messages['item_invalid'],
                    code='item_invalid',
                    params={'nth': index + 1},
                )

    def formfield(self, **kwargs):
        # Caller-supplied kwargs override these defaults.
        return super().formfield(**{
            'form_class': SimpleArrayField,
            'base_field': self.base_field.formfield(),
            'max_length': self.size,
            **kwargs,
        })
|
2014-03-27 00:44:21 +08:00
|
|
|
|
|
|
|
|
2020-11-03 03:38:34 +08:00
|
|
|
class ArrayRHSMixin:
    """
    Mixin for array lookups that compare against an array right-hand side.

    A list/tuple RHS becomes an ``ARRAY[...]`` expression: expression items
    are kept as-is, and plain items are prepared by the LHS field's
    base_field and wrapped in Value(). process_rhs() then casts the compiled
    RHS to the LHS field's cast type.
    """

    def __init__(self, lhs, rhs):
        if isinstance(rhs, (tuple, list)):
            items = [
                value if hasattr(value, 'resolve_expression')
                else Value(lhs.output_field.base_field.get_prep_value(value))
                for value in rhs
            ]
            rhs = Func(*items, function='ARRAY', template='%(function)s[%(expressions)s]')
        super().__init__(lhs, rhs)

    def process_rhs(self, compiler, connection):
        sql, sql_params = super().process_rhs(compiler, connection)
        # Cast so both sides of the comparison have the same array type.
        target_type = self.lhs.output_field.cast_db_type(connection)
        return '%s::%s' % (sql, target_type), sql_params
|
|
|
|
|
|
|
|
|
2014-10-10 00:04:50 +08:00
|
|
|
@ArrayField.register_lookup
class ArrayContains(ArrayRHSMixin, lookups.DataContains):
    """Containment lookup from lookups.DataContains, with an array-typed RHS."""
|
2014-03-27 00:44:21 +08:00
|
|
|
|
|
|
|
|
2015-01-20 17:52:23 +08:00
|
|
|
@ArrayField.register_lookup
class ArrayContainedBy(ArrayRHSMixin, lookups.ContainedBy):
    """Contained-by lookup from lookups.ContainedBy, with an array-typed RHS."""
|
2015-11-07 20:52:40 +08:00
|
|
|
|
|
|
|
|
|
|
|
@ArrayField.register_lookup
class ArrayExact(ArrayRHSMixin, Exact):
    """Exact-match lookup whose list/tuple RHS is compiled as a cast ARRAY[...]."""
|
2015-01-20 17:52:23 +08:00
|
|
|
|
|
|
|
|
|
|
|
@ArrayField.register_lookup
class ArrayOverlap(ArrayRHSMixin, lookups.Overlap):
    """Overlap lookup from lookups.Overlap, with an array-typed RHS."""
|
2014-03-27 00:44:21 +08:00
|
|
|
|
|
|
|
|
2014-10-10 00:04:50 +08:00
|
|
|
@ArrayField.register_lookup
class ArrayLenTransform(Transform):
    """``__len``: length of the array's first dimension as an integer."""
    lookup_name = 'len'
    output_field = IntegerField()

    def as_sql(self, compiler, connection):
        lhs, params = compiler.compile(self.lhs)
        # Distinguish NULL and empty arrays: a NULL array stays NULL, while
        # array_length() of an empty array (NULL) is coalesced to 0.
        sql = (
            'CASE WHEN %(lhs)s IS NULL THEN NULL ELSE '
            'coalesce(array_length(%(lhs)s, 1), 0) END'
        ) % {'lhs': lhs}
        # The LHS SQL appears twice, so any placeholders it contains (e.g. an
        # index transform's "col"[%s]) need their parameters twice as well.
        return sql, [*params, *params]
|
2014-03-27 00:44:21 +08:00
|
|
|
|
|
|
|
|
2015-11-07 23:06:06 +08:00
|
|
|
@ArrayField.register_lookup
class ArrayInLookup(In):
    """``__in`` lookup for ArrayField, where each candidate value is itself an array."""

    def get_prep_lookup(self):
        values = super().get_prep_lookup()
        if hasattr(values, 'resolve_expression'):
            return values
        # In.process_rhs() expects values to be hashable, so convert lists
        # to tuples. Expressions are kept as they are. None is also left
        # untouched, because tuple(None) would raise TypeError; the parent In
        # lookup then handles it (NOTE(review): In.process_rhs presumably
        # discards None since NULL never matches — confirm there).
        prepared_values = []
        for value in values:
            if value is None or hasattr(value, 'resolve_expression'):
                prepared_values.append(value)
            else:
                prepared_values.append(tuple(value))
        return prepared_values
|
2015-11-07 23:06:06 +08:00
|
|
|
|
|
|
|
|
2014-03-27 00:44:21 +08:00
|
|
|
class IndexTransform(Transform):
    """
    Subscript a single array element: ``lhs[index]``.

    ``index`` is the PostgreSQL (1-based) subscript. The output field is the
    array's base field.
    """

    def __init__(self, index, base_field, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.index = index
        self.base_field = base_field

    def as_sql(self, compiler, connection):
        lhs, params = compiler.compile(self.lhs)
        # PostgreSQL only allows subscripting an unparenthesized expression
        # when it is a column reference, so wrap anything else (aggregates,
        # casts, function calls). Chained subscripts are left bare so that
        # multidimensional access still compiles to lhs[i][j].
        if not isinstance(self.lhs, (IndexTransform, SliceTransform)):
            lhs = '(%s)' % lhs
        # Build a fresh list: params may be a tuple and must not be mutated.
        return '%s[%%s]' % lhs, [*params, self.index]

    @property
    def output_field(self):
        return self.base_field
|
|
|
|
|
|
|
|
|
2017-01-19 15:39:46 +08:00
|
|
|
class IndexTransformFactory:
    """
    Transform "class" returned by ArrayField.get_transform() for ``__<n>``.

    It remembers the subscript and base field, and builds an IndexTransform
    with them when the ORM calls it like a transform class.
    """

    def __init__(self, index, base_field):
        self.index = index
        self.base_field = base_field

    def __call__(self, *args, **kwargs):
        subscript, item_field = self.index, self.base_field
        return IndexTransform(subscript, item_field, *args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
class SliceTransform(Transform):
    """
    Slice an array: ``lhs[start:end]``.

    ``start`` and ``end`` are PostgreSQL (1-based, inclusive) bounds. The
    output field is inherited from the LHS, so the result is still an array.
    """

    def __init__(self, start, end, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.start = start
        self.end = end

    def as_sql(self, compiler, connection):
        lhs, params = compiler.compile(self.lhs)
        # PostgreSQL only allows subscripting an unparenthesized expression
        # when it is a column reference, so wrap anything else (aggregates,
        # casts, function calls). Chained subscripts are left bare.
        if not isinstance(self.lhs, (IndexTransform, SliceTransform)):
            lhs = '(%s)' % lhs
        # Build a fresh list: params may be a tuple and must not be mutated.
        return '%s[%%s:%%s]' % lhs, [*params, self.start, self.end]
|
2014-03-27 00:44:21 +08:00
|
|
|
|
|
|
|
|
2017-01-19 15:39:46 +08:00
|
|
|
class SliceTransformFactory:
    """
    Transform "class" returned by ArrayField.get_transform() for ``__<a>_<b>``.

    It remembers the slice bounds and builds a SliceTransform with them when
    the ORM calls it like a transform class.
    """

    def __init__(self, start, end):
        self.start = start
        self.end = end

    def __call__(self, *args, **kwargs):
        lower, upper = self.start, self.end
        return SliceTransform(lower, upper, *args, **kwargs)
|