mirror of https://github.com/django/django.git
Refs #29898 -- Moved state_forwards()'s logic from migration operations to ProjectState.
Thanks Simon Charette and Markus Holtermann for reviews.
This commit is contained in:
parent
594d6e9407
commit
503ee41497
|
@ -1,7 +1,4 @@
|
|||
from django.core.exceptions import FieldDoesNotExist
|
||||
from django.db.migrations.utils import (
|
||||
field_is_referenced, field_references, get_references,
|
||||
)
|
||||
from django.db.migrations.utils import field_references
|
||||
from django.db.models import NOT_PROVIDED
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
|
@ -85,16 +82,13 @@ class AddField(FieldOperation):
|
|||
)
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
# If preserve default is off, don't use the default for future state
|
||||
if not self.preserve_default:
|
||||
field = self.field.clone()
|
||||
field.default = NOT_PROVIDED
|
||||
else:
|
||||
field = self.field
|
||||
state.models[app_label, self.model_name_lower].fields[self.name] = field
|
||||
# Delay rendering of relationships if it's not a relational field
|
||||
delay = not field.is_relation
|
||||
state.reload_model(app_label, self.model_name_lower, delay=delay)
|
||||
state.add_field(
|
||||
app_label,
|
||||
self.model_name_lower,
|
||||
self.name,
|
||||
self.field,
|
||||
self.preserve_default,
|
||||
)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
to_model = to_state.apps.get_model(app_label, self.model_name)
|
||||
|
@ -160,11 +154,7 @@ class RemoveField(FieldOperation):
|
|||
)
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
model_state = state.models[app_label, self.model_name_lower]
|
||||
old_field = model_state.fields.pop(self.name)
|
||||
# Delay rendering of relationships if it's not a relational field
|
||||
delay = not old_field.is_relation
|
||||
state.reload_model(app_label, self.model_name_lower, delay=delay)
|
||||
state.remove_field(app_label, self.model_name_lower, self.name)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
from_model = from_state.apps.get_model(app_label, self.model_name)
|
||||
|
@ -216,24 +206,13 @@ class AlterField(FieldOperation):
|
|||
)
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
if not self.preserve_default:
|
||||
field = self.field.clone()
|
||||
field.default = NOT_PROVIDED
|
||||
else:
|
||||
field = self.field
|
||||
model_state = state.models[app_label, self.model_name_lower]
|
||||
model_state.fields[self.name] = field
|
||||
# TODO: investigate if old relational fields must be reloaded or if it's
|
||||
# sufficient if the new field is (#27737).
|
||||
# Delay rendering of relationships if it's not a relational field and
|
||||
# not referenced by a foreign key.
|
||||
delay = (
|
||||
not field.is_relation and
|
||||
not field_is_referenced(
|
||||
state, (app_label, self.model_name_lower), (self.name, field),
|
||||
state.alter_field(
|
||||
app_label,
|
||||
self.model_name_lower,
|
||||
self.name,
|
||||
self.field,
|
||||
self.preserve_default,
|
||||
)
|
||||
)
|
||||
state.reload_model(app_label, self.model_name_lower, delay=delay)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
to_model = to_state.apps.get_model(app_label, self.model_name)
|
||||
|
@ -301,49 +280,7 @@ class RenameField(FieldOperation):
|
|||
)
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
model_state = state.models[app_label, self.model_name_lower]
|
||||
# Rename the field
|
||||
fields = model_state.fields
|
||||
try:
|
||||
found = fields.pop(self.old_name)
|
||||
except KeyError:
|
||||
raise FieldDoesNotExist(
|
||||
"%s.%s has no field named '%s'" % (app_label, self.model_name, self.old_name)
|
||||
)
|
||||
fields[self.new_name] = found
|
||||
for field in fields.values():
|
||||
# Fix from_fields to refer to the new field.
|
||||
from_fields = getattr(field, 'from_fields', None)
|
||||
if from_fields:
|
||||
field.from_fields = tuple([
|
||||
self.new_name if from_field_name == self.old_name else from_field_name
|
||||
for from_field_name in from_fields
|
||||
])
|
||||
# Fix index/unique_together to refer to the new field
|
||||
options = model_state.options
|
||||
for option in ('index_together', 'unique_together'):
|
||||
if option in options:
|
||||
options[option] = [
|
||||
[self.new_name if n == self.old_name else n for n in together]
|
||||
for together in options[option]
|
||||
]
|
||||
# Fix to_fields to refer to the new field.
|
||||
delay = True
|
||||
references = get_references(
|
||||
state, (app_label, self.model_name_lower), (self.old_name, found),
|
||||
)
|
||||
for *_, field, reference in references:
|
||||
delay = False
|
||||
if reference.to:
|
||||
remote_field, to_fields = reference.to
|
||||
if getattr(remote_field, 'field_name', None) == self.old_name:
|
||||
remote_field.field_name = self.new_name
|
||||
if to_fields:
|
||||
field.to_fields = tuple([
|
||||
self.new_name if to_field_name == self.old_name else to_field_name
|
||||
for to_field_name in to_fields
|
||||
])
|
||||
state.reload_model(app_label, self.model_name_lower, delay=delay)
|
||||
state.rename_field(app_label, self.model_name_lower, self.old_name, self.new_name)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
to_model = to_state.apps.get_model(app_label, self.model_name)
|
||||
|
|
|
@ -1,9 +1,7 @@
|
|||
from django.db import models
|
||||
from django.db.migrations.operations.base import Operation
|
||||
from django.db.migrations.state import ModelState
|
||||
from django.db.migrations.utils import (
|
||||
field_references, get_references, resolve_relation,
|
||||
)
|
||||
from django.db.migrations.utils import field_references, resolve_relation
|
||||
from django.db.models.options import normalize_together
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
|
@ -316,31 +314,7 @@ class RenameModel(ModelOperation):
|
|||
)
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
# Add a new model.
|
||||
renamed_model = state.models[app_label, self.old_name_lower].clone()
|
||||
renamed_model.name = self.new_name
|
||||
state.models[app_label, self.new_name_lower] = renamed_model
|
||||
# Repoint all fields pointing to the old model to the new one.
|
||||
old_model_tuple = (app_label, self.old_name_lower)
|
||||
new_remote_model = '%s.%s' % (app_label, self.new_name)
|
||||
to_reload = set()
|
||||
for model_state, name, field, reference in get_references(state, old_model_tuple):
|
||||
changed_field = None
|
||||
if reference.to:
|
||||
changed_field = field.clone()
|
||||
changed_field.remote_field.model = new_remote_model
|
||||
if reference.through:
|
||||
if changed_field is None:
|
||||
changed_field = field.clone()
|
||||
changed_field.remote_field.through = new_remote_model
|
||||
if changed_field:
|
||||
model_state.fields[name] = changed_field
|
||||
to_reload.add((model_state.app_label, model_state.name_lower))
|
||||
# Reload models related to old model before removing the old model.
|
||||
state.reload_models(to_reload, delay=True)
|
||||
# Remove the old model.
|
||||
state.remove_model(app_label, self.old_name_lower)
|
||||
state.reload_model(app_label, self.new_name_lower, delay=True)
|
||||
state.rename_model(app_label, self.old_name, self.new_name)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
new_model = to_state.apps.get_model(app_label, self.new_name)
|
||||
|
@ -458,8 +432,7 @@ class AlterModelTable(ModelOptionOperation):
|
|||
)
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
state.models[app_label, self.name_lower].options["db_table"] = self.table
|
||||
state.reload_model(app_label, self.name_lower, delay=True)
|
||||
state.alter_model_options(app_label, self.name_lower, {'db_table': self.table})
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
new_model = to_state.apps.get_model(app_label, self.name)
|
||||
|
@ -518,9 +491,11 @@ class AlterTogetherOptionOperation(ModelOptionOperation):
|
|||
)
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
model_state = state.models[app_label, self.name_lower]
|
||||
model_state.options[self.option_name] = self.option_value
|
||||
state.reload_model(app_label, self.name_lower, delay=True)
|
||||
state.alter_model_options(
|
||||
app_label,
|
||||
self.name_lower,
|
||||
{self.option_name: self.option_value},
|
||||
)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
new_model = to_state.apps.get_model(app_label, self.name)
|
||||
|
@ -596,9 +571,11 @@ class AlterOrderWithRespectTo(ModelOptionOperation):
|
|||
)
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
model_state = state.models[app_label, self.name_lower]
|
||||
model_state.options['order_with_respect_to'] = self.order_with_respect_to
|
||||
state.reload_model(app_label, self.name_lower, delay=True)
|
||||
state.alter_model_options(
|
||||
app_label,
|
||||
self.name_lower,
|
||||
{self.option_name: self.order_with_respect_to},
|
||||
)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
to_model = to_state.apps.get_model(app_label, self.name)
|
||||
|
@ -676,12 +653,12 @@ class AlterModelOptions(ModelOptionOperation):
|
|||
)
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
model_state = state.models[app_label, self.name_lower]
|
||||
model_state.options = {**model_state.options, **self.options}
|
||||
for key in self.ALTER_OPTION_KEYS:
|
||||
if key not in self.options:
|
||||
model_state.options.pop(key, False)
|
||||
state.reload_model(app_label, self.name_lower, delay=True)
|
||||
state.alter_model_options(
|
||||
app_label,
|
||||
self.name_lower,
|
||||
self.options,
|
||||
self.ALTER_OPTION_KEYS,
|
||||
)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
pass
|
||||
|
@ -714,9 +691,7 @@ class AlterModelManagers(ModelOptionOperation):
|
|||
)
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
model_state = state.models[app_label, self.name_lower]
|
||||
model_state.managers = list(self.managers)
|
||||
state.reload_model(app_label, self.name_lower, delay=True)
|
||||
state.alter_model_managers(app_label, self.name_lower, self.managers)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
pass
|
||||
|
@ -753,9 +728,7 @@ class AddIndex(IndexOperation):
|
|||
self.index = index
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
model_state = state.models[app_label, self.model_name_lower]
|
||||
model_state.options[self.option_name] = [*model_state.options[self.option_name], self.index.clone()]
|
||||
state.reload_model(app_label, self.model_name_lower, delay=True)
|
||||
state.add_index(app_label, self.model_name_lower, self.index)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
model = to_state.apps.get_model(app_label, self.model_name)
|
||||
|
@ -804,10 +777,7 @@ class RemoveIndex(IndexOperation):
|
|||
self.name = name
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
model_state = state.models[app_label, self.model_name_lower]
|
||||
indexes = model_state.options[self.option_name]
|
||||
model_state.options[self.option_name] = [idx for idx in indexes if idx.name != self.name]
|
||||
state.reload_model(app_label, self.model_name_lower, delay=True)
|
||||
state.remove_index(app_label, self.model_name_lower, self.name)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
model = from_state.apps.get_model(app_label, self.model_name)
|
||||
|
@ -850,9 +820,7 @@ class AddConstraint(IndexOperation):
|
|||
self.constraint = constraint
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
model_state = state.models[app_label, self.model_name_lower]
|
||||
model_state.options[self.option_name] = [*model_state.options[self.option_name], self.constraint]
|
||||
state.reload_model(app_label, self.model_name_lower, delay=True)
|
||||
state.add_constraint(app_label, self.model_name_lower, self.constraint)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
model = to_state.apps.get_model(app_label, self.model_name)
|
||||
|
@ -886,10 +854,7 @@ class RemoveConstraint(IndexOperation):
|
|||
self.name = name
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
model_state = state.models[app_label, self.model_name_lower]
|
||||
constraints = model_state.options[self.option_name]
|
||||
model_state.options[self.option_name] = [c for c in constraints if c.name != self.name]
|
||||
state.reload_model(app_label, self.model_name_lower, delay=True)
|
||||
state.remove_constraint(app_label, self.model_name_lower, self.name)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
model = to_state.apps.get_model(app_label, self.model_name)
|
||||
|
|
|
@ -6,7 +6,10 @@ from functools import partial
|
|||
from django.apps import AppConfig
|
||||
from django.apps.registry import Apps, apps as global_apps
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import FieldDoesNotExist
|
||||
from django.db import models
|
||||
from django.db.migrations.utils import field_is_referenced, get_references
|
||||
from django.db.models import NOT_PROVIDED
|
||||
from django.db.models.fields.related import RECURSIVE_RELATIONSHIP_CONSTANT
|
||||
from django.db.models.options import DEFAULT_NAMES, normalize_together
|
||||
from django.db.models.utils import make_model_tuple
|
||||
|
@ -107,6 +110,152 @@ class ProjectState:
|
|||
# the cache automatically (#24513)
|
||||
self.apps.clear_cache()
|
||||
|
||||
def rename_model(self, app_label, old_name, new_name):
|
||||
# Add a new model.
|
||||
old_name_lower = old_name.lower()
|
||||
new_name_lower = new_name.lower()
|
||||
renamed_model = self.models[app_label, old_name_lower].clone()
|
||||
renamed_model.name = new_name
|
||||
self.models[app_label, new_name_lower] = renamed_model
|
||||
# Repoint all fields pointing to the old model to the new one.
|
||||
old_model_tuple = (app_label, old_name_lower)
|
||||
new_remote_model = f'{app_label}.{new_name}'
|
||||
to_reload = set()
|
||||
for model_state, name, field, reference in get_references(self, old_model_tuple):
|
||||
changed_field = None
|
||||
if reference.to:
|
||||
changed_field = field.clone()
|
||||
changed_field.remote_field.model = new_remote_model
|
||||
if reference.through:
|
||||
if changed_field is None:
|
||||
changed_field = field.clone()
|
||||
changed_field.remote_field.through = new_remote_model
|
||||
if changed_field:
|
||||
model_state.fields[name] = changed_field
|
||||
to_reload.add((model_state.app_label, model_state.name_lower))
|
||||
# Reload models related to old model before removing the old model.
|
||||
self.reload_models(to_reload, delay=True)
|
||||
# Remove the old model.
|
||||
self.remove_model(app_label, old_name_lower)
|
||||
self.reload_model(app_label, new_name_lower, delay=True)
|
||||
|
||||
def alter_model_options(self, app_label, model_name, options, option_keys=None):
|
||||
model_state = self.models[app_label, model_name]
|
||||
model_state.options = {**model_state.options, **options}
|
||||
if option_keys:
|
||||
for key in option_keys:
|
||||
if key not in options:
|
||||
model_state.options.pop(key, False)
|
||||
self.reload_model(app_label, model_name, delay=True)
|
||||
|
||||
def alter_model_managers(self, app_label, model_name, managers):
|
||||
model_state = self.models[app_label, model_name]
|
||||
model_state.managers = list(managers)
|
||||
self.reload_model(app_label, model_name, delay=True)
|
||||
|
||||
def _append_option(self, app_label, model_name, option_name, obj):
|
||||
model_state = self.models[app_label, model_name]
|
||||
model_state.options[option_name] = [*model_state.options[option_name], obj]
|
||||
self.reload_model(app_label, model_name, delay=True)
|
||||
|
||||
def _remove_option(self, app_label, model_name, option_name, obj_name):
|
||||
model_state = self.models[app_label, model_name]
|
||||
objs = model_state.options[option_name]
|
||||
model_state.options[option_name] = [obj for obj in objs if obj.name != obj_name]
|
||||
self.reload_model(app_label, model_name, delay=True)
|
||||
|
||||
def add_index(self, app_label, model_name, index):
|
||||
self._append_option(app_label, model_name, 'indexes', index)
|
||||
|
||||
def remove_index(self, app_label, model_name, index_name):
|
||||
self._remove_option(app_label, model_name, 'indexes', index_name)
|
||||
|
||||
def add_constraint(self, app_label, model_name, constraint):
|
||||
self._append_option(app_label, model_name, 'constraints', constraint)
|
||||
|
||||
def remove_constraint(self, app_label, model_name, constraint_name):
|
||||
self._remove_option(app_label, model_name, 'constraints', constraint_name)
|
||||
|
||||
def add_field(self, app_label, model_name, name, field, preserve_default):
|
||||
# If preserve default is off, don't use the default for future state.
|
||||
if not preserve_default:
|
||||
field = field.clone()
|
||||
field.default = NOT_PROVIDED
|
||||
else:
|
||||
field = field
|
||||
self.models[app_label, model_name].fields[name] = field
|
||||
# Delay rendering of relationships if it's not a relational field.
|
||||
delay = not field.is_relation
|
||||
self.reload_model(app_label, model_name, delay=delay)
|
||||
|
||||
def remove_field(self, app_label, model_name, name):
|
||||
model_state = self.models[app_label, model_name]
|
||||
old_field = model_state.fields.pop(name)
|
||||
# Delay rendering of relationships if it's not a relational field.
|
||||
delay = not old_field.is_relation
|
||||
self.reload_model(app_label, model_name, delay=delay)
|
||||
|
||||
def alter_field(self, app_label, model_name, name, field, preserve_default):
|
||||
if not preserve_default:
|
||||
field = field.clone()
|
||||
field.default = NOT_PROVIDED
|
||||
else:
|
||||
field = field
|
||||
model_state = self.models[app_label, model_name]
|
||||
model_state.fields[name] = field
|
||||
# TODO: investigate if old relational fields must be reloaded or if
|
||||
# it's sufficient if the new field is (#27737).
|
||||
# Delay rendering of relationships if it's not a relational field and
|
||||
# not referenced by a foreign key.
|
||||
delay = (
|
||||
not field.is_relation and
|
||||
not field_is_referenced(self, (app_label, model_name), (name, field))
|
||||
)
|
||||
self.reload_model(app_label, model_name, delay=delay)
|
||||
|
||||
def rename_field(self, app_label, model_name, old_name, new_name):
|
||||
model_state = self.models[app_label, model_name]
|
||||
# Rename the field.
|
||||
fields = model_state.fields
|
||||
try:
|
||||
found = fields.pop(old_name)
|
||||
except KeyError:
|
||||
raise FieldDoesNotExist(
|
||||
f"{app_label}.{model_name} has no field named '{old_name}'"
|
||||
)
|
||||
fields[new_name] = found
|
||||
for field in fields.values():
|
||||
# Fix from_fields to refer to the new field.
|
||||
from_fields = getattr(field, 'from_fields', None)
|
||||
if from_fields:
|
||||
field.from_fields = tuple([
|
||||
new_name if from_field_name == old_name else from_field_name
|
||||
for from_field_name in from_fields
|
||||
])
|
||||
# Fix index/unique_together to refer to the new field.
|
||||
options = model_state.options
|
||||
for option in ('index_together', 'unique_together'):
|
||||
if option in options:
|
||||
options[option] = [
|
||||
[new_name if n == old_name else n for n in together]
|
||||
for together in options[option]
|
||||
]
|
||||
# Fix to_fields to refer to the new field.
|
||||
delay = True
|
||||
references = get_references(self, (app_label, model_name), (old_name, found))
|
||||
for *_, field, reference in references:
|
||||
delay = False
|
||||
if reference.to:
|
||||
remote_field, to_fields = reference.to
|
||||
if getattr(remote_field, 'field_name', None) == old_name:
|
||||
remote_field.field_name = new_name
|
||||
if to_fields:
|
||||
field.to_fields = tuple([
|
||||
new_name if to_field_name == old_name else to_field_name
|
||||
for to_field_name in to_fields
|
||||
])
|
||||
self.reload_model(app_label, model_name, delay=delay)
|
||||
|
||||
def _find_reload_model(self, app_label, model_name, delay=False):
|
||||
if delay:
|
||||
self.is_delayed = True
|
||||
|
|
Loading…
Reference in New Issue