diff --git a/django/contrib/postgres/fields/array.py b/django/contrib/postgres/fields/array.py index f8009bfb8a..8ec5f88e9e 100644 --- a/django/contrib/postgres/fields/array.py +++ b/django/contrib/postgres/fields/array.py @@ -28,6 +28,10 @@ class ArrayField(Field): if self.size: self.default_validators = self.default_validators[:] self.default_validators.append(ArrayMaxLengthValidator(self.size)) + # For performance, only add a from_db_value() method if the base field + # implements it. + if hasattr(self.base_field, 'from_db_value'): + self.from_db_value = self._from_db_value super(ArrayField, self).__init__(**kwargs) @property @@ -100,6 +104,14 @@ class ArrayField(Field): value = [self.base_field.to_python(val) for val in vals] return value + def _from_db_value(self, value, expression, connection, context): + if value is None: + return value + return [ + self.base_field.from_db_value(item, expression, connection, context) + for item in value + ] + def value_to_string(self, obj): values = [] vals = self.value_from_object(obj) diff --git a/tests/postgres_tests/migrations/0002_create_test_models.py b/tests/postgres_tests/migrations/0002_create_test_models.py index a58119e23c..872d8aeb58 100644 --- a/tests/postgres_tests/migrations/0002_create_test_models.py +++ b/tests/postgres_tests/migrations/0002_create_test_models.py @@ -4,6 +4,7 @@ from __future__ import unicode_literals from django.db import migrations, models from ..fields import * # NOQA +from ..models import TagField class Migration(migrations.Migration): @@ -55,6 +56,7 @@ class Migration(migrations.Migration): ('ips', ArrayField(models.GenericIPAddressField(), size=None)), ('uuids', ArrayField(models.UUIDField(), size=None)), ('decimals', ArrayField(models.DecimalField(max_digits=5, decimal_places=2), size=None)), + ('tags', ArrayField(TagField(), blank=True, null=True, size=None)), ], options={ 'required_db_vendor': 'postgresql', diff --git a/tests/postgres_tests/models.py b/tests/postgres_tests/models.py index 62a32b113d..b950134bb9 100644 --- a/tests/postgres_tests/models.py +++ b/tests/postgres_tests/models.py @@ -6,6 +6,35 @@ from .fields import ( ) +class Tag(object): + def __init__(self, tag_id): + self.tag_id = tag_id + + def __eq__(self, other): + return isinstance(other, Tag) and self.tag_id == other.tag_id + + +class TagField(models.SmallIntegerField): + + def from_db_value(self, value, expression, connection, context): + if value is None: + return value + return Tag(int(value)) + + def to_python(self, value): + if isinstance(value, Tag): + return value + if value is None: + return value + return Tag(int(value)) + + def get_prep_value(self, value): + return value.tag_id + + def get_db_prep_value(self, value, connection, prepared=False): + return self.get_prep_value(value) + + class PostgreSQLModel(models.Model): class Meta: abstract = True @@ -38,6 +67,7 @@ class OtherTypesArrayModel(PostgreSQLModel): ips = ArrayField(models.GenericIPAddressField()) uuids = ArrayField(models.UUIDField()) decimals = ArrayField(models.DecimalField(max_digits=5, decimal_places=2)) + tags = ArrayField(TagField(), blank=True, null=True) class HStoreModel(PostgreSQLModel): diff --git a/tests/postgres_tests/test_array.py b/tests/postgres_tests/test_array.py index 5df25c1c0f..f5c333e56f 100644 --- a/tests/postgres_tests/test_array.py +++ b/tests/postgres_tests/test_array.py @@ -15,7 +15,7 @@ from . import PostgreSQLTestCase from .models import ( ArrayFieldSubclass, CharArrayModel, DateTimeArrayModel, IntegerArrayModel, NestedIntegerArrayModel, NullableIntegerArrayModel, OtherTypesArrayModel, - PostgreSQLModel, + PostgreSQLModel, Tag, ) try: @@ -92,12 +92,24 @@ class TestSaveLoad(PostgreSQLTestCase): ips=['192.168.0.1', '::1'], uuids=[uuid.uuid4()], decimals=[decimal.Decimal(1.25), 1.75], + tags=[Tag(1), Tag(2), Tag(3)], ) instance.save() loaded = OtherTypesArrayModel.objects.get() self.assertEqual(instance.ips, loaded.ips) self.assertEqual(instance.uuids, loaded.uuids) self.assertEqual(instance.decimals, loaded.decimals) + self.assertEqual(instance.tags, loaded.tags) + + def test_null_from_db_value_handling(self): + instance = OtherTypesArrayModel.objects.create( + ips=['192.168.0.1', '::1'], + uuids=[uuid.uuid4()], + decimals=[decimal.Decimal(1.25), 1.75], + tags=None, + ) + instance.refresh_from_db() + self.assertIsNone(instance.tags) def test_model_set_on_base_field(self): instance = IntegerArrayModel() @@ -306,11 +318,13 @@ class TestOtherTypesExactQuerying(PostgreSQLTestCase): self.ips = ['192.168.0.1', '::1'] self.uuids = [uuid.uuid4()] self.decimals = [decimal.Decimal(1.25), 1.75] + self.tags = [Tag(1), Tag(2), Tag(3)] self.objs = [ OtherTypesArrayModel.objects.create( ips=self.ips, uuids=self.uuids, decimals=self.decimals, + tags=self.tags, ) ] @@ -332,6 +346,12 @@ class TestOtherTypesExactQuerying(PostgreSQLTestCase): self.objs ) + def test_exact_tags(self): + self.assertSequenceEqual( + OtherTypesArrayModel.objects.filter(tags=self.tags), + self.objs + ) + @isolate_apps('postgres_tests') class TestChecks(PostgreSQLTestCase):