diff --git a/django/db/migrations/operations/fields.py b/django/db/migrations/operations/fields.py index 8bde081a87..568515d8ac 100644 --- a/django/db/migrations/operations/fields.py +++ b/django/db/migrations/operations/fields.py @@ -4,7 +4,7 @@ from django.utils.functional import cached_property from .base import Operation from .utils import ( - ModelTuple, field_references_model, is_referenced_by_foreign_key, + field_references_model, is_referenced_by_foreign_key, resolve_relation, ) @@ -33,7 +33,9 @@ class FieldOperation(Operation): if name_lower == self.model_name_lower: return True if self.field: - return field_references_model(self.field, ModelTuple(app_label, name_lower)) + return field_references_model( + (app_label, self.model_name_lower), self.field, (app_label, name_lower) + ) return False def references_field(self, model_name, name, app_label): @@ -46,15 +48,15 @@ class FieldOperation(Operation): return True # Check if this operation remotely references the field. if self.field: - model_tuple = ModelTuple(app_label, model_name_lower) + model_tuple = (app_label, model_name_lower) remote_field = self.field.remote_field if remote_field: - if (ModelTuple.from_model(remote_field.model) == model_tuple and + if (resolve_relation(remote_field.model, app_label, self.model_name_lower) == model_tuple and (not hasattr(self.field, 'to_fields') or name in self.field.to_fields or None in self.field.to_fields)): return True through = getattr(remote_field, 'through', None) - if (through and ModelTuple.from_model(through) == model_tuple and + if (through and resolve_relation(through, app_label, self.model_name_lower) == model_tuple and (getattr(remote_field, 'through_fields', None) is None or name in remote_field.through_fields)): return True @@ -339,7 +341,7 @@ class RenameField(FieldOperation): for index, (name, field) in enumerate(model_state.fields): remote_field = field.remote_field if remote_field: - remote_model_tuple = ModelTuple.from_model( + remote_model_tuple = resolve_relation( remote_field.model, model_app_label, model_name ) if remote_model_tuple == model_tuple: diff --git a/django/db/migrations/operations/models.py b/django/db/migrations/operations/models.py index b415bd6417..fa247f56eb 100644 --- a/django/db/migrations/operations/models.py +++ b/django/db/migrations/operations/models.py @@ -7,7 +7,7 @@ from django.utils.functional import cached_property from .fields import ( AddField, AlterField, FieldOperation, RemoveField, RenameField, ) -from .utils import ModelTuple, field_references_model +from .utils import field_references_model, resolve_relation def _check_for_duplicates(arg_name, objs): @@ -105,15 +105,15 @@ class CreateModel(ModelOperation): return True # Check we didn't inherit from the model - model_tuple = ModelTuple(app_label, name_lower) + reference_model_tuple = (app_label, name_lower) for base in self.bases: if (base is not models.Model and isinstance(base, (models.base.ModelBase, str)) and - ModelTuple.from_model(base) == model_tuple): + resolve_relation(base, app_label) == reference_model_tuple): return True # Check we have no FKs/M2Ms with it for _name, field in self.fields: - if field_references_model(field, model_tuple): + if field_references_model((app_label, self.name_lower), field, reference_model_tuple): return True return False @@ -307,7 +307,7 @@ class RenameModel(ModelOperation): renamed_model.name = self.new_name state.models[app_label, self.new_name_lower] = renamed_model # Repoint all fields pointing to the old model to the new one. - old_model_tuple = ModelTuple(app_label, self.old_name_lower) + old_model_tuple = (app_label, self.old_name_lower) new_remote_model = '%s.%s' % (app_label, self.new_name) to_reload = [] for (model_app_label, model_name), model_state in state.models.items(): @@ -316,7 +316,7 @@ class RenameModel(ModelOperation): changed_field = None remote_field = field.remote_field if remote_field: - remote_model_tuple = ModelTuple.from_model( + remote_model_tuple = resolve_relation( remote_field.model, model_app_label, model_name ) if remote_model_tuple == old_model_tuple: @@ -324,7 +324,7 @@ class RenameModel(ModelOperation): changed_field.remote_field.model = new_remote_model through_model = getattr(remote_field, 'through', None) if through_model: - through_model_tuple = ModelTuple.from_model( + through_model_tuple = resolve_relation( through_model, model_app_label, model_name ) if through_model_tuple == old_model_tuple: diff --git a/django/db/migrations/operations/utils.py b/django/db/migrations/operations/utils.py index adb2a7febb..24319fb383 100644 --- a/django/db/migrations/operations/utils.py +++ b/django/db/migrations/operations/utils.py @@ -1,5 +1,3 @@ -from collections import namedtuple - from django.db.models.fields.related import RECURSIVE_RELATIONSHIP_CONSTANT @@ -14,40 +12,39 @@ def is_referenced_by_foreign_key(state, model_name_lower, field, field_name): return False -class ModelTuple(namedtuple('ModelTupleBase', ('app_label', 'model_name'))): - @classmethod - def from_model(cls, model, app_label=None, model_name=None): - """ - Take a model class or an 'app_label.ModelName' string and return a - ModelTuple('app_label', 'modelname'). The optional app_label and - model_name arguments are the defaults if "self" or "ModelName" are - passed. - """ - if isinstance(model, str): - if model == RECURSIVE_RELATIONSHIP_CONSTANT: - return cls(app_label, model_name) - if '.' in model: - return cls(*model.lower().split('.', 1)) - return cls(app_label, model.lower()) - return cls(model._meta.app_label, model._meta.model_name) +def resolve_relation(model, app_label=None, model_name=None): + """ + Turn a model class or model reference string and return a model tuple. - def __eq__(self, other): - if isinstance(other, ModelTuple): - # Consider ModelTuple equal if their model_name is equal and either - # one of them is missing an app_label. - return self.model_name == other.model_name and ( - self.app_label is None or other.app_label is None or self.app_label == other.app_label + app_label and model_name are used to resolve the scope of recursive and + unscoped model relationship. + """ + if isinstance(model, str): + if model == RECURSIVE_RELATIONSHIP_CONSTANT: + if app_label is None or model_name is None: + raise TypeError( + 'app_label and model_name must be provided to resolve ' + 'recursive relationships.' + ) + return app_label, model_name + if '.' in model: + return tuple(model.lower().split('.', 1)) + if app_label is None: + raise TypeError( + 'app_label must be provided to resolve unscoped model ' + 'relationships.' ) - return super().__eq__(other) + return app_label, model.lower() + return model._meta.app_label, model._meta.model_name -def field_references_model(field, model_tuple): - """Return whether or not field references model_tuple.""" +def field_references_model(model_tuple, field, reference_model_tuple): + """Return whether or not field references reference_model_tuple.""" remote_field = field.remote_field if remote_field: - if ModelTuple.from_model(remote_field.model) == model_tuple: + if resolve_relation(remote_field.model, *model_tuple) == reference_model_tuple: return True through = getattr(remote_field, 'through', None) - if through and ModelTuple.from_model(through) == model_tuple: + if through and resolve_relation(through, *model_tuple) == reference_model_tuple: return True return False