diff --git a/django/contrib/postgres/__init__.py b/django/contrib/postgres/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/django/contrib/postgres/fields/__init__.py b/django/contrib/postgres/fields/__init__.py new file mode 100644 index 00000000000..e3ceebd62cd --- /dev/null +++ b/django/contrib/postgres/fields/__init__.py @@ -0,0 +1 @@ +from .array import * # NOQA diff --git a/django/contrib/postgres/fields/array.py b/django/contrib/postgres/fields/array.py new file mode 100644 index 00000000000..7a37267400d --- /dev/null +++ b/django/contrib/postgres/fields/array.py @@ -0,0 +1,254 @@ +import json + +from django.contrib.postgres.forms import SimpleArrayField +from django.contrib.postgres.validators import ArrayMaxLengthValidator +from django.core import checks, exceptions +from django.db.models import Field, Lookup, Transform, IntegerField +from django.utils import six +from django.utils.translation import string_concat, ugettext_lazy as _ + + +__all__ = ['ArrayField'] + + +class AttributeSetter(object): + def __init__(self, name, value): + setattr(self, name, value) + + +class ArrayField(Field): + empty_strings_allowed = False + default_error_messages = { + 'item_invalid': _('Item %(nth)s in the array did not validate: '), + 'nested_array_mismatch': _('Nested arrays must have the same length.'), + } + + def __init__(self, base_field, size=None, **kwargs): + self.base_field = base_field + self.size = size + if self.size: + self.default_validators = self.default_validators[:] + self.default_validators.append(ArrayMaxLengthValidator(self.size)) + super(ArrayField, self).__init__(**kwargs) + + def check(self, **kwargs): + errors = super(ArrayField, self).check(**kwargs) + if self.base_field.rel: + errors.append( + checks.Error( + 'Base field for array cannot be a related field.', + hint=None, + obj=self, + id='postgres.E002' + ) + ) + else: + # Remove the field name checks as they are not needed here. + base_errors = self.base_field.check() + if base_errors: + messages = '\n '.join('%s (%s)' % (error.msg, error.id) for error in base_errors) + errors.append( + checks.Error( + 'Base field for array has errors:\n %s' % messages, + hint=None, + obj=self, + id='postgres.E001' + ) + ) + return errors + + def set_attributes_from_name(self, name): + super(ArrayField, self).set_attributes_from_name(name) + self.base_field.set_attributes_from_name(name) + + @property + def description(self): + return 'Array of %s' % self.base_field.description + + def db_type(self, connection): + size = self.size or '' + return '%s[%s]' % (self.base_field.db_type(connection), size) + + def get_prep_value(self, value): + if isinstance(value, list) or isinstance(value, tuple): + return [self.base_field.get_prep_value(i) for i in value] + return value + + def get_db_prep_lookup(self, lookup_type, value, connection, prepared=False): + if lookup_type == 'contains': + return [self.get_prep_value(value)] + return super(ArrayField, self).get_db_prep_lookup(lookup_type, value, + connection, prepared=False) + + def deconstruct(self): + name, path, args, kwargs = super(ArrayField, self).deconstruct() + path = 'django.contrib.postgres.fields.ArrayField' + args.insert(0, self.base_field) + kwargs['size'] = self.size + return name, path, args, kwargs + + def to_python(self, value): + if isinstance(value, six.string_types): + # Assume we're deserializing + vals = json.loads(value) + value = [self.base_field.to_python(val) for val in vals] + return value + + def value_to_string(self, obj): + values = [] + vals = self._get_val_from_obj(obj) + base_field = self.base_field + + for val in vals: + obj = AttributeSetter(base_field.attname, val) + values.append(base_field.value_to_string(obj)) + return json.dumps(values) + + def get_transform(self, name): + transform = super(ArrayField, self).get_transform(name) + if transform: + return transform + try: + index = int(name) + except ValueError: + pass + else: + index += 1 # postgres uses 1-indexing + return IndexTransformFactory(index, self.base_field) + try: + start, end = name.split('_') + start = int(start) + 1 + end = int(end) # don't add one here because postgres slices are weird + except ValueError: + pass + else: + return SliceTransformFactory(start, end) + + def validate(self, value, model_instance): + super(ArrayField, self).validate(value, model_instance) + for i, part in enumerate(value): + try: + self.base_field.validate(part, model_instance) + except exceptions.ValidationError as e: + raise exceptions.ValidationError( + string_concat(self.error_messages['item_invalid'], e.message), + code='item_invalid', + params={'nth': i}, + ) + if isinstance(self.base_field, ArrayField): + if len({len(i) for i in value}) > 1: + raise exceptions.ValidationError( + self.error_messages['nested_array_mismatch'], + code='nested_array_mismatch', + ) + + def formfield(self, **kwargs): + defaults = { + 'form_class': SimpleArrayField, + 'base_field': self.base_field.formfield(), + 'max_length': self.size, + } + defaults.update(kwargs) + return super(ArrayField, self).formfield(**defaults) + + +class ArrayContainsLookup(Lookup): + lookup_name = 'contains' + + def as_sql(self, qn, connection): + lhs, lhs_params = self.process_lhs(qn, connection) + rhs, rhs_params = self.process_rhs(qn, connection) + params = lhs_params + rhs_params + return '%s @> %s' % (lhs, rhs), params + + +ArrayField.register_lookup(ArrayContainsLookup) + + +class ArrayContainedByLookup(Lookup): + lookup_name = 'contained_by' + + def as_sql(self, qn, connection): + lhs, lhs_params = self.process_lhs(qn, connection) + rhs, rhs_params = self.process_rhs(qn, connection) + params = lhs_params + rhs_params + return '%s <@ %s' % (lhs, rhs), params + + +ArrayField.register_lookup(ArrayContainedByLookup) + + +class ArrayOverlapLookup(Lookup): + lookup_name = 'overlap' + + def as_sql(self, qn, connection): + lhs, lhs_params = self.process_lhs(qn, connection) + rhs, rhs_params = self.process_rhs(qn, connection) + params = lhs_params + rhs_params + return '%s && %s' % (lhs, rhs), params + + +ArrayField.register_lookup(ArrayOverlapLookup) + + +class ArrayLenTransform(Transform): + lookup_name = 'len' + + @property + def output_type(self): + return IntegerField() + + def as_sql(self, qn, connection): + lhs, params = qn.compile(self.lhs) + return 'array_length(%s, 1)' % lhs, params + + +ArrayField.register_lookup(ArrayLenTransform) + + +class IndexTransform(Transform): + + def __init__(self, index, base_field, *args, **kwargs): + super(IndexTransform, self).__init__(*args, **kwargs) + self.index = index + self.base_field = base_field + + def as_sql(self, qn, connection): + lhs, params = qn.compile(self.lhs) + return '%s[%s]' % (lhs, self.index), params + + @property + def output_type(self): + return self.base_field + + +class IndexTransformFactory(object): + + def __init__(self, index, base_field): + self.index = index + self.base_field = base_field + + def __call__(self, *args, **kwargs): + return IndexTransform(self.index, self.base_field, *args, **kwargs) + + +class SliceTransform(Transform): + + def __init__(self, start, end, *args, **kwargs): + super(SliceTransform, self).__init__(*args, **kwargs) + self.start = start + self.end = end + + def as_sql(self, qn, connection): + lhs, params = qn.compile(self.lhs) + return '%s[%s:%s]' % (lhs, self.start, self.end), params + + +class SliceTransformFactory(object): + + def __init__(self, start, end): + self.start = start + self.end = end + + def __call__(self, *args, **kwargs): + return SliceTransform(self.start, self.end, *args, **kwargs) diff --git a/django/contrib/postgres/forms/__init__.py b/django/contrib/postgres/forms/__init__.py new file mode 100644 index 00000000000..e3ceebd62cd --- /dev/null +++ b/django/contrib/postgres/forms/__init__.py @@ -0,0 +1 @@ +from .array import * # NOQA diff --git a/django/contrib/postgres/forms/array.py b/django/contrib/postgres/forms/array.py new file mode 100644 index 00000000000..620c7c7b6e0 --- /dev/null +++ b/django/contrib/postgres/forms/array.py @@ -0,0 +1,185 @@ +import copy + +from django.contrib.postgres.validators import ArrayMinLengthValidator, ArrayMaxLengthValidator +from django.core.exceptions import ValidationError +from django import forms +from django.utils.safestring import mark_safe +from django.utils import six +from django.utils.translation import string_concat, ugettext_lazy as _ + + +class SimpleArrayField(forms.CharField): + default_error_messages = { + 'item_invalid': _('Item %(nth)s in the array did not validate: '), + } + + def __init__(self, base_field, delimiter=',', max_length=None, min_length=None, *args, **kwargs): + self.base_field = base_field + self.delimiter = delimiter + super(SimpleArrayField, self).__init__(*args, **kwargs) + if min_length is not None: + self.min_length = min_length + self.validators.append(ArrayMinLengthValidator(int(min_length))) + if max_length is not None: + self.max_length = max_length + self.validators.append(ArrayMaxLengthValidator(int(max_length))) + + def prepare_value(self, value): + if isinstance(value, list): + return self.delimiter.join([six.text_type(self.base_field.prepare_value(v)) for v in value]) + return value + + def to_python(self, value): + if value: + items = value.split(self.delimiter) + else: + items = [] + errors = [] + values = [] + for i, item in enumerate(items): + try: + values.append(self.base_field.to_python(item)) + except ValidationError as e: + for error in e.error_list: + errors.append(ValidationError( + string_concat(self.error_messages['item_invalid'], error.message), + code='item_invalid', + params={'nth': i}, + )) + if errors: + raise ValidationError(errors) + return values + + def validate(self, value): + super(SimpleArrayField, self).validate(value) + errors = [] + for i, item in enumerate(value): + try: + self.base_field.validate(item) + except ValidationError as e: + for error in e.error_list: + errors.append(ValidationError( + string_concat(self.error_messages['item_invalid'], error.message), + code='item_invalid', + params={'nth': i}, + )) + if errors: + raise ValidationError(errors) + + def run_validators(self, value): + super(SimpleArrayField, self).run_validators(value) + errors = [] + for i, item in enumerate(value): + try: + self.base_field.run_validators(item) + except ValidationError as e: + for error in e.error_list: + errors.append(ValidationError( + string_concat(self.error_messages['item_invalid'], error.message), + code='item_invalid', + params={'nth': i}, + )) + if errors: + raise ValidationError(errors) + + +class SplitArrayWidget(forms.Widget): + + def __init__(self, widget, size, **kwargs): + self.widget = widget() if isinstance(widget, type) else widget + self.size = size + super(SplitArrayWidget, self).__init__(**kwargs) + + @property + def is_hidden(self): + return self.widget.is_hidden + + def value_from_datadict(self, data, files, name): + return [self.widget.value_from_datadict(data, files, '%s_%s' % (name, index)) + for index in range(self.size)] + + def id_for_label(self, id_): + # See the comment for RadioSelect.id_for_label() + if id_: + id_ += '_0' + return id_ + + def render(self, name, value, attrs=None): + if self.is_localized: + self.widget.is_localized = self.is_localized + value = value or [] + output = [] + final_attrs = self.build_attrs(attrs) + id_ = final_attrs.get('id', None) + for i in range(max(len(value), self.size)): + try: + widget_value = value[i] + except IndexError: + widget_value = None + if id_: + final_attrs = dict(final_attrs, id='%s_%s' % (id_, i)) + output.append(self.widget.render(name + '_%s' % i, widget_value, final_attrs)) + return mark_safe(self.format_output(output)) + + def format_output(self, rendered_widgets): + return ''.join(rendered_widgets) + + @property + def media(self): + return self.widget.media + + def __deepcopy__(self, memo): + obj = super(SplitArrayWidget, self).__deepcopy__(memo) + obj.widget = copy.deepcopy(self.widget) + return obj + + @property + def needs_multipart_form(self): + return self.widget.needs_multipart_form + + +class SplitArrayField(forms.Field): + default_error_messages = { + 'item_invalid': _('Item %(nth)s in the array did not validate: '), + } + + def __init__(self, base_field, size, remove_trailing_nulls=False, **kwargs): + self.base_field = base_field + self.size = size + self.remove_trailing_nulls = remove_trailing_nulls + widget = SplitArrayWidget(widget=base_field.widget, size=size) + kwargs.setdefault('widget', widget) + super(SplitArrayField, self).__init__(**kwargs) + + def clean(self, value): + cleaned_data = [] + errors = [] + if not any(value) and self.required: + raise ValidationError(self.error_messages['required']) + max_size = max(self.size, len(value)) + for i in range(max_size): + item = value[i] + try: + cleaned_data.append(self.base_field.clean(item)) + errors.append(None) + except ValidationError as error: + errors.append(ValidationError( + string_concat(self.error_messages['item_invalid'], error.message), + code='item_invalid', + params={'nth': i}, + )) + cleaned_data.append(None) + if self.remove_trailing_nulls: + null_index = None + for i, value in reversed(list(enumerate(cleaned_data))): + if value in self.base_field.empty_values: + null_index = i + else: + break + if null_index: + cleaned_data = cleaned_data[:null_index] + errors = errors[:null_index] + errors = list(filter(None, errors)) + if errors: + raise ValidationError(errors) + return cleaned_data diff --git a/django/contrib/postgres/validators.py b/django/contrib/postgres/validators.py new file mode 100644 index 00000000000..353305949ee --- /dev/null +++ b/django/contrib/postgres/validators.py @@ -0,0 +1,16 @@ +from django.core.validators import MaxLengthValidator, MinLengthValidator +from django.utils.translation import ungettext_lazy + + +class ArrayMaxLengthValidator(MaxLengthValidator): + message = ungettext_lazy( + 'List contains %(show_value)d item, it should contain no more than %(limit_value)d.', + 'List contains %(show_value)d items, it should contain no more than %(limit_value)d.', + 'limit_value') + + +class ArrayMinLengthValidator(MinLengthValidator): + message = ungettext_lazy( + 'List contains %(show_value)d item, it should contain no fewer than %(limit_value)d.', + 'List contains %(show_value)d items, it should contain no fewer than %(limit_value)d.', + 'limit_value') diff --git a/docs/index.txt b/docs/index.txt index e0529c4503f..b41f5c0ecb0 100644 --- a/docs/index.txt +++ b/docs/index.txt @@ -91,7 +91,8 @@ manipulating the data of your Web application. Learn more about it below: :doc:`Supported databases ` | :doc:`Legacy databases ` | :doc:`Providing initial data ` | - :doc:`Optimize database access ` + :doc:`Optimize database access ` | + :doc:`PostgreSQL specific features ` The view layer ============== diff --git a/docs/ref/contrib/index.txt b/docs/ref/contrib/index.txt index 533680659ea..ebfc2874b44 100644 --- a/docs/ref/contrib/index.txt +++ b/docs/ref/contrib/index.txt @@ -31,6 +31,7 @@ those packages have. gis/index humanize messages + postgres/index redirects sitemaps sites @@ -122,6 +123,13 @@ messages See the :doc:`messages documentation `. +postgres +======== + +A collection of PostgreSQL specific features. + +See the :doc:`contrib.postgres documentation `. + redirects ========= diff --git a/docs/ref/contrib/postgres/fields.txt b/docs/ref/contrib/postgres/fields.txt new file mode 100644 index 00000000000..dcde84d2ecd --- /dev/null +++ b/docs/ref/contrib/postgres/fields.txt @@ -0,0 +1,228 @@ +PostgreSQL specific model fields +================================ + +All of these fields are available from the ``django.contrib.postgres.fields`` +module. + +.. currentmodule:: django.contrib.postgres.fields + +ArrayField +---------- + +.. class:: ArrayField(base_field, size=None, **options) + + A field for storing lists of data. Most field types can be used, you simply + pass another field instance as the :attr:`base_field + `. You may also specify a :attr:`size + `. ``ArrayField`` can be nested to store multi-dimensional + arrays. + + .. attribute:: base_field + + This is a required argument. + + Specifies the underlying data type and behaviour for the array. It + should be an instance of a subclass of + :class:`~django.db.models.Field`. For example, it could be an + :class:`~django.db.models.IntegerField` or a + :class:`~django.db.models.CharField`. Most field types are permitted, + with the exception of those handling relational data + (:class:`~django.db.models.ForeignKey`, + :class:`~django.db.models.OneToOneField` and + :class:`~django.db.models.ManyToManyField`). + + It is possible to nest array fields - you can specify an instance of + ``ArrayField`` as the ``base_field``. For example:: + + from django.db import models + from django.contrib.postgres.fields import ArrayField + + class ChessBoard(models.Model): + board = ArrayField( + ArrayField( + CharField(max_length=10, blank=True, null=True), + size=8), + size=8) + + Transformation of values between the database and the model, validation + of data and configuration, and serialization are all delegated to the + underlying base field. + + .. attribute:: size + + This is an optional argument. + + If passed, the array will have a maximum size as specified. This will + be passed to the database, although PostgreSQL at present does not + enforce the restriction. + +.. note:: + + When nesting ``ArrayField``, whether you use the `size` parameter or not, + PostgreSQL requires that the arrays are rectangular:: + + from django.db import models + from django.contrib.postgres.fields import ArrayField + + class Board(models.Model): + pieces = ArrayField(ArrayField(models.IntegerField())) + + # Valid + Board(pieces=[ + [2, 3], + [2, 1], + ]) + + # Not valid + Board(pieces=[ + [2, 3], + [2], + ]) + + If irregular shapes are required, then the underlying field should be made + nullable and the values padded with ``None``. + +Querying ArrayField +^^^^^^^^^^^^^^^^^^^ + +There are a number of custom lookups and transforms for :class:`ArrayField`. +We will use the following example model:: + + from django.db import models + from django.contrib.postgres.fields import ArrayField + + class Post(models.Model): + name = models.CharField(max_length=200) + tags = ArrayField(models.CharField(max_length=200), blank=True) + + def __str__(self): # __unicode__ on python 2 + return self.name + +.. fieldlookup:: arrayfield.contains + +contains +~~~~~~~~ + +The :lookup:`contains` lookup is overridden on :class:`ArrayField`. The +returned objects will be those where the values passed are a subset of the +data. It uses the SQL operator ``@>``. For example:: + + >>> Post.objects.create(name='First post', tags=['thoughts', 'django']) + >>> Post.objects.create(name='Second post', tags=['thoughts']) + >>> Post.objects.create(name='Third post', tags=['tutorial', 'django']) + + >>> Post.objects.filter(tags__contains=['thoughts']) + [, ] + + >>> Post.objects.filter(tags__contains=['django']) + [, ] + + >>> Post.objects.filter(tags__contains=['django', 'thoughts']) + [] + +.. fieldlookup:: arrayfield.contained_by + +contained_by +~~~~~~~~~~~~ + +This is the inverse of the :lookup:`contains ` lookup - +the objects returned will be those where the data is a subset of the values +passed. It uses the SQL operator ``<@``. For example:: + + >>> Post.objects.create(name='First post', tags=['thoughts', 'django']) + >>> Post.objects.create(name='Second post', tags=['thoughts']) + >>> Post.objects.create(name='Third post', tags=['tutorial', 'django']) + + >>> Post.objects.filter(tags__contained_by=['thoughts', 'django']) + [] + + >>> Post.objects.filter(tags__contained_by=['thoughts', 'django', 'tutorial']) + [, , ] + +.. fieldlookup:: arrayfield.overlap + +overlap +~~~~~~~ + +Returns objects where the data shares any results with the values passed. Uses +the SQL operator ``&&``. For example:: + + >>> Post.objects.create(name='First post', tags=['thoughts', 'django']) + >>> Post.objects.create(name='Second post', tags=['thoughts']) + >>> Post.objects.create(name='Third post', tags=['tutorial', 'django']) + + >>> Post.objects.filter(tags__overlap=['thoughts']) + [, ] + + >>> Post.objects.filter(tags__overlap=['thoughts', 'tutorial']) + [, , ] + +.. fieldlookup:: arrayfield.index + +Index transforms +~~~~~~~~~~~~~~~~ + +This class of transforms allows you to index into the array in queries. Any +non-negative integer can be used. There are no errors if it exceeds the +:attr:`size ` of the array. The lookups available after the +transform are those from the :attr:`base_field `. For +example:: + + >>> Post.objects.create(name='First post', tags=['thoughts', 'django']) + >>> Post.objects.create(name='Second post', tags=['thoughts']) + + >>> Post.objects.filter(tags__0='thoughts') + [, ] + + >>> Post.objects.filter(tags__1__iexact='Django') + [] + + >>> Post.objects.filter(tags__276='javascript') + [] + +.. note:: + + PostgreSQL uses 1-based indexing for array fields when writing raw SQL. + However these indexes and those used in :lookup:`slices ` + use 0-based indexing to be consistent with Python. + +.. fieldlookup:: arrayfield.slice + +Slice transforms +~~~~~~~~~~~~~~~~ + +This class of transforms allow you to take a slice of the array. Any two +non-negative integers can be used, separated by a single underscore. The +lookups available after the transform do not change. For example:: + + >>> Post.objects.create(name='First post', tags=['thoughts', 'django']) + >>> Post.objects.create(name='Second post', tags=['thoughts']) + >>> Post.objects.create(name='Third post', tags=['django', 'python', 'thoughts']) + + >>> Post.objects.filter(tags__0_1=['thoughts']) + [] + + >>> Post.objects.filter(tags__0_2__contains='thoughts') + [, ] + +.. note:: + + PostgreSQL uses 1-based indexing for array fields when writing raw SQL. + However these slices and those used in :lookup:`indexes ` + use 0-based indexing to be consistent with Python. + +.. admonition:: Multidimensional arrays with indexes and slices + + PostgreSQL has some rather esoteric behaviour when using indexes and slices + on multidimensional arrays. It will always work to use indexes to reach + down to the final underlying data, but most other slices behave strangely + at the database level and cannot be supported in a logical, consistent + fashion by Django. + +Indexing ArrayField +^^^^^^^^^^^^^^^^^^^ + +At present using :attr:`~django.db.models.Field.db_index` will create a +``btree`` index. This does not offer particularly significant help to querying. +A more useful index is a ``GIN`` index, which you should create using a +:class:`~django.db.migrations.operations.RunSQL` operation. diff --git a/docs/ref/contrib/postgres/forms.txt b/docs/ref/contrib/postgres/forms.txt new file mode 100644 index 00000000000..6cad537f3be --- /dev/null +++ b/docs/ref/contrib/postgres/forms.txt @@ -0,0 +1,135 @@ +PostgreSQL specific form fields and widgets +=========================================== + +All of these fields and widgets are available from the +``django.contrib.postgres.forms`` module. + +.. currentmodule:: django.contrib.postgres.forms + +SimpleArrayField +---------------- + +.. class:: SimpleArrayField(base_field, delimiter=',', max_length=None, min_length=None) + + A simple field which maps to an array. It is represented by an HTML + ````. + + .. attribute:: base_field + + This is a required argument. + + It specifies the underlying form field for the array. This is not used + to render any HTML, but it is used to process the submitted data and + validate it. For example:: + + >>> from django.contrib.postgres.forms import SimpleArrayField + >>> from django import forms + + >>> class NumberListForm(forms.Form): + ... numbers = SimpleArrayField(forms.IntegerField()) + + >>> form = NumberListForm({'numbers': '1,2,3'}) + >>> form.is_valid() + True + >>> form.cleaned_data + {'numbers': [1, 2, 3]} + + >>> form = NumberListForm({'numbers': '1,2,a'}) + >>> form.is_valid() + False + + .. attribute:: delimiter + + This is an optional argument which defaults to a comma: ``,``. This + value is used to split the submitted data. It allows you to chain + ``SimpleArrayField`` for multidimensional data:: + + >>> from django.contrib.postgres.forms import SimpleArrayField + >>> from django import forms + + >>> class GridForm(forms.Form): + ... places = SimpleArrayField(SimpleArrayField(IntegerField()), delimiter='|') + + >>> form = GridForm({'places': '1,2|2,1|4,3'}) + >>> form.is_valid() + True + >>> form.cleaned_data + {'places': [[1, 2], [2, 1], [4, 3]]} + + .. note:: + + The field does not support escaping of the delimiter, so be careful + in cases where the delimiter is a valid character in the underlying + field. The delimiter does not need to be only one character. + + .. attribute:: max_length + + This is an optional argument which validates that the array does not + exceed the stated length. + + .. attribute:: min_length + + This is an optional argument which validates that the array reaches at + least the stated length. + + .. admonition:: User friendly forms + + ``SimpleArrayField`` is not particularly user friendly in most cases, + however it is a useful way to format data from a client-side widget for + submission to the server. + +SplitArrayField +--------------- + +.. class:: SplitArrayField(base_field, size, remove_trailing_nulls=False) + + This field handles arrays by reproducing the underlying field a fixed + number of times. + + .. attribute:: base_field + + This is a required argument. It specifies the form field to be + repeated. + + .. attribute:: size + + This is the fixed number of times the underlying field will be used. + + .. attribute:: remove_trailing_nulls + + By default, this is set to ``False``. When ``False``, each value from + the repeated fields is stored. When set to ``True``, any trailing + values which are blank will be stripped from the result. If the + underlying field has ``required=True``, but ``remove_trailing_nulls`` + is ``True``, then null values are only allowed at the end, and will be + stripped. + + Some examples:: + + SplitArrayField(IntegerField(required=True), size=3, remove_trailing_nulls=False) + + ['1', '2', '3'] # -> [1, 2, 3] + ['1', '2', ''] # -> ValidationError - third entry required. + ['1', '', '3'] # -> ValidationError - second entry required. + ['', '2', ''] # -> ValidationError - first and third entries required. + + SplitArrayField(IntegerField(required=False), size=3, remove_trailing_nulls=False) + + ['1', '2', '3'] # -> [1, 2, 3] + ['1', '2', ''] # -> [1, 2, None] + ['1', '', '3'] # -> [1, None, 3] + ['', '2', ''] # -> [None, 2, None] + + SplitArrayField(IntegerField(required=True), size=3, remove_trailing_nulls=True) + + ['1', '2', '3'] # -> [1, 2, 3] + ['1', '2', ''] # -> [1, 2] + ['1', '', '3'] # -> ValidationError - second entry required. + ['', '2', ''] # -> ValidationError - first entry required. + + SplitArrayField(IntegerField(required=False), size=3, remove_trailing_nulls=True) + + ['1', '2', '3'] # -> [1, 2, 3] + ['1', '2', ''] # -> [1, 2] + ['1', '', '3'] # -> [1, None, 3] + ['', '2', ''] # -> [None, 2] diff --git a/docs/ref/contrib/postgres/index.txt b/docs/ref/contrib/postgres/index.txt new file mode 100644 index 00000000000..5db4ab80ed0 --- /dev/null +++ b/docs/ref/contrib/postgres/index.txt @@ -0,0 +1,28 @@ +``django.contrib.postgres`` +=========================== + +PostgreSQL has a number of features which are not shared by the other databases +Django supports. This optional module contains model fields and form fields for +a number of PostgreSQL specific data types. + +.. note:: + Django is, and will continue to be, a database-agnostic web framework. We + would encourage those writing reusable applications for the Django + community to write database-agnostic code where practical. However, we + recognise that real world projects written using Django need not be + database-agnostic. In fact, once a project reaches a given size changing + the underlying data store is already a significant challenge and is likely + to require changing the code base in some ways to handle differences + between the data stores. + + Django provides support for a number of data types which will + only work with PostgreSQL. There is no fundamental reason why (for example) + a ``contrib.mysql`` module does not exist, except that PostgreSQL has the + richest feature set of the supported databases so its users have the most + to gain. + +.. toctree:: + :maxdepth: 2 + + fields + forms diff --git a/tests/postgres_tests/__init__.py b/tests/postgres_tests/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/postgres_tests/models.py b/tests/postgres_tests/models.py new file mode 100644 index 00000000000..6420ebe1cd6 --- /dev/null +++ b/tests/postgres_tests/models.py @@ -0,0 +1,22 @@ +from django.contrib.postgres.fields import ArrayField +from django.db import models + + +class IntegerArrayModel(models.Model): + field = ArrayField(models.IntegerField()) + + +class NullableIntegerArrayModel(models.Model): + field = ArrayField(models.IntegerField(), blank=True, null=True) + + +class CharArrayModel(models.Model): + field = ArrayField(models.CharField(max_length=10)) + + +class DateTimeArrayModel(models.Model): + field = ArrayField(models.DateTimeField()) + + +class NestedIntegerArrayModel(models.Model): + field = ArrayField(ArrayField(models.IntegerField())) diff --git a/tests/postgres_tests/test_array.py b/tests/postgres_tests/test_array.py new file mode 100644 index 00000000000..35ea65480ad --- /dev/null +++ b/tests/postgres_tests/test_array.py @@ -0,0 +1,389 @@ +import unittest + +from django.contrib.postgres.fields import ArrayField +from django.contrib.postgres.forms import SimpleArrayField, SplitArrayField +from django.core import exceptions, serializers +from django.db import models, IntegrityError, connection +from django.db.migrations.writer import MigrationWriter +from django import forms +from django.test import TestCase +from django.utils import timezone + +from .models import IntegerArrayModel, NullableIntegerArrayModel, CharArrayModel, DateTimeArrayModel, NestedIntegerArrayModel + + +@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL required') +class TestSaveLoad(TestCase): + + def test_integer(self): + instance = IntegerArrayModel(field=[1, 2, 3]) + instance.save() + loaded = IntegerArrayModel.objects.get() + self.assertEqual(instance.field, loaded.field) + + def test_char(self): + instance = CharArrayModel(field=['hello', 'goodbye']) + instance.save() + loaded = CharArrayModel.objects.get() + self.assertEqual(instance.field, loaded.field) + + def test_dates(self): + instance = DateTimeArrayModel(field=[timezone.now()]) + instance.save() + loaded = DateTimeArrayModel.objects.get() + self.assertEqual(instance.field, loaded.field) + + def test_tuples(self): + instance = IntegerArrayModel(field=(1,)) + instance.save() + loaded = IntegerArrayModel.objects.get() + self.assertSequenceEqual(instance.field, loaded.field) + + def test_integers_passed_as_strings(self): + # This checks that get_prep_value is deferred properly + instance = IntegerArrayModel(field=['1']) + instance.save() + loaded = IntegerArrayModel.objects.get() + self.assertEqual(loaded.field, [1]) + + def test_null_handling(self): + instance = NullableIntegerArrayModel(field=None) + instance.save() + loaded = NullableIntegerArrayModel.objects.get() + self.assertEqual(instance.field, loaded.field) + + instance = IntegerArrayModel(field=None) + with self.assertRaises(IntegrityError): + instance.save() + + def test_nested(self): + instance = NestedIntegerArrayModel(field=[[1, 2], [3, 4]]) + instance.save() + loaded = NestedIntegerArrayModel.objects.get() + self.assertEqual(instance.field, loaded.field) + + +@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL required') +class TestQuerying(TestCase): + + def setUp(self): + self.objs = [ + NullableIntegerArrayModel.objects.create(field=[1]), + NullableIntegerArrayModel.objects.create(field=[2]), + NullableIntegerArrayModel.objects.create(field=[2, 3]), + NullableIntegerArrayModel.objects.create(field=[20, 30, 40]), + NullableIntegerArrayModel.objects.create(field=None), + ] + + def test_exact(self): + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__exact=[1]), + self.objs[:1] + ) + + def test_isnull(self): + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__isnull=True), + self.objs[-1:] + ) + + def test_gt(self): + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__gt=[0]), + self.objs[:4] + ) + + def test_lt(self): + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__lt=[2]), + self.objs[:1] + ) + + def test_in(self): + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__in=[[1], [2]]), + self.objs[:2] + ) + + def test_contained_by(self): + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__contained_by=[1, 2]), + self.objs[:2] + ) + + def test_contains(self): + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__contains=[2]), + self.objs[1:3] + ) + + def test_index(self): + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__0=2), + self.objs[1:3] + ) + + def test_index_chained(self): + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__0__lt=3), + self.objs[0:3] + ) + + def test_index_nested(self): + instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]]) + self.assertSequenceEqual( + NestedIntegerArrayModel.objects.filter(field__0__0=1), + [instance] + ) + + @unittest.expectedFailure + def test_index_used_on_nested_data(self): + instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]]) + self.assertSequenceEqual( + NestedIntegerArrayModel.objects.filter(field__0=[1, 2]), + [instance] + ) + + def test_overlap(self): + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__overlap=[1, 2]), + self.objs[0:3] + ) + + def test_len(self): + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__len__lte=2), + self.objs[0:3] + ) + + def test_slice(self): + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__0_1=[2]), + self.objs[1:3] + ) + + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__0_2=[2, 3]), + self.objs[2:3] + ) + + @unittest.expectedFailure + def test_slice_nested(self): + instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]]) + self.assertSequenceEqual( + NestedIntegerArrayModel.objects.filter(field__0__0_1=[1]), + [instance] + ) + + +class TestChecks(TestCase): + + def test_field_checks(self): + field = ArrayField(models.CharField()) + field.set_attributes_from_name('field') + errors = field.check() + self.assertEqual(len(errors), 1) + self.assertEqual(errors[0].id, 'postgres.E001') + + def test_invalid_base_fields(self): + field = ArrayField(models.ManyToManyField('postgres_tests.IntegerArrayModel')) + field.set_attributes_from_name('field') + errors = field.check() + self.assertEqual(len(errors), 1) + self.assertEqual(errors[0].id, 'postgres.E002') + + +class TestMigrations(TestCase): + + def test_deconstruct(self): + field = ArrayField(models.IntegerField()) + name, path, args, kwargs = field.deconstruct() + new = ArrayField(*args, **kwargs) + self.assertEqual(type(new.base_field), type(field.base_field)) + + def test_deconstruct_with_size(self): + field = ArrayField(models.IntegerField(), size=3) + name, path, args, kwargs = field.deconstruct() + new = ArrayField(*args, **kwargs) + self.assertEqual(new.size, field.size) + + def test_deconstruct_args(self): + field = ArrayField(models.CharField(max_length=20)) + name, path, args, kwargs = field.deconstruct() + new = ArrayField(*args, **kwargs) + self.assertEqual(new.base_field.max_length, field.base_field.max_length) + + def test_makemigrations(self): + field = ArrayField(models.CharField(max_length=20)) + statement, imports = MigrationWriter.serialize(field) + self.assertEqual(statement, 'django.contrib.postgres.fields.ArrayField(models.CharField(max_length=20), size=None)') + + +@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL required') +class TestSerialization(TestCase): + test_data = '[{"fields": {"field": "[\\"1\\", \\"2\\"]"}, "model": "postgres_tests.integerarraymodel", "pk": null}]' + + def test_dumping(self): + instance = IntegerArrayModel(field=[1, 2]) + data = serializers.serialize('json', [instance]) + self.assertEqual(data, self.test_data) + + def test_loading(self): + instance = list(serializers.deserialize('json', self.test_data))[0].object + self.assertEqual(instance.field, [1, 2]) + + +class TestValidation(TestCase): + + def test_unbounded(self): + field = ArrayField(models.IntegerField()) + with self.assertRaises(exceptions.ValidationError) as cm: + field.clean([1, None], None) + self.assertEqual(cm.exception.code, 'item_invalid') + self.assertEqual(cm.exception.message % cm.exception.params, 'Item 1 in the array did not validate: This field cannot be null.') + + def test_blank_true(self): + field = ArrayField(models.IntegerField(blank=True, null=True)) + # This should not raise a validation error + field.clean([1, None], None) + + def test_with_size(self): + field = ArrayField(models.IntegerField(), size=3) + field.clean([1, 2, 3], None) + with self.assertRaises(exceptions.ValidationError) as cm: + field.clean([1, 2, 3, 4], None) + self.assertEqual(cm.exception.messages[0], 'List contains 4 items, it should contain no more than 3.') + + def test_nested_array_mismatch(self): + field = ArrayField(ArrayField(models.IntegerField())) + field.clean([[1, 2], [3, 4]], None) + with self.assertRaises(exceptions.ValidationError) as cm: + field.clean([[1, 2], [3, 4, 5]], None) + self.assertEqual(cm.exception.code, 'nested_array_mismatch') + self.assertEqual(cm.exception.messages[0], 'Nested arrays must have the same length.') + + +class TestSimpleFormField(TestCase): + + def test_valid(self): + field = SimpleArrayField(forms.CharField()) + value = field.clean('a,b,c') + self.assertEqual(value, ['a', 'b', 'c']) + + def test_to_python_fail(self): + field = SimpleArrayField(forms.IntegerField()) + with self.assertRaises(exceptions.ValidationError) as cm: + field.clean('a,b,9') + self.assertEqual(cm.exception.messages[0], 'Item 0 in the array did not validate: Enter a whole number.') + + def test_validate_fail(self): + field = SimpleArrayField(forms.CharField(required=True)) + with self.assertRaises(exceptions.ValidationError) as cm: + field.clean('a,b,') + self.assertEqual(cm.exception.messages[0], 'Item 2 in the array did not validate: This field is required.') + + def test_validators_fail(self): + field = SimpleArrayField(forms.RegexField('[a-e]{2}')) + with self.assertRaises(exceptions.ValidationError) as cm: + field.clean('a,bc,de') + self.assertEqual(cm.exception.messages[0], 'Item 0 in the array did not validate: Enter a valid value.') + + def test_delimiter(self): + field = SimpleArrayField(forms.CharField(), delimiter='|') + value = field.clean('a|b|c') + self.assertEqual(value, ['a', 'b', 'c']) + + def test_delimiter_with_nesting(self): + field = SimpleArrayField(SimpleArrayField(forms.CharField()), delimiter='|') + value = field.clean('a,b|c,d') + self.assertEqual(value, [['a', 'b'], ['c', 'd']]) + + def test_prepare_value(self): + field = SimpleArrayField(forms.CharField()) + value = field.prepare_value(['a', 'b', 'c']) + self.assertEqual(value, 'a,b,c') + + def test_max_length(self): + field = SimpleArrayField(forms.CharField(), max_length=2) + with self.assertRaises(exceptions.ValidationError) as cm: + field.clean('a,b,c') + self.assertEqual(cm.exception.messages[0], 'List contains 3 items, it should contain no more than 2.') + + def test_min_length(self): + field = SimpleArrayField(forms.CharField(), min_length=4) + with self.assertRaises(exceptions.ValidationError) as cm: + field.clean('a,b,c') + self.assertEqual(cm.exception.messages[0], 'List contains 3 items, it should contain no fewer than 4.') + + def test_required(self): + field = SimpleArrayField(forms.CharField(), required=True) + with self.assertRaises(exceptions.ValidationError) as cm: + field.clean('') + self.assertEqual(cm.exception.messages[0], 'This field is required.') + + def test_model_field_formfield(self): + model_field = ArrayField(models.CharField(max_length=27)) + form_field = model_field.formfield() + self.assertIsInstance(form_field, SimpleArrayField) + self.assertIsInstance(form_field.base_field, forms.CharField) + self.assertEqual(form_field.base_field.max_length, 27) + + def test_model_field_formfield_size(self): + model_field = ArrayField(models.CharField(max_length=27), size=4) + form_field = model_field.formfield() + self.assertIsInstance(form_field, SimpleArrayField) + self.assertEqual(form_field.max_length, 4) + + +class TestSplitFormField(TestCase): + + def test_valid(self): + class SplitForm(forms.Form): + array = SplitArrayField(forms.CharField(), size=3) + + data = {'array_0': 'a', 'array_1': 'b', 'array_2': 'c'} + form = SplitForm(data) + self.assertTrue(form.is_valid()) + self.assertEqual(form.cleaned_data, {'array': ['a', 'b', 'c']}) + + def test_required(self): + class SplitForm(forms.Form): + array = SplitArrayField(forms.CharField(), required=True, size=3) + + data = {'array_0': '', 'array_1': '', 'array_2': ''} + form = SplitForm(data) + self.assertFalse(form.is_valid()) + self.assertEqual(form.errors, {'array': ['This field is required.']}) + + def test_remove_trailing_nulls(self): + class SplitForm(forms.Form): + array = SplitArrayField(forms.CharField(required=False), size=5, remove_trailing_nulls=True) + + data = {'array_0': 'a', 'array_1': '', 'array_2': 'b', 'array_3': '', 'array_4': ''} + form = SplitForm(data) + self.assertTrue(form.is_valid(), form.errors) + self.assertEqual(form.cleaned_data, {'array': ['a', '', 'b']}) + + def test_required_field(self): + class SplitForm(forms.Form): + array = SplitArrayField(forms.CharField(), size=3) + + data = {'array_0': 'a', 'array_1': 'b', 'array_2': ''} + form = SplitForm(data) + self.assertFalse(form.is_valid()) + self.assertEqual(form.errors, {'array': ['Item 2 in the array did not validate: This field is required.']}) + + def test_rendering(self): + class SplitForm(forms.Form): + array = SplitArrayField(forms.CharField(), size=3) + + self.assertHTMLEqual(str(SplitForm()), ''' + + + + + + + + + ''') diff --git a/tests/runtests.py b/tests/runtests.py index 787b83e7a5d..14014f4b014 100755 --- a/tests/runtests.py +++ b/tests/runtests.py @@ -57,6 +57,7 @@ ALWAYS_INSTALLED_APPS = [ def get_test_modules(): from django.contrib.gis.tests.utils import HAS_SPATIAL_DB + from django.db import connection modules = [] discovery_paths = [ (None, RUNTESTS_DIR), @@ -75,6 +76,8 @@ def get_test_modules(): os.path.isfile(f) or not os.path.exists(os.path.join(dirpath, f, '__init__.py'))): continue + if not connection.vendor == 'postgresql' and f == 'postgres_tests': + continue modules.append((modpath, f)) return modules