Fixed #25143 -- Added ArrayField.from_db_value().
Thanks Karan Lyons for contributing to the patch.
This commit is contained in:
parent
f8d20da047
commit
2495023a4c
|
@ -28,6 +28,10 @@ class ArrayField(Field):
|
||||||
if self.size:
|
if self.size:
|
||||||
self.default_validators = self.default_validators[:]
|
self.default_validators = self.default_validators[:]
|
||||||
self.default_validators.append(ArrayMaxLengthValidator(self.size))
|
self.default_validators.append(ArrayMaxLengthValidator(self.size))
|
||||||
|
# For performance, only add a from_db_value() method if the base field
|
||||||
|
# implements it.
|
||||||
|
if hasattr(self.base_field, 'from_db_value'):
|
||||||
|
self.from_db_value = self._from_db_value
|
||||||
super(ArrayField, self).__init__(**kwargs)
|
super(ArrayField, self).__init__(**kwargs)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -100,6 +104,14 @@ class ArrayField(Field):
|
||||||
value = [self.base_field.to_python(val) for val in vals]
|
value = [self.base_field.to_python(val) for val in vals]
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
def _from_db_value(self, value, expression, connection, context):
|
||||||
|
if value is None:
|
||||||
|
return value
|
||||||
|
return [
|
||||||
|
self.base_field.from_db_value(item, expression, connection, context)
|
||||||
|
for item in value
|
||||||
|
]
|
||||||
|
|
||||||
def value_to_string(self, obj):
|
def value_to_string(self, obj):
|
||||||
values = []
|
values = []
|
||||||
vals = self.value_from_object(obj)
|
vals = self.value_from_object(obj)
|
||||||
|
|
|
@ -4,6 +4,7 @@ from __future__ import unicode_literals
|
||||||
from django.db import migrations, models
|
from django.db import migrations, models
|
||||||
|
|
||||||
from ..fields import * # NOQA
|
from ..fields import * # NOQA
|
||||||
|
from ..models import TagField
|
||||||
|
|
||||||
|
|
||||||
class Migration(migrations.Migration):
|
class Migration(migrations.Migration):
|
||||||
|
@ -55,6 +56,7 @@ class Migration(migrations.Migration):
|
||||||
('ips', ArrayField(models.GenericIPAddressField(), size=None)),
|
('ips', ArrayField(models.GenericIPAddressField(), size=None)),
|
||||||
('uuids', ArrayField(models.UUIDField(), size=None)),
|
('uuids', ArrayField(models.UUIDField(), size=None)),
|
||||||
('decimals', ArrayField(models.DecimalField(max_digits=5, decimal_places=2), size=None)),
|
('decimals', ArrayField(models.DecimalField(max_digits=5, decimal_places=2), size=None)),
|
||||||
|
('tags', ArrayField(TagField(), blank=True, null=True, size=None)),
|
||||||
],
|
],
|
||||||
options={
|
options={
|
||||||
'required_db_vendor': 'postgresql',
|
'required_db_vendor': 'postgresql',
|
||||||
|
|
|
@ -6,6 +6,35 @@ from .fields import (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Tag(object):
|
||||||
|
def __init__(self, tag_id):
|
||||||
|
self.tag_id = tag_id
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
return isinstance(other, Tag) and self.tag_id == other.tag_id
|
||||||
|
|
||||||
|
|
||||||
|
class TagField(models.SmallIntegerField):
|
||||||
|
|
||||||
|
def from_db_value(self, value, expression, connection, context):
|
||||||
|
if value is None:
|
||||||
|
return value
|
||||||
|
return Tag(int(value))
|
||||||
|
|
||||||
|
def to_python(self, value):
|
||||||
|
if isinstance(value, Tag):
|
||||||
|
return value
|
||||||
|
if value is None:
|
||||||
|
return value
|
||||||
|
return Tag(int(value))
|
||||||
|
|
||||||
|
def get_prep_value(self, value):
|
||||||
|
return value.tag_id
|
||||||
|
|
||||||
|
def get_db_prep_value(self, value, connection, prepared=False):
|
||||||
|
return self.get_prep_value(value)
|
||||||
|
|
||||||
|
|
||||||
class PostgreSQLModel(models.Model):
|
class PostgreSQLModel(models.Model):
|
||||||
class Meta:
|
class Meta:
|
||||||
abstract = True
|
abstract = True
|
||||||
|
@ -38,6 +67,7 @@ class OtherTypesArrayModel(PostgreSQLModel):
|
||||||
ips = ArrayField(models.GenericIPAddressField())
|
ips = ArrayField(models.GenericIPAddressField())
|
||||||
uuids = ArrayField(models.UUIDField())
|
uuids = ArrayField(models.UUIDField())
|
||||||
decimals = ArrayField(models.DecimalField(max_digits=5, decimal_places=2))
|
decimals = ArrayField(models.DecimalField(max_digits=5, decimal_places=2))
|
||||||
|
tags = ArrayField(TagField(), blank=True, null=True)
|
||||||
|
|
||||||
|
|
||||||
class HStoreModel(PostgreSQLModel):
|
class HStoreModel(PostgreSQLModel):
|
||||||
|
|
|
@ -15,7 +15,7 @@ from . import PostgreSQLTestCase
|
||||||
from .models import (
|
from .models import (
|
||||||
ArrayFieldSubclass, CharArrayModel, DateTimeArrayModel, IntegerArrayModel,
|
ArrayFieldSubclass, CharArrayModel, DateTimeArrayModel, IntegerArrayModel,
|
||||||
NestedIntegerArrayModel, NullableIntegerArrayModel, OtherTypesArrayModel,
|
NestedIntegerArrayModel, NullableIntegerArrayModel, OtherTypesArrayModel,
|
||||||
PostgreSQLModel,
|
PostgreSQLModel, Tag,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -92,12 +92,24 @@ class TestSaveLoad(PostgreSQLTestCase):
|
||||||
ips=['192.168.0.1', '::1'],
|
ips=['192.168.0.1', '::1'],
|
||||||
uuids=[uuid.uuid4()],
|
uuids=[uuid.uuid4()],
|
||||||
decimals=[decimal.Decimal(1.25), 1.75],
|
decimals=[decimal.Decimal(1.25), 1.75],
|
||||||
|
tags=[Tag(1), Tag(2), Tag(3)],
|
||||||
)
|
)
|
||||||
instance.save()
|
instance.save()
|
||||||
loaded = OtherTypesArrayModel.objects.get()
|
loaded = OtherTypesArrayModel.objects.get()
|
||||||
self.assertEqual(instance.ips, loaded.ips)
|
self.assertEqual(instance.ips, loaded.ips)
|
||||||
self.assertEqual(instance.uuids, loaded.uuids)
|
self.assertEqual(instance.uuids, loaded.uuids)
|
||||||
self.assertEqual(instance.decimals, loaded.decimals)
|
self.assertEqual(instance.decimals, loaded.decimals)
|
||||||
|
self.assertEqual(instance.tags, loaded.tags)
|
||||||
|
|
||||||
|
def test_null_from_db_value_handling(self):
|
||||||
|
instance = OtherTypesArrayModel.objects.create(
|
||||||
|
ips=['192.168.0.1', '::1'],
|
||||||
|
uuids=[uuid.uuid4()],
|
||||||
|
decimals=[decimal.Decimal(1.25), 1.75],
|
||||||
|
tags=None,
|
||||||
|
)
|
||||||
|
instance.refresh_from_db()
|
||||||
|
self.assertIsNone(instance.tags)
|
||||||
|
|
||||||
def test_model_set_on_base_field(self):
|
def test_model_set_on_base_field(self):
|
||||||
instance = IntegerArrayModel()
|
instance = IntegerArrayModel()
|
||||||
|
@ -306,11 +318,13 @@ class TestOtherTypesExactQuerying(PostgreSQLTestCase):
|
||||||
self.ips = ['192.168.0.1', '::1']
|
self.ips = ['192.168.0.1', '::1']
|
||||||
self.uuids = [uuid.uuid4()]
|
self.uuids = [uuid.uuid4()]
|
||||||
self.decimals = [decimal.Decimal(1.25), 1.75]
|
self.decimals = [decimal.Decimal(1.25), 1.75]
|
||||||
|
self.tags = [Tag(1), Tag(2), Tag(3)]
|
||||||
self.objs = [
|
self.objs = [
|
||||||
OtherTypesArrayModel.objects.create(
|
OtherTypesArrayModel.objects.create(
|
||||||
ips=self.ips,
|
ips=self.ips,
|
||||||
uuids=self.uuids,
|
uuids=self.uuids,
|
||||||
decimals=self.decimals,
|
decimals=self.decimals,
|
||||||
|
tags=self.tags,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -332,6 +346,12 @@ class TestOtherTypesExactQuerying(PostgreSQLTestCase):
|
||||||
self.objs
|
self.objs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_exact_tags(self):
|
||||||
|
self.assertSequenceEqual(
|
||||||
|
OtherTypesArrayModel.objects.filter(tags=self.tags),
|
||||||
|
self.objs
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@isolate_apps('postgres_tests')
|
@isolate_apps('postgres_tests')
|
||||||
class TestChecks(PostgreSQLTestCase):
|
class TestChecks(PostgreSQLTestCase):
|
||||||
|
|
Loading…
Reference in New Issue