2014-03-27 00:44:21 +08:00
import json
from django.contrib.postgres.forms import SimpleArrayField
from django.contrib.postgres.validators import ArrayMaxLengthValidator
from django.core import checks, exceptions
from django.db.models import Field, Lookup, Transform, IntegerField
from django.utils import six
from django.utils.translation import string_concat, ugettext_lazy as _
# Public names exported by ``from <module> import *``: only the field class.
__all__ = ['ArrayField']
class AttributeSetter(object):
    """Bare object carrying exactly one attribute chosen at construction.

    ``AttributeSetter('foo', 1).foo == 1``.
    """

    def __init__(self, name, value):
        """Store ``value`` on the new instance under the attribute ``name``."""
        setattr(self, name, value)
class ArrayField(Field):
    """Model field holding a list of values that all share one ``base_field``.

    The database column type is the base field's column type followed by
    ``[]`` (or ``[size]`` when ``size`` is given); see :meth:`db_type`.
    """

    empty_strings_allowed = False
    default_error_messages = {
        'item_invalid': _('Item %(nth)s in the array did not validate: '),
        'nested_array_mismatch': _('Nested arrays must have the same length.'),
    }

    def __init__(self, base_field, size=None, **kwargs):
        """Create the field.

        ``base_field`` is the field instance describing each element.
        ``size``, when truthy, is the maximum number of elements; it is
        enforced by an ``ArrayMaxLengthValidator`` and echoed in the column
        type. Remaining keyword arguments are passed to ``Field``.
        """
        self.base_field = base_field
        self.size = size
        if self.size:
            # Copy first so the validator list inherited from the class is
            # not mutated for every other field instance.
            self.default_validators = self.default_validators[:]
            self.default_validators.append(ArrayMaxLengthValidator(self.size))
        super(ArrayField, self).__init__(**kwargs)

    def check(self, **kwargs):
        """Return the parent's check messages plus array-specific errors.

        A related base field is rejected outright; otherwise any errors
        reported by the base field's own ``check()`` are folded into a single
        error attached to this field.
        """
        errors = super(ArrayField, self).check(**kwargs)
        if self.base_field.rel:
            errors.append(
                checks.Error(
                    'Base field for array cannot be a related field.',
                    hint=None,
                    obj=self,
                    id='postgres.E002',
                )
            )
        else:
            # Remove the field name checks as they are not needed here.
            base_errors = self.base_field.check()
            if base_errors:
                messages = '\n '.join('%s (%s)' % (error.msg, error.id) for error in base_errors)
                errors.append(
                    checks.Error(
                        'Base field for array has errors:\n %s' % messages,
                        hint=None,
                        obj=self,
                        id='postgres.E001',
                    )
                )
        return errors

    def set_attributes_from_name(self, name):
        """Set name-derived attributes on this field and on the base field.

        The base field is never attached to a model itself, yet
        ``value_to_string`` reads ``base_field.attname``, so the name has to
        be propagated here.
        """
        super(ArrayField, self).set_attributes_from_name(name)
        self.base_field.set_attributes_from_name(name)

    @property
    def description(self):
        """Human-readable description, built from the base field's own.

        Exposed as a property because it is read as a plain attribute
        (``self.base_field.description``), which must also work when the base
        field is itself an ``ArrayField``.
        """
        return 'Array of %s' % self.base_field.description

    def db_type(self, connection):
        """Return the column type: base column type plus ``[size]`` or ``[]``."""
        size = self.size or ''
        return '%s[%s]' % (self.base_field.db_type(connection), size)

    def get_prep_value(self, value):
        """Prepare each element with the base field.

        Lists and tuples are converted element-wise into a new list; any
        other value (for example ``None``) is returned unchanged.
        """
        if isinstance(value, (list, tuple)):
            return [self.base_field.get_prep_value(i) for i in value]
        return value

    def get_db_prep_lookup(self, lookup_type, value, connection, prepared=False):
        """Return the query parameters for ``lookup_type``.

        ``contains`` is handled here so the whole array is sent as a single
        parameter; every other lookup type is delegated to ``Field``.
        """
        if lookup_type == 'contains':
            return [self.get_prep_value(value)]
        # NOTE(review): ``prepared=False`` is passed regardless of the
        # ``prepared`` argument received; kept as-is because the intent is not
        # visible from this file -- confirm before changing.
        return super(ArrayField, self).get_db_prep_lookup(
            lookup_type, value, connection, prepared=False)

    def deconstruct(self):
        """Return ``(name, path, args, kwargs)`` able to recreate this field.

        ``base_field`` and ``size`` are added to the parent's kwargs, and the
        import path is shortened to the package-level alias.
        """
        name, path, args, kwargs = super(ArrayField, self).deconstruct()
        if path == 'django.contrib.postgres.fields.array.ArrayField':
            path = 'django.contrib.postgres.fields.ArrayField'
        kwargs.update({
            'base_field': self.base_field,
            'size': self.size,
        })
        return name, path, args, kwargs

    def to_python(self, value):
        """Convert a JSON-encoded string into a list of base-field values.

        Non-string input is returned untouched.
        """
        if isinstance(value, six.string_types):
            # Assume we're deserializing
            vals = json.loads(value)
            value = [self.base_field.to_python(val) for val in vals]
        return value

    def value_to_string(self, obj):
        """Serialise this field's value on ``obj`` as a JSON list of strings.

        Each element is rendered by the base field's ``value_to_string``. That
        method expects an object exposing the value under the base field's
        ``attname``, so a throwaway ``AttributeSetter`` is built per element.
        """
        values = []
        vals = self._get_val_from_obj(obj)
        base_field = self.base_field

        for val in vals:
            holder = AttributeSetter(base_field.attname, val)
            values.append(base_field.value_to_string(holder))
        return json.dumps(values)

    def get_transform(self, name):
        """Resolve ``name`` to a transform.

        Lookup order: transforms known to the parent class, then an integer
        index (``field__0``), then a ``start_end`` slice (``field__0_2``).
        Returns ``None`` when nothing matches.
        """
        transform = super(ArrayField, self).get_transform(name)
        if transform:
            return transform
        try:
            index = int(name)
        except ValueError:
            pass
        else:
            index += 1  # postgres uses 1-indexing
            return IndexTransformFactory(index, self.base_field)
        try:
            start, end = name.split('_')
            start = int(start) + 1
            # PostgreSQL slice bounds are inclusive, so a 0-based exclusive
            # end is numerically the same as the 1-based inclusive end.
            end = int(end)
        except ValueError:
            pass
        else:
            return SliceTransformFactory(start, end)
        return None

    def validate(self, value, model_instance):
        """Validate the array as a whole, then each element.

        Raises ``ValidationError`` naming the first failing element, or
        reporting a length mismatch between nested arrays.
        """
        super(ArrayField, self).validate(value, model_instance)
        for i, part in enumerate(value):
            try:
                self.base_field.validate(part, model_instance)
            except exceptions.ValidationError as e:
                raise exceptions.ValidationError(
                    string_concat(self.error_messages['item_invalid'], e.message),
                    code='item_invalid',
                    params={'nth': i},
                )
        if isinstance(self.base_field, ArrayField):
            # More than one distinct length means the nested arrays are ragged.
            if len({len(i) for i in value}) > 1:
                raise exceptions.ValidationError(
                    self.error_messages['nested_array_mismatch'],
                    code='nested_array_mismatch',
                )

    def formfield(self, **kwargs):
        """Return a ``SimpleArrayField`` form field for this model field.

        Defaults derived from this field can be overridden through ``kwargs``.
        """
        defaults = {
            'form_class': SimpleArrayField,
            'base_field': self.base_field.formfield(),
            'max_length': self.size,
        }
        defaults.update(kwargs)
        return super(ArrayField, self).formfield(**defaults)
2014-10-10 00:04:50 +08:00
2014-03-27 00:44:21 +08:00
class ArrayContainsLookup(Lookup):
    """``contains`` lookup: the column array contains every right-hand element.

    Compiles to the PostgreSQL ``@>`` operator.
    """

    lookup_name = 'contains'

    def as_sql(self, compiler, connection):
        """Return ``(sql, params)`` for ``lhs @> rhs::<column type>``."""
        lhs, lhs_params = self.process_lhs(compiler, connection)
        rhs, rhs_params = self.process_rhs(compiler, connection)
        params = lhs_params + rhs_params
        # Cast the right-hand side to the left-hand side's column type.
        type_cast = self.lhs.output_field.db_type(connection)
        return '%s @> %s::%s' % (lhs, rhs, type_cast), params
2014-03-27 00:44:21 +08:00
2014-10-10 00:04:50 +08:00
2014-03-27 00:44:21 +08:00
class ArrayContainedByLookup(Lookup):
    """``contained_by`` lookup: every column element is in the right-hand array.

    Compiles to the PostgreSQL ``<@`` operator.
    """

    lookup_name = 'contained_by'

    def as_sql(self, compiler, connection):
        """Return ``(sql, params)`` for ``lhs <@ rhs::<column type>``."""
        lhs, lhs_params = self.process_lhs(compiler, connection)
        rhs, rhs_params = self.process_rhs(compiler, connection)
        params = lhs_params + rhs_params
        # Cast the right-hand side to the column's array type, exactly as
        # ArrayContainsLookup does, so the operator resolves even when the
        # parameter's inferred type differs (e.g. text[] against varchar[]).
        type_cast = self.lhs.output_field.db_type(connection)
        return '%s <@ %s::%s' % (lhs, rhs, type_cast), params
2014-10-10 00:04:50 +08:00
2014-03-27 00:44:21 +08:00
class ArrayOverlapLookup(Lookup):
    """``overlap`` lookup: the column and right-hand arrays share an element.

    Compiles to the PostgreSQL ``&&`` operator.
    """

    lookup_name = 'overlap'

    def as_sql(self, compiler, connection):
        """Return ``(sql, params)`` for ``lhs && rhs::<column type>``."""
        lhs, lhs_params = self.process_lhs(compiler, connection)
        rhs, rhs_params = self.process_rhs(compiler, connection)
        params = lhs_params + rhs_params
        # Cast the right-hand side to the column's array type, exactly as
        # ArrayContainsLookup does, so the operator resolves even when the
        # parameter's inferred type differs (e.g. text[] against varchar[]).
        type_cast = self.lhs.output_field.db_type(connection)
        return '%s && %s::%s' % (lhs, rhs, type_cast), params
2014-10-10 00:04:50 +08:00
2014-03-27 00:44:21 +08:00
class ArrayLenTransform(Transform):
    """``len`` transform: length of the array's first dimension."""

    lookup_name = 'len'

    @property
    def output_field(self):
        """The transform yields an integer.

        Exposed as a property because lookups read ``lhs.output_field`` as a
        plain attribute (e.g. ``self.lhs.output_field.db_type(connection)``).
        """
        return IntegerField()

    def as_sql(self, compiler, connection):
        """Return ``(sql, params)`` wrapping the operand in ``array_length(..., 1)``."""
        lhs, params = compiler.compile(self.lhs)
        return 'array_length(%s, 1)' % lhs, params
class IndexTransform(Transform):
    """Transform selecting one element of an array: ``lhs[index]``."""

    def __init__(self, index, base_field, *args, **kwargs):
        """Remember the subscript and the element field.

        ``index`` is interpolated into the SQL unchanged; ``base_field`` is
        reported as the output field. Remaining arguments go to ``Transform``.
        """
        super(IndexTransform, self).__init__(*args, **kwargs)
        self.index = index
        self.base_field = base_field

    def as_sql(self, compiler, connection):
        """Return ``(sql, params)`` subscripting the compiled operand."""
        lhs, params = compiler.compile(self.lhs)
        return '%s[%s]' % (lhs, self.index), params

    @property
    def output_field(self):
        """An indexed element has the type of the array's base field.

        Exposed as a property because lookups read ``lhs.output_field`` as a
        plain attribute (e.g. ``self.lhs.output_field.db_type(connection)``).
        """
        return self.base_field
class IndexTransformFactory(object):
    """Callable that builds ``IndexTransform`` objects for a fixed index.

    It pre-binds ``index`` and ``base_field``; whatever arguments it is later
    called with are passed through after them.
    """

    def __init__(self, index, base_field):
        """Capture the index and element field used for every transform built."""
        self.index, self.base_field = index, base_field

    def __call__(self, *args, **kwargs):
        """Build an ``IndexTransform`` from the stored values plus call arguments."""
        bound = (self.index, self.base_field)
        return IndexTransform(*(bound + args), **kwargs)
class SliceTransform(Transform):
    """Transform selecting a sub-array: ``lhs[start:end]``."""

    def __init__(self, start, end, *args, **kwargs):
        """Remember the slice bounds.

        ``start`` and ``end`` are interpolated into the SQL unchanged.
        Remaining arguments go to ``Transform``.
        """
        super(SliceTransform, self).__init__(*args, **kwargs)
        self.start = start
        self.end = end

    def as_sql(self, compiler, connection):
        """Return ``(sql, params)`` slicing the compiled operand."""
        lhs, params = compiler.compile(self.lhs)
        return '%s[%s:%s]' % (lhs, self.start, self.end), params
class SliceTransformFactory(object):
    """Callable that builds ``SliceTransform`` objects for fixed bounds.

    It pre-binds ``start`` and ``end``; whatever arguments it is later called
    with are passed through after them.
    """

    def __init__(self, start, end):
        """Capture the slice bounds used for every transform built."""
        self.start, self.end = start, end

    def __call__(self, *args, **kwargs):
        """Build a ``SliceTransform`` from the stored bounds plus call arguments."""
        bounds = (self.start, self.end)
        return SliceTransform(*(bounds + args), **kwargs)