Introduced ModelTuple to remove migrations boilerplate.
This commit is contained in:
parent
ad82900ad9
commit
013bcf57d5
|
@ -80,16 +80,6 @@ class Operation:
|
|||
"""
|
||||
return "%s: %s" % (self.__class__.__name__, self._constructor_args)
|
||||
|
||||
def model_to_key(self, model):
|
||||
"""
|
||||
Take either a model class or an 'app_label.ModelName' string and return
|
||||
(app_label, model_name).
|
||||
"""
|
||||
if isinstance(model, str):
|
||||
return tuple(model.lower().split('.', 1))
|
||||
else:
|
||||
return model._meta.app_label, model._meta.model_name
|
||||
|
||||
def references_model(self, name, app_label=None):
|
||||
"""
|
||||
Return True if there is a chance this operation references the given
|
||||
|
|
|
@ -3,7 +3,9 @@ from django.db.models.fields import NOT_PROVIDED
|
|||
from django.utils.functional import cached_property
|
||||
|
||||
from .base import Operation
|
||||
from .utils import is_referenced_by_foreign_key
|
||||
from .utils import (
|
||||
ModelTuple, field_references_model, is_referenced_by_foreign_key,
|
||||
)
|
||||
|
||||
|
||||
class FieldOperation(Operation):
|
||||
|
@ -31,40 +33,9 @@ class FieldOperation(Operation):
|
|||
if name_lower == self.model_name_lower:
|
||||
return True
|
||||
if self.field:
|
||||
if self.field.remote_field:
|
||||
remote_app_label, remote_model_name = self.model_to_key(self.field.remote_field.model)
|
||||
if (remote_model_name == name_lower and app_label is None or
|
||||
not remote_app_label or remote_app_label == app_label):
|
||||
return True
|
||||
through = getattr(self.field.remote_field, 'through', None)
|
||||
if through and self.model_to_key(through) == (app_label, name_lower):
|
||||
through_app_label, through_model_name = self.model_to_key(through)
|
||||
if (through_model_name == name_lower and app_label is None or
|
||||
not through_app_label or through_app_label == app_label):
|
||||
return True
|
||||
return False
|
||||
return True
|
||||
|
||||
def references_field(self, model_name, name, app_label=None):
|
||||
if self.field:
|
||||
model_name_lower = model_name.lower()
|
||||
remote_field = self.field.remote_field
|
||||
if remote_field:
|
||||
remote_app_label, remote_model_name = self.model_to_key(remote_field.model)
|
||||
if (remote_model_name == model_name_lower and
|
||||
(app_label is None or not remote_app_label or remote_app_label == app_label)):
|
||||
# TODO: Consider to_fields/from_fields.
|
||||
return True
|
||||
through = getattr(remote_field, 'through', None)
|
||||
if through and self.model_to_key(through) == (app_label, model_name_lower):
|
||||
through_app_label, through_model_name = self.model_to_key(through)
|
||||
if (through_model_name == model_name_lower and
|
||||
(app_label is None or not through_app_label or through_app_label == app_label) and
|
||||
(remote_field.through_fields is None or name in remote_field.through_fields)):
|
||||
return True
|
||||
elif model_name_lower == self.model_name_lower and name == self.name:
|
||||
return True
|
||||
return False
|
||||
return field_references_model(self.field, ModelTuple(app_label, name_lower))
|
||||
# Refuse the temptation to guess. This operation could be performed on
|
||||
# a field referencing the specified model.
|
||||
return True
|
||||
|
||||
def reduce(self, operation, in_between, app_label=None):
|
||||
|
|
|
@ -7,6 +7,7 @@ from django.utils.functional import cached_property
|
|||
from .fields import (
|
||||
AddField, AlterField, FieldOperation, RemoveField, RenameField,
|
||||
)
|
||||
from .utils import ModelTuple, field_references_model
|
||||
|
||||
|
||||
def _check_for_duplicates(arg_name, objs):
|
||||
|
@ -104,19 +105,15 @@ class CreateModel(ModelOperation):
|
|||
return True
|
||||
|
||||
# Check we didn't inherit from the model
|
||||
models_to_check = [
|
||||
base for base in self.bases
|
||||
if base is not models.Model and isinstance(base, (models.base.ModelBase, str))
|
||||
]
|
||||
model_tuple = ModelTuple(app_label, name_lower)
|
||||
for base in self.bases:
|
||||
if (base is not models.Model and isinstance(base, (models.base.ModelBase, str)) and
|
||||
ModelTuple.from_model(base) == model_tuple):
|
||||
return True
|
||||
|
||||
# Check we have no FKs/M2Ms with it
|
||||
for fname, field in self.fields:
|
||||
if field.remote_field:
|
||||
models_to_check.append(field.remote_field.model)
|
||||
# Now go over all the models and check against them
|
||||
for model in models_to_check:
|
||||
model_app_label, model_name = self.model_to_key(model)
|
||||
if (model_name == name_lower and app_label is None or
|
||||
not model_app_label or model_app_label == app_label):
|
||||
for _name, field in self.fields:
|
||||
if field_references_model(field, model_tuple):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
@ -267,7 +264,7 @@ class RenameModel(ModelOperation):
|
|||
renamed_model.name = self.new_name
|
||||
state.models[app_label, self.new_name_lower] = renamed_model
|
||||
# Repoint all fields pointing to the old model to the new one.
|
||||
old_model_tuple = app_label, self.old_name_lower
|
||||
old_model_tuple = ModelTuple(app_label, self.old_name_lower)
|
||||
new_remote_model = '%s.%s' % (app_label, self.new_name)
|
||||
to_reload = []
|
||||
for (model_app_label, model_name), model_state in state.models.items():
|
||||
|
@ -276,7 +273,7 @@ class RenameModel(ModelOperation):
|
|||
changed_field = None
|
||||
remote_field = field.remote_field
|
||||
if remote_field:
|
||||
remote_model_tuple = self._get_model_tuple(
|
||||
remote_model_tuple = ModelTuple.from_model(
|
||||
remote_field.model, model_app_label, model_name
|
||||
)
|
||||
if remote_model_tuple == old_model_tuple:
|
||||
|
@ -284,7 +281,7 @@ class RenameModel(ModelOperation):
|
|||
changed_field.remote_field.model = new_remote_model
|
||||
through_model = getattr(remote_field, 'through', None)
|
||||
if through_model:
|
||||
through_model_tuple = self._get_model_tuple(
|
||||
through_model_tuple = ModelTuple.from_model(
|
||||
through_model, model_app_label, model_name
|
||||
)
|
||||
if through_model_tuple == old_model_tuple:
|
||||
|
|
|
@ -1,3 +1,8 @@
|
|||
from collections import namedtuple
|
||||
|
||||
from django.db.models.fields.related import RECURSIVE_RELATIONSHIP_CONSTANT
|
||||
|
||||
|
||||
def is_referenced_by_foreign_key(state, model_name_lower, field, field_name):
|
||||
for state_app_label, state_model in state.models:
|
||||
for _, f in state.models[state_app_label, state_model].fields:
|
||||
|
@ -7,3 +12,42 @@ def is_referenced_by_foreign_key(state, model_name_lower, field, field_name):
|
|||
if (f.to_fields[0] is None and field.primary_key) or field_name in f.to_fields:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class ModelTuple(namedtuple('ModelTupleBase', ('app_label', 'model_name'))):
|
||||
@classmethod
|
||||
def from_model(cls, model, app_label=None, model_name=None):
|
||||
"""
|
||||
Take a model class or a 'app_label.ModelName' string and return a
|
||||
ModelTuple('app_label', 'modelname'). The optional app_label and
|
||||
model_name arguments are the defaults if "self" or "ModelName" are
|
||||
passed.
|
||||
"""
|
||||
if isinstance(model, str):
|
||||
if model == RECURSIVE_RELATIONSHIP_CONSTANT:
|
||||
return cls(app_label, model_name)
|
||||
if '.' in model:
|
||||
return cls(*model.lower().split('.', 1))
|
||||
return cls(app_label, model.lower())
|
||||
return cls(model._meta.app_label, model._meta.model_name)
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, ModelTuple):
|
||||
# Consider ModelTuple equal if their model_name is equal and either
|
||||
# one of them is missing an app_label.
|
||||
return self.model_name == other.model_name and (
|
||||
self.app_label is None or other.app_label is None or self.app_label == other.app_label
|
||||
)
|
||||
return super().__eq__(other)
|
||||
|
||||
|
||||
def field_references_model(field, model_tuple):
|
||||
"""Return whether or not field references model_tuple."""
|
||||
remote_field = field.remote_field
|
||||
if remote_field:
|
||||
if ModelTuple.from_model(remote_field.model) == model_tuple:
|
||||
return True
|
||||
through = getattr(remote_field, 'through', None)
|
||||
if through and ModelTuple.from_model(through) == model_tuple:
|
||||
return True
|
||||
return False
|
||||
|
|
Loading…
Reference in New Issue