Fixed #25995 -- Added an encoder option to JSONField

Thanks Berker Peksag and Tim Graham for the reviews.
This commit is contained in:
Claude Paroz 2016-08-11 21:05:52 +02:00
parent 989f6108d3
commit 13c3e5d5a0
7 changed files with 98 additions and 10 deletions

View File

@ -10,6 +10,19 @@ from django.utils.translation import ugettext_lazy as _
__all__ = ['JSONField'] __all__ = ['JSONField']
class JsonAdapter(Json):
"""
Customized psycopg2.extras.Json to allow for a custom encoder.
"""
def __init__(self, adapted, dumps=None, encoder=None):
self.encoder = encoder
super(JsonAdapter, self).__init__(adapted, dumps=dumps)
def dumps(self, obj):
options = {'cls': self.encoder} if self.encoder else {}
return json.dumps(obj, **options)
class JSONField(Field): class JSONField(Field):
empty_strings_allowed = False empty_strings_allowed = False
description = _('A JSON object') description = _('A JSON object')
@ -17,9 +30,21 @@ class JSONField(Field):
'invalid': _("Value must be valid JSON."), 'invalid': _("Value must be valid JSON."),
} }
def __init__(self, verbose_name=None, name=None, encoder=None, **kwargs):
if encoder and not callable(encoder):
raise ValueError("The encoder parameter must be a callable object.")
self.encoder = encoder
super(JSONField, self).__init__(verbose_name, name, **kwargs)
def db_type(self, connection): def db_type(self, connection):
return 'jsonb' return 'jsonb'
def deconstruct(self):
name, path, args, kwargs = super(JSONField, self).deconstruct()
if self.encoder is not None:
kwargs['encoder'] = self.encoder
return name, path, args, kwargs
def get_transform(self, name): def get_transform(self, name):
transform = super(JSONField, self).get_transform(name) transform = super(JSONField, self).get_transform(name)
if transform: if transform:
@ -28,13 +53,14 @@ class JSONField(Field):
def get_prep_value(self, value): def get_prep_value(self, value):
if value is not None: if value is not None:
return Json(value) return JsonAdapter(value, encoder=self.encoder)
return value return value
def validate(self, value, model_instance): def validate(self, value, model_instance):
super(JSONField, self).validate(value, model_instance) super(JSONField, self).validate(value, model_instance)
options = {'cls': self.encoder} if self.encoder else {}
try: try:
json.dumps(value) json.dumps(value, **options)
except TypeError: except TypeError:
raise exceptions.ValidationError( raise exceptions.ValidationError(
self.error_messages['invalid'], self.error_messages['invalid'],

View File

@ -458,17 +458,32 @@ using in conjunction with lookups on
``JSONField`` ``JSONField``
============= =============
.. class:: JSONField(**options) .. class:: JSONField(encoder=None, **options)
A field for storing JSON encoded data. In Python the data is represented in A field for storing JSON encoded data. In Python the data is represented in
its Python native format: dictionaries, lists, strings, numbers, booleans its Python native format: dictionaries, lists, strings, numbers, booleans
and ``None``. and ``None``.
If you want to store other data types, you'll need to serialize them first. .. attribute:: encoder
For example, you might cast a ``datetime`` to a string. You might also want
to convert the string back to a ``datetime`` when you retrieve the data .. versionadded:: 1.11
from the database. There are some third-party ``JSONField`` implementations
which do this sort of thing automatically. An optional JSON-encoding class to serialize data types not supported
by the standard JSON serializer (``datetime``, ``uuid``, etc.). For
example, you can use the
:class:`~django.core.serializers.json.DjangoJSONEncoder` class or any
other :py:class:`json.JSONEncoder` subclass.
When the value is retrieved from the database, it will be in the format
chosen by the custom encoder (most often a string), so you'll need to
take extra steps to convert the value back to the initial data type
(:meth:`Model.from_db() <django.db.models.Model.from_db>` and
:meth:`Field.from_db_value() <django.db.models.Field.from_db_value>`
are two possible hooks for that purpose). Your deserialization may need
to account for the fact that you can't be certain of the input type.
For example, you run the risk of returning a ``datetime`` that was
actually a string that just happened to be in the same format chosen
for ``datetime``\s.
If you give the field a :attr:`~django.db.models.Field.default`, ensure If you give the field a :attr:`~django.db.models.Field.default`, ensure
it's a callable such as ``dict`` (for an empty default) or a callable that it's a callable such as ``dict`` (for an empty default) or a callable that

View File

@ -129,6 +129,10 @@ Minor features
* The new :class:`~django.contrib.postgres.indexes.GinIndex` class allows * The new :class:`~django.contrib.postgres.indexes.GinIndex` class allows
creating gin indexes in the database. creating gin indexes in the database.
* :class:`~django.contrib.postgres.fields.JSONField` accepts a new ``encoder``
parameter to specify a custom class to encode data types not supported by the
standard encoder.
:mod:`django.contrib.redirects` :mod:`django.contrib.redirects`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -23,6 +23,10 @@ except ImportError:
}) })
return name, path, args, kwargs return name, path, args, kwargs
class DummyJSONField(models.Field):
def __init__(self, encoder=None, **kwargs):
super(DummyJSONField, self).__init__(**kwargs)
ArrayField = DummyArrayField ArrayField = DummyArrayField
BigIntegerRangeField = models.Field BigIntegerRangeField = models.Field
DateRangeField = models.Field DateRangeField = models.Field
@ -30,5 +34,5 @@ except ImportError:
FloatRangeField = models.Field FloatRangeField = models.Field
HStoreField = models.Field HStoreField = models.Field
IntegerRangeField = models.Field IntegerRangeField = models.Field
JSONField = models.Field JSONField = DummyJSONField
SearchVectorField = models.Field SearchVectorField = models.Field

View File

@ -1,6 +1,7 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import unicode_literals from __future__ import unicode_literals
from django.core.serializers.json import DjangoJSONEncoder
from django.db import migrations, models from django.db import migrations, models
from ..fields import ( from ..fields import (
@ -223,6 +224,7 @@ class Migration(migrations.Migration):
fields=[ fields=[
('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
('field', JSONField(null=True, blank=True)), ('field', JSONField(null=True, blank=True)),
('field_custom', JSONField(null=True, blank=True, encoder=DjangoJSONEncoder)),
], ],
options={ options={
}, },

View File

@ -1,3 +1,4 @@
from django.core.serializers.json import DjangoJSONEncoder
from django.db import connection, models from django.db import connection, models
from .fields import ( from .fields import (
@ -132,6 +133,7 @@ class RangeLookupsModel(PostgreSQLModel):
if connection.vendor == 'postgresql' and connection.pg_version >= 90400: if connection.vendor == 'postgresql' and connection.pg_version >= 90400:
class JSONModel(models.Model): class JSONModel(models.Model):
field = JSONField(blank=True, null=True) field = JSONField(blank=True, null=True)
field_custom = JSONField(blank=True, null=True, encoder=DjangoJSONEncoder)
else: else:
# create an object with this name so we don't have failing imports # create an object with this name so we don't have failing imports
class JSONModel(object): class JSONModel(object):

View File

@ -1,7 +1,12 @@
from __future__ import unicode_literals
import datetime import datetime
import unittest import unittest
import uuid
from decimal import Decimal
from django.core import exceptions, serializers from django.core import exceptions, serializers
from django.core.serializers.json import DjangoJSONEncoder
from django.db import connection from django.db import connection
from django.forms import CharField, Form, widgets from django.forms import CharField, Form, widgets
from django.test import TestCase from django.test import TestCase
@ -79,6 +84,27 @@ class TestSaveLoad(TestCase):
loaded = JSONModel.objects.get() loaded = JSONModel.objects.get()
self.assertEqual(loaded.field, obj) self.assertEqual(loaded.field, obj)
def test_custom_encoding(self):
"""
JSONModel.field_custom has a custom DjangoJSONEncoder.
"""
some_uuid = uuid.uuid4()
obj_before = {
'date': datetime.date(2016, 8, 12),
'datetime': datetime.datetime(2016, 8, 12, 13, 44, 47, 575981),
'decimal': Decimal('10.54'),
'uuid': some_uuid,
}
obj_after = {
'date': '2016-08-12',
'datetime': '2016-08-12T13:44:47.575',
'decimal': '10.54',
'uuid': str(some_uuid),
}
JSONModel.objects.create(field_custom=obj_before)
loaded = JSONModel.objects.get()
self.assertEqual(loaded.field_custom, obj_after)
@skipUnlessPG94 @skipUnlessPG94
class TestQuerying(TestCase): class TestQuerying(TestCase):
@ -215,7 +241,10 @@ class TestQuerying(TestCase):
@skipUnlessPG94 @skipUnlessPG94
class TestSerialization(TestCase): class TestSerialization(TestCase):
test_data = '[{"fields": {"field": {"a": "b", "c": null}}, "model": "postgres_tests.jsonmodel", "pk": null}]' test_data = (
'[{"fields": {"field": {"a": "b", "c": null}, "field_custom": null}, '
'"model": "postgres_tests.jsonmodel", "pk": null}]'
)
def test_dumping(self): def test_dumping(self):
instance = JSONModel(field={'a': 'b', 'c': None}) instance = JSONModel(field={'a': 'b', 'c': None})
@ -236,6 +265,12 @@ class TestValidation(PostgreSQLTestCase):
self.assertEqual(cm.exception.code, 'invalid') self.assertEqual(cm.exception.code, 'invalid')
self.assertEqual(cm.exception.message % cm.exception.params, "Value must be valid JSON.") self.assertEqual(cm.exception.message % cm.exception.params, "Value must be valid JSON.")
def test_custom_encoder(self):
with self.assertRaisesMessage(ValueError, "The encoder parameter must be a callable object."):
field = JSONField(encoder=DjangoJSONEncoder())
field = JSONField(encoder=DjangoJSONEncoder)
self.assertEqual(field.clean(datetime.timedelta(days=1), None), datetime.timedelta(days=1))
class TestFormField(PostgreSQLTestCase): class TestFormField(PostgreSQLTestCase):