Refs #29898 -- Moved state_forwards()'s logic from migration operations to ProjectState.

Thanks Simon Charette and Markus Holtermann for reviews.
This commit is contained in:
manav014 2021-05-24 04:30:01 +05:30 committed by Mariusz Felisiak
parent 594d6e9407
commit 503ee41497
3 changed files with 189 additions and 138 deletions

View File

@ -1,7 +1,4 @@
from django.core.exceptions import FieldDoesNotExist from django.db.migrations.utils import field_references
from django.db.migrations.utils import (
field_is_referenced, field_references, get_references,
)
from django.db.models import NOT_PROVIDED from django.db.models import NOT_PROVIDED
from django.utils.functional import cached_property from django.utils.functional import cached_property
@ -85,16 +82,13 @@ class AddField(FieldOperation):
) )
def state_forwards(self, app_label, state): def state_forwards(self, app_label, state):
# If preserve default is off, don't use the default for future state state.add_field(
if not self.preserve_default: app_label,
field = self.field.clone() self.model_name_lower,
field.default = NOT_PROVIDED self.name,
else: self.field,
field = self.field self.preserve_default,
state.models[app_label, self.model_name_lower].fields[self.name] = field )
# Delay rendering of relationships if it's not a relational field
delay = not field.is_relation
state.reload_model(app_label, self.model_name_lower, delay=delay)
def database_forwards(self, app_label, schema_editor, from_state, to_state): def database_forwards(self, app_label, schema_editor, from_state, to_state):
to_model = to_state.apps.get_model(app_label, self.model_name) to_model = to_state.apps.get_model(app_label, self.model_name)
@ -160,11 +154,7 @@ class RemoveField(FieldOperation):
) )
def state_forwards(self, app_label, state): def state_forwards(self, app_label, state):
model_state = state.models[app_label, self.model_name_lower] state.remove_field(app_label, self.model_name_lower, self.name)
old_field = model_state.fields.pop(self.name)
# Delay rendering of relationships if it's not a relational field
delay = not old_field.is_relation
state.reload_model(app_label, self.model_name_lower, delay=delay)
def database_forwards(self, app_label, schema_editor, from_state, to_state): def database_forwards(self, app_label, schema_editor, from_state, to_state):
from_model = from_state.apps.get_model(app_label, self.model_name) from_model = from_state.apps.get_model(app_label, self.model_name)
@ -216,24 +206,13 @@ class AlterField(FieldOperation):
) )
def state_forwards(self, app_label, state): def state_forwards(self, app_label, state):
if not self.preserve_default: state.alter_field(
field = self.field.clone() app_label,
field.default = NOT_PROVIDED self.model_name_lower,
else: self.name,
field = self.field self.field,
model_state = state.models[app_label, self.model_name_lower] self.preserve_default,
model_state.fields[self.name] = field
# TODO: investigate if old relational fields must be reloaded or if it's
# sufficient if the new field is (#27737).
# Delay rendering of relationships if it's not a relational field and
# not referenced by a foreign key.
delay = (
not field.is_relation and
not field_is_referenced(
state, (app_label, self.model_name_lower), (self.name, field),
) )
)
state.reload_model(app_label, self.model_name_lower, delay=delay)
def database_forwards(self, app_label, schema_editor, from_state, to_state): def database_forwards(self, app_label, schema_editor, from_state, to_state):
to_model = to_state.apps.get_model(app_label, self.model_name) to_model = to_state.apps.get_model(app_label, self.model_name)
@ -301,49 +280,7 @@ class RenameField(FieldOperation):
) )
def state_forwards(self, app_label, state): def state_forwards(self, app_label, state):
model_state = state.models[app_label, self.model_name_lower] state.rename_field(app_label, self.model_name_lower, self.old_name, self.new_name)
# Rename the field
fields = model_state.fields
try:
found = fields.pop(self.old_name)
except KeyError:
raise FieldDoesNotExist(
"%s.%s has no field named '%s'" % (app_label, self.model_name, self.old_name)
)
fields[self.new_name] = found
for field in fields.values():
# Fix from_fields to refer to the new field.
from_fields = getattr(field, 'from_fields', None)
if from_fields:
field.from_fields = tuple([
self.new_name if from_field_name == self.old_name else from_field_name
for from_field_name in from_fields
])
# Fix index/unique_together to refer to the new field
options = model_state.options
for option in ('index_together', 'unique_together'):
if option in options:
options[option] = [
[self.new_name if n == self.old_name else n for n in together]
for together in options[option]
]
# Fix to_fields to refer to the new field.
delay = True
references = get_references(
state, (app_label, self.model_name_lower), (self.old_name, found),
)
for *_, field, reference in references:
delay = False
if reference.to:
remote_field, to_fields = reference.to
if getattr(remote_field, 'field_name', None) == self.old_name:
remote_field.field_name = self.new_name
if to_fields:
field.to_fields = tuple([
self.new_name if to_field_name == self.old_name else to_field_name
for to_field_name in to_fields
])
state.reload_model(app_label, self.model_name_lower, delay=delay)
def database_forwards(self, app_label, schema_editor, from_state, to_state): def database_forwards(self, app_label, schema_editor, from_state, to_state):
to_model = to_state.apps.get_model(app_label, self.model_name) to_model = to_state.apps.get_model(app_label, self.model_name)

View File

@ -1,9 +1,7 @@
from django.db import models from django.db import models
from django.db.migrations.operations.base import Operation from django.db.migrations.operations.base import Operation
from django.db.migrations.state import ModelState from django.db.migrations.state import ModelState
from django.db.migrations.utils import ( from django.db.migrations.utils import field_references, resolve_relation
field_references, get_references, resolve_relation,
)
from django.db.models.options import normalize_together from django.db.models.options import normalize_together
from django.utils.functional import cached_property from django.utils.functional import cached_property
@ -316,31 +314,7 @@ class RenameModel(ModelOperation):
) )
def state_forwards(self, app_label, state): def state_forwards(self, app_label, state):
# Add a new model. state.rename_model(app_label, self.old_name, self.new_name)
renamed_model = state.models[app_label, self.old_name_lower].clone()
renamed_model.name = self.new_name
state.models[app_label, self.new_name_lower] = renamed_model
# Repoint all fields pointing to the old model to the new one.
old_model_tuple = (app_label, self.old_name_lower)
new_remote_model = '%s.%s' % (app_label, self.new_name)
to_reload = set()
for model_state, name, field, reference in get_references(state, old_model_tuple):
changed_field = None
if reference.to:
changed_field = field.clone()
changed_field.remote_field.model = new_remote_model
if reference.through:
if changed_field is None:
changed_field = field.clone()
changed_field.remote_field.through = new_remote_model
if changed_field:
model_state.fields[name] = changed_field
to_reload.add((model_state.app_label, model_state.name_lower))
# Reload models related to old model before removing the old model.
state.reload_models(to_reload, delay=True)
# Remove the old model.
state.remove_model(app_label, self.old_name_lower)
state.reload_model(app_label, self.new_name_lower, delay=True)
def database_forwards(self, app_label, schema_editor, from_state, to_state): def database_forwards(self, app_label, schema_editor, from_state, to_state):
new_model = to_state.apps.get_model(app_label, self.new_name) new_model = to_state.apps.get_model(app_label, self.new_name)
@ -458,8 +432,7 @@ class AlterModelTable(ModelOptionOperation):
) )
def state_forwards(self, app_label, state): def state_forwards(self, app_label, state):
state.models[app_label, self.name_lower].options["db_table"] = self.table state.alter_model_options(app_label, self.name_lower, {'db_table': self.table})
state.reload_model(app_label, self.name_lower, delay=True)
def database_forwards(self, app_label, schema_editor, from_state, to_state): def database_forwards(self, app_label, schema_editor, from_state, to_state):
new_model = to_state.apps.get_model(app_label, self.name) new_model = to_state.apps.get_model(app_label, self.name)
@ -518,9 +491,11 @@ class AlterTogetherOptionOperation(ModelOptionOperation):
) )
def state_forwards(self, app_label, state): def state_forwards(self, app_label, state):
model_state = state.models[app_label, self.name_lower] state.alter_model_options(
model_state.options[self.option_name] = self.option_value app_label,
state.reload_model(app_label, self.name_lower, delay=True) self.name_lower,
{self.option_name: self.option_value},
)
def database_forwards(self, app_label, schema_editor, from_state, to_state): def database_forwards(self, app_label, schema_editor, from_state, to_state):
new_model = to_state.apps.get_model(app_label, self.name) new_model = to_state.apps.get_model(app_label, self.name)
@ -596,9 +571,11 @@ class AlterOrderWithRespectTo(ModelOptionOperation):
) )
def state_forwards(self, app_label, state): def state_forwards(self, app_label, state):
model_state = state.models[app_label, self.name_lower] state.alter_model_options(
model_state.options['order_with_respect_to'] = self.order_with_respect_to app_label,
state.reload_model(app_label, self.name_lower, delay=True) self.name_lower,
{self.option_name: self.order_with_respect_to},
)
def database_forwards(self, app_label, schema_editor, from_state, to_state): def database_forwards(self, app_label, schema_editor, from_state, to_state):
to_model = to_state.apps.get_model(app_label, self.name) to_model = to_state.apps.get_model(app_label, self.name)
@ -676,12 +653,12 @@ class AlterModelOptions(ModelOptionOperation):
) )
def state_forwards(self, app_label, state): def state_forwards(self, app_label, state):
model_state = state.models[app_label, self.name_lower] state.alter_model_options(
model_state.options = {**model_state.options, **self.options} app_label,
for key in self.ALTER_OPTION_KEYS: self.name_lower,
if key not in self.options: self.options,
model_state.options.pop(key, False) self.ALTER_OPTION_KEYS,
state.reload_model(app_label, self.name_lower, delay=True) )
def database_forwards(self, app_label, schema_editor, from_state, to_state): def database_forwards(self, app_label, schema_editor, from_state, to_state):
pass pass
@ -714,9 +691,7 @@ class AlterModelManagers(ModelOptionOperation):
) )
def state_forwards(self, app_label, state): def state_forwards(self, app_label, state):
model_state = state.models[app_label, self.name_lower] state.alter_model_managers(app_label, self.name_lower, self.managers)
model_state.managers = list(self.managers)
state.reload_model(app_label, self.name_lower, delay=True)
def database_forwards(self, app_label, schema_editor, from_state, to_state): def database_forwards(self, app_label, schema_editor, from_state, to_state):
pass pass
@ -753,9 +728,7 @@ class AddIndex(IndexOperation):
self.index = index self.index = index
def state_forwards(self, app_label, state): def state_forwards(self, app_label, state):
model_state = state.models[app_label, self.model_name_lower] state.add_index(app_label, self.model_name_lower, self.index)
model_state.options[self.option_name] = [*model_state.options[self.option_name], self.index.clone()]
state.reload_model(app_label, self.model_name_lower, delay=True)
def database_forwards(self, app_label, schema_editor, from_state, to_state): def database_forwards(self, app_label, schema_editor, from_state, to_state):
model = to_state.apps.get_model(app_label, self.model_name) model = to_state.apps.get_model(app_label, self.model_name)
@ -804,10 +777,7 @@ class RemoveIndex(IndexOperation):
self.name = name self.name = name
def state_forwards(self, app_label, state): def state_forwards(self, app_label, state):
model_state = state.models[app_label, self.model_name_lower] state.remove_index(app_label, self.model_name_lower, self.name)
indexes = model_state.options[self.option_name]
model_state.options[self.option_name] = [idx for idx in indexes if idx.name != self.name]
state.reload_model(app_label, self.model_name_lower, delay=True)
def database_forwards(self, app_label, schema_editor, from_state, to_state): def database_forwards(self, app_label, schema_editor, from_state, to_state):
model = from_state.apps.get_model(app_label, self.model_name) model = from_state.apps.get_model(app_label, self.model_name)
@ -850,9 +820,7 @@ class AddConstraint(IndexOperation):
self.constraint = constraint self.constraint = constraint
def state_forwards(self, app_label, state): def state_forwards(self, app_label, state):
model_state = state.models[app_label, self.model_name_lower] state.add_constraint(app_label, self.model_name_lower, self.constraint)
model_state.options[self.option_name] = [*model_state.options[self.option_name], self.constraint]
state.reload_model(app_label, self.model_name_lower, delay=True)
def database_forwards(self, app_label, schema_editor, from_state, to_state): def database_forwards(self, app_label, schema_editor, from_state, to_state):
model = to_state.apps.get_model(app_label, self.model_name) model = to_state.apps.get_model(app_label, self.model_name)
@ -886,10 +854,7 @@ class RemoveConstraint(IndexOperation):
self.name = name self.name = name
def state_forwards(self, app_label, state): def state_forwards(self, app_label, state):
model_state = state.models[app_label, self.model_name_lower] state.remove_constraint(app_label, self.model_name_lower, self.name)
constraints = model_state.options[self.option_name]
model_state.options[self.option_name] = [c for c in constraints if c.name != self.name]
state.reload_model(app_label, self.model_name_lower, delay=True)
def database_forwards(self, app_label, schema_editor, from_state, to_state): def database_forwards(self, app_label, schema_editor, from_state, to_state):
model = to_state.apps.get_model(app_label, self.model_name) model = to_state.apps.get_model(app_label, self.model_name)

View File

@ -6,7 +6,10 @@ from functools import partial
from django.apps import AppConfig from django.apps import AppConfig
from django.apps.registry import Apps, apps as global_apps from django.apps.registry import Apps, apps as global_apps
from django.conf import settings from django.conf import settings
from django.core.exceptions import FieldDoesNotExist
from django.db import models from django.db import models
from django.db.migrations.utils import field_is_referenced, get_references
from django.db.models import NOT_PROVIDED
from django.db.models.fields.related import RECURSIVE_RELATIONSHIP_CONSTANT from django.db.models.fields.related import RECURSIVE_RELATIONSHIP_CONSTANT
from django.db.models.options import DEFAULT_NAMES, normalize_together from django.db.models.options import DEFAULT_NAMES, normalize_together
from django.db.models.utils import make_model_tuple from django.db.models.utils import make_model_tuple
@ -107,6 +110,152 @@ class ProjectState:
# the cache automatically (#24513) # the cache automatically (#24513)
self.apps.clear_cache() self.apps.clear_cache()
def rename_model(self, app_label, old_name, new_name):
# Add a new model.
old_name_lower = old_name.lower()
new_name_lower = new_name.lower()
renamed_model = self.models[app_label, old_name_lower].clone()
renamed_model.name = new_name
self.models[app_label, new_name_lower] = renamed_model
# Repoint all fields pointing to the old model to the new one.
old_model_tuple = (app_label, old_name_lower)
new_remote_model = f'{app_label}.{new_name}'
to_reload = set()
for model_state, name, field, reference in get_references(self, old_model_tuple):
changed_field = None
if reference.to:
changed_field = field.clone()
changed_field.remote_field.model = new_remote_model
if reference.through:
if changed_field is None:
changed_field = field.clone()
changed_field.remote_field.through = new_remote_model
if changed_field:
model_state.fields[name] = changed_field
to_reload.add((model_state.app_label, model_state.name_lower))
# Reload models related to old model before removing the old model.
self.reload_models(to_reload, delay=True)
# Remove the old model.
self.remove_model(app_label, old_name_lower)
self.reload_model(app_label, new_name_lower, delay=True)
def alter_model_options(self, app_label, model_name, options, option_keys=None):
model_state = self.models[app_label, model_name]
model_state.options = {**model_state.options, **options}
if option_keys:
for key in option_keys:
if key not in options:
model_state.options.pop(key, False)
self.reload_model(app_label, model_name, delay=True)
def alter_model_managers(self, app_label, model_name, managers):
model_state = self.models[app_label, model_name]
model_state.managers = list(managers)
self.reload_model(app_label, model_name, delay=True)
def _append_option(self, app_label, model_name, option_name, obj):
model_state = self.models[app_label, model_name]
model_state.options[option_name] = [*model_state.options[option_name], obj]
self.reload_model(app_label, model_name, delay=True)
def _remove_option(self, app_label, model_name, option_name, obj_name):
model_state = self.models[app_label, model_name]
objs = model_state.options[option_name]
model_state.options[option_name] = [obj for obj in objs if obj.name != obj_name]
self.reload_model(app_label, model_name, delay=True)
def add_index(self, app_label, model_name, index):
self._append_option(app_label, model_name, 'indexes', index)
def remove_index(self, app_label, model_name, index_name):
self._remove_option(app_label, model_name, 'indexes', index_name)
def add_constraint(self, app_label, model_name, constraint):
self._append_option(app_label, model_name, 'constraints', constraint)
def remove_constraint(self, app_label, model_name, constraint_name):
self._remove_option(app_label, model_name, 'constraints', constraint_name)
def add_field(self, app_label, model_name, name, field, preserve_default):
# If preserve default is off, don't use the default for future state.
if not preserve_default:
field = field.clone()
field.default = NOT_PROVIDED
else:
field = field
self.models[app_label, model_name].fields[name] = field
# Delay rendering of relationships if it's not a relational field.
delay = not field.is_relation
self.reload_model(app_label, model_name, delay=delay)
def remove_field(self, app_label, model_name, name):
model_state = self.models[app_label, model_name]
old_field = model_state.fields.pop(name)
# Delay rendering of relationships if it's not a relational field.
delay = not old_field.is_relation
self.reload_model(app_label, model_name, delay=delay)
def alter_field(self, app_label, model_name, name, field, preserve_default):
if not preserve_default:
field = field.clone()
field.default = NOT_PROVIDED
else:
field = field
model_state = self.models[app_label, model_name]
model_state.fields[name] = field
# TODO: investigate if old relational fields must be reloaded or if
# it's sufficient if the new field is (#27737).
# Delay rendering of relationships if it's not a relational field and
# not referenced by a foreign key.
delay = (
not field.is_relation and
not field_is_referenced(self, (app_label, model_name), (name, field))
)
self.reload_model(app_label, model_name, delay=delay)
def rename_field(self, app_label, model_name, old_name, new_name):
model_state = self.models[app_label, model_name]
# Rename the field.
fields = model_state.fields
try:
found = fields.pop(old_name)
except KeyError:
raise FieldDoesNotExist(
f"{app_label}.{model_name} has no field named '{old_name}'"
)
fields[new_name] = found
for field in fields.values():
# Fix from_fields to refer to the new field.
from_fields = getattr(field, 'from_fields', None)
if from_fields:
field.from_fields = tuple([
new_name if from_field_name == old_name else from_field_name
for from_field_name in from_fields
])
# Fix index/unique_together to refer to the new field.
options = model_state.options
for option in ('index_together', 'unique_together'):
if option in options:
options[option] = [
[new_name if n == old_name else n for n in together]
for together in options[option]
]
# Fix to_fields to refer to the new field.
delay = True
references = get_references(self, (app_label, model_name), (old_name, found))
for *_, field, reference in references:
delay = False
if reference.to:
remote_field, to_fields = reference.to
if getattr(remote_field, 'field_name', None) == old_name:
remote_field.field_name = new_name
if to_fields:
field.to_fields = tuple([
new_name if to_field_name == old_name else to_field_name
for to_field_name in to_fields
])
self.reload_model(app_label, model_name, delay=delay)
def _find_reload_model(self, app_label, model_name, delay=False): def _find_reload_model(self, app_label, model_name, delay=False):
if delay: if delay:
self.is_delayed = True self.is_delayed = True