Refs #28305 -- Consolidated field referencing detection in migrations.
This moves all the field referencing resolution methods to shared functions instead of duplicating efforts amongst state_forwards and references methods.
This commit is contained in:
parent
734fde7714
commit
f5ede1cb6d
|
@ -3,9 +3,7 @@ from django.db.models import NOT_PROVIDED
|
|||
from django.utils.functional import cached_property
|
||||
|
||||
from .base import Operation
|
||||
from .utils import (
|
||||
field_references_model, is_referenced_by_foreign_key, resolve_relation,
|
||||
)
|
||||
from .utils import field_is_referenced, field_references, get_references
|
||||
|
||||
|
||||
class FieldOperation(Operation):
|
||||
|
@ -33,9 +31,9 @@ class FieldOperation(Operation):
|
|||
if name_lower == self.model_name_lower:
|
||||
return True
|
||||
if self.field:
|
||||
return field_references_model(
|
||||
return bool(field_references(
|
||||
(app_label, self.model_name_lower), self.field, (app_label, name_lower)
|
||||
)
|
||||
))
|
||||
return False
|
||||
|
||||
def references_field(self, model_name, name, app_label):
|
||||
|
@ -47,20 +45,14 @@ class FieldOperation(Operation):
|
|||
elif self.field and hasattr(self.field, 'from_fields') and name in self.field.from_fields:
|
||||
return True
|
||||
# Check if this operation remotely references the field.
|
||||
if self.field:
|
||||
model_tuple = (app_label, model_name_lower)
|
||||
remote_field = self.field.remote_field
|
||||
if remote_field:
|
||||
if (resolve_relation(remote_field.model, app_label, self.model_name_lower) == model_tuple and
|
||||
(not hasattr(self.field, 'to_fields') or
|
||||
name in self.field.to_fields or None in self.field.to_fields)):
|
||||
return True
|
||||
through = getattr(remote_field, 'through', None)
|
||||
if (through and resolve_relation(through, app_label, self.model_name_lower) == model_tuple and
|
||||
(getattr(remote_field, 'through_fields', None) is None or
|
||||
name in remote_field.through_fields)):
|
||||
return True
|
||||
return False
|
||||
if self.field is None:
|
||||
return False
|
||||
return bool(field_references(
|
||||
(app_label, self.model_name_lower),
|
||||
self.field,
|
||||
(app_label, model_name_lower),
|
||||
name,
|
||||
))
|
||||
|
||||
def reduce(self, operation, app_label):
|
||||
return (
|
||||
|
@ -236,7 +228,9 @@ class AlterField(FieldOperation):
|
|||
# not referenced by a foreign key.
|
||||
delay = (
|
||||
not field.is_relation and
|
||||
not is_referenced_by_foreign_key(state, self.model_name_lower, self.field, self.name)
|
||||
not field_is_referenced(
|
||||
state, (app_label, self.model_name_lower), (self.name, field),
|
||||
)
|
||||
)
|
||||
state.reload_model(app_label, self.model_name_lower, delay=delay)
|
||||
|
||||
|
@ -305,17 +299,11 @@ class RenameField(FieldOperation):
|
|||
model_state = state.models[app_label, self.model_name_lower]
|
||||
# Rename the field
|
||||
fields = model_state.fields
|
||||
found = False
|
||||
found = None
|
||||
for index, (name, field) in enumerate(fields):
|
||||
if not found and name == self.old_name:
|
||||
fields[index] = (self.new_name, field)
|
||||
found = True
|
||||
# Delay rendering of relationships if it's not a relational
|
||||
# field and not referenced by a foreign key.
|
||||
delay = (
|
||||
not field.is_relation and
|
||||
not is_referenced_by_foreign_key(state, self.model_name_lower, field, self.name)
|
||||
)
|
||||
found = field
|
||||
# Fix from_fields to refer to the new field.
|
||||
from_fields = getattr(field, 'from_fields', None)
|
||||
if from_fields:
|
||||
|
@ -323,7 +311,7 @@ class RenameField(FieldOperation):
|
|||
self.new_name if from_field_name == self.old_name else from_field_name
|
||||
for from_field_name in from_fields
|
||||
])
|
||||
if not found:
|
||||
if found is None:
|
||||
raise FieldDoesNotExist(
|
||||
"%s.%s has no field named '%s'" % (app_label, self.model_name, self.old_name)
|
||||
)
|
||||
|
@ -336,23 +324,21 @@ class RenameField(FieldOperation):
|
|||
for together in options[option]
|
||||
]
|
||||
# Fix to_fields to refer to the new field.
|
||||
model_tuple = app_label, self.model_name_lower
|
||||
for (model_app_label, model_name), model_state in state.models.items():
|
||||
for index, (name, field) in enumerate(model_state.fields):
|
||||
remote_field = field.remote_field
|
||||
if remote_field:
|
||||
remote_model_tuple = resolve_relation(
|
||||
remote_field.model, model_app_label, model_name
|
||||
)
|
||||
if remote_model_tuple == model_tuple:
|
||||
if getattr(remote_field, 'field_name', None) == self.old_name:
|
||||
remote_field.field_name = self.new_name
|
||||
to_fields = getattr(field, 'to_fields', None)
|
||||
if to_fields:
|
||||
field.to_fields = tuple([
|
||||
self.new_name if to_field_name == self.old_name else to_field_name
|
||||
for to_field_name in to_fields
|
||||
])
|
||||
delay = True
|
||||
references = get_references(
|
||||
state, (app_label, self.model_name_lower), (self.old_name, found),
|
||||
)
|
||||
for *_, field, reference in references:
|
||||
delay = False
|
||||
if reference.to:
|
||||
remote_field, to_fields = reference.to
|
||||
if getattr(remote_field, 'field_name', None) == self.old_name:
|
||||
remote_field.field_name = self.new_name
|
||||
if to_fields:
|
||||
field.to_fields = tuple([
|
||||
self.new_name if to_field_name == self.old_name else to_field_name
|
||||
for to_field_name in to_fields
|
||||
])
|
||||
state.reload_model(app_label, self.model_name_lower, delay=delay)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
|
|
|
@ -7,7 +7,7 @@ from django.utils.functional import cached_property
|
|||
from .fields import (
|
||||
AddField, AlterField, FieldOperation, RemoveField, RenameField,
|
||||
)
|
||||
from .utils import field_references_model, resolve_relation
|
||||
from .utils import field_references, get_references, resolve_relation
|
||||
|
||||
|
||||
def _check_for_duplicates(arg_name, objs):
|
||||
|
@ -113,7 +113,7 @@ class CreateModel(ModelOperation):
|
|||
|
||||
# Check we have no FKs/M2Ms with it
|
||||
for _name, field in self.fields:
|
||||
if field_references_model((app_label, self.name_lower), field, reference_model_tuple):
|
||||
if field_references((app_label, self.name_lower), field, reference_model_tuple):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
@ -309,33 +309,19 @@ class RenameModel(ModelOperation):
|
|||
# Repoint all fields pointing to the old model to the new one.
|
||||
old_model_tuple = (app_label, self.old_name_lower)
|
||||
new_remote_model = '%s.%s' % (app_label, self.new_name)
|
||||
to_reload = []
|
||||
for (model_app_label, model_name), model_state in state.models.items():
|
||||
model_changed = False
|
||||
for index, (name, field) in enumerate(model_state.fields):
|
||||
changed_field = None
|
||||
remote_field = field.remote_field
|
||||
if remote_field:
|
||||
remote_model_tuple = resolve_relation(
|
||||
remote_field.model, model_app_label, model_name
|
||||
)
|
||||
if remote_model_tuple == old_model_tuple:
|
||||
changed_field = field.clone()
|
||||
changed_field.remote_field.model = new_remote_model
|
||||
through_model = getattr(remote_field, 'through', None)
|
||||
if through_model:
|
||||
through_model_tuple = resolve_relation(
|
||||
through_model, model_app_label, model_name
|
||||
)
|
||||
if through_model_tuple == old_model_tuple:
|
||||
if changed_field is None:
|
||||
changed_field = field.clone()
|
||||
changed_field.remote_field.through = new_remote_model
|
||||
if changed_field:
|
||||
model_state.fields[index] = name, changed_field
|
||||
model_changed = True
|
||||
if model_changed:
|
||||
to_reload.append((model_app_label, model_name))
|
||||
to_reload = set()
|
||||
for model_state, index, name, field, reference in get_references(state, old_model_tuple):
|
||||
changed_field = None
|
||||
if reference.to:
|
||||
changed_field = field.clone()
|
||||
changed_field.remote_field.model = new_remote_model
|
||||
if reference.through:
|
||||
if changed_field is None:
|
||||
changed_field = field.clone()
|
||||
changed_field.remote_field.through = new_remote_model
|
||||
if changed_field:
|
||||
model_state.fields[index] = name, changed_field
|
||||
to_reload.add((model_state.app_label, model_state.name_lower))
|
||||
# Reload models related to old model before removing the old model.
|
||||
state.reload_models(to_reload, delay=True)
|
||||
# Remove the old model.
|
||||
|
|
|
@ -1,17 +1,8 @@
|
|||
from collections import namedtuple
|
||||
|
||||
from django.db.models.fields.related import RECURSIVE_RELATIONSHIP_CONSTANT
|
||||
|
||||
|
||||
def is_referenced_by_foreign_key(state, model_name_lower, field, field_name):
|
||||
for state_app_label, state_model in state.models:
|
||||
for _, f in state.models[state_app_label, state_model].fields:
|
||||
if (f.related_model and
|
||||
'%s.%s' % (state_app_label, model_name_lower) == f.related_model.lower() and
|
||||
hasattr(f, 'to_fields')):
|
||||
if (f.to_fields[0] is None and field.primary_key) or field_name in f.to_fields:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def resolve_relation(model, app_label=None, model_name=None):
|
||||
"""
|
||||
Turn a model class or model reference string and return a model tuple.
|
||||
|
@ -38,13 +29,73 @@ def resolve_relation(model, app_label=None, model_name=None):
|
|||
return model._meta.app_label, model._meta.model_name
|
||||
|
||||
|
||||
def field_references_model(model_tuple, field, reference_model_tuple):
|
||||
"""Return whether or not field references reference_model_tuple."""
|
||||
FieldReference = namedtuple('FieldReference', 'to through')
|
||||
|
||||
|
||||
def field_references(
|
||||
model_tuple,
|
||||
field,
|
||||
reference_model_tuple,
|
||||
reference_field_name=None,
|
||||
reference_field=None,
|
||||
):
|
||||
"""
|
||||
Return either False or a FieldReference if `field` references provided
|
||||
context.
|
||||
|
||||
False positives can be returned if `reference_field_name` is provided
|
||||
without `reference_field` because of the introspection limitation it
|
||||
incurs. This should not be an issue when this function is used to determine
|
||||
whether or not an optimization can take place.
|
||||
"""
|
||||
remote_field = field.remote_field
|
||||
if remote_field:
|
||||
if resolve_relation(remote_field.model, *model_tuple) == reference_model_tuple:
|
||||
return True
|
||||
through = getattr(remote_field, 'through', None)
|
||||
if through and resolve_relation(through, *model_tuple) == reference_model_tuple:
|
||||
return True
|
||||
return False
|
||||
if not remote_field:
|
||||
return False
|
||||
references_to = None
|
||||
references_through = None
|
||||
if resolve_relation(remote_field.model, *model_tuple) == reference_model_tuple:
|
||||
to_fields = getattr(field, 'to_fields', None)
|
||||
if (
|
||||
reference_field_name is None or
|
||||
# Unspecified to_field(s).
|
||||
to_fields is None or
|
||||
# Reference to primary key.
|
||||
(None in to_fields and (reference_field is None or reference_field.primary_key)) or
|
||||
# Reference to field.
|
||||
reference_field_name in to_fields
|
||||
):
|
||||
references_to = (remote_field, to_fields)
|
||||
through = getattr(remote_field, 'through', None)
|
||||
if through and resolve_relation(through, *model_tuple) == reference_model_tuple:
|
||||
through_fields = remote_field.through_fields
|
||||
if (
|
||||
reference_field_name is None or
|
||||
# Unspecified through_fields.
|
||||
through_fields is None or
|
||||
# Reference to field.
|
||||
reference_field_name in through_fields
|
||||
):
|
||||
references_through = (remote_field, through_fields)
|
||||
if not (references_to or references_through):
|
||||
return False
|
||||
return FieldReference(references_to, references_through)
|
||||
|
||||
|
||||
def get_references(state, model_tuple, field_tuple=()):
|
||||
"""
|
||||
Generator of (model_state, index, name, field, reference) referencing
|
||||
provided context.
|
||||
|
||||
If field_tuple is provided only references to this particular field of
|
||||
model_tuple will be generated.
|
||||
"""
|
||||
for state_model_tuple, model_state in state.models.items():
|
||||
for index, (name, field) in enumerate(model_state.fields):
|
||||
reference = field_references(state_model_tuple, field, model_tuple, *field_tuple)
|
||||
if reference:
|
||||
yield model_state, index, name, field, reference
|
||||
|
||||
|
||||
def field_is_referenced(state, model_tuple, field_tuple):
|
||||
"""Return whether `field_tuple` is referenced by any state models."""
|
||||
return next(get_references(state, model_tuple, field_tuple), None) is not None
|
||||
|
|
Loading…
Reference in New Issue