Refs #29738 -- Allowed registering serializers with MigrationWriter.
This commit is contained in:
parent
3c01fe30f3
commit
7d3b3897c1
|
@ -8,6 +8,7 @@ import math
|
||||||
import re
|
import re
|
||||||
import types
|
import types
|
||||||
import uuid
|
import uuid
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
from django.conf import SettingsReference
|
from django.conf import SettingsReference
|
||||||
from django.db import models
|
from django.db import models
|
||||||
|
@ -271,6 +272,38 @@ class UUIDSerializer(BaseSerializer):
|
||||||
return "uuid.%s" % repr(self.value), {"import uuid"}
|
return "uuid.%s" % repr(self.value), {"import uuid"}
|
||||||
|
|
||||||
|
|
||||||
|
class Serializer:
|
||||||
|
_registry = OrderedDict([
|
||||||
|
(frozenset, FrozensetSerializer),
|
||||||
|
(list, SequenceSerializer),
|
||||||
|
(set, SetSerializer),
|
||||||
|
(tuple, TupleSerializer),
|
||||||
|
(dict, DictionarySerializer),
|
||||||
|
(enum.Enum, EnumSerializer),
|
||||||
|
(datetime.datetime, DatetimeDatetimeSerializer),
|
||||||
|
((datetime.date, datetime.timedelta, datetime.time), DateTimeSerializer),
|
||||||
|
(SettingsReference, SettingsReferenceSerializer),
|
||||||
|
(float, FloatSerializer),
|
||||||
|
((bool, int, type(None), bytes, str), BaseSimpleSerializer),
|
||||||
|
(decimal.Decimal, DecimalSerializer),
|
||||||
|
((functools.partial, functools.partialmethod), FunctoolsPartialSerializer),
|
||||||
|
((types.FunctionType, types.BuiltinFunctionType, types.MethodType), FunctionTypeSerializer),
|
||||||
|
(collections.abc.Iterable, IterableSerializer),
|
||||||
|
((COMPILED_REGEX_TYPE, RegexObject), RegexSerializer),
|
||||||
|
(uuid.UUID, UUIDSerializer),
|
||||||
|
])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register(cls, type_, serializer):
|
||||||
|
if not issubclass(serializer, BaseSerializer):
|
||||||
|
raise ValueError("'%s' must inherit from 'BaseSerializer'." % serializer.__name__)
|
||||||
|
cls._registry[type_] = serializer
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def unregister(cls, type_):
|
||||||
|
cls._registry.pop(type_)
|
||||||
|
|
||||||
|
|
||||||
def serializer_factory(value):
|
def serializer_factory(value):
|
||||||
if isinstance(value, Promise):
|
if isinstance(value, Promise):
|
||||||
value = str(value)
|
value = str(value)
|
||||||
|
@ -290,42 +323,9 @@ def serializer_factory(value):
|
||||||
# Anything that knows how to deconstruct itself.
|
# Anything that knows how to deconstruct itself.
|
||||||
if hasattr(value, 'deconstruct'):
|
if hasattr(value, 'deconstruct'):
|
||||||
return DeconstructableSerializer(value)
|
return DeconstructableSerializer(value)
|
||||||
|
for type_, serializer_cls in Serializer._registry.items():
|
||||||
# Unfortunately some of these are order-dependent.
|
if isinstance(value, type_):
|
||||||
if isinstance(value, frozenset):
|
return serializer_cls(value)
|
||||||
return FrozensetSerializer(value)
|
|
||||||
if isinstance(value, list):
|
|
||||||
return SequenceSerializer(value)
|
|
||||||
if isinstance(value, set):
|
|
||||||
return SetSerializer(value)
|
|
||||||
if isinstance(value, tuple):
|
|
||||||
return TupleSerializer(value)
|
|
||||||
if isinstance(value, dict):
|
|
||||||
return DictionarySerializer(value)
|
|
||||||
if isinstance(value, enum.Enum):
|
|
||||||
return EnumSerializer(value)
|
|
||||||
if isinstance(value, datetime.datetime):
|
|
||||||
return DatetimeDatetimeSerializer(value)
|
|
||||||
if isinstance(value, (datetime.date, datetime.timedelta, datetime.time)):
|
|
||||||
return DateTimeSerializer(value)
|
|
||||||
if isinstance(value, SettingsReference):
|
|
||||||
return SettingsReferenceSerializer(value)
|
|
||||||
if isinstance(value, float):
|
|
||||||
return FloatSerializer(value)
|
|
||||||
if isinstance(value, (bool, int, type(None), bytes, str)):
|
|
||||||
return BaseSimpleSerializer(value)
|
|
||||||
if isinstance(value, decimal.Decimal):
|
|
||||||
return DecimalSerializer(value)
|
|
||||||
if isinstance(value, (functools.partial, functools.partialmethod)):
|
|
||||||
return FunctoolsPartialSerializer(value)
|
|
||||||
if isinstance(value, (types.FunctionType, types.BuiltinFunctionType, types.MethodType)):
|
|
||||||
return FunctionTypeSerializer(value)
|
|
||||||
if isinstance(value, collections.abc.Iterable):
|
|
||||||
return IterableSerializer(value)
|
|
||||||
if isinstance(value, (COMPILED_REGEX_TYPE, RegexObject)):
|
|
||||||
return RegexSerializer(value)
|
|
||||||
if isinstance(value, uuid.UUID):
|
|
||||||
return UUIDSerializer(value)
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Cannot serialize: %r\nThere are some values Django cannot serialize into "
|
"Cannot serialize: %r\nThere are some values Django cannot serialize into "
|
||||||
"migration files.\nFor more, see https://docs.djangoproject.com/en/%s/"
|
"migration files.\nFor more, see https://docs.djangoproject.com/en/%s/"
|
||||||
|
|
|
@ -8,7 +8,7 @@ from django.apps import apps
|
||||||
from django.conf import SettingsReference # NOQA
|
from django.conf import SettingsReference # NOQA
|
||||||
from django.db import migrations
|
from django.db import migrations
|
||||||
from django.db.migrations.loader import MigrationLoader
|
from django.db.migrations.loader import MigrationLoader
|
||||||
from django.db.migrations.serializer import serializer_factory
|
from django.db.migrations.serializer import Serializer, serializer_factory
|
||||||
from django.utils.inspect import get_func_args
|
from django.utils.inspect import get_func_args
|
||||||
from django.utils.module_loading import module_dir
|
from django.utils.module_loading import module_dir
|
||||||
from django.utils.timezone import now
|
from django.utils.timezone import now
|
||||||
|
@ -270,6 +270,14 @@ class MigrationWriter:
|
||||||
def serialize(cls, value):
|
def serialize(cls, value):
|
||||||
return serializer_factory(value).serialize()
|
return serializer_factory(value).serialize()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register_serializer(cls, type_, serializer):
|
||||||
|
Serializer.register(type_, serializer)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def unregister_serializer(cls, type_):
|
||||||
|
Serializer.unregister(type_)
|
||||||
|
|
||||||
|
|
||||||
MIGRATION_HEADER_TEMPLATE = """\
|
MIGRATION_HEADER_TEMPLATE = """\
|
||||||
# Generated by Django %(version)s on %(timestamp)s
|
# Generated by Django %(version)s on %(timestamp)s
|
||||||
|
|
|
@ -211,6 +211,9 @@ Migrations
|
||||||
|
|
||||||
* ``NoneType`` can now be serialized in migrations.
|
* ``NoneType`` can now be serialized in migrations.
|
||||||
|
|
||||||
|
* You can now :ref:`register custom serializers <custom-migration-serializers>`
|
||||||
|
for migrations.
|
||||||
|
|
||||||
Models
|
Models
|
||||||
~~~~~~
|
~~~~~~
|
||||||
|
|
||||||
|
|
|
@ -697,6 +697,35 @@ Django cannot serialize:
|
||||||
- Arbitrary class instances (e.g. ``MyClass(4.3, 5.7)``)
|
- Arbitrary class instances (e.g. ``MyClass(4.3, 5.7)``)
|
||||||
- Lambdas
|
- Lambdas
|
||||||
|
|
||||||
|
.. _custom-migration-serializers:
|
||||||
|
|
||||||
|
Custom serializers
|
||||||
|
------------------
|
||||||
|
|
||||||
|
.. versionadded:: 2.2
|
||||||
|
|
||||||
|
You can serialize other types by writing a custom serializer. For example, if
|
||||||
|
Django didn't serialize :class:`~decimal.Decimal` by default, you could do
|
||||||
|
this::
|
||||||
|
|
||||||
|
from decimal import Decimal
|
||||||
|
|
||||||
|
from django.db.migrations.serializer import BaseSerializer
|
||||||
|
from django.db.migrations.writer import MigrationWriter
|
||||||
|
|
||||||
|
class DecimalSerializer(BaseSerializer):
|
||||||
|
def serialize(self):
|
||||||
|
return repr(self.value), {'from decimal import Decimal'}
|
||||||
|
|
||||||
|
MigrationWriter.register_serializer(Decimal, DecimalSerializer)
|
||||||
|
|
||||||
|
The first argument of ``MigrationWriter.register_serializer()`` is a type or
|
||||||
|
iterable of types that should use the serializer.
|
||||||
|
|
||||||
|
The ``serialize()`` method of your serializer must return a string of how the
|
||||||
|
value should appear in migrations and a set of any imports that are needed in
|
||||||
|
the migration.
|
||||||
|
|
||||||
.. _custom-deconstruct-method:
|
.. _custom-deconstruct-method:
|
||||||
|
|
||||||
Adding a ``deconstruct()`` method
|
Adding a ``deconstruct()`` method
|
||||||
|
|
|
@ -15,6 +15,7 @@ from django import get_version
|
||||||
from django.conf import SettingsReference, settings
|
from django.conf import SettingsReference, settings
|
||||||
from django.core.validators import EmailValidator, RegexValidator
|
from django.core.validators import EmailValidator, RegexValidator
|
||||||
from django.db import migrations, models
|
from django.db import migrations, models
|
||||||
|
from django.db.migrations.serializer import BaseSerializer
|
||||||
from django.db.migrations.writer import MigrationWriter, OperationWriter
|
from django.db.migrations.writer import MigrationWriter, OperationWriter
|
||||||
from django.test import SimpleTestCase
|
from django.test import SimpleTestCase
|
||||||
from django.utils.deconstruct import deconstructible
|
from django.utils.deconstruct import deconstructible
|
||||||
|
@ -653,3 +654,18 @@ class WriterTests(SimpleTestCase):
|
||||||
|
|
||||||
string = MigrationWriter.serialize(models.CharField(default=DeconstructibleInstances))[0]
|
string = MigrationWriter.serialize(models.CharField(default=DeconstructibleInstances))[0]
|
||||||
self.assertEqual(string, "models.CharField(default=migrations.test_writer.DeconstructibleInstances)")
|
self.assertEqual(string, "models.CharField(default=migrations.test_writer.DeconstructibleInstances)")
|
||||||
|
|
||||||
|
def test_register_serializer(self):
|
||||||
|
class ComplexSerializer(BaseSerializer):
|
||||||
|
def serialize(self):
|
||||||
|
return 'complex(%r)' % self.value, {}
|
||||||
|
|
||||||
|
MigrationWriter.register_serializer(complex, ComplexSerializer)
|
||||||
|
self.assertSerializedEqual(complex(1, 2))
|
||||||
|
MigrationWriter.unregister_serializer(complex)
|
||||||
|
with self.assertRaisesMessage(ValueError, 'Cannot serialize: (1+2j)'):
|
||||||
|
self.assertSerializedEqual(complex(1, 2))
|
||||||
|
|
||||||
|
def test_register_non_serializer(self):
|
||||||
|
with self.assertRaisesMessage(ValueError, "'TestModel1' must inherit from 'BaseSerializer'."):
|
||||||
|
MigrationWriter.register_serializer(complex, TestModel1)
|
||||||
|
|
Loading…
Reference in New Issue