From 604162604bf816fa46b0d972621f06de64df6a66 Mon Sep 17 00:00:00 2001 From: Marc Tamlyn Date: Wed, 26 Mar 2014 16:44:21 +0000 Subject: [PATCH] Added array field support for PostgreSQL. The first part of django.contrib.postgres, including model and two form fields for arrays of other data types. This commit is formed of the following work: Add shell of postgres app and test handling. First draft of array fields. Use recursive deconstruction. Stop creating classes at lookup time. Add validation and size parameter. Add contained_by lookup. Add SimpleArrayField for forms. Add SplitArrayField (mainly for admin). Fix prepare_value for SimpleArrayField. Stop using MultiValueField and MultiWidget. They don't play nice with flexible sizes. Add basics of admin integration. Missing: - Tests - Fully working js Add reference document for django.contrib.postgres.fields.ArrayField. Various performance and style tweaks. Fix internal docs link, formalise code snippets. Remove the admin code for now. It needs a better way of handing JS widgets in the admin as a whole before it is easy to write. In particular there are serious issues involving DateTimePicker when used in an array. Add a test for nested array fields with different delimiters. This will be a documented pattern so having a test for it is useful. Add docs for SimpleArrayField. Add docs for SplitArrayField. Remove admin related code for now. definition -> description Fix typo. Py3 errors. Avoid using regexes where they're not needed. Allow passing tuples by the programmer. Add some more tests for multidimensional arrays. Also fix slicing as much as it can be fixed. Simplify SplitArrayWidget's data loading. If we aren't including the variable size one, we don't need to search like this. --- django/contrib/postgres/__init__.py | 0 django/contrib/postgres/fields/__init__.py | 1 + django/contrib/postgres/fields/array.py | 254 ++++++++++++++ django/contrib/postgres/forms/__init__.py | 1 + django/contrib/postgres/forms/array.py | 185 ++++++++++ django/contrib/postgres/validators.py | 16 + docs/index.txt | 3 +- docs/ref/contrib/index.txt | 8 + docs/ref/contrib/postgres/fields.txt | 228 ++++++++++++ docs/ref/contrib/postgres/forms.txt | 135 +++++++ docs/ref/contrib/postgres/index.txt | 28 ++ tests/postgres_tests/__init__.py | 0 tests/postgres_tests/models.py | 22 ++ tests/postgres_tests/test_array.py | 389 +++++++++++++++++++++ tests/runtests.py | 3 + 15 files changed, 1272 insertions(+), 1 deletion(-) create mode 100644 django/contrib/postgres/__init__.py create mode 100644 django/contrib/postgres/fields/__init__.py create mode 100644 django/contrib/postgres/fields/array.py create mode 100644 django/contrib/postgres/forms/__init__.py create mode 100644 django/contrib/postgres/forms/array.py create mode 100644 django/contrib/postgres/validators.py create mode 100644 docs/ref/contrib/postgres/fields.txt create mode 100644 docs/ref/contrib/postgres/forms.txt create mode 100644 docs/ref/contrib/postgres/index.txt create mode 100644 tests/postgres_tests/__init__.py create mode 100644 tests/postgres_tests/models.py create mode 100644 tests/postgres_tests/test_array.py diff --git a/django/contrib/postgres/__init__.py b/django/contrib/postgres/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/django/contrib/postgres/fields/__init__.py b/django/contrib/postgres/fields/__init__.py new file mode 100644 index 0000000000..e3ceebd62c --- /dev/null +++ b/django/contrib/postgres/fields/__init__.py @@ -0,0 +1 @@ +from .array import * # NOQA diff --git a/django/contrib/postgres/fields/array.py b/django/contrib/postgres/fields/array.py new file mode 100644 index 0000000000..7a37267400 --- /dev/null +++ b/django/contrib/postgres/fields/array.py @@ -0,0 +1,254 @@ +import json + +from django.contrib.postgres.forms import SimpleArrayField +from django.contrib.postgres.validators import ArrayMaxLengthValidator +from django.core import checks, exceptions +from django.db.models import Field, Lookup, Transform, IntegerField +from django.utils import six +from django.utils.translation import string_concat, ugettext_lazy as _ + + +__all__ = ['ArrayField'] + + +class AttributeSetter(object): + def __init__(self, name, value): + setattr(self, name, value) + + +class ArrayField(Field): + empty_strings_allowed = False + default_error_messages = { + 'item_invalid': _('Item %(nth)s in the array did not validate: '), + 'nested_array_mismatch': _('Nested arrays must have the same length.'), + } + + def __init__(self, base_field, size=None, **kwargs): + self.base_field = base_field + self.size = size + if self.size: + self.default_validators = self.default_validators[:] + self.default_validators.append(ArrayMaxLengthValidator(self.size)) + super(ArrayField, self).__init__(**kwargs) + + def check(self, **kwargs): + errors = super(ArrayField, self).check(**kwargs) + if self.base_field.rel: + errors.append( + checks.Error( + 'Base field for array cannot be a related field.', + hint=None, + obj=self, + id='postgres.E002' + ) + ) + else: + # Remove the field name checks as they are not needed here. + base_errors = self.base_field.check() + if base_errors: + messages = '\n '.join('%s (%s)' % (error.msg, error.id) for error in base_errors) + errors.append( + checks.Error( + 'Base field for array has errors:\n %s' % messages, + hint=None, + obj=self, + id='postgres.E001' + ) + ) + return errors + + def set_attributes_from_name(self, name): + super(ArrayField, self).set_attributes_from_name(name) + self.base_field.set_attributes_from_name(name) + + @property + def description(self): + return 'Array of %s' % self.base_field.description + + def db_type(self, connection): + size = self.size or '' + return '%s[%s]' % (self.base_field.db_type(connection), size) + + def get_prep_value(self, value): + if isinstance(value, list) or isinstance(value, tuple): + return [self.base_field.get_prep_value(i) for i in value] + return value + + def get_db_prep_lookup(self, lookup_type, value, connection, prepared=False): + if lookup_type == 'contains': + return [self.get_prep_value(value)] + return super(ArrayField, self).get_db_prep_lookup(lookup_type, value, + connection, prepared=False) + + def deconstruct(self): + name, path, args, kwargs = super(ArrayField, self).deconstruct() + path = 'django.contrib.postgres.fields.ArrayField' + args.insert(0, self.base_field) + kwargs['size'] = self.size + return name, path, args, kwargs + + def to_python(self, value): + if isinstance(value, six.string_types): + # Assume we're deserializing + vals = json.loads(value) + value = [self.base_field.to_python(val) for val in vals] + return value + + def value_to_string(self, obj): + values = [] + vals = self._get_val_from_obj(obj) + base_field = self.base_field + + for val in vals: + obj = AttributeSetter(base_field.attname, val) + values.append(base_field.value_to_string(obj)) + return json.dumps(values) + + def get_transform(self, name): + transform = super(ArrayField, self).get_transform(name) + if transform: + return transform + try: + index = int(name) + except ValueError: + pass + else: + index += 1 # postgres uses 1-indexing + return IndexTransformFactory(index, self.base_field) + try: + start, end = name.split('_') + start = int(start) + 1 + end = int(end) # don't add one here because postgres slices are weird + except ValueError: + pass + else: + return SliceTransformFactory(start, end) + + def validate(self, value, model_instance): + super(ArrayField, self).validate(value, model_instance) + for i, part in enumerate(value): + try: + self.base_field.validate(part, model_instance) + except exceptions.ValidationError as e: + raise exceptions.ValidationError( + string_concat(self.error_messages['item_invalid'], e.message), + code='item_invalid', + params={'nth': i}, + ) + if isinstance(self.base_field, ArrayField): + if len({len(i) for i in value}) > 1: + raise exceptions.ValidationError( + self.error_messages['nested_array_mismatch'], + code='nested_array_mismatch', + ) + + def formfield(self, **kwargs): + defaults = { + 'form_class': SimpleArrayField, + 'base_field': self.base_field.formfield(), + 'max_length': self.size, + } + defaults.update(kwargs) + return super(ArrayField, self).formfield(**defaults) + + +class ArrayContainsLookup(Lookup): + lookup_name = 'contains' + + def as_sql(self, qn, connection): + lhs, lhs_params = self.process_lhs(qn, connection) + rhs, rhs_params = self.process_rhs(qn, connection) + params = lhs_params + rhs_params + return '%s @> %s' % (lhs, rhs), params + + +ArrayField.register_lookup(ArrayContainsLookup) + + +class ArrayContainedByLookup(Lookup): + lookup_name = 'contained_by' + + def as_sql(self, qn, connection): + lhs, lhs_params = self.process_lhs(qn, connection) + rhs, rhs_params = self.process_rhs(qn, connection) + params = lhs_params + rhs_params + return '%s <@ %s' % (lhs, rhs), params + + +ArrayField.register_lookup(ArrayContainedByLookup) + + +class ArrayOverlapLookup(Lookup): + lookup_name = 'overlap' + + def as_sql(self, qn, connection): + lhs, lhs_params = self.process_lhs(qn, connection) + rhs, rhs_params = self.process_rhs(qn, connection) + params = lhs_params + rhs_params + return '%s && %s' % (lhs, rhs), params + + +ArrayField.register_lookup(ArrayOverlapLookup) + + +class ArrayLenTransform(Transform): + lookup_name = 'len' + + @property + def output_type(self): + return IntegerField() + + def as_sql(self, qn, connection): + lhs, params = qn.compile(self.lhs) + return 'array_length(%s, 1)' % lhs, params + + +ArrayField.register_lookup(ArrayLenTransform) + + +class IndexTransform(Transform): + + def __init__(self, index, base_field, *args, **kwargs): + super(IndexTransform, self).__init__(*args, **kwargs) + self.index = index + self.base_field = base_field + + def as_sql(self, qn, connection): + lhs, params = qn.compile(self.lhs) + return '%s[%s]' % (lhs, self.index), params + + @property + def output_type(self): + return self.base_field + + +class IndexTransformFactory(object): + + def __init__(self, index, base_field): + self.index = index + self.base_field = base_field + + def __call__(self, *args, **kwargs): + return IndexTransform(self.index, self.base_field, *args, **kwargs) + + +class SliceTransform(Transform): + + def __init__(self, start, end, *args, **kwargs): + super(SliceTransform, self).__init__(*args, **kwargs) + self.start = start + self.end = end + + def as_sql(self, qn, connection): + lhs, params = qn.compile(self.lhs) + return '%s[%s:%s]' % (lhs, self.start, self.end), params + + +class SliceTransformFactory(object): + + def __init__(self, start, end): + self.start = start + self.end = end + + def __call__(self, *args, **kwargs): + return SliceTransform(self.start, self.end, *args, **kwargs) diff --git a/django/contrib/postgres/forms/__init__.py b/django/contrib/postgres/forms/__init__.py new file mode 100644 index 0000000000..e3ceebd62c --- /dev/null +++ b/django/contrib/postgres/forms/__init__.py @@ -0,0 +1 @@ +from .array import * # NOQA diff --git a/django/contrib/postgres/forms/array.py b/django/contrib/postgres/forms/array.py new file mode 100644 index 0000000000..620c7c7b6e --- /dev/null +++ b/django/contrib/postgres/forms/array.py @@ -0,0 +1,185 @@ +import copy + +from django.contrib.postgres.validators import ArrayMinLengthValidator, ArrayMaxLengthValidator +from django.core.exceptions import ValidationError +from django import forms +from django.utils.safestring import mark_safe +from django.utils import six +from django.utils.translation import string_concat, ugettext_lazy as _ + + +class SimpleArrayField(forms.CharField): + default_error_messages = { + 'item_invalid': _('Item %(nth)s in the array did not validate: '), + } + + def __init__(self, base_field, delimiter=',', max_length=None, min_length=None, *args, **kwargs): + self.base_field = base_field + self.delimiter = delimiter + super(SimpleArrayField, self).__init__(*args, **kwargs) + if min_length is not None: + self.min_length = min_length + self.validators.append(ArrayMinLengthValidator(int(min_length))) + if max_length is not None: + self.max_length = max_length + self.validators.append(ArrayMaxLengthValidator(int(max_length))) + + def prepare_value(self, value): + if isinstance(value, list): + return self.delimiter.join([six.text_type(self.base_field.prepare_value(v)) for v in value]) + return value + + def to_python(self, value): + if value: + items = value.split(self.delimiter) + else: + items = [] + errors = [] + values = [] + for i, item in enumerate(items): + try: + values.append(self.base_field.to_python(item)) + except ValidationError as e: + for error in e.error_list: + errors.append(ValidationError( + string_concat(self.error_messages['item_invalid'], error.message), + code='item_invalid', + params={'nth': i}, + )) + if errors: + raise ValidationError(errors) + return values + + def validate(self, value): + super(SimpleArrayField, self).validate(value) + errors = [] + for i, item in enumerate(value): + try: + self.base_field.validate(item) + except ValidationError as e: + for error in e.error_list: + errors.append(ValidationError( + string_concat(self.error_messages['item_invalid'], error.message), + code='item_invalid', + params={'nth': i}, + )) + if errors: + raise ValidationError(errors) + + def run_validators(self, value): + super(SimpleArrayField, self).run_validators(value) + errors = [] + for i, item in enumerate(value): + try: + self.base_field.run_validators(item) + except ValidationError as e: + for error in e.error_list: + errors.append(ValidationError( + string_concat(self.error_messages['item_invalid'], error.message), + code='item_invalid', + params={'nth': i}, + )) + if errors: + raise ValidationError(errors) + + +class SplitArrayWidget(forms.Widget): + + def __init__(self, widget, size, **kwargs): + self.widget = widget() if isinstance(widget, type) else widget + self.size = size + super(SplitArrayWidget, self).__init__(**kwargs) + + @property + def is_hidden(self): + return self.widget.is_hidden + + def value_from_datadict(self, data, files, name): + return [self.widget.value_from_datadict(data, files, '%s_%s' % (name, index)) + for index in range(self.size)] + + def id_for_label(self, id_): + # See the comment for RadioSelect.id_for_label() + if id_: + id_ += '_0' + return id_ + + def render(self, name, value, attrs=None): + if self.is_localized: + self.widget.is_localized = self.is_localized + value = value or [] + output = [] + final_attrs = self.build_attrs(attrs) + id_ = final_attrs.get('id', None) + for i in range(max(len(value), self.size)): + try: + widget_value = value[i] + except IndexError: + widget_value = None + if id_: + final_attrs = dict(final_attrs, id='%s_%s' % (id_, i)) + output.append(self.widget.render(name + '_%s' % i, widget_value, final_attrs)) + return mark_safe(self.format_output(output)) + + def format_output(self, rendered_widgets): + return ''.join(rendered_widgets) + + @property + def media(self): + return self.widget.media + + def __deepcopy__(self, memo): + obj = super(SplitArrayWidget, self).__deepcopy__(memo) + obj.widget = copy.deepcopy(self.widget) + return obj + + @property + def needs_multipart_form(self): + return self.widget.needs_multipart_form + + +class SplitArrayField(forms.Field): + default_error_messages = { + 'item_invalid': _('Item %(nth)s in the array did not validate: '), + } + + def __init__(self, base_field, size, remove_trailing_nulls=False, **kwargs): + self.base_field = base_field + self.size = size + self.remove_trailing_nulls = remove_trailing_nulls + widget = SplitArrayWidget(widget=base_field.widget, size=size) + kwargs.setdefault('widget', widget) + super(SplitArrayField, self).__init__(**kwargs) + + def clean(self, value): + cleaned_data = [] + errors = [] + if not any(value) and self.required: + raise ValidationError(self.error_messages['required']) + max_size = max(self.size, len(value)) + for i in range(max_size): + item = value[i] + try: + cleaned_data.append(self.base_field.clean(item)) + errors.append(None) + except ValidationError as error: + errors.append(ValidationError( + string_concat(self.error_messages['item_invalid'], error.message), + code='item_invalid', + params={'nth': i}, + )) + cleaned_data.append(None) + if self.remove_trailing_nulls: + null_index = None + for i, value in reversed(list(enumerate(cleaned_data))): + if value in self.base_field.empty_values: + null_index = i + else: + break + if null_index: + cleaned_data = cleaned_data[:null_index] + errors = errors[:null_index] + errors = list(filter(None, errors)) + if errors: + raise ValidationError(errors) + return cleaned_data diff --git a/django/contrib/postgres/validators.py b/django/contrib/postgres/validators.py new file mode 100644 index 0000000000..353305949e --- /dev/null +++ b/django/contrib/postgres/validators.py @@ -0,0 +1,16 @@ +from django.core.validators import MaxLengthValidator, MinLengthValidator +from django.utils.translation import ungettext_lazy + + +class ArrayMaxLengthValidator(MaxLengthValidator): + message = ungettext_lazy( + 'List contains %(show_value)d item, it should contain no more than %(limit_value)d.', + 'List contains %(show_value)d items, it should contain no more than %(limit_value)d.', + 'limit_value') + + +class ArrayMinLengthValidator(MinLengthValidator): + message = ungettext_lazy( + 'List contains %(show_value)d item, it should contain no fewer than %(limit_value)d.', + 'List contains %(show_value)d items, it should contain no fewer than %(limit_value)d.', + 'limit_value') diff --git a/docs/index.txt b/docs/index.txt index e0529c4503..b41f5c0ecb 100644 --- a/docs/index.txt +++ b/docs/index.txt @@ -91,7 +91,8 @@ manipulating the data of your Web application. Learn more about it below: :doc:`Supported databases ` | :doc:`Legacy databases ` | :doc:`Providing initial data ` | - :doc:`Optimize database access ` + :doc:`Optimize database access ` | + :doc:`PostgreSQL specific features ` The view layer ============== diff --git a/docs/ref/contrib/index.txt b/docs/ref/contrib/index.txt index 533680659e..ebfc2874b4 100644 --- a/docs/ref/contrib/index.txt +++ b/docs/ref/contrib/index.txt @@ -31,6 +31,7 @@ those packages have. gis/index humanize messages + postgres/index redirects sitemaps sites @@ -122,6 +123,13 @@ messages See the :doc:`messages documentation `. +postgres +======== + +A collection of PostgreSQL specific features. + +See the :doc:`contrib.postgres documentation `. + redirects ========= diff --git a/docs/ref/contrib/postgres/fields.txt b/docs/ref/contrib/postgres/fields.txt new file mode 100644 index 0000000000..dcde84d2ec --- /dev/null +++ b/docs/ref/contrib/postgres/fields.txt @@ -0,0 +1,228 @@ +PostgreSQL specific model fields +================================ + +All of these fields are available from the ``django.contrib.postgres.fields`` +module. + +.. currentmodule:: django.contrib.postgres.fields + +ArrayField +---------- + +.. class:: ArrayField(base_field, size=None, **options) + + A field for storing lists of data. Most field types can be used, you simply + pass another field instance as the :attr:`base_field + `. You may also specify a :attr:`size + `. ``ArrayField`` can be nested to store multi-dimensional + arrays. + + .. attribute:: base_field + + This is a required argument. + + Specifies the underlying data type and behaviour for the array. It + should be an instance of a subclass of + :class:`~django.db.models.Field`. For example, it could be an + :class:`~django.db.models.IntegerField` or a + :class:`~django.db.models.CharField`. Most field types are permitted, + with the exception of those handling relational data + (:class:`~django.db.models.ForeignKey`, + :class:`~django.db.models.OneToOneField` and + :class:`~django.db.models.ManyToManyField`). + + It is possible to nest array fields - you can specify an instance of + ``ArrayField`` as the ``base_field``. For example:: + + from django.db import models + from django.contrib.postgres.fields import ArrayField + + class ChessBoard(models.Model): + board = ArrayField( + ArrayField( + CharField(max_length=10, blank=True, null=True), + size=8), + size=8) + + Transformation of values between the database and the model, validation + of data and configuration, and serialization are all delegated to the + underlying base field. + + .. attribute:: size + + This is an optional argument. + + If passed, the array will have a maximum size as specified. This will + be passed to the database, although PostgreSQL at present does not + enforce the restriction. + +.. note:: + + When nesting ``ArrayField``, whether you use the `size` parameter or not, + PostgreSQL requires that the arrays are rectangular:: + + from django.db import models + from django.contrib.postgres.fields import ArrayField + + class Board(models.Model): + pieces = ArrayField(ArrayField(models.IntegerField())) + + # Valid + Board(pieces=[ + [2, 3], + [2, 1], + ]) + + # Not valid + Board(pieces=[ + [2, 3], + [2], + ]) + + If irregular shapes are required, then the underlying field should be made + nullable and the values padded with ``None``. + +Querying ArrayField +^^^^^^^^^^^^^^^^^^^ + +There are a number of custom lookups and transforms for :class:`ArrayField`. +We will use the following example model:: + + from django.db import models + from django.contrib.postgres.fields import ArrayField + + class Post(models.Model): + name = models.CharField(max_length=200) + tags = ArrayField(models.CharField(max_length=200), blank=True) + + def __str__(self): # __unicode__ on python 2 + return self.name + +.. fieldlookup:: arrayfield.contains + +contains +~~~~~~~~ + +The :lookup:`contains` lookup is overridden on :class:`ArrayField`. The +returned objects will be those where the values passed are a subset of the +data. It uses the SQL operator ``@>``. For example:: + + >>> Post.objects.create(name='First post', tags=['thoughts', 'django']) + >>> Post.objects.create(name='Second post', tags=['thoughts']) + >>> Post.objects.create(name='Third post', tags=['tutorial', 'django']) + + >>> Post.objects.filter(tags__contains=['thoughts']) + [, ] + + >>> Post.objects.filter(tags__contains=['django']) + [, ] + + >>> Post.objects.filter(tags__contains=['django', 'thoughts']) + [] + +.. fieldlookup:: arrayfield.contained_by + +contained_by +~~~~~~~~~~~~ + +This is the inverse of the :lookup:`contains ` lookup - +the objects returned will be those where the data is a subset of the values +passed. It uses the SQL operator ``<@``. For example:: + + >>> Post.objects.create(name='First post', tags=['thoughts', 'django']) + >>> Post.objects.create(name='Second post', tags=['thoughts']) + >>> Post.objects.create(name='Third post', tags=['tutorial', 'django']) + + >>> Post.objects.filter(tags__contained_by=['thoughts', 'django']) + [] + + >>> Post.objects.filter(tags__contained_by=['thoughts', 'django', 'tutorial']) + [, , ] + +.. fieldlookup:: arrayfield.overlap + +overlap +~~~~~~~ + +Returns objects where the data shares any results with the values passed. Uses +the SQL operator ``&&``. For example:: + + >>> Post.objects.create(name='First post', tags=['thoughts', 'django']) + >>> Post.objects.create(name='Second post', tags=['thoughts']) + >>> Post.objects.create(name='Third post', tags=['tutorial', 'django']) + + >>> Post.objects.filter(tags__overlap=['thoughts']) + [, ] + + >>> Post.objects.filter(tags__overlap=['thoughts', 'tutorial']) + [, , ] + +.. fieldlookup:: arrayfield.index + +Index transforms +~~~~~~~~~~~~~~~~ + +This class of transforms allows you to index into the array in queries. Any +non-negative integer can be used. There are no errors if it exceeds the +:attr:`size ` of the array. The lookups available after the +transform are those from the :attr:`base_field `. For +example:: + + >>> Post.objects.create(name='First post', tags=['thoughts', 'django']) + >>> Post.objects.create(name='Second post', tags=['thoughts']) + + >>> Post.objects.filter(tags__0='thoughts') + [, ] + + >>> Post.objects.filter(tags__1__iexact='Django') + [] + + >>> Post.objects.filter(tags__276='javascript') + [] + +.. note:: + + PostgreSQL uses 1-based indexing for array fields when writing raw SQL. + However these indexes and those used in :lookup:`slices ` + use 0-based indexing to be consistent with Python. + +.. fieldlookup:: arrayfield.slice + +Slice transforms +~~~~~~~~~~~~~~~~ + +This class of transforms allow you to take a slice of the array. Any two +non-negative integers can be used, separated by a single underscore. The +lookups available after the transform do not change. For example:: + + >>> Post.objects.create(name='First post', tags=['thoughts', 'django']) + >>> Post.objects.create(name='Second post', tags=['thoughts']) + >>> Post.objects.create(name='Third post', tags=['django', 'python', 'thoughts']) + + >>> Post.objects.filter(tags__0_1=['thoughts']) + [] + + >>> Post.objects.filter(tags__0_2__contains='thoughts') + [, ] + +.. note:: + + PostgreSQL uses 1-based indexing for array fields when writing raw SQL. + However these slices and those used in :lookup:`indexes ` + use 0-based indexing to be consistent with Python. + +.. admonition:: Multidimensional arrays with indexes and slices + + PostgreSQL has some rather esoteric behaviour when using indexes and slices + on multidimensional arrays. It will always work to use indexes to reach + down to the final underlying data, but most other slices behave strangely + at the database level and cannot be supported in a logical, consistent + fashion by Django. + +Indexing ArrayField +^^^^^^^^^^^^^^^^^^^ + +At present using :attr:`~django.db.models.Field.db_index` will create a +``btree`` index. This does not offer particularly significant help to querying. +A more useful index is a ``GIN`` index, which you should create using a +:class:`~django.db.migrations.operations.RunSQL` operation. diff --git a/docs/ref/contrib/postgres/forms.txt b/docs/ref/contrib/postgres/forms.txt new file mode 100644 index 0000000000..6cad537f3b --- /dev/null +++ b/docs/ref/contrib/postgres/forms.txt @@ -0,0 +1,135 @@ +PostgreSQL specific form fields and widgets +=========================================== + +All of these fields and widgets are available from the +``django.contrib.postgres.forms`` module. + +.. currentmodule:: django.contrib.postgres.forms + +SimpleArrayField +---------------- + +.. class:: SimpleArrayField(base_field, delimiter=',', max_length=None, min_length=None) + + A simple field which maps to an array. It is represented by an HTML + ````. + + .. attribute:: base_field + + This is a required argument. + + It specifies the underlying form field for the array. This is not used + to render any HTML, but it is used to process the submitted data and + validate it. For example:: + + >>> from django.contrib.postgres.forms import SimpleArrayField + >>> from django import forms + + >>> class NumberListForm(forms.Form): + ... numbers = SimpleArrayField(forms.IntegerField()) + + >>> form = NumberListForm({'numbers': '1,2,3'}) + >>> form.is_valid() + True + >>> form.cleaned_data + {'numbers': [1, 2, 3]} + + >>> form = NumberListForm({'numbers': '1,2,a'}) + >>> form.is_valid() + False + + .. attribute:: delimiter + + This is an optional argument which defaults to a comma: ``,``. This + value is used to split the submitted data. It allows you to chain + ``SimpleArrayField`` for multidimensional data:: + + >>> from django.contrib.postgres.forms import SimpleArrayField + >>> from django import forms + + >>> class GridForm(forms.Form): + ... places = SimpleArrayField(SimpleArrayField(IntegerField()), delimiter='|') + + >>> form = GridForm({'places': '1,2|2,1|4,3'}) + >>> form.is_valid() + True + >>> form.cleaned_data + {'places': [[1, 2], [2, 1], [4, 3]]} + + .. note:: + + The field does not support escaping of the delimiter, so be careful + in cases where the delimiter is a valid character in the underlying + field. The delimiter does not need to be only one character. + + .. attribute:: max_length + + This is an optional argument which validates that the array does not + exceed the stated length. + + .. attribute:: min_length + + This is an optional argument which validates that the array reaches at + least the stated length. + + .. admonition:: User friendly forms + + ``SimpleArrayField`` is not particularly user friendly in most cases, + however it is a useful way to format data from a client-side widget for + submission to the server. + +SplitArrayField +--------------- + +.. class:: SplitArrayField(base_field, size, remove_trailing_nulls=False) + + This field handles arrays by reproducing the underlying field a fixed + number of times. + + .. attribute:: base_field + + This is a required argument. It specifies the form field to be + repeated. + + .. attribute:: size + + This is the fixed number of times the underlying field will be used. + + .. attribute:: remove_trailing_nulls + + By default, this is set to ``False``. When ``False``, each value from + the repeated fields is stored. When set to ``True``, any trailing + values which are blank will be stripped from the result. If the + underlying field has ``required=True``, but ``remove_trailing_nulls`` + is ``True``, then null values are only allowed at the end, and will be + stripped. + + Some examples:: + + SplitArrayField(IntegerField(required=True), size=3, remove_trailing_nulls=False) + + ['1', '2', '3'] # -> [1, 2, 3] + ['1', '2', ''] # -> ValidationError - third entry required. + ['1', '', '3'] # -> ValidationError - second entry required. + ['', '2', ''] # -> ValidationError - first and third entries required. + + SplitArrayField(IntegerField(required=False), size=3, remove_trailing_nulls=False) + + ['1', '2', '3'] # -> [1, 2, 3] + ['1', '2', ''] # -> [1, 2, None] + ['1', '', '3'] # -> [1, None, 3] + ['', '2', ''] # -> [None, 2, None] + + SplitArrayField(IntegerField(required=True), size=3, remove_trailing_nulls=True) + + ['1', '2', '3'] # -> [1, 2, 3] + ['1', '2', ''] # -> [1, 2] + ['1', '', '3'] # -> ValidationError - second entry required. + ['', '2', ''] # -> ValidationError - first entry required. + + SplitArrayField(IntegerField(required=False), size=3, remove_trailing_nulls=True) + + ['1', '2', '3'] # -> [1, 2, 3] + ['1', '2', ''] # -> [1, 2] + ['1', '', '3'] # -> [1, None, 3] + ['', '2', ''] # -> [None, 2] diff --git a/docs/ref/contrib/postgres/index.txt b/docs/ref/contrib/postgres/index.txt new file mode 100644 index 0000000000..5db4ab80ed --- /dev/null +++ b/docs/ref/contrib/postgres/index.txt @@ -0,0 +1,28 @@ +``django.contrib.postgres`` +=========================== + +PostgreSQL has a number of features which are not shared by the other databases +Django supports. This optional module contains model fields and form fields for +a number of PostgreSQL specific data types. + +.. note:: + Django is, and will continue to be, a database-agnostic web framework. We + would encourage those writing reusable applications for the Django + community to write database-agnostic code where practical. However, we + recognise that real world projects written using Django need not be + database-agnostic. In fact, once a project reaches a given size changing + the underlying data store is already a significant challenge and is likely + to require changing the code base in some ways to handle differences + between the data stores. + + Django provides support for a number of data types which will + only work with PostgreSQL. There is no fundamental reason why (for example) + a ``contrib.mysql`` module does not exist, except that PostgreSQL has the + richest feature set of the supported databases so its users have the most + to gain. + +.. toctree:: + :maxdepth: 2 + + fields + forms diff --git a/tests/postgres_tests/__init__.py b/tests/postgres_tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/postgres_tests/models.py b/tests/postgres_tests/models.py new file mode 100644 index 0000000000..6420ebe1cd --- /dev/null +++ b/tests/postgres_tests/models.py @@ -0,0 +1,22 @@ +from django.contrib.postgres.fields import ArrayField +from django.db import models + + +class IntegerArrayModel(models.Model): + field = ArrayField(models.IntegerField()) + + +class NullableIntegerArrayModel(models.Model): + field = ArrayField(models.IntegerField(), blank=True, null=True) + + +class CharArrayModel(models.Model): + field = ArrayField(models.CharField(max_length=10)) + + +class DateTimeArrayModel(models.Model): + field = ArrayField(models.DateTimeField()) + + +class NestedIntegerArrayModel(models.Model): + field = ArrayField(ArrayField(models.IntegerField())) diff --git a/tests/postgres_tests/test_array.py b/tests/postgres_tests/test_array.py new file mode 100644 index 0000000000..35ea65480a --- /dev/null +++ b/tests/postgres_tests/test_array.py @@ -0,0 +1,389 @@ +import unittest + +from django.contrib.postgres.fields import ArrayField +from django.contrib.postgres.forms import SimpleArrayField, SplitArrayField +from django.core import exceptions, serializers +from django.db import models, IntegrityError, connection +from django.db.migrations.writer import MigrationWriter +from django import forms +from django.test import TestCase +from django.utils import timezone + +from .models import IntegerArrayModel, NullableIntegerArrayModel, CharArrayModel, DateTimeArrayModel, NestedIntegerArrayModel + + +@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL required') +class TestSaveLoad(TestCase): + + def test_integer(self): + instance = IntegerArrayModel(field=[1, 2, 3]) + instance.save() + loaded = IntegerArrayModel.objects.get() + self.assertEqual(instance.field, loaded.field) + + def test_char(self): + instance = CharArrayModel(field=['hello', 'goodbye']) + instance.save() + loaded = CharArrayModel.objects.get() + self.assertEqual(instance.field, loaded.field) + + def test_dates(self): + instance = DateTimeArrayModel(field=[timezone.now()]) + instance.save() + loaded = DateTimeArrayModel.objects.get() + self.assertEqual(instance.field, loaded.field) + + def test_tuples(self): + instance = IntegerArrayModel(field=(1,)) + instance.save() + loaded = IntegerArrayModel.objects.get() + self.assertSequenceEqual(instance.field, loaded.field) + + def test_integers_passed_as_strings(self): + # This checks that get_prep_value is deferred properly + instance = IntegerArrayModel(field=['1']) + instance.save() + loaded = IntegerArrayModel.objects.get() + self.assertEqual(loaded.field, [1]) + + def test_null_handling(self): + instance = NullableIntegerArrayModel(field=None) + instance.save() + loaded = NullableIntegerArrayModel.objects.get() + self.assertEqual(instance.field, loaded.field) + + instance = IntegerArrayModel(field=None) + with self.assertRaises(IntegrityError): + instance.save() + + def test_nested(self): + instance = NestedIntegerArrayModel(field=[[1, 2], [3, 4]]) + instance.save() + loaded = NestedIntegerArrayModel.objects.get() + self.assertEqual(instance.field, loaded.field) + + +@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL required') +class TestQuerying(TestCase): + + def setUp(self): + self.objs = [ + NullableIntegerArrayModel.objects.create(field=[1]), + NullableIntegerArrayModel.objects.create(field=[2]), + NullableIntegerArrayModel.objects.create(field=[2, 3]), + NullableIntegerArrayModel.objects.create(field=[20, 30, 40]), + NullableIntegerArrayModel.objects.create(field=None), + ] + + def test_exact(self): + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__exact=[1]), + self.objs[:1] + ) + + def test_isnull(self): + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__isnull=True), + self.objs[-1:] + ) + + def test_gt(self): + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__gt=[0]), + self.objs[:4] + ) + + def test_lt(self): + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__lt=[2]), + self.objs[:1] + ) + + def test_in(self): + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__in=[[1], [2]]), + self.objs[:2] + ) + + def test_contained_by(self): + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__contained_by=[1, 2]), + self.objs[:2] + ) + + def test_contains(self): + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__contains=[2]), + self.objs[1:3] + ) + + def test_index(self): + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__0=2), + self.objs[1:3] + ) + + def test_index_chained(self): + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__0__lt=3), + self.objs[0:3] + ) + + def test_index_nested(self): + instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]]) + self.assertSequenceEqual( + NestedIntegerArrayModel.objects.filter(field__0__0=1), + [instance] + ) + + @unittest.expectedFailure + def test_index_used_on_nested_data(self): + instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]]) + self.assertSequenceEqual( + NestedIntegerArrayModel.objects.filter(field__0=[1, 2]), + [instance] + ) + + def test_overlap(self): + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__overlap=[1, 2]), + self.objs[0:3] + ) + + def test_len(self): + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__len__lte=2), + self.objs[0:3] + ) + + def test_slice(self): + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__0_1=[2]), + self.objs[1:3] + ) + + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__0_2=[2, 3]), + self.objs[2:3] + ) + + @unittest.expectedFailure + def test_slice_nested(self): + instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]]) + self.assertSequenceEqual( + NestedIntegerArrayModel.objects.filter(field__0__0_1=[1]), + [instance] + ) + + +class TestChecks(TestCase): + + def test_field_checks(self): + field = ArrayField(models.CharField()) + field.set_attributes_from_name('field') + errors = field.check() + self.assertEqual(len(errors), 1) + self.assertEqual(errors[0].id, 'postgres.E001') + + def test_invalid_base_fields(self): + field = ArrayField(models.ManyToManyField('postgres_tests.IntegerArrayModel')) + field.set_attributes_from_name('field') + errors = field.check() + self.assertEqual(len(errors), 1) + self.assertEqual(errors[0].id, 'postgres.E002') + + +class TestMigrations(TestCase): + + def test_deconstruct(self): + field = ArrayField(models.IntegerField()) + name, path, args, kwargs = field.deconstruct() + new = ArrayField(*args, **kwargs) + self.assertEqual(type(new.base_field), type(field.base_field)) + + def test_deconstruct_with_size(self): + field = ArrayField(models.IntegerField(), size=3) + name, path, args, kwargs = field.deconstruct() + new = ArrayField(*args, **kwargs) + self.assertEqual(new.size, field.size) + + def test_deconstruct_args(self): + field = ArrayField(models.CharField(max_length=20)) + name, path, args, kwargs = field.deconstruct() + new = ArrayField(*args, **kwargs) + self.assertEqual(new.base_field.max_length, field.base_field.max_length) + + def test_makemigrations(self): + field = ArrayField(models.CharField(max_length=20)) + statement, imports = MigrationWriter.serialize(field) + self.assertEqual(statement, 'django.contrib.postgres.fields.ArrayField(models.CharField(max_length=20), size=None)') + + +@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL required') +class TestSerialization(TestCase): + test_data = '[{"fields": {"field": "[\\"1\\", \\"2\\"]"}, "model": "postgres_tests.integerarraymodel", "pk": null}]' + + def test_dumping(self): + instance = IntegerArrayModel(field=[1, 2]) + data = serializers.serialize('json', [instance]) + self.assertEqual(data, self.test_data) + + def test_loading(self): + instance = list(serializers.deserialize('json', self.test_data))[0].object + self.assertEqual(instance.field, [1, 2]) + + +class TestValidation(TestCase): + + def test_unbounded(self): + field = ArrayField(models.IntegerField()) + with self.assertRaises(exceptions.ValidationError) as cm: + field.clean([1, None], None) + self.assertEqual(cm.exception.code, 'item_invalid') + self.assertEqual(cm.exception.message % cm.exception.params, 'Item 1 in the array did not validate: This field cannot be null.') + + def test_blank_true(self): + field = ArrayField(models.IntegerField(blank=True, null=True)) + # This should not raise a validation error + field.clean([1, None], None) + + def test_with_size(self): + field = ArrayField(models.IntegerField(), size=3) + field.clean([1, 2, 3], None) + with self.assertRaises(exceptions.ValidationError) as cm: + field.clean([1, 2, 3, 4], None) + self.assertEqual(cm.exception.messages[0], 'List contains 4 items, it should contain no more than 3.') + + def test_nested_array_mismatch(self): + field = ArrayField(ArrayField(models.IntegerField())) + field.clean([[1, 2], [3, 4]], None) + with self.assertRaises(exceptions.ValidationError) as cm: + field.clean([[1, 2], [3, 4, 5]], None) + self.assertEqual(cm.exception.code, 'nested_array_mismatch') + self.assertEqual(cm.exception.messages[0], 'Nested arrays must have the same length.') + + +class TestSimpleFormField(TestCase): + + def test_valid(self): + field = SimpleArrayField(forms.CharField()) + value = field.clean('a,b,c') + self.assertEqual(value, ['a', 'b', 'c']) + + def test_to_python_fail(self): + field = SimpleArrayField(forms.IntegerField()) + with self.assertRaises(exceptions.ValidationError) as cm: + field.clean('a,b,9') + self.assertEqual(cm.exception.messages[0], 'Item 0 in the array did not validate: Enter a whole number.') + + def test_validate_fail(self): + field = SimpleArrayField(forms.CharField(required=True)) + with self.assertRaises(exceptions.ValidationError) as cm: + field.clean('a,b,') + self.assertEqual(cm.exception.messages[0], 'Item 2 in the array did not validate: This field is required.') + + def test_validators_fail(self): + field = SimpleArrayField(forms.RegexField('[a-e]{2}')) + with self.assertRaises(exceptions.ValidationError) as cm: + field.clean('a,bc,de') + self.assertEqual(cm.exception.messages[0], 'Item 0 in the array did not validate: Enter a valid value.') + + def test_delimiter(self): + field = SimpleArrayField(forms.CharField(), delimiter='|') + value = field.clean('a|b|c') + self.assertEqual(value, ['a', 'b', 'c']) + + def test_delimiter_with_nesting(self): + field = SimpleArrayField(SimpleArrayField(forms.CharField()), delimiter='|') + value = field.clean('a,b|c,d') + self.assertEqual(value, [['a', 'b'], ['c', 'd']]) + + def test_prepare_value(self): + field = SimpleArrayField(forms.CharField()) + value = field.prepare_value(['a', 'b', 'c']) + self.assertEqual(value, 'a,b,c') + + def test_max_length(self): + field = SimpleArrayField(forms.CharField(), max_length=2) + with self.assertRaises(exceptions.ValidationError) as cm: + field.clean('a,b,c') + self.assertEqual(cm.exception.messages[0], 'List contains 3 items, it should contain no more than 2.') + + def test_min_length(self): + field = SimpleArrayField(forms.CharField(), min_length=4) + with self.assertRaises(exceptions.ValidationError) as cm: + field.clean('a,b,c') + self.assertEqual(cm.exception.messages[0], 'List contains 3 items, it should contain no fewer than 4.') + + def test_required(self): + field = SimpleArrayField(forms.CharField(), required=True) + with self.assertRaises(exceptions.ValidationError) as cm: + field.clean('') + self.assertEqual(cm.exception.messages[0], 'This field is required.') + + def test_model_field_formfield(self): + model_field = ArrayField(models.CharField(max_length=27)) + form_field = model_field.formfield() + self.assertIsInstance(form_field, SimpleArrayField) + self.assertIsInstance(form_field.base_field, forms.CharField) + self.assertEqual(form_field.base_field.max_length, 27) + + def test_model_field_formfield_size(self): + model_field = ArrayField(models.CharField(max_length=27), size=4) + form_field = model_field.formfield() + self.assertIsInstance(form_field, SimpleArrayField) + self.assertEqual(form_field.max_length, 4) + + +class TestSplitFormField(TestCase): + + def test_valid(self): + class SplitForm(forms.Form): + array = SplitArrayField(forms.CharField(), size=3) + + data = {'array_0': 'a', 'array_1': 'b', 'array_2': 'c'} + form = SplitForm(data) + self.assertTrue(form.is_valid()) + self.assertEqual(form.cleaned_data, {'array': ['a', 'b', 'c']}) + + def test_required(self): + class SplitForm(forms.Form): + array = SplitArrayField(forms.CharField(), required=True, size=3) + + data = {'array_0': '', 'array_1': '', 'array_2': ''} + form = SplitForm(data) + self.assertFalse(form.is_valid()) + self.assertEqual(form.errors, {'array': ['This field is required.']}) + + def test_remove_trailing_nulls(self): + class SplitForm(forms.Form): + array = SplitArrayField(forms.CharField(required=False), size=5, remove_trailing_nulls=True) + + data = {'array_0': 'a', 'array_1': '', 'array_2': 'b', 'array_3': '', 'array_4': ''} + form = SplitForm(data) + self.assertTrue(form.is_valid(), form.errors) + self.assertEqual(form.cleaned_data, {'array': ['a', '', 'b']}) + + def test_required_field(self): + class SplitForm(forms.Form): + array = SplitArrayField(forms.CharField(), size=3) + + data = {'array_0': 'a', 'array_1': 'b', 'array_2': ''} + form = SplitForm(data) + self.assertFalse(form.is_valid()) + self.assertEqual(form.errors, {'array': ['Item 2 in the array did not validate: This field is required.']}) + + def test_rendering(self): + class SplitForm(forms.Form): + array = SplitArrayField(forms.CharField(), size=3) + + self.assertHTMLEqual(str(SplitForm()), ''' + + + + + + + + + ''') diff --git a/tests/runtests.py b/tests/runtests.py index 787b83e7a5..14014f4b01 100755 --- a/tests/runtests.py +++ b/tests/runtests.py @@ -57,6 +57,7 @@ ALWAYS_INSTALLED_APPS = [ def get_test_modules(): from django.contrib.gis.tests.utils import HAS_SPATIAL_DB + from django.db import connection modules = [] discovery_paths = [ (None, RUNTESTS_DIR), @@ -75,6 +76,8 @@ def get_test_modules(): os.path.isfile(f) or not os.path.exists(os.path.join(dirpath, f, '__init__.py'))): continue + if not connection.vendor == 'postgresql' and f == 'postgres_tests': + continue modules.append((modpath, f)) return modules