From 013bcf57d54afea02611413c0169351a1521ee7c Mon Sep 17 00:00:00 2001 From: Simon Charette Date: Fri, 3 Feb 2017 00:42:15 -0500 Subject: [PATCH] Introduced ModelTuple to remove migrations boilerplate. --- django/db/migrations/operations/base.py | 10 ------ django/db/migrations/operations/fields.py | 41 ++++----------------- django/db/migrations/operations/models.py | 27 +++++++------- django/db/migrations/operations/utils.py | 44 +++++++++++++++++++++++ 4 files changed, 62 insertions(+), 60 deletions(-) diff --git a/django/db/migrations/operations/base.py b/django/db/migrations/operations/base.py index 2448284a2bf..3fb1002c445 100644 --- a/django/db/migrations/operations/base.py +++ b/django/db/migrations/operations/base.py @@ -80,16 +80,6 @@ class Operation: """ return "%s: %s" % (self.__class__.__name__, self._constructor_args) - def model_to_key(self, model): - """ - Take either a model class or an 'app_label.ModelName' string and return - (app_label, model_name). - """ - if isinstance(model, str): - return tuple(model.lower().split('.', 1)) - else: - return model._meta.app_label, model._meta.model_name - def references_model(self, name, app_label=None): """ Return True if there is a chance this operation references the given diff --git a/django/db/migrations/operations/fields.py b/django/db/migrations/operations/fields.py index 66231292e84..34f1d2b64d5 100644 --- a/django/db/migrations/operations/fields.py +++ b/django/db/migrations/operations/fields.py @@ -3,7 +3,9 @@ from django.db.models.fields import NOT_PROVIDED from django.utils.functional import cached_property from .base import Operation -from .utils import is_referenced_by_foreign_key +from .utils import ( + ModelTuple, field_references_model, is_referenced_by_foreign_key, +) class FieldOperation(Operation): @@ -31,40 +33,9 @@ class FieldOperation(Operation): if name_lower == self.model_name_lower: return True if self.field: - if self.field.remote_field: - remote_app_label, remote_model_name = self.model_to_key(self.field.remote_field.model) - if (remote_model_name == name_lower and app_label is None or - not remote_app_label or remote_app_label == app_label): - return True - through = getattr(self.field.remote_field, 'through', None) - if through and self.model_to_key(through) == (app_label, name_lower): - through_app_label, through_model_name = self.model_to_key(through) - if (through_model_name == name_lower and app_label is None or - not through_app_label or through_app_label == app_label): - return True - return False - return True - - def references_field(self, model_name, name, app_label=None): - if self.field: - model_name_lower = model_name.lower() - remote_field = self.field.remote_field - if remote_field: - remote_app_label, remote_model_name = self.model_to_key(remote_field.model) - if (remote_model_name == model_name_lower and - (app_label is None or not remote_app_label or remote_app_label == app_label)): - # TODO: Consider to_fields/from_fields. - return True - through = getattr(remote_field, 'through', None) - if through and self.model_to_key(through) == (app_label, model_name_lower): - through_app_label, through_model_name = self.model_to_key(through) - if (through_model_name == model_name_lower and - (app_label is None or not through_app_label or through_app_label == app_label) and - (remote_field.through_fields is None or name in remote_field.through_fields)): - return True - elif model_name_lower == self.model_name_lower and name == self.name: - return True - return False + return field_references_model(self.field, ModelTuple(app_label, name_lower)) + # Refuse the temptation to guess. This operation could be performed on + # a field referencing the specified model. return True def reduce(self, operation, in_between, app_label=None): diff --git a/django/db/migrations/operations/models.py b/django/db/migrations/operations/models.py index 88f3507c229..b2d3f70feab 100644 --- a/django/db/migrations/operations/models.py +++ b/django/db/migrations/operations/models.py @@ -7,6 +7,7 @@ from django.utils.functional import cached_property from .fields import ( AddField, AlterField, FieldOperation, RemoveField, RenameField, ) +from .utils import ModelTuple, field_references_model def _check_for_duplicates(arg_name, objs): @@ -104,19 +105,15 @@ class CreateModel(ModelOperation): return True # Check we didn't inherit from the model - models_to_check = [ - base for base in self.bases - if base is not models.Model and isinstance(base, (models.base.ModelBase, str)) - ] + model_tuple = ModelTuple(app_label, name_lower) + for base in self.bases: + if (base is not models.Model and isinstance(base, (models.base.ModelBase, str)) and + ModelTuple.from_model(base) == model_tuple): + return True + # Check we have no FKs/M2Ms with it - for fname, field in self.fields: - if field.remote_field: - models_to_check.append(field.remote_field.model) - # Now go over all the models and check against them - for model in models_to_check: - model_app_label, model_name = self.model_to_key(model) - if (model_name == name_lower and app_label is None or - not model_app_label or model_app_label == app_label): + for _name, field in self.fields: + if field_references_model(field, model_tuple): return True return False @@ -267,7 +264,7 @@ class RenameModel(ModelOperation): renamed_model.name = self.new_name state.models[app_label, self.new_name_lower] = renamed_model # Repoint all fields pointing to the old model to the new one. - old_model_tuple = app_label, self.old_name_lower + old_model_tuple = ModelTuple(app_label, self.old_name_lower) new_remote_model = '%s.%s' % (app_label, self.new_name) to_reload = [] for (model_app_label, model_name), model_state in state.models.items(): @@ -276,7 +273,7 @@ class RenameModel(ModelOperation): changed_field = None remote_field = field.remote_field if remote_field: - remote_model_tuple = self._get_model_tuple( + remote_model_tuple = ModelTuple.from_model( remote_field.model, model_app_label, model_name ) if remote_model_tuple == old_model_tuple: @@ -284,7 +281,7 @@ class RenameModel(ModelOperation): changed_field.remote_field.model = new_remote_model through_model = getattr(remote_field, 'through', None) if through_model: - through_model_tuple = self._get_model_tuple( + through_model_tuple = ModelTuple.from_model( through_model, model_app_label, model_name ) if through_model_tuple == old_model_tuple: diff --git a/django/db/migrations/operations/utils.py b/django/db/migrations/operations/utils.py index af23ea95634..34fdaba8217 100644 --- a/django/db/migrations/operations/utils.py +++ b/django/db/migrations/operations/utils.py @@ -1,3 +1,8 @@ +from collections import namedtuple + +from django.db.models.fields.related import RECURSIVE_RELATIONSHIP_CONSTANT + + def is_referenced_by_foreign_key(state, model_name_lower, field, field_name): for state_app_label, state_model in state.models: for _, f in state.models[state_app_label, state_model].fields: @@ -7,3 +12,42 @@ def is_referenced_by_foreign_key(state, model_name_lower, field, field_name): if (f.to_fields[0] is None and field.primary_key) or field_name in f.to_fields: return True return False + + +class ModelTuple(namedtuple('ModelTupleBase', ('app_label', 'model_name'))): + @classmethod + def from_model(cls, model, app_label=None, model_name=None): + """ + Take a model class or a 'app_label.ModelName' string and return a + ModelTuple('app_label', 'modelname'). The optional app_label and + model_name arguments are the defaults if "self" or "ModelName" are + passed. + """ + if isinstance(model, str): + if model == RECURSIVE_RELATIONSHIP_CONSTANT: + return cls(app_label, model_name) + if '.' in model: + return cls(*model.lower().split('.', 1)) + return cls(app_label, model.lower()) + return cls(model._meta.app_label, model._meta.model_name) + + def __eq__(self, other): + if isinstance(other, ModelTuple): + # Consider ModelTuple equal if their model_name is equal and either + # one of them is missing an app_label. + return self.model_name == other.model_name and ( + self.app_label is None or other.app_label is None or self.app_label == other.app_label + ) + return super().__eq__(other) + + +def field_references_model(field, model_tuple): + """Return whether or not field references model_tuple.""" + remote_field = field.remote_field + if remote_field: + if ModelTuple.from_model(remote_field.model) == model_tuple: + return True + through = getattr(remote_field, 'through', None) + if through and ModelTuple.from_model(through) == model_tuple: + return True + return False