diff --git a/django/db/migrations/serializer.py b/django/db/migrations/serializer.py new file mode 100644 index 00000000000..9004d25e214 --- /dev/null +++ b/django/db/migrations/serializer.py @@ -0,0 +1,388 @@ +from __future__ import unicode_literals + +import collections +import datetime +import decimal +import functools +import math +import types +from importlib import import_module + +from django.db import models +from django.db.migrations.operations.base import Operation +from django.db.migrations.utils import COMPILED_REGEX_TYPE, RegexObject +from django.utils import datetime_safe, six +from django.utils.encoding import force_text +from django.utils.functional import LazyObject, Promise +from django.utils.timezone import utc +from django.utils.version import get_docs_version + +try: + import enum +except ImportError: + # No support on Python 2 if enum34 isn't installed. + enum = None + + +class BaseSerializer(object): + def __init__(self, value): + self.value = value + + def serialize(self): + raise NotImplementedError('Subclasses of BaseSerializer must implement the serialize() method.') + + +class BaseSequenceSerializer(BaseSerializer): + def _format(self): + raise NotImplementedError('Subclasses of BaseSequenceSerializer must implement the _format() method.') + + def serialize(self): + imports = set() + strings = [] + for item in self.value: + item_string, item_imports = serializer_factory(item).serialize() + imports.update(item_imports) + strings.append(item_string) + value = self._format() + return value % (", ".join(strings)), imports + + +class BaseSimpleSerializer(BaseSerializer): + def serialize(self): + return repr(self.value), set() + + +class ByteTypeSerializer(BaseSerializer): + def serialize(self): + value_repr = repr(self.value) + if six.PY2: + # Prepend the `b` prefix since we're importing unicode_literals + value_repr = 'b' + value_repr + return value_repr, set() + + +class DatetimeSerializer(BaseSerializer): + def serialize(self): + if self.value.tzinfo is not None and self.value.tzinfo != utc: + self.value = self.value.astimezone(utc) + value_repr = repr(self.value).replace("", "utc") + if isinstance(self.value, datetime_safe.datetime): + value_repr = "datetime.%s" % value_repr + imports = ["import datetime"] + if self.value.tzinfo is not None: + imports.append("from django.utils.timezone import utc") + return value_repr, set(imports) + + +class DateSerializer(BaseSerializer): + def serialize(self): + value_repr = repr(self.value) + if isinstance(self.value, datetime_safe.date): + value_repr = "datetime.%s" % value_repr + return value_repr, {"import datetime"} + + +class DecimalSerializer(BaseSerializer): + def serialize(self): + return repr(self.value), {"from decimal import Decimal"} + + +class DeconstructableSerializer(BaseSerializer): + @staticmethod + def serialize_deconstructed(path, args, kwargs): + name, imports = DeconstructableSerializer._serialize_path(path) + strings = [] + for arg in args: + arg_string, arg_imports = serializer_factory(arg).serialize() + strings.append(arg_string) + imports.update(arg_imports) + for kw, arg in sorted(kwargs.items()): + arg_string, arg_imports = serializer_factory(arg).serialize() + imports.update(arg_imports) + strings.append("%s=%s" % (kw, arg_string)) + return "%s(%s)" % (name, ", ".join(strings)), imports + + @staticmethod + def _serialize_path(path): + module, name = path.rsplit(".", 1) + if module == "django.db.models": + imports = {"from django.db import models"} + name = "models.%s" % name + else: + imports = {"import %s" % module} + name = path + return name, imports + + def serialize(self): + return self.serialize_deconstructed(*self.value.deconstruct()) + + +class DictionarySerializer(BaseSerializer): + def serialize(self): + imports = set() + strings = [] + for k, v in sorted(self.value.items()): + k_string, k_imports = serializer_factory(k).serialize() + v_string, v_imports = serializer_factory(v).serialize() + imports.update(k_imports) + imports.update(v_imports) + strings.append((k_string, v_string)) + return "{%s}" % (", ".join("%s: %s" % (k, v) for k, v in strings)), imports + + +class EnumSerializer(BaseSerializer): + def serialize(self): + enum_class = self.value.__class__ + module = enum_class.__module__ + imports = {"import %s" % module} + v_string, v_imports = serializer_factory(self.value.value).serialize() + imports.update(v_imports) + return "%s.%s(%s)" % (module, enum_class.__name__, v_string), imports + + +class FloatSerializer(BaseSimpleSerializer): + def serialize(self): + if math.isnan(self.value) or math.isinf(self.value): + return 'float("{}")'.format(self.value), set() + return super(FloatSerializer, self).serialize() + + +class FrozensetSerializer(BaseSequenceSerializer): + def _format(self): + return "frozenset([%s])" + + +class FunctionTypeSerializer(BaseSerializer): + def serialize(self): + if getattr(self.value, "__self__", None) and isinstance(self.value.__self__, type): + klass = self.value.__self__ + module = klass.__module__ + return "%s.%s.%s" % (module, klass.__name__, self.value.__name__), {"import %s" % module} + # Further error checking + if self.value.__name__ == '': + raise ValueError("Cannot serialize function: lambda") + if self.value.__module__ is None: + raise ValueError("Cannot serialize function %r: No module" % self.value) + # Python 3 is a lot easier, and only uses this branch if it's not local. + if getattr(self.value, "__qualname__", None) and getattr(self.value, "__module__", None): + if "<" not in self.value.__qualname__: # Qualname can include + return "%s.%s" % \ + (self.value.__module__, self.value.__qualname__), {"import %s" % self.value.__module__} + # Python 2/fallback version + module_name = self.value.__module__ + # Make sure it's actually there and not an unbound method + module = import_module(module_name) + if not hasattr(module, self.value.__name__): + raise ValueError( + "Could not find function %s in %s.\n" + "Please note that due to Python 2 limitations, you cannot " + "serialize unbound method functions (e.g. a method " + "declared and used in the same class body). Please move " + "the function into the main module body to use migrations.\n" + "For more information, see " + "https://docs.djangoproject.com/en/%s/topics/migrations/#serializing-values" + % (self.value.__name__, module_name, get_docs_version()) + ) + # Needed on Python 2 only + if module_name == '__builtin__': + return self.value.__name__, set() + return "%s.%s" % (module_name, self.value.__name__), {"import %s" % module_name} + + +class FunctoolsPartialSerializer(BaseSerializer): + def serialize(self): + imports = {'import functools'} + # Serialize functools.partial() arguments + func_string, func_imports = serializer_factory(self.value.func).serialize() + args_string, args_imports = serializer_factory(self.value.args).serialize() + keywords_string, keywords_imports = serializer_factory(self.value.keywords).serialize() + # Add any imports needed by arguments + imports.update(func_imports) + imports.update(args_imports) + imports.update(keywords_imports) + return ( + "functools.partial(%s, *%s, **%s)" % ( + func_string, args_string, keywords_string, + ), + imports, + ) + + +class IterableSerializer(BaseSerializer): + def serialize(self): + imports = set() + strings = [] + for item in self.value: + item_string, item_imports = serializer_factory(item).serialize() + imports.update(item_imports) + strings.append(item_string) + # When len(strings)==0, the empty iterable should be serialized as + # "()", not "(,)" because (,) is invalid Python syntax. + value = "(%s)" if len(strings) != 1 else "(%s,)" + return value % (", ".join(strings)), imports + + +class ModelFieldSerializer(DeconstructableSerializer): + def serialize(self): + attr_name, path, args, kwargs = self.value.deconstruct() + return self.serialize_deconstructed(path, args, kwargs) + + +class ModelManagerSerializer(DeconstructableSerializer): + def serialize(self): + as_manager, manager_path, qs_path, args, kwargs = self.value.deconstruct() + if as_manager: + name, imports = self._serialize_path(qs_path) + return "%s.as_manager()" % name, imports + else: + return self.serialize_deconstructed(manager_path, args, kwargs) + + +class OperationSerializer(BaseSerializer): + def serialize(self): + from django.db.migrations.writer import OperationWriter + string, imports = OperationWriter(self.value, indentation=0).serialize() + # Nested operation, trailing comma is handled in upper OperationWriter._write() + return string.rstrip(','), imports + + +class RegexSerializer(BaseSerializer): + def serialize(self): + imports = {"import re"} + regex_pattern, pattern_imports = serializer_factory(self.value.pattern).serialize() + regex_flags, flag_imports = serializer_factory(self.value.flags).serialize() + imports.update(pattern_imports) + imports.update(flag_imports) + args = [regex_pattern] + if self.value.flags: + args.append(regex_flags) + return "re.compile(%s)" % ', '.join(args), imports + + +class SequenceSerializer(BaseSequenceSerializer): + def _format(self): + return "[%s]" + + +class SetSerializer(BaseSequenceSerializer): + def _format(self): + # Don't use the literal "{%s}" as it doesn't support empty set + return "set([%s])" + + +class SettingsReferenceSerializer(BaseSerializer): + def serialize(self): + return "settings.%s" % self.value.setting_name, {"from django.conf import settings"} + + +class TextTypeSerializer(BaseSerializer): + def serialize(self): + value_repr = repr(self.value) + if six.PY2: + # Strip the `u` prefix since we're importing unicode_literals + value_repr = value_repr[1:] + return value_repr, set() + + +class TimedeltaSerializer(BaseSerializer): + def serialize(self): + return repr(self.value), {"import datetime"} + + +class TimeSerializer(BaseSerializer): + def serialize(self): + value_repr = repr(self.value) + if isinstance(self.value, datetime_safe.time): + value_repr = "datetime.%s" % value_repr + return value_repr, {"import datetime"} + + +class TupleSerializer(BaseSequenceSerializer): + def _format(self): + # When len(value)==0, the empty tuple should be serialized as "()", + # not "(,)" because (,) is invalid Python syntax. + return "(%s)" if len(self.value) != 1 else "(%s,)" + + +class TypeSerializer(BaseSerializer): + def serialize(self): + special_cases = [ + (models.Model, "models.Model", []), + ] + for case, string, imports in special_cases: + if case is self.value: + return string, set(imports) + if hasattr(self.value, "__module__"): + module = self.value.__module__ + if module == six.moves.builtins.__name__: + return self.value.__name__, set() + else: + return "%s.%s" % (module, self.value.__name__), {"import %s" % module} + + +def serializer_factory(value): + from django.db.migrations.writer import SettingsReference + if isinstance(value, Promise): + value = force_text(value) + elif isinstance(value, LazyObject): + # The unwrapped value is returned as the first item of the arguments + # tuple. + value = value.__reduce__()[1][0] + + # Unfortunately some of these are order-dependent. + if isinstance(value, frozenset): + return FrozensetSerializer(value) + if isinstance(value, list): + return SequenceSerializer(value) + if isinstance(value, set): + return SetSerializer(value) + if isinstance(value, tuple): + return TupleSerializer(value) + if isinstance(value, dict): + return DictionarySerializer(value) + if enum and isinstance(value, enum.Enum): + return EnumSerializer(value) + if isinstance(value, datetime.datetime): + return DatetimeSerializer(value) + if isinstance(value, datetime.date): + return DateSerializer(value) + if isinstance(value, datetime.time): + return TimeSerializer(value) + if isinstance(value, datetime.timedelta): + return TimedeltaSerializer(value) + if isinstance(value, SettingsReference): + return SettingsReferenceSerializer(value) + if isinstance(value, float): + return FloatSerializer(value) + if isinstance(value, six.integer_types + (bool, type(None))): + return BaseSimpleSerializer(value) + if isinstance(value, six.binary_type): + return ByteTypeSerializer(value) + if isinstance(value, six.text_type): + return TextTypeSerializer(value) + if isinstance(value, decimal.Decimal): + return DecimalSerializer(value) + if isinstance(value, models.Field): + return ModelFieldSerializer(value) + if isinstance(value, type): + return TypeSerializer(value) + if isinstance(value, models.manager.BaseManager): + return ModelManagerSerializer(value) + if isinstance(value, Operation): + return OperationSerializer(value) + if isinstance(value, functools.partial): + return FunctoolsPartialSerializer(value) + # Anything that knows how to deconstruct itself. + if hasattr(value, 'deconstruct'): + return DeconstructableSerializer(value) + if isinstance(value, (types.FunctionType, types.BuiltinFunctionType)): + return FunctionTypeSerializer(value) + if isinstance(value, collections.Iterable): + return IterableSerializer(value) + if isinstance(value, (COMPILED_REGEX_TYPE, RegexObject)): + return RegexSerializer(value) + raise ValueError( + "Cannot serialize: %r\nThere are some values Django cannot serialize into " + "migration files.\nFor more, see https://docs.djangoproject.com/en/%s/" + "topics/migrations/#migration-serializing" % (value, get_docs_version()) + ) diff --git a/django/db/migrations/writer.py b/django/db/migrations/writer.py index f36db323a09..137f83dbc01 100644 --- a/django/db/migrations/writer.py +++ b/django/db/migrations/writer.py @@ -1,29 +1,19 @@ from __future__ import unicode_literals -import collections -import datetime -import decimal -import functools -import math import os import re -import types from importlib import import_module from django import get_version from django.apps import apps -from django.db import migrations, models +from django.db import migrations from django.db.migrations.loader import MigrationLoader -from django.db.migrations.operations.base import Operation -from django.db.migrations.utils import COMPILED_REGEX_TYPE, RegexObject -from django.utils import datetime_safe, six +from django.db.migrations.serializer import serializer_factory from django.utils._os import upath from django.utils.encoding import force_text -from django.utils.functional import LazyObject, Promise from django.utils.inspect import get_func_args from django.utils.module_loading import module_dir -from django.utils.timezone import now, utc -from django.utils.version import get_docs_version +from django.utils.timezone import now try: import enum @@ -229,20 +219,6 @@ class MigrationWriter(object): return (MIGRATION_TEMPLATE % items).encode("utf8") - @staticmethod - def serialize_datetime(value): - """ - Returns a serialized version of a datetime object that is valid, - executable python code. It converts timezone-aware values to utc with - an 'executable' utc representation of tzinfo. - """ - if value.tzinfo is not None and value.tzinfo != utc: - value = value.astimezone(utc) - value_repr = repr(value).replace("", "utc") - if isinstance(value, datetime_safe.datetime): - value_repr = "datetime.%s" % value_repr - return value_repr - @property def basedir(self): migrations_package_name = MigrationLoader.migrations_module(self.migration.app_label) @@ -312,247 +288,9 @@ class MigrationWriter(object): def path(self): return os.path.join(self.basedir, self.filename) - @classmethod - def serialize_deconstructed(cls, path, args, kwargs): - name, imports = cls._serialize_path(path) - strings = [] - for arg in args: - arg_string, arg_imports = cls.serialize(arg) - strings.append(arg_string) - imports.update(arg_imports) - for kw, arg in sorted(kwargs.items()): - arg_string, arg_imports = cls.serialize(arg) - imports.update(arg_imports) - strings.append("%s=%s" % (kw, arg_string)) - return "%s(%s)" % (name, ", ".join(strings)), imports - - @classmethod - def _serialize_path(cls, path): - module, name = path.rsplit(".", 1) - if module == "django.db.models": - imports = {"from django.db import models"} - name = "models.%s" % name - else: - imports = {"import %s" % module} - name = path - return name, imports - @classmethod def serialize(cls, value): - """ - Serializes the value to a string that's parsable by Python, along - with any needed imports to make that string work. - More advanced than repr() as it can encode things - like datetime.datetime.now. - """ - # FIXME: Ideally Promise would be reconstructible, but for now we - # use force_text on them and defer to the normal string serialization - # process. - if isinstance(value, Promise): - value = force_text(value) - elif isinstance(value, LazyObject): - # The unwrapped value is returned as the first item of the - # arguments tuple. - value = value.__reduce__()[1][0] - - # Sequences - if isinstance(value, (frozenset, list, set, tuple)): - imports = set() - strings = [] - for item in value: - item_string, item_imports = cls.serialize(item) - imports.update(item_imports) - strings.append(item_string) - if isinstance(value, set): - # Don't use the literal "{%s}" as it doesn't support empty set - format = "set([%s])" - elif isinstance(value, frozenset): - format = "frozenset([%s])" - elif isinstance(value, tuple): - # When len(value)==0, the empty tuple should be serialized as - # "()", not "(,)" because (,) is invalid Python syntax. - format = "(%s)" if len(value) != 1 else "(%s,)" - else: - format = "[%s]" - return format % (", ".join(strings)), imports - # Dictionaries - elif isinstance(value, dict): - imports = set() - strings = [] - for k, v in sorted(value.items()): - k_string, k_imports = cls.serialize(k) - v_string, v_imports = cls.serialize(v) - imports.update(k_imports) - imports.update(v_imports) - strings.append((k_string, v_string)) - return "{%s}" % (", ".join("%s: %s" % (k, v) for k, v in strings)), imports - # Enums - elif enum and isinstance(value, enum.Enum): - enum_class = value.__class__ - module = enum_class.__module__ - imports = {"import %s" % module} - v_string, v_imports = cls.serialize(value.value) - imports.update(v_imports) - return "%s.%s(%s)" % (module, enum_class.__name__, v_string), imports - # Datetimes - elif isinstance(value, datetime.datetime): - value_repr = cls.serialize_datetime(value) - imports = ["import datetime"] - if value.tzinfo is not None: - imports.append("from django.utils.timezone import utc") - return value_repr, set(imports) - # Dates - elif isinstance(value, datetime.date): - value_repr = repr(value) - if isinstance(value, datetime_safe.date): - value_repr = "datetime.%s" % value_repr - return value_repr, {"import datetime"} - # Times - elif isinstance(value, datetime.time): - value_repr = repr(value) - if isinstance(value, datetime_safe.time): - value_repr = "datetime.%s" % value_repr - return value_repr, {"import datetime"} - # Timedeltas - elif isinstance(value, datetime.timedelta): - return repr(value), {"import datetime"} - # Settings references - elif isinstance(value, SettingsReference): - return "settings.%s" % value.setting_name, {"from django.conf import settings"} - # Simple types - elif isinstance(value, float): - if math.isnan(value) or math.isinf(value): - return 'float("{}")'.format(value), set() - return repr(value), set() - elif isinstance(value, six.integer_types + (bool, type(None))): - return repr(value), set() - elif isinstance(value, six.binary_type): - value_repr = repr(value) - if six.PY2: - # Prepend the `b` prefix since we're importing unicode_literals - value_repr = 'b' + value_repr - return value_repr, set() - elif isinstance(value, six.text_type): - value_repr = repr(value) - if six.PY2: - # Strip the `u` prefix since we're importing unicode_literals - value_repr = value_repr[1:] - return value_repr, set() - # Decimal - elif isinstance(value, decimal.Decimal): - return repr(value), {"from decimal import Decimal"} - # Django fields - elif isinstance(value, models.Field): - attr_name, path, args, kwargs = value.deconstruct() - return cls.serialize_deconstructed(path, args, kwargs) - # Classes - elif isinstance(value, type): - special_cases = [ - (models.Model, "models.Model", []), - ] - for case, string, imports in special_cases: - if case is value: - return string, set(imports) - if hasattr(value, "__module__"): - module = value.__module__ - if module == six.moves.builtins.__name__: - return value.__name__, set() - else: - return "%s.%s" % (module, value.__name__), {"import %s" % module} - elif isinstance(value, models.manager.BaseManager): - as_manager, manager_path, qs_path, args, kwargs = value.deconstruct() - if as_manager: - name, imports = cls._serialize_path(qs_path) - return "%s.as_manager()" % name, imports - else: - return cls.serialize_deconstructed(manager_path, args, kwargs) - elif isinstance(value, Operation): - string, imports = OperationWriter(value, indentation=0).serialize() - # Nested operation, trailing comma is handled in upper OperationWriter._write() - return string.rstrip(','), imports - elif isinstance(value, functools.partial): - imports = {'import functools'} - # Serialize functools.partial() arguments - func_string, func_imports = cls.serialize(value.func) - args_string, args_imports = cls.serialize(value.args) - keywords_string, keywords_imports = cls.serialize(value.keywords) - # Add any imports needed by arguments - imports.update(func_imports) - imports.update(args_imports) - imports.update(keywords_imports) - return ( - "functools.partial(%s, *%s, **%s)" % ( - func_string, args_string, keywords_string, - ), - imports, - ) - # Anything that knows how to deconstruct itself. - elif hasattr(value, 'deconstruct'): - return cls.serialize_deconstructed(*value.deconstruct()) - # Functions - elif isinstance(value, (types.FunctionType, types.BuiltinFunctionType)): - # @classmethod? - if getattr(value, "__self__", None) and isinstance(value.__self__, type): - klass = value.__self__ - module = klass.__module__ - return "%s.%s.%s" % (module, klass.__name__, value.__name__), {"import %s" % module} - # Further error checking - if value.__name__ == '': - raise ValueError("Cannot serialize function: lambda") - if value.__module__ is None: - raise ValueError("Cannot serialize function %r: No module" % value) - # Python 3 is a lot easier, and only uses this branch if it's not local. - if getattr(value, "__qualname__", None) and getattr(value, "__module__", None): - if "<" not in value.__qualname__: # Qualname can include - return "%s.%s" % (value.__module__, value.__qualname__), {"import %s" % value.__module__} - # Python 2/fallback version - module_name = value.__module__ - # Make sure it's actually there and not an unbound method - module = import_module(module_name) - if not hasattr(module, value.__name__): - raise ValueError( - "Could not find function %s in %s.\n" - "Please note that due to Python 2 limitations, you cannot " - "serialize unbound method functions (e.g. a method " - "declared and used in the same class body). Please move " - "the function into the main module body to use migrations.\n" - "For more information, see " - "https://docs.djangoproject.com/en/%s/topics/migrations/#serializing-values" - % (value.__name__, module_name, get_docs_version())) - # Needed on Python 2 only - if module_name == '__builtin__': - return value.__name__, set() - return "%s.%s" % (module_name, value.__name__), {"import %s" % module_name} - # Other iterables - elif isinstance(value, collections.Iterable): - imports = set() - strings = [] - for item in value: - item_string, item_imports = cls.serialize(item) - imports.update(item_imports) - strings.append(item_string) - # When len(strings)==0, the empty iterable should be serialized as - # "()", not "(,)" because (,) is invalid Python syntax. - format = "(%s)" if len(strings) != 1 else "(%s,)" - return format % (", ".join(strings)), imports - # Compiled regex - elif isinstance(value, (COMPILED_REGEX_TYPE, RegexObject)): - imports = {"import re"} - regex_pattern, pattern_imports = cls.serialize(value.pattern) - regex_flags, flag_imports = cls.serialize(value.flags) - imports.update(pattern_imports) - imports.update(flag_imports) - args = [regex_pattern] - if value.flags: - args.append(regex_flags) - return "re.compile(%s)" % ', '.join(args), imports - # Uh oh. - else: - raise ValueError( - "Cannot serialize: %r\nThere are some values Django cannot serialize into " - "migration files.\nFor more, see https://docs.djangoproject.com/en/%s/" - "topics/migrations/#migration-serializing" % (value, get_docs_version()) - ) + return serializer_factory(value).serialize() MIGRATION_TEMPLATE = """\