From 503ee41497f346de27843a4ea976c57303c76558 Mon Sep 17 00:00:00 2001 From: manav014 Date: Mon, 24 May 2021 04:30:01 +0530 Subject: [PATCH] Refs #29898 -- Moved state_forwards()'s logic from migration operations to ProjectState. Thanks Simon Charette and Markus Holtermann for reviews. --- django/db/migrations/operations/fields.py | 95 +++----------- django/db/migrations/operations/models.py | 83 ++++-------- django/db/migrations/state.py | 149 ++++++++++++++++++++++ 3 files changed, 189 insertions(+), 138 deletions(-) diff --git a/django/db/migrations/operations/fields.py b/django/db/migrations/operations/fields.py index b303f70465..641c142191 100644 --- a/django/db/migrations/operations/fields.py +++ b/django/db/migrations/operations/fields.py @@ -1,7 +1,4 @@ -from django.core.exceptions import FieldDoesNotExist -from django.db.migrations.utils import ( - field_is_referenced, field_references, get_references, -) +from django.db.migrations.utils import field_references from django.db.models import NOT_PROVIDED from django.utils.functional import cached_property @@ -85,16 +82,13 @@ class AddField(FieldOperation): ) def state_forwards(self, app_label, state): - # If preserve default is off, don't use the default for future state - if not self.preserve_default: - field = self.field.clone() - field.default = NOT_PROVIDED - else: - field = self.field - state.models[app_label, self.model_name_lower].fields[self.name] = field - # Delay rendering of relationships if it's not a relational field - delay = not field.is_relation - state.reload_model(app_label, self.model_name_lower, delay=delay) + state.add_field( + app_label, + self.model_name_lower, + self.name, + self.field, + self.preserve_default, + ) def database_forwards(self, app_label, schema_editor, from_state, to_state): to_model = to_state.apps.get_model(app_label, self.model_name) @@ -160,11 +154,7 @@ class RemoveField(FieldOperation): ) def state_forwards(self, app_label, state): - model_state = state.models[app_label, self.model_name_lower] - old_field = model_state.fields.pop(self.name) - # Delay rendering of relationships if it's not a relational field - delay = not old_field.is_relation - state.reload_model(app_label, self.model_name_lower, delay=delay) + state.remove_field(app_label, self.model_name_lower, self.name) def database_forwards(self, app_label, schema_editor, from_state, to_state): from_model = from_state.apps.get_model(app_label, self.model_name) @@ -216,24 +206,13 @@ class AlterField(FieldOperation): ) def state_forwards(self, app_label, state): - if not self.preserve_default: - field = self.field.clone() - field.default = NOT_PROVIDED - else: - field = self.field - model_state = state.models[app_label, self.model_name_lower] - model_state.fields[self.name] = field - # TODO: investigate if old relational fields must be reloaded or if it's - # sufficient if the new field is (#27737). - # Delay rendering of relationships if it's not a relational field and - # not referenced by a foreign key. - delay = ( - not field.is_relation and - not field_is_referenced( - state, (app_label, self.model_name_lower), (self.name, field), - ) + state.alter_field( + app_label, + self.model_name_lower, + self.name, + self.field, + self.preserve_default, ) - state.reload_model(app_label, self.model_name_lower, delay=delay) def database_forwards(self, app_label, schema_editor, from_state, to_state): to_model = to_state.apps.get_model(app_label, self.model_name) @@ -301,49 +280,7 @@ class RenameField(FieldOperation): ) def state_forwards(self, app_label, state): - model_state = state.models[app_label, self.model_name_lower] - # Rename the field - fields = model_state.fields - try: - found = fields.pop(self.old_name) - except KeyError: - raise FieldDoesNotExist( - "%s.%s has no field named '%s'" % (app_label, self.model_name, self.old_name) - ) - fields[self.new_name] = found - for field in fields.values(): - # Fix from_fields to refer to the new field. - from_fields = getattr(field, 'from_fields', None) - if from_fields: - field.from_fields = tuple([ - self.new_name if from_field_name == self.old_name else from_field_name - for from_field_name in from_fields - ]) - # Fix index/unique_together to refer to the new field - options = model_state.options - for option in ('index_together', 'unique_together'): - if option in options: - options[option] = [ - [self.new_name if n == self.old_name else n for n in together] - for together in options[option] - ] - # Fix to_fields to refer to the new field. - delay = True - references = get_references( - state, (app_label, self.model_name_lower), (self.old_name, found), - ) - for *_, field, reference in references: - delay = False - if reference.to: - remote_field, to_fields = reference.to - if getattr(remote_field, 'field_name', None) == self.old_name: - remote_field.field_name = self.new_name - if to_fields: - field.to_fields = tuple([ - self.new_name if to_field_name == self.old_name else to_field_name - for to_field_name in to_fields - ]) - state.reload_model(app_label, self.model_name_lower, delay=delay) + state.rename_field(app_label, self.model_name_lower, self.old_name, self.new_name) def database_forwards(self, app_label, schema_editor, from_state, to_state): to_model = to_state.apps.get_model(app_label, self.model_name) diff --git a/django/db/migrations/operations/models.py b/django/db/migrations/operations/models.py index 10f2686dc6..982816be3a 100644 --- a/django/db/migrations/operations/models.py +++ b/django/db/migrations/operations/models.py @@ -1,9 +1,7 @@ from django.db import models from django.db.migrations.operations.base import Operation from django.db.migrations.state import ModelState -from django.db.migrations.utils import ( - field_references, get_references, resolve_relation, -) +from django.db.migrations.utils import field_references, resolve_relation from django.db.models.options import normalize_together from django.utils.functional import cached_property @@ -316,31 +314,7 @@ class RenameModel(ModelOperation): ) def state_forwards(self, app_label, state): - # Add a new model. - renamed_model = state.models[app_label, self.old_name_lower].clone() - renamed_model.name = self.new_name - state.models[app_label, self.new_name_lower] = renamed_model - # Repoint all fields pointing to the old model to the new one. - old_model_tuple = (app_label, self.old_name_lower) - new_remote_model = '%s.%s' % (app_label, self.new_name) - to_reload = set() - for model_state, name, field, reference in get_references(state, old_model_tuple): - changed_field = None - if reference.to: - changed_field = field.clone() - changed_field.remote_field.model = new_remote_model - if reference.through: - if changed_field is None: - changed_field = field.clone() - changed_field.remote_field.through = new_remote_model - if changed_field: - model_state.fields[name] = changed_field - to_reload.add((model_state.app_label, model_state.name_lower)) - # Reload models related to old model before removing the old model. - state.reload_models(to_reload, delay=True) - # Remove the old model. - state.remove_model(app_label, self.old_name_lower) - state.reload_model(app_label, self.new_name_lower, delay=True) + state.rename_model(app_label, self.old_name, self.new_name) def database_forwards(self, app_label, schema_editor, from_state, to_state): new_model = to_state.apps.get_model(app_label, self.new_name) @@ -458,8 +432,7 @@ class AlterModelTable(ModelOptionOperation): ) def state_forwards(self, app_label, state): - state.models[app_label, self.name_lower].options["db_table"] = self.table - state.reload_model(app_label, self.name_lower, delay=True) + state.alter_model_options(app_label, self.name_lower, {'db_table': self.table}) def database_forwards(self, app_label, schema_editor, from_state, to_state): new_model = to_state.apps.get_model(app_label, self.name) @@ -518,9 +491,11 @@ class AlterTogetherOptionOperation(ModelOptionOperation): ) def state_forwards(self, app_label, state): - model_state = state.models[app_label, self.name_lower] - model_state.options[self.option_name] = self.option_value - state.reload_model(app_label, self.name_lower, delay=True) + state.alter_model_options( + app_label, + self.name_lower, + {self.option_name: self.option_value}, + ) def database_forwards(self, app_label, schema_editor, from_state, to_state): new_model = to_state.apps.get_model(app_label, self.name) @@ -596,9 +571,11 @@ class AlterOrderWithRespectTo(ModelOptionOperation): ) def state_forwards(self, app_label, state): - model_state = state.models[app_label, self.name_lower] - model_state.options['order_with_respect_to'] = self.order_with_respect_to - state.reload_model(app_label, self.name_lower, delay=True) + state.alter_model_options( + app_label, + self.name_lower, + {self.option_name: self.order_with_respect_to}, + ) def database_forwards(self, app_label, schema_editor, from_state, to_state): to_model = to_state.apps.get_model(app_label, self.name) @@ -676,12 +653,12 @@ class AlterModelOptions(ModelOptionOperation): ) def state_forwards(self, app_label, state): - model_state = state.models[app_label, self.name_lower] - model_state.options = {**model_state.options, **self.options} - for key in self.ALTER_OPTION_KEYS: - if key not in self.options: - model_state.options.pop(key, False) - state.reload_model(app_label, self.name_lower, delay=True) + state.alter_model_options( + app_label, + self.name_lower, + self.options, + self.ALTER_OPTION_KEYS, + ) def database_forwards(self, app_label, schema_editor, from_state, to_state): pass @@ -714,9 +691,7 @@ class AlterModelManagers(ModelOptionOperation): ) def state_forwards(self, app_label, state): - model_state = state.models[app_label, self.name_lower] - model_state.managers = list(self.managers) - state.reload_model(app_label, self.name_lower, delay=True) + state.alter_model_managers(app_label, self.name_lower, self.managers) def database_forwards(self, app_label, schema_editor, from_state, to_state): pass @@ -753,9 +728,7 @@ class AddIndex(IndexOperation): self.index = index def state_forwards(self, app_label, state): - model_state = state.models[app_label, self.model_name_lower] - model_state.options[self.option_name] = [*model_state.options[self.option_name], self.index.clone()] - state.reload_model(app_label, self.model_name_lower, delay=True) + state.add_index(app_label, self.model_name_lower, self.index) def database_forwards(self, app_label, schema_editor, from_state, to_state): model = to_state.apps.get_model(app_label, self.model_name) @@ -804,10 +777,7 @@ class RemoveIndex(IndexOperation): self.name = name def state_forwards(self, app_label, state): - model_state = state.models[app_label, self.model_name_lower] - indexes = model_state.options[self.option_name] - model_state.options[self.option_name] = [idx for idx in indexes if idx.name != self.name] - state.reload_model(app_label, self.model_name_lower, delay=True) + state.remove_index(app_label, self.model_name_lower, self.name) def database_forwards(self, app_label, schema_editor, from_state, to_state): model = from_state.apps.get_model(app_label, self.model_name) @@ -850,9 +820,7 @@ class AddConstraint(IndexOperation): self.constraint = constraint def state_forwards(self, app_label, state): - model_state = state.models[app_label, self.model_name_lower] - model_state.options[self.option_name] = [*model_state.options[self.option_name], self.constraint] - state.reload_model(app_label, self.model_name_lower, delay=True) + state.add_constraint(app_label, self.model_name_lower, self.constraint) def database_forwards(self, app_label, schema_editor, from_state, to_state): model = to_state.apps.get_model(app_label, self.model_name) @@ -886,10 +854,7 @@ class RemoveConstraint(IndexOperation): self.name = name def state_forwards(self, app_label, state): - model_state = state.models[app_label, self.model_name_lower] - constraints = model_state.options[self.option_name] - model_state.options[self.option_name] = [c for c in constraints if c.name != self.name] - state.reload_model(app_label, self.model_name_lower, delay=True) + state.remove_constraint(app_label, self.model_name_lower, self.name) def database_forwards(self, app_label, schema_editor, from_state, to_state): model = to_state.apps.get_model(app_label, self.model_name) diff --git a/django/db/migrations/state.py b/django/db/migrations/state.py index 38ec72dd72..af9093c4e8 100644 --- a/django/db/migrations/state.py +++ b/django/db/migrations/state.py @@ -6,7 +6,10 @@ from functools import partial from django.apps import AppConfig from django.apps.registry import Apps, apps as global_apps from django.conf import settings +from django.core.exceptions import FieldDoesNotExist from django.db import models +from django.db.migrations.utils import field_is_referenced, get_references +from django.db.models import NOT_PROVIDED from django.db.models.fields.related import RECURSIVE_RELATIONSHIP_CONSTANT from django.db.models.options import DEFAULT_NAMES, normalize_together from django.db.models.utils import make_model_tuple @@ -107,6 +110,152 @@ class ProjectState: # the cache automatically (#24513) self.apps.clear_cache() + def rename_model(self, app_label, old_name, new_name): + # Add a new model. + old_name_lower = old_name.lower() + new_name_lower = new_name.lower() + renamed_model = self.models[app_label, old_name_lower].clone() + renamed_model.name = new_name + self.models[app_label, new_name_lower] = renamed_model + # Repoint all fields pointing to the old model to the new one. + old_model_tuple = (app_label, old_name_lower) + new_remote_model = f'{app_label}.{new_name}' + to_reload = set() + for model_state, name, field, reference in get_references(self, old_model_tuple): + changed_field = None + if reference.to: + changed_field = field.clone() + changed_field.remote_field.model = new_remote_model + if reference.through: + if changed_field is None: + changed_field = field.clone() + changed_field.remote_field.through = new_remote_model + if changed_field: + model_state.fields[name] = changed_field + to_reload.add((model_state.app_label, model_state.name_lower)) + # Reload models related to old model before removing the old model. + self.reload_models(to_reload, delay=True) + # Remove the old model. + self.remove_model(app_label, old_name_lower) + self.reload_model(app_label, new_name_lower, delay=True) + + def alter_model_options(self, app_label, model_name, options, option_keys=None): + model_state = self.models[app_label, model_name] + model_state.options = {**model_state.options, **options} + if option_keys: + for key in option_keys: + if key not in options: + model_state.options.pop(key, False) + self.reload_model(app_label, model_name, delay=True) + + def alter_model_managers(self, app_label, model_name, managers): + model_state = self.models[app_label, model_name] + model_state.managers = list(managers) + self.reload_model(app_label, model_name, delay=True) + + def _append_option(self, app_label, model_name, option_name, obj): + model_state = self.models[app_label, model_name] + model_state.options[option_name] = [*model_state.options[option_name], obj] + self.reload_model(app_label, model_name, delay=True) + + def _remove_option(self, app_label, model_name, option_name, obj_name): + model_state = self.models[app_label, model_name] + objs = model_state.options[option_name] + model_state.options[option_name] = [obj for obj in objs if obj.name != obj_name] + self.reload_model(app_label, model_name, delay=True) + + def add_index(self, app_label, model_name, index): + self._append_option(app_label, model_name, 'indexes', index) + + def remove_index(self, app_label, model_name, index_name): + self._remove_option(app_label, model_name, 'indexes', index_name) + + def add_constraint(self, app_label, model_name, constraint): + self._append_option(app_label, model_name, 'constraints', constraint) + + def remove_constraint(self, app_label, model_name, constraint_name): + self._remove_option(app_label, model_name, 'constraints', constraint_name) + + def add_field(self, app_label, model_name, name, field, preserve_default): + # If preserve default is off, don't use the default for future state. + if not preserve_default: + field = field.clone() + field.default = NOT_PROVIDED + else: + field = field + self.models[app_label, model_name].fields[name] = field + # Delay rendering of relationships if it's not a relational field. + delay = not field.is_relation + self.reload_model(app_label, model_name, delay=delay) + + def remove_field(self, app_label, model_name, name): + model_state = self.models[app_label, model_name] + old_field = model_state.fields.pop(name) + # Delay rendering of relationships if it's not a relational field. + delay = not old_field.is_relation + self.reload_model(app_label, model_name, delay=delay) + + def alter_field(self, app_label, model_name, name, field, preserve_default): + if not preserve_default: + field = field.clone() + field.default = NOT_PROVIDED + else: + field = field + model_state = self.models[app_label, model_name] + model_state.fields[name] = field + # TODO: investigate if old relational fields must be reloaded or if + # it's sufficient if the new field is (#27737). + # Delay rendering of relationships if it's not a relational field and + # not referenced by a foreign key. + delay = ( + not field.is_relation and + not field_is_referenced(self, (app_label, model_name), (name, field)) + ) + self.reload_model(app_label, model_name, delay=delay) + + def rename_field(self, app_label, model_name, old_name, new_name): + model_state = self.models[app_label, model_name] + # Rename the field. + fields = model_state.fields + try: + found = fields.pop(old_name) + except KeyError: + raise FieldDoesNotExist( + f"{app_label}.{model_name} has no field named '{old_name}'" + ) + fields[new_name] = found + for field in fields.values(): + # Fix from_fields to refer to the new field. + from_fields = getattr(field, 'from_fields', None) + if from_fields: + field.from_fields = tuple([ + new_name if from_field_name == old_name else from_field_name + for from_field_name in from_fields + ]) + # Fix index/unique_together to refer to the new field. + options = model_state.options + for option in ('index_together', 'unique_together'): + if option in options: + options[option] = [ + [new_name if n == old_name else n for n in together] + for together in options[option] + ] + # Fix to_fields to refer to the new field. + delay = True + references = get_references(self, (app_label, model_name), (old_name, found)) + for *_, field, reference in references: + delay = False + if reference.to: + remote_field, to_fields = reference.to + if getattr(remote_field, 'field_name', None) == old_name: + remote_field.field_name = new_name + if to_fields: + field.to_fields = tuple([ + new_name if to_field_name == old_name else to_field_name + for to_field_name in to_fields + ]) + self.reload_model(app_label, model_name, delay=delay) + def _find_reload_model(self, app_label, model_name, delay=False): if delay: self.is_delayed = True