Added array field support for PostgreSQL.

The first part of django.contrib.postgres, including model and two form
fields for arrays of other data types.

This commit is formed of the following work:

    Add shell of postgres app and test handling.

    First draft of array fields.

    Use recursive deconstruction.

    Stop creating classes at lookup time.

    Add validation and size parameter.

    Add contained_by lookup.

    Add SimpleArrayField for forms.

    Add SplitArrayField (mainly for admin).

    Fix prepare_value for SimpleArrayField.

    Stop using MultiValueField and MultiWidget.

    They don't play nice with flexible sizes.

    Add basics of admin integration.

    Missing:
    - Tests
    - Fully working js

    Add reference document for django.contrib.postgres.fields.ArrayField.

    Various performance and style tweaks.

    Fix internal docs link, formalise code snippets.

    Remove the admin code for now.

    It needs a better way of handing JS widgets in the admin as a whole
    before it is easy to write. In particular there are serious issues
    involving DateTimePicker when used in an array.

    Add a test for nested array fields with different delimiters.

    This will be a documented pattern so having a test for it is useful.

    Add docs for SimpleArrayField.

    Add docs for SplitArrayField.

    Remove admin related code for now.

    definition -> description

    Fix typo.

    Py3 errors.

    Avoid using regexes where they're not needed.

    Allow passing tuples by the programmer.

    Add some more tests for multidimensional arrays.

    Also fix slicing as much as it can be fixed.

    Simplify SplitArrayWidget's data loading.

    If we aren't including the variable size one, we don't need to search
    like this.
This commit is contained in:
Marc Tamlyn 2014-03-26 16:44:21 +00:00
parent d9d9242505
commit 604162604b
15 changed files with 1272 additions and 1 deletions

View File

View File

@ -0,0 +1 @@
from .array import * # NOQA

View File

@ -0,0 +1,254 @@
import json
from django.contrib.postgres.forms import SimpleArrayField
from django.contrib.postgres.validators import ArrayMaxLengthValidator
from django.core import checks, exceptions
from django.db.models import Field, Lookup, Transform, IntegerField
from django.utils import six
from django.utils.translation import string_concat, ugettext_lazy as _
__all__ = ['ArrayField']
class AttributeSetter(object):
def __init__(self, name, value):
setattr(self, name, value)
class ArrayField(Field):
empty_strings_allowed = False
default_error_messages = {
'item_invalid': _('Item %(nth)s in the array did not validate: '),
'nested_array_mismatch': _('Nested arrays must have the same length.'),
}
def __init__(self, base_field, size=None, **kwargs):
self.base_field = base_field
self.size = size
if self.size:
self.default_validators = self.default_validators[:]
self.default_validators.append(ArrayMaxLengthValidator(self.size))
super(ArrayField, self).__init__(**kwargs)
def check(self, **kwargs):
errors = super(ArrayField, self).check(**kwargs)
if self.base_field.rel:
errors.append(
checks.Error(
'Base field for array cannot be a related field.',
hint=None,
obj=self,
id='postgres.E002'
)
)
else:
# Remove the field name checks as they are not needed here.
base_errors = self.base_field.check()
if base_errors:
messages = '\n '.join('%s (%s)' % (error.msg, error.id) for error in base_errors)
errors.append(
checks.Error(
'Base field for array has errors:\n %s' % messages,
hint=None,
obj=self,
id='postgres.E001'
)
)
return errors
def set_attributes_from_name(self, name):
super(ArrayField, self).set_attributes_from_name(name)
self.base_field.set_attributes_from_name(name)
@property
def description(self):
return 'Array of %s' % self.base_field.description
def db_type(self, connection):
size = self.size or ''
return '%s[%s]' % (self.base_field.db_type(connection), size)
def get_prep_value(self, value):
if isinstance(value, list) or isinstance(value, tuple):
return [self.base_field.get_prep_value(i) for i in value]
return value
def get_db_prep_lookup(self, lookup_type, value, connection, prepared=False):
if lookup_type == 'contains':
return [self.get_prep_value(value)]
return super(ArrayField, self).get_db_prep_lookup(lookup_type, value,
connection, prepared=False)
def deconstruct(self):
name, path, args, kwargs = super(ArrayField, self).deconstruct()
path = 'django.contrib.postgres.fields.ArrayField'
args.insert(0, self.base_field)
kwargs['size'] = self.size
return name, path, args, kwargs
def to_python(self, value):
if isinstance(value, six.string_types):
# Assume we're deserializing
vals = json.loads(value)
value = [self.base_field.to_python(val) for val in vals]
return value
def value_to_string(self, obj):
values = []
vals = self._get_val_from_obj(obj)
base_field = self.base_field
for val in vals:
obj = AttributeSetter(base_field.attname, val)
values.append(base_field.value_to_string(obj))
return json.dumps(values)
def get_transform(self, name):
transform = super(ArrayField, self).get_transform(name)
if transform:
return transform
try:
index = int(name)
except ValueError:
pass
else:
index += 1 # postgres uses 1-indexing
return IndexTransformFactory(index, self.base_field)
try:
start, end = name.split('_')
start = int(start) + 1
end = int(end) # don't add one here because postgres slices are weird
except ValueError:
pass
else:
return SliceTransformFactory(start, end)
def validate(self, value, model_instance):
super(ArrayField, self).validate(value, model_instance)
for i, part in enumerate(value):
try:
self.base_field.validate(part, model_instance)
except exceptions.ValidationError as e:
raise exceptions.ValidationError(
string_concat(self.error_messages['item_invalid'], e.message),
code='item_invalid',
params={'nth': i},
)
if isinstance(self.base_field, ArrayField):
if len({len(i) for i in value}) > 1:
raise exceptions.ValidationError(
self.error_messages['nested_array_mismatch'],
code='nested_array_mismatch',
)
def formfield(self, **kwargs):
defaults = {
'form_class': SimpleArrayField,
'base_field': self.base_field.formfield(),
'max_length': self.size,
}
defaults.update(kwargs)
return super(ArrayField, self).formfield(**defaults)
class ArrayContainsLookup(Lookup):
lookup_name = 'contains'
def as_sql(self, qn, connection):
lhs, lhs_params = self.process_lhs(qn, connection)
rhs, rhs_params = self.process_rhs(qn, connection)
params = lhs_params + rhs_params
return '%s @> %s' % (lhs, rhs), params
ArrayField.register_lookup(ArrayContainsLookup)
class ArrayContainedByLookup(Lookup):
lookup_name = 'contained_by'
def as_sql(self, qn, connection):
lhs, lhs_params = self.process_lhs(qn, connection)
rhs, rhs_params = self.process_rhs(qn, connection)
params = lhs_params + rhs_params
return '%s <@ %s' % (lhs, rhs), params
ArrayField.register_lookup(ArrayContainedByLookup)
class ArrayOverlapLookup(Lookup):
lookup_name = 'overlap'
def as_sql(self, qn, connection):
lhs, lhs_params = self.process_lhs(qn, connection)
rhs, rhs_params = self.process_rhs(qn, connection)
params = lhs_params + rhs_params
return '%s && %s' % (lhs, rhs), params
ArrayField.register_lookup(ArrayOverlapLookup)
class ArrayLenTransform(Transform):
lookup_name = 'len'
@property
def output_type(self):
return IntegerField()
def as_sql(self, qn, connection):
lhs, params = qn.compile(self.lhs)
return 'array_length(%s, 1)' % lhs, params
ArrayField.register_lookup(ArrayLenTransform)
class IndexTransform(Transform):
def __init__(self, index, base_field, *args, **kwargs):
super(IndexTransform, self).__init__(*args, **kwargs)
self.index = index
self.base_field = base_field
def as_sql(self, qn, connection):
lhs, params = qn.compile(self.lhs)
return '%s[%s]' % (lhs, self.index), params
@property
def output_type(self):
return self.base_field
class IndexTransformFactory(object):
def __init__(self, index, base_field):
self.index = index
self.base_field = base_field
def __call__(self, *args, **kwargs):
return IndexTransform(self.index, self.base_field, *args, **kwargs)
class SliceTransform(Transform):
def __init__(self, start, end, *args, **kwargs):
super(SliceTransform, self).__init__(*args, **kwargs)
self.start = start
self.end = end
def as_sql(self, qn, connection):
lhs, params = qn.compile(self.lhs)
return '%s[%s:%s]' % (lhs, self.start, self.end), params
class SliceTransformFactory(object):
def __init__(self, start, end):
self.start = start
self.end = end
def __call__(self, *args, **kwargs):
return SliceTransform(self.start, self.end, *args, **kwargs)

View File

@ -0,0 +1 @@
from .array import * # NOQA

View File

@ -0,0 +1,185 @@
import copy
from django.contrib.postgres.validators import ArrayMinLengthValidator, ArrayMaxLengthValidator
from django.core.exceptions import ValidationError
from django import forms
from django.utils.safestring import mark_safe
from django.utils import six
from django.utils.translation import string_concat, ugettext_lazy as _
class SimpleArrayField(forms.CharField):
default_error_messages = {
'item_invalid': _('Item %(nth)s in the array did not validate: '),
}
def __init__(self, base_field, delimiter=',', max_length=None, min_length=None, *args, **kwargs):
self.base_field = base_field
self.delimiter = delimiter
super(SimpleArrayField, self).__init__(*args, **kwargs)
if min_length is not None:
self.min_length = min_length
self.validators.append(ArrayMinLengthValidator(int(min_length)))
if max_length is not None:
self.max_length = max_length
self.validators.append(ArrayMaxLengthValidator(int(max_length)))
def prepare_value(self, value):
if isinstance(value, list):
return self.delimiter.join([six.text_type(self.base_field.prepare_value(v)) for v in value])
return value
def to_python(self, value):
if value:
items = value.split(self.delimiter)
else:
items = []
errors = []
values = []
for i, item in enumerate(items):
try:
values.append(self.base_field.to_python(item))
except ValidationError as e:
for error in e.error_list:
errors.append(ValidationError(
string_concat(self.error_messages['item_invalid'], error.message),
code='item_invalid',
params={'nth': i},
))
if errors:
raise ValidationError(errors)
return values
def validate(self, value):
super(SimpleArrayField, self).validate(value)
errors = []
for i, item in enumerate(value):
try:
self.base_field.validate(item)
except ValidationError as e:
for error in e.error_list:
errors.append(ValidationError(
string_concat(self.error_messages['item_invalid'], error.message),
code='item_invalid',
params={'nth': i},
))
if errors:
raise ValidationError(errors)
def run_validators(self, value):
super(SimpleArrayField, self).run_validators(value)
errors = []
for i, item in enumerate(value):
try:
self.base_field.run_validators(item)
except ValidationError as e:
for error in e.error_list:
errors.append(ValidationError(
string_concat(self.error_messages['item_invalid'], error.message),
code='item_invalid',
params={'nth': i},
))
if errors:
raise ValidationError(errors)
class SplitArrayWidget(forms.Widget):
def __init__(self, widget, size, **kwargs):
self.widget = widget() if isinstance(widget, type) else widget
self.size = size
super(SplitArrayWidget, self).__init__(**kwargs)
@property
def is_hidden(self):
return self.widget.is_hidden
def value_from_datadict(self, data, files, name):
return [self.widget.value_from_datadict(data, files, '%s_%s' % (name, index))
for index in range(self.size)]
def id_for_label(self, id_):
# See the comment for RadioSelect.id_for_label()
if id_:
id_ += '_0'
return id_
def render(self, name, value, attrs=None):
if self.is_localized:
self.widget.is_localized = self.is_localized
value = value or []
output = []
final_attrs = self.build_attrs(attrs)
id_ = final_attrs.get('id', None)
for i in range(max(len(value), self.size)):
try:
widget_value = value[i]
except IndexError:
widget_value = None
if id_:
final_attrs = dict(final_attrs, id='%s_%s' % (id_, i))
output.append(self.widget.render(name + '_%s' % i, widget_value, final_attrs))
return mark_safe(self.format_output(output))
def format_output(self, rendered_widgets):
return ''.join(rendered_widgets)
@property
def media(self):
return self.widget.media
def __deepcopy__(self, memo):
obj = super(SplitArrayWidget, self).__deepcopy__(memo)
obj.widget = copy.deepcopy(self.widget)
return obj
@property
def needs_multipart_form(self):
return self.widget.needs_multipart_form
class SplitArrayField(forms.Field):
default_error_messages = {
'item_invalid': _('Item %(nth)s in the array did not validate: '),
}
def __init__(self, base_field, size, remove_trailing_nulls=False, **kwargs):
self.base_field = base_field
self.size = size
self.remove_trailing_nulls = remove_trailing_nulls
widget = SplitArrayWidget(widget=base_field.widget, size=size)
kwargs.setdefault('widget', widget)
super(SplitArrayField, self).__init__(**kwargs)
def clean(self, value):
cleaned_data = []
errors = []
if not any(value) and self.required:
raise ValidationError(self.error_messages['required'])
max_size = max(self.size, len(value))
for i in range(max_size):
item = value[i]
try:
cleaned_data.append(self.base_field.clean(item))
errors.append(None)
except ValidationError as error:
errors.append(ValidationError(
string_concat(self.error_messages['item_invalid'], error.message),
code='item_invalid',
params={'nth': i},
))
cleaned_data.append(None)
if self.remove_trailing_nulls:
null_index = None
for i, value in reversed(list(enumerate(cleaned_data))):
if value in self.base_field.empty_values:
null_index = i
else:
break
if null_index:
cleaned_data = cleaned_data[:null_index]
errors = errors[:null_index]
errors = list(filter(None, errors))
if errors:
raise ValidationError(errors)
return cleaned_data

View File

@ -0,0 +1,16 @@
from django.core.validators import MaxLengthValidator, MinLengthValidator
from django.utils.translation import ungettext_lazy
class ArrayMaxLengthValidator(MaxLengthValidator):
message = ungettext_lazy(
'List contains %(show_value)d item, it should contain no more than %(limit_value)d.',
'List contains %(show_value)d items, it should contain no more than %(limit_value)d.',
'limit_value')
class ArrayMinLengthValidator(MinLengthValidator):
message = ungettext_lazy(
'List contains %(show_value)d item, it should contain no fewer than %(limit_value)d.',
'List contains %(show_value)d items, it should contain no fewer than %(limit_value)d.',
'limit_value')

View File

@ -91,7 +91,8 @@ manipulating the data of your Web application. Learn more about it below:
:doc:`Supported databases <ref/databases>` |
:doc:`Legacy databases <howto/legacy-databases>` |
:doc:`Providing initial data <howto/initial-data>` |
:doc:`Optimize database access <topics/db/optimization>`
:doc:`Optimize database access <topics/db/optimization>` |
:doc:`PostgreSQL specific features <ref/contrib/postgres/index>`
The view layer
==============

View File

@ -31,6 +31,7 @@ those packages have.
gis/index
humanize
messages
postgres/index
redirects
sitemaps
sites
@ -122,6 +123,13 @@ messages
See the :doc:`messages documentation </ref/contrib/messages>`.
postgres
========
A collection of PostgreSQL specific features.
See the :doc:`contrib.postgres documentation </ref/contrib/postgres/index>`.
redirects
=========

View File

@ -0,0 +1,228 @@
PostgreSQL specific model fields
================================
All of these fields are available from the ``django.contrib.postgres.fields``
module.
.. currentmodule:: django.contrib.postgres.fields
ArrayField
----------
.. class:: ArrayField(base_field, size=None, **options)
A field for storing lists of data. Most field types can be used, you simply
pass another field instance as the :attr:`base_field
<ArrayField.base_field>`. You may also specify a :attr:`size
<ArrayField.size>`. ``ArrayField`` can be nested to store multi-dimensional
arrays.
.. attribute:: base_field
This is a required argument.
Specifies the underlying data type and behaviour for the array. It
should be an instance of a subclass of
:class:`~django.db.models.Field`. For example, it could be an
:class:`~django.db.models.IntegerField` or a
:class:`~django.db.models.CharField`. Most field types are permitted,
with the exception of those handling relational data
(:class:`~django.db.models.ForeignKey`,
:class:`~django.db.models.OneToOneField` and
:class:`~django.db.models.ManyToManyField`).
It is possible to nest array fields - you can specify an instance of
``ArrayField`` as the ``base_field``. For example::
from django.db import models
from django.contrib.postgres.fields import ArrayField
class ChessBoard(models.Model):
board = ArrayField(
ArrayField(
CharField(max_length=10, blank=True, null=True),
size=8),
size=8)
Transformation of values between the database and the model, validation
of data and configuration, and serialization are all delegated to the
underlying base field.
.. attribute:: size
This is an optional argument.
If passed, the array will have a maximum size as specified. This will
be passed to the database, although PostgreSQL at present does not
enforce the restriction.
.. note::
When nesting ``ArrayField``, whether you use the `size` parameter or not,
PostgreSQL requires that the arrays are rectangular::
from django.db import models
from django.contrib.postgres.fields import ArrayField
class Board(models.Model):
pieces = ArrayField(ArrayField(models.IntegerField()))
# Valid
Board(pieces=[
[2, 3],
[2, 1],
])
# Not valid
Board(pieces=[
[2, 3],
[2],
])
If irregular shapes are required, then the underlying field should be made
nullable and the values padded with ``None``.
Querying ArrayField
^^^^^^^^^^^^^^^^^^^
There are a number of custom lookups and transforms for :class:`ArrayField`.
We will use the following example model::
from django.db import models
from django.contrib.postgres.fields import ArrayField
class Post(models.Model):
name = models.CharField(max_length=200)
tags = ArrayField(models.CharField(max_length=200), blank=True)
def __str__(self): # __unicode__ on python 2
return self.name
.. fieldlookup:: arrayfield.contains
contains
~~~~~~~~
The :lookup:`contains` lookup is overridden on :class:`ArrayField`. The
returned objects will be those where the values passed are a subset of the
data. It uses the SQL operator ``@>``. For example::
>>> Post.objects.create(name='First post', tags=['thoughts', 'django'])
>>> Post.objects.create(name='Second post', tags=['thoughts'])
>>> Post.objects.create(name='Third post', tags=['tutorial', 'django'])
>>> Post.objects.filter(tags__contains=['thoughts'])
[<Post: First post>, <Post: Second post>]
>>> Post.objects.filter(tags__contains=['django'])
[<Post: First post>, <Post: Third post>]
>>> Post.objects.filter(tags__contains=['django', 'thoughts'])
[<Post: First post>]
.. fieldlookup:: arrayfield.contained_by
contained_by
~~~~~~~~~~~~
This is the inverse of the :lookup:`contains <arrayfield.contains>` lookup -
the objects returned will be those where the data is a subset of the values
passed. It uses the SQL operator ``<@``. For example::
>>> Post.objects.create(name='First post', tags=['thoughts', 'django'])
>>> Post.objects.create(name='Second post', tags=['thoughts'])
>>> Post.objects.create(name='Third post', tags=['tutorial', 'django'])
>>> Post.objects.filter(tags__contained_by=['thoughts', 'django'])
[<Post: First post>]
>>> Post.objects.filter(tags__contained_by=['thoughts', 'django', 'tutorial'])
[<Post: First post>, <Post: Second post>, <Post: Third post>]
.. fieldlookup:: arrayfield.overlap
overlap
~~~~~~~
Returns objects where the data shares any results with the values passed. Uses
the SQL operator ``&&``. For example::
>>> Post.objects.create(name='First post', tags=['thoughts', 'django'])
>>> Post.objects.create(name='Second post', tags=['thoughts'])
>>> Post.objects.create(name='Third post', tags=['tutorial', 'django'])
>>> Post.objects.filter(tags__overlap=['thoughts'])
[<Post: First post>, <Post: Second post>]
>>> Post.objects.filter(tags__overlap=['thoughts', 'tutorial'])
[<Post: First post>, <Post: Second post>, <Post: Third post>]
.. fieldlookup:: arrayfield.index
Index transforms
~~~~~~~~~~~~~~~~
This class of transforms allows you to index into the array in queries. Any
non-negative integer can be used. There are no errors if it exceeds the
:attr:`size <ArrayField.size>` of the array. The lookups available after the
transform are those from the :attr:`base_field <ArrayField.base_field>`. For
example::
>>> Post.objects.create(name='First post', tags=['thoughts', 'django'])
>>> Post.objects.create(name='Second post', tags=['thoughts'])
>>> Post.objects.filter(tags__0='thoughts')
[<Post: First post>, <Post: Second post>]
>>> Post.objects.filter(tags__1__iexact='Django')
[<Post: First post>]
>>> Post.objects.filter(tags__276='javascript')
[]
.. note::
PostgreSQL uses 1-based indexing for array fields when writing raw SQL.
However these indexes and those used in :lookup:`slices <arrayfield.slice>`
use 0-based indexing to be consistent with Python.
.. fieldlookup:: arrayfield.slice
Slice transforms
~~~~~~~~~~~~~~~~
This class of transforms allow you to take a slice of the array. Any two
non-negative integers can be used, separated by a single underscore. The
lookups available after the transform do not change. For example::
>>> Post.objects.create(name='First post', tags=['thoughts', 'django'])
>>> Post.objects.create(name='Second post', tags=['thoughts'])
>>> Post.objects.create(name='Third post', tags=['django', 'python', 'thoughts'])
>>> Post.objects.filter(tags__0_1=['thoughts'])
[<Post: First post>]
>>> Post.objects.filter(tags__0_2__contains='thoughts')
[<Post: First post>, <Post: Second post>]
.. note::
PostgreSQL uses 1-based indexing for array fields when writing raw SQL.
However these slices and those used in :lookup:`indexes <arrayfield.index>`
use 0-based indexing to be consistent with Python.
.. admonition:: Multidimensional arrays with indexes and slices
PostgreSQL has some rather esoteric behaviour when using indexes and slices
on multidimensional arrays. It will always work to use indexes to reach
down to the final underlying data, but most other slices behave strangely
at the database level and cannot be supported in a logical, consistent
fashion by Django.
Indexing ArrayField
^^^^^^^^^^^^^^^^^^^
At present using :attr:`~django.db.models.Field.db_index` will create a
``btree`` index. This does not offer particularly significant help to querying.
A more useful index is a ``GIN`` index, which you should create using a
:class:`~django.db.migrations.operations.RunSQL` operation.

View File

@ -0,0 +1,135 @@
PostgreSQL specific form fields and widgets
===========================================
All of these fields and widgets are available from the
``django.contrib.postgres.forms`` module.
.. currentmodule:: django.contrib.postgres.forms
SimpleArrayField
----------------
.. class:: SimpleArrayField(base_field, delimiter=',', max_length=None, min_length=None)
A simple field which maps to an array. It is represented by an HTML
``<input>``.
.. attribute:: base_field
This is a required argument.
It specifies the underlying form field for the array. This is not used
to render any HTML, but it is used to process the submitted data and
validate it. For example::
>>> from django.contrib.postgres.forms import SimpleArrayField
>>> from django import forms
>>> class NumberListForm(forms.Form):
... numbers = SimpleArrayField(forms.IntegerField())
>>> form = NumberListForm({'numbers': '1,2,3'})
>>> form.is_valid()
True
>>> form.cleaned_data
{'numbers': [1, 2, 3]}
>>> form = NumberListForm({'numbers': '1,2,a'})
>>> form.is_valid()
False
.. attribute:: delimiter
This is an optional argument which defaults to a comma: ``,``. This
value is used to split the submitted data. It allows you to chain
``SimpleArrayField`` for multidimensional data::
>>> from django.contrib.postgres.forms import SimpleArrayField
>>> from django import forms
>>> class GridForm(forms.Form):
... places = SimpleArrayField(SimpleArrayField(IntegerField()), delimiter='|')
>>> form = GridForm({'places': '1,2|2,1|4,3'})
>>> form.is_valid()
True
>>> form.cleaned_data
{'places': [[1, 2], [2, 1], [4, 3]]}
.. note::
The field does not support escaping of the delimiter, so be careful
in cases where the delimiter is a valid character in the underlying
field. The delimiter does not need to be only one character.
.. attribute:: max_length
This is an optional argument which validates that the array does not
exceed the stated length.
.. attribute:: min_length
This is an optional argument which validates that the array reaches at
least the stated length.
.. admonition:: User friendly forms
``SimpleArrayField`` is not particularly user friendly in most cases,
however it is a useful way to format data from a client-side widget for
submission to the server.
SplitArrayField
---------------
.. class:: SplitArrayField(base_field, size, remove_trailing_nulls=False)
This field handles arrays by reproducing the underlying field a fixed
number of times.
.. attribute:: base_field
This is a required argument. It specifies the form field to be
repeated.
.. attribute:: size
This is the fixed number of times the underlying field will be used.
.. attribute:: remove_trailing_nulls
By default, this is set to ``False``. When ``False``, each value from
the repeated fields is stored. When set to ``True``, any trailing
values which are blank will be stripped from the result. If the
underlying field has ``required=True``, but ``remove_trailing_nulls``
is ``True``, then null values are only allowed at the end, and will be
stripped.
Some examples::
SplitArrayField(IntegerField(required=True), size=3, remove_trailing_nulls=False)
['1', '2', '3'] # -> [1, 2, 3]
['1', '2', ''] # -> ValidationError - third entry required.
['1', '', '3'] # -> ValidationError - second entry required.
['', '2', ''] # -> ValidationError - first and third entries required.
SplitArrayField(IntegerField(required=False), size=3, remove_trailing_nulls=False)
['1', '2', '3'] # -> [1, 2, 3]
['1', '2', ''] # -> [1, 2, None]
['1', '', '3'] # -> [1, None, 3]
['', '2', ''] # -> [None, 2, None]
SplitArrayField(IntegerField(required=True), size=3, remove_trailing_nulls=True)
['1', '2', '3'] # -> [1, 2, 3]
['1', '2', ''] # -> [1, 2]
['1', '', '3'] # -> ValidationError - second entry required.
['', '2', ''] # -> ValidationError - first entry required.
SplitArrayField(IntegerField(required=False), size=3, remove_trailing_nulls=True)
['1', '2', '3'] # -> [1, 2, 3]
['1', '2', ''] # -> [1, 2]
['1', '', '3'] # -> [1, None, 3]
['', '2', ''] # -> [None, 2]

View File

@ -0,0 +1,28 @@
``django.contrib.postgres``
===========================
PostgreSQL has a number of features which are not shared by the other databases
Django supports. This optional module contains model fields and form fields for
a number of PostgreSQL specific data types.
.. note::
Django is, and will continue to be, a database-agnostic web framework. We
would encourage those writing reusable applications for the Django
community to write database-agnostic code where practical. However, we
recognise that real world projects written using Django need not be
database-agnostic. In fact, once a project reaches a given size changing
the underlying data store is already a significant challenge and is likely
to require changing the code base in some ways to handle differences
between the data stores.
Django provides support for a number of data types which will
only work with PostgreSQL. There is no fundamental reason why (for example)
a ``contrib.mysql`` module does not exist, except that PostgreSQL has the
richest feature set of the supported databases so its users have the most
to gain.
.. toctree::
:maxdepth: 2
fields
forms

View File

View File

@ -0,0 +1,22 @@
from django.contrib.postgres.fields import ArrayField
from django.db import models
class IntegerArrayModel(models.Model):
field = ArrayField(models.IntegerField())
class NullableIntegerArrayModel(models.Model):
field = ArrayField(models.IntegerField(), blank=True, null=True)
class CharArrayModel(models.Model):
field = ArrayField(models.CharField(max_length=10))
class DateTimeArrayModel(models.Model):
field = ArrayField(models.DateTimeField())
class NestedIntegerArrayModel(models.Model):
field = ArrayField(ArrayField(models.IntegerField()))

View File

@ -0,0 +1,389 @@
import unittest
from django.contrib.postgres.fields import ArrayField
from django.contrib.postgres.forms import SimpleArrayField, SplitArrayField
from django.core import exceptions, serializers
from django.db import models, IntegrityError, connection
from django.db.migrations.writer import MigrationWriter
from django import forms
from django.test import TestCase
from django.utils import timezone
from .models import IntegerArrayModel, NullableIntegerArrayModel, CharArrayModel, DateTimeArrayModel, NestedIntegerArrayModel
@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL required')
class TestSaveLoad(TestCase):
def test_integer(self):
instance = IntegerArrayModel(field=[1, 2, 3])
instance.save()
loaded = IntegerArrayModel.objects.get()
self.assertEqual(instance.field, loaded.field)
def test_char(self):
instance = CharArrayModel(field=['hello', 'goodbye'])
instance.save()
loaded = CharArrayModel.objects.get()
self.assertEqual(instance.field, loaded.field)
def test_dates(self):
instance = DateTimeArrayModel(field=[timezone.now()])
instance.save()
loaded = DateTimeArrayModel.objects.get()
self.assertEqual(instance.field, loaded.field)
def test_tuples(self):
instance = IntegerArrayModel(field=(1,))
instance.save()
loaded = IntegerArrayModel.objects.get()
self.assertSequenceEqual(instance.field, loaded.field)
def test_integers_passed_as_strings(self):
# This checks that get_prep_value is deferred properly
instance = IntegerArrayModel(field=['1'])
instance.save()
loaded = IntegerArrayModel.objects.get()
self.assertEqual(loaded.field, [1])
def test_null_handling(self):
instance = NullableIntegerArrayModel(field=None)
instance.save()
loaded = NullableIntegerArrayModel.objects.get()
self.assertEqual(instance.field, loaded.field)
instance = IntegerArrayModel(field=None)
with self.assertRaises(IntegrityError):
instance.save()
def test_nested(self):
instance = NestedIntegerArrayModel(field=[[1, 2], [3, 4]])
instance.save()
loaded = NestedIntegerArrayModel.objects.get()
self.assertEqual(instance.field, loaded.field)
@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL required')
class TestQuerying(TestCase):
def setUp(self):
self.objs = [
NullableIntegerArrayModel.objects.create(field=[1]),
NullableIntegerArrayModel.objects.create(field=[2]),
NullableIntegerArrayModel.objects.create(field=[2, 3]),
NullableIntegerArrayModel.objects.create(field=[20, 30, 40]),
NullableIntegerArrayModel.objects.create(field=None),
]
def test_exact(self):
self.assertSequenceEqual(
NullableIntegerArrayModel.objects.filter(field__exact=[1]),
self.objs[:1]
)
def test_isnull(self):
self.assertSequenceEqual(
NullableIntegerArrayModel.objects.filter(field__isnull=True),
self.objs[-1:]
)
def test_gt(self):
self.assertSequenceEqual(
NullableIntegerArrayModel.objects.filter(field__gt=[0]),
self.objs[:4]
)
def test_lt(self):
self.assertSequenceEqual(
NullableIntegerArrayModel.objects.filter(field__lt=[2]),
self.objs[:1]
)
def test_in(self):
self.assertSequenceEqual(
NullableIntegerArrayModel.objects.filter(field__in=[[1], [2]]),
self.objs[:2]
)
def test_contained_by(self):
self.assertSequenceEqual(
NullableIntegerArrayModel.objects.filter(field__contained_by=[1, 2]),
self.objs[:2]
)
def test_contains(self):
self.assertSequenceEqual(
NullableIntegerArrayModel.objects.filter(field__contains=[2]),
self.objs[1:3]
)
def test_index(self):
self.assertSequenceEqual(
NullableIntegerArrayModel.objects.filter(field__0=2),
self.objs[1:3]
)
def test_index_chained(self):
self.assertSequenceEqual(
NullableIntegerArrayModel.objects.filter(field__0__lt=3),
self.objs[0:3]
)
def test_index_nested(self):
instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]])
self.assertSequenceEqual(
NestedIntegerArrayModel.objects.filter(field__0__0=1),
[instance]
)
@unittest.expectedFailure
def test_index_used_on_nested_data(self):
instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]])
self.assertSequenceEqual(
NestedIntegerArrayModel.objects.filter(field__0=[1, 2]),
[instance]
)
def test_overlap(self):
self.assertSequenceEqual(
NullableIntegerArrayModel.objects.filter(field__overlap=[1, 2]),
self.objs[0:3]
)
def test_len(self):
self.assertSequenceEqual(
NullableIntegerArrayModel.objects.filter(field__len__lte=2),
self.objs[0:3]
)
def test_slice(self):
self.assertSequenceEqual(
NullableIntegerArrayModel.objects.filter(field__0_1=[2]),
self.objs[1:3]
)
self.assertSequenceEqual(
NullableIntegerArrayModel.objects.filter(field__0_2=[2, 3]),
self.objs[2:3]
)
@unittest.expectedFailure
def test_slice_nested(self):
instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]])
self.assertSequenceEqual(
NestedIntegerArrayModel.objects.filter(field__0__0_1=[1]),
[instance]
)
class TestChecks(TestCase):
def test_field_checks(self):
field = ArrayField(models.CharField())
field.set_attributes_from_name('field')
errors = field.check()
self.assertEqual(len(errors), 1)
self.assertEqual(errors[0].id, 'postgres.E001')
def test_invalid_base_fields(self):
field = ArrayField(models.ManyToManyField('postgres_tests.IntegerArrayModel'))
field.set_attributes_from_name('field')
errors = field.check()
self.assertEqual(len(errors), 1)
self.assertEqual(errors[0].id, 'postgres.E002')
class TestMigrations(TestCase):
def test_deconstruct(self):
field = ArrayField(models.IntegerField())
name, path, args, kwargs = field.deconstruct()
new = ArrayField(*args, **kwargs)
self.assertEqual(type(new.base_field), type(field.base_field))
def test_deconstruct_with_size(self):
field = ArrayField(models.IntegerField(), size=3)
name, path, args, kwargs = field.deconstruct()
new = ArrayField(*args, **kwargs)
self.assertEqual(new.size, field.size)
def test_deconstruct_args(self):
field = ArrayField(models.CharField(max_length=20))
name, path, args, kwargs = field.deconstruct()
new = ArrayField(*args, **kwargs)
self.assertEqual(new.base_field.max_length, field.base_field.max_length)
def test_makemigrations(self):
field = ArrayField(models.CharField(max_length=20))
statement, imports = MigrationWriter.serialize(field)
self.assertEqual(statement, 'django.contrib.postgres.fields.ArrayField(models.CharField(max_length=20), size=None)')
@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL required')
class TestSerialization(TestCase):
test_data = '[{"fields": {"field": "[\\"1\\", \\"2\\"]"}, "model": "postgres_tests.integerarraymodel", "pk": null}]'
def test_dumping(self):
instance = IntegerArrayModel(field=[1, 2])
data = serializers.serialize('json', [instance])
self.assertEqual(data, self.test_data)
def test_loading(self):
instance = list(serializers.deserialize('json', self.test_data))[0].object
self.assertEqual(instance.field, [1, 2])
class TestValidation(TestCase):
def test_unbounded(self):
field = ArrayField(models.IntegerField())
with self.assertRaises(exceptions.ValidationError) as cm:
field.clean([1, None], None)
self.assertEqual(cm.exception.code, 'item_invalid')
self.assertEqual(cm.exception.message % cm.exception.params, 'Item 1 in the array did not validate: This field cannot be null.')
def test_blank_true(self):
field = ArrayField(models.IntegerField(blank=True, null=True))
# This should not raise a validation error
field.clean([1, None], None)
def test_with_size(self):
field = ArrayField(models.IntegerField(), size=3)
field.clean([1, 2, 3], None)
with self.assertRaises(exceptions.ValidationError) as cm:
field.clean([1, 2, 3, 4], None)
self.assertEqual(cm.exception.messages[0], 'List contains 4 items, it should contain no more than 3.')
def test_nested_array_mismatch(self):
field = ArrayField(ArrayField(models.IntegerField()))
field.clean([[1, 2], [3, 4]], None)
with self.assertRaises(exceptions.ValidationError) as cm:
field.clean([[1, 2], [3, 4, 5]], None)
self.assertEqual(cm.exception.code, 'nested_array_mismatch')
self.assertEqual(cm.exception.messages[0], 'Nested arrays must have the same length.')
class TestSimpleFormField(TestCase):
def test_valid(self):
field = SimpleArrayField(forms.CharField())
value = field.clean('a,b,c')
self.assertEqual(value, ['a', 'b', 'c'])
def test_to_python_fail(self):
field = SimpleArrayField(forms.IntegerField())
with self.assertRaises(exceptions.ValidationError) as cm:
field.clean('a,b,9')
self.assertEqual(cm.exception.messages[0], 'Item 0 in the array did not validate: Enter a whole number.')
def test_validate_fail(self):
field = SimpleArrayField(forms.CharField(required=True))
with self.assertRaises(exceptions.ValidationError) as cm:
field.clean('a,b,')
self.assertEqual(cm.exception.messages[0], 'Item 2 in the array did not validate: This field is required.')
def test_validators_fail(self):
field = SimpleArrayField(forms.RegexField('[a-e]{2}'))
with self.assertRaises(exceptions.ValidationError) as cm:
field.clean('a,bc,de')
self.assertEqual(cm.exception.messages[0], 'Item 0 in the array did not validate: Enter a valid value.')
def test_delimiter(self):
field = SimpleArrayField(forms.CharField(), delimiter='|')
value = field.clean('a|b|c')
self.assertEqual(value, ['a', 'b', 'c'])
def test_delimiter_with_nesting(self):
field = SimpleArrayField(SimpleArrayField(forms.CharField()), delimiter='|')
value = field.clean('a,b|c,d')
self.assertEqual(value, [['a', 'b'], ['c', 'd']])
def test_prepare_value(self):
field = SimpleArrayField(forms.CharField())
value = field.prepare_value(['a', 'b', 'c'])
self.assertEqual(value, 'a,b,c')
def test_max_length(self):
field = SimpleArrayField(forms.CharField(), max_length=2)
with self.assertRaises(exceptions.ValidationError) as cm:
field.clean('a,b,c')
self.assertEqual(cm.exception.messages[0], 'List contains 3 items, it should contain no more than 2.')
def test_min_length(self):
field = SimpleArrayField(forms.CharField(), min_length=4)
with self.assertRaises(exceptions.ValidationError) as cm:
field.clean('a,b,c')
self.assertEqual(cm.exception.messages[0], 'List contains 3 items, it should contain no fewer than 4.')
def test_required(self):
field = SimpleArrayField(forms.CharField(), required=True)
with self.assertRaises(exceptions.ValidationError) as cm:
field.clean('')
self.assertEqual(cm.exception.messages[0], 'This field is required.')
def test_model_field_formfield(self):
model_field = ArrayField(models.CharField(max_length=27))
form_field = model_field.formfield()
self.assertIsInstance(form_field, SimpleArrayField)
self.assertIsInstance(form_field.base_field, forms.CharField)
self.assertEqual(form_field.base_field.max_length, 27)
def test_model_field_formfield_size(self):
model_field = ArrayField(models.CharField(max_length=27), size=4)
form_field = model_field.formfield()
self.assertIsInstance(form_field, SimpleArrayField)
self.assertEqual(form_field.max_length, 4)
class TestSplitFormField(TestCase):
def test_valid(self):
class SplitForm(forms.Form):
array = SplitArrayField(forms.CharField(), size=3)
data = {'array_0': 'a', 'array_1': 'b', 'array_2': 'c'}
form = SplitForm(data)
self.assertTrue(form.is_valid())
self.assertEqual(form.cleaned_data, {'array': ['a', 'b', 'c']})
def test_required(self):
class SplitForm(forms.Form):
array = SplitArrayField(forms.CharField(), required=True, size=3)
data = {'array_0': '', 'array_1': '', 'array_2': ''}
form = SplitForm(data)
self.assertFalse(form.is_valid())
self.assertEqual(form.errors, {'array': ['This field is required.']})
def test_remove_trailing_nulls(self):
class SplitForm(forms.Form):
array = SplitArrayField(forms.CharField(required=False), size=5, remove_trailing_nulls=True)
data = {'array_0': 'a', 'array_1': '', 'array_2': 'b', 'array_3': '', 'array_4': ''}
form = SplitForm(data)
self.assertTrue(form.is_valid(), form.errors)
self.assertEqual(form.cleaned_data, {'array': ['a', '', 'b']})
def test_required_field(self):
class SplitForm(forms.Form):
array = SplitArrayField(forms.CharField(), size=3)
data = {'array_0': 'a', 'array_1': 'b', 'array_2': ''}
form = SplitForm(data)
self.assertFalse(form.is_valid())
self.assertEqual(form.errors, {'array': ['Item 2 in the array did not validate: This field is required.']})
def test_rendering(self):
class SplitForm(forms.Form):
array = SplitArrayField(forms.CharField(), size=3)
self.assertHTMLEqual(str(SplitForm()), '''
<tr>
<th><label for="id_array_0">Array:</label></th>
<td>
<input id="id_array_0" name="array_0" type="text" />
<input id="id_array_1" name="array_1" type="text" />
<input id="id_array_2" name="array_2" type="text" />
</td>
</tr>
''')

View File

@ -57,6 +57,7 @@ ALWAYS_INSTALLED_APPS = [
def get_test_modules():
from django.contrib.gis.tests.utils import HAS_SPATIAL_DB
from django.db import connection
modules = []
discovery_paths = [
(None, RUNTESTS_DIR),
@ -75,6 +76,8 @@ def get_test_modules():
os.path.isfile(f) or
not os.path.exists(os.path.join(dirpath, f, '__init__.py'))):
continue
if not connection.vendor == 'postgresql' and f == 'postgres_tests':
continue
modules.append((modpath, f))
return modules