diff --git a/django/db/migrations/state.py b/django/db/migrations/state.py index 392708134d1..e04213b6566 100644 --- a/django/db/migrations/state.py +++ b/django/db/migrations/state.py @@ -98,18 +98,34 @@ class ProjectState: self.real_apps = real_apps self.is_delayed = False # {remote_model_key: {model_key: [(field_name, field)]}} - self.relations = None + self._relations = None + + @property + def relations(self): + if self._relations is None: + self.resolve_fields_and_relations() + return self._relations def add_model(self, model_state): - app_label, model_name = model_state.app_label, model_state.name_lower - self.models[(app_label, model_name)] = model_state + model_key = model_state.app_label, model_state.name_lower + self.models[model_key] = model_state + if self._relations is not None: + self.resolve_model_relations(model_key) if 'apps' in self.__dict__: # hasattr would cache the property - self.reload_model(app_label, model_name) + self.reload_model(*model_key) def remove_model(self, app_label, model_name): - del self.models[app_label, model_name] + model_key = app_label, model_name + del self.models[model_key] + if self._relations is not None: + self._relations.pop(model_key, None) + # Call list() since _relations can change size during iteration. + for related_model_key, model_relations in list(self._relations.items()): + model_relations.pop(model_key, None) + if not model_relations: + del self._relations[related_model_key] if 'apps' in self.__dict__: # hasattr would cache the property - self.apps.unregister_model(app_label, model_name) + self.apps.unregister_model(*model_key) # Need to do this explicitly since unregister_model() doesn't clear # the cache automatically (#24513) self.apps.clear_cache() @@ -137,6 +153,14 @@ class ProjectState: if changed_field: model_state.fields[name] = changed_field to_reload.add((model_state.app_label, model_state.name_lower)) + if self._relations is not None: + old_name_key = app_label, old_name_lower + new_name_key = app_label, new_name_lower + if old_name_key in self._relations: + self._relations[new_name_key] = self._relations.pop(old_name_key) + for model_relations in self._relations.values(): + if old_name_key in model_relations: + model_relations[new_name_key] = model_relations.pop(old_name_key) # Reload models related to old model before removing the old model. self.reload_models(to_reload, delay=True) # Remove the old model. @@ -187,17 +211,23 @@ class ProjectState: field.default = NOT_PROVIDED else: field = field - self.models[app_label, model_name].fields[name] = field + model_key = app_label, model_name + self.models[model_key].fields[name] = field + if self._relations is not None: + self.resolve_model_field_relations(model_key, name, field) # Delay rendering of relationships if it's not a relational field. delay = not field.is_relation - self.reload_model(app_label, model_name, delay=delay) + self.reload_model(*model_key, delay=delay) def remove_field(self, app_label, model_name, name): - model_state = self.models[app_label, model_name] + model_key = app_label, model_name + model_state = self.models[model_key] old_field = model_state.fields.pop(name) + if self._relations is not None: + self.resolve_model_field_relations(model_key, name, old_field) # Delay rendering of relationships if it's not a relational field. delay = not old_field.is_relation - self.reload_model(app_label, model_name, delay=delay) + self.reload_model(*model_key, delay=delay) def alter_field(self, app_label, model_name, name, field, preserve_default): if not preserve_default: @@ -205,20 +235,30 @@ class ProjectState: field.default = NOT_PROVIDED else: field = field - model_state = self.models[app_label, model_name] - model_state.fields[name] = field + model_key = app_label, model_name + fields = self.models[model_key].fields + if self._relations is not None: + old_field = fields.pop(name) + if old_field.is_relation: + self.resolve_model_field_relations(model_key, name, old_field) + fields[name] = field + if field.is_relation: + self.resolve_model_field_relations(model_key, name, field) + else: + fields[name] = field # TODO: investigate if old relational fields must be reloaded or if # it's sufficient if the new field is (#27737). # Delay rendering of relationships if it's not a relational field and # not referenced by a foreign key. delay = ( not field.is_relation and - not field_is_referenced(self, (app_label, model_name), (name, field)) + not field_is_referenced(self, model_key, (name, field)) ) - self.reload_model(app_label, model_name, delay=delay) + self.reload_model(*model_key, delay=delay) def rename_field(self, app_label, model_name, old_name, new_name): - model_state = self.models[app_label, model_name] + model_key = app_label, model_name + model_state = self.models[model_key] # Rename the field. fields = model_state.fields try: @@ -246,7 +286,7 @@ class ProjectState: ] # Fix to_fields to refer to the new field. delay = True - references = get_references(self, (app_label, model_name), (old_name, found)) + references = get_references(self, model_key, (old_name, found)) for *_, field, reference in references: delay = False if reference.to: @@ -258,7 +298,19 @@ class ProjectState: new_name if to_field_name == old_name else to_field_name for to_field_name in to_fields ]) - self.reload_model(app_label, model_name, delay=delay) + if self._relations is not None: + old_name_lower = old_name.lower() + new_name_lower = new_name.lower() + for to_model in self._relations.values(): + # It's safe to modify the same collection that is iterated + # because `break` is called right after. + for field_name, field in to_model[model_key]: + if field_name == old_name_lower: + field.name = new_name_lower + to_model[model_key].remove((old_name_lower, field)) + to_model[model_key].append((new_name_lower, field)) + break + self.reload_model(*model_key, delay=delay) def _find_reload_model(self, app_label, model_name, delay=False): if delay: @@ -352,7 +404,13 @@ class ProjectState: remote_model_key = resolve_relation(model, *model_key) if remote_model_key[0] not in self.real_apps and remote_model_key in concretes: remote_model_key = concretes[remote_model_key] - self.relations[remote_model_key][model_key].append((field_name, field)) + relations_to_remote_model = self._relations[remote_model_key] + if field_name in self.models[model_key].fields: + relations_to_remote_model[model_key].append((field_name, field)) + else: + relations_to_remote_model[model_key].remove((field_name, field)) + if not relations_to_remote_model[model_key]: + del relations_to_remote_model[model_key] def resolve_model_field_relations( self, model_key, field_name, field, concretes=None, @@ -387,14 +445,14 @@ class ProjectState: field.name = field_name # Resolve relations. # {remote_model_key: {model_key: [(field_name, field)]}} - self.relations = defaultdict(partial(defaultdict, list)) + self._relations = defaultdict(partial(defaultdict, list)) concretes, proxies = self._get_concrete_models_mapping_and_proxy_models() for model_key in concretes: self.resolve_model_relations(model_key, concretes) for model_key in proxies: - self.relations[model_key] = self.relations[concretes[model_key]] + self._relations[model_key] = self._relations[concretes[model_key]] def get_concrete_model_key(self, model): concrete_models_mapping, _ = self._get_concrete_models_mapping_and_proxy_models() diff --git a/tests/migrations/test_state.py b/tests/migrations/test_state.py index 5be29ca2f14..550245fb8ea 100644 --- a/tests/migrations/test_state.py +++ b/tests/migrations/test_state.py @@ -1017,6 +1017,402 @@ class StateTests(SimpleTestCase): self.assertEqual(list(choices_field.choices), choices) +class StateRelationsTests(SimpleTestCase): + def get_base_project_state(self): + new_apps = Apps() + + class User(models.Model): + class Meta: + app_label = 'tests' + apps = new_apps + + class Comment(models.Model): + text = models.TextField() + user = models.ForeignKey(User, models.CASCADE) + comments = models.ManyToManyField('self') + + class Meta: + app_label = 'tests' + apps = new_apps + + class Post(models.Model): + text = models.TextField() + authors = models.ManyToManyField(User) + + class Meta: + app_label = 'tests' + apps = new_apps + + project_state = ProjectState() + project_state.add_model(ModelState.from_model(User)) + project_state.add_model(ModelState.from_model(Comment)) + project_state.add_model(ModelState.from_model(Post)) + return project_state + + def test_relations_population(self): + tests = [ + ('add_model', [ + ModelState( + app_label='migrations', + name='Tag', + fields=[('id', models.AutoField(primary_key=True))], + ), + ]), + ('remove_model', ['tests', 'comment']), + ('rename_model', ['tests', 'comment', 'opinion']), + ('add_field', [ + 'tests', + 'post', + 'next_post', + models.ForeignKey('self', models.CASCADE), + True, + ]), + ('remove_field', ['tests', 'post', 'text']), + ('rename_field', ['tests', 'comment', 'user', 'author']), + ('alter_field', [ + 'tests', + 'comment', + 'user', + models.IntegerField(), + True, + ]), + ] + for method, args in tests: + with self.subTest(method=method): + project_state = self.get_base_project_state() + getattr(project_state, method)(*args) + # ProjectState's `_relations` are populated on `relations` access. + self.assertIsNone(project_state._relations) + self.assertEqual(project_state.relations, project_state._relations) + self.assertIsNotNone(project_state._relations) + + def test_add_model(self): + project_state = self.get_base_project_state() + self.assertEqual( + list(project_state.relations['tests', 'user']), + [('tests', 'comment'), ('tests', 'post')], + ) + self.assertEqual( + list(project_state.relations['tests', 'comment']), + [('tests', 'comment')], + ) + self.assertNotIn(('tests', 'post'), project_state.relations) + + def test_add_model_no_relations(self): + project_state = ProjectState() + project_state.add_model(ModelState( + app_label='migrations', + name='Tag', + fields=[('id', models.AutoField(primary_key=True))], + )) + self.assertEqual(project_state.relations, {}) + + def test_add_model_other_app(self): + project_state = self.get_base_project_state() + self.assertEqual( + list(project_state.relations['tests', 'user']), + [('tests', 'comment'), ('tests', 'post')], + ) + project_state.add_model(ModelState( + app_label='tests_other', + name='comment', + fields=[ + ('id', models.AutoField(primary_key=True)), + ('user', models.ForeignKey('tests.user', models.CASCADE)), + ], + )) + self.assertEqual( + list(project_state.relations['tests', 'user']), + [('tests', 'comment'), ('tests', 'post'), ('tests_other', 'comment')], + ) + + def test_remove_model(self): + project_state = self.get_base_project_state() + self.assertEqual( + list(project_state.relations['tests', 'user']), + [('tests', 'comment'), ('tests', 'post')], + ) + self.assertEqual( + list(project_state.relations['tests', 'comment']), + [('tests', 'comment')], + ) + + project_state.remove_model('tests', 'comment') + self.assertEqual( + list(project_state.relations['tests', 'user']), + [('tests', 'post')], + ) + self.assertNotIn(('tests', 'comment'), project_state.relations) + project_state.remove_model('tests', 'post') + self.assertEqual(project_state.relations, {}) + project_state.remove_model('tests', 'user') + self.assertEqual(project_state.relations, {}) + + def test_rename_model(self): + project_state = self.get_base_project_state() + self.assertEqual( + list(project_state.relations['tests', 'user']), + [('tests', 'comment'), ('tests', 'post')], + ) + self.assertEqual( + list(project_state.relations['tests', 'comment']), + [('tests', 'comment')], + ) + + related_field = project_state.relations['tests', 'user']['tests', 'comment'] + project_state.rename_model('tests', 'comment', 'opinion') + self.assertEqual( + list(project_state.relations['tests', 'user']), + [('tests', 'post'), ('tests', 'opinion')], + ) + self.assertEqual( + list(project_state.relations['tests', 'opinion']), + [('tests', 'opinion')], + ) + self.assertNotIn(('tests', 'comment'), project_state.relations) + self.assertEqual( + project_state.relations['tests', 'user']['tests', 'opinion'], + related_field, + ) + + project_state.rename_model('tests', 'user', 'author') + self.assertEqual( + list(project_state.relations['tests', 'author']), + [('tests', 'post'), ('tests', 'opinion')], + ) + self.assertNotIn(('tests', 'user'), project_state.relations) + + def test_rename_model_no_relations(self): + project_state = self.get_base_project_state() + self.assertEqual( + list(project_state.relations['tests', 'user']), + [('tests', 'comment'), ('tests', 'post')], + ) + related_field = project_state.relations['tests', 'user']['tests', 'post'] + self.assertNotIn(('tests', 'post'), project_state.relations) + # Rename a model without relations. + project_state.rename_model('tests', 'post', 'blog') + self.assertEqual( + list(project_state.relations['tests', 'user']), + [('tests', 'comment'), ('tests', 'blog')], + ) + self.assertNotIn(('tests', 'blog'), project_state.relations) + self.assertEqual( + related_field, + project_state.relations['tests', 'user']['tests', 'blog'], + ) + + def test_add_field(self): + project_state = self.get_base_project_state() + self.assertNotIn(('tests', 'post'), project_state.relations) + # Add a self-referential foreign key. + new_field = models.ForeignKey('self', models.CASCADE) + project_state.add_field( + 'tests', 'post', 'next_post', new_field, preserve_default=True, + ) + self.assertEqual( + list(project_state.relations['tests', 'post']), + [('tests', 'post')], + ) + self.assertEqual( + project_state.relations['tests', 'post']['tests', 'post'], + [('next_post', new_field)], + ) + # Add a foreign key. + new_field = models.ForeignKey('tests.post', models.CASCADE) + project_state.add_field( + 'tests', 'comment', 'post', new_field, preserve_default=True, + ) + self.assertEqual( + list(project_state.relations['tests', 'post']), + [('tests', 'post'), ('tests', 'comment')], + ) + self.assertEqual( + project_state.relations['tests', 'post']['tests', 'comment'], + [('post', new_field)], + ) + + def test_add_field_m2m_with_through(self): + project_state = self.get_base_project_state() + project_state.add_model(ModelState( + app_label='tests', + name='Tag', + fields=[('id', models.AutoField(primary_key=True))], + )) + project_state.add_model(ModelState( + app_label='tests', + name='PostTag', + fields=[ + ('id', models.AutoField(primary_key=True)), + ('post', models.ForeignKey('tests.post', models.CASCADE)), + ('tag', models.ForeignKey('tests.tag', models.CASCADE)), + ], + )) + self.assertEqual( + list(project_state.relations['tests', 'post']), + [('tests', 'posttag')], + ) + self.assertEqual( + list(project_state.relations['tests', 'tag']), + [('tests', 'posttag')], + ) + # Add a many-to-many field with the through model. + new_field = models.ManyToManyField('tests.tag', through='tests.posttag') + project_state.add_field( + 'tests', 'post', 'tags', new_field, preserve_default=True, + ) + self.assertEqual( + list(project_state.relations['tests', 'post']), + [('tests', 'posttag')], + ) + self.assertEqual( + list(project_state.relations['tests', 'tag']), + [('tests', 'posttag'), ('tests', 'post')], + ) + self.assertEqual( + project_state.relations['tests', 'tag']['tests', 'post'], + [('tags', new_field)], + ) + + def test_remove_field(self): + project_state = self.get_base_project_state() + self.assertEqual( + list(project_state.relations['tests', 'user']), + [('tests', 'comment'), ('tests', 'post')], + ) + # Remove a many-to-many field. + project_state.remove_field('tests', 'post', 'authors') + self.assertEqual( + list(project_state.relations['tests', 'user']), + [('tests', 'comment')], + ) + # Remove a foreign key. + project_state.remove_field('tests', 'comment', 'user') + self.assertEqual(project_state.relations['tests', 'user'], {}) + + def test_remove_field_no_relations(self): + project_state = self.get_base_project_state() + self.assertEqual( + list(project_state.relations['tests', 'user']), + [('tests', 'comment'), ('tests', 'post')], + ) + # Remove a non-relation field. + project_state.remove_field('tests', 'post', 'text') + self.assertEqual( + list(project_state.relations['tests', 'user']), + [('tests', 'comment'), ('tests', 'post')], + ) + + def test_rename_field(self): + project_state = self.get_base_project_state() + field = project_state.models['tests', 'comment'].fields['user'] + self.assertEqual( + project_state.relations['tests', 'user']['tests', 'comment'], + [('user', field)], + ) + + project_state.rename_field('tests', 'comment', 'user', 'author') + renamed_field = project_state.models['tests', 'comment'].fields['author'] + self.assertEqual( + project_state.relations['tests', 'user']['tests', 'comment'], + [('author', renamed_field)], + ) + self.assertEqual(field, renamed_field) + + def test_rename_field_no_relations(self): + project_state = self.get_base_project_state() + self.assertEqual( + list(project_state.relations['tests', 'user']), + [('tests', 'comment'), ('tests', 'post')], + ) + # Rename a non-relation field. + project_state.rename_field('tests', 'post', 'text', 'description') + self.assertEqual( + list(project_state.relations['tests', 'user']), + [('tests', 'comment'), ('tests', 'post')], + ) + + def test_alter_field(self): + project_state = self.get_base_project_state() + self.assertEqual( + list(project_state.relations['tests', 'user']), + [('tests', 'comment'), ('tests', 'post')], + ) + # Alter a foreign key to a non-relation field. + project_state.alter_field( + 'tests', 'comment', 'user', models.IntegerField(), preserve_default=True, + ) + self.assertEqual( + list(project_state.relations['tests', 'user']), + [('tests', 'post')], + ) + # Alter a non-relation field to a many-to-many field. + m2m_field = models.ManyToManyField('tests.user') + project_state.alter_field( + 'tests', 'comment', 'user', m2m_field, preserve_default=True, + ) + self.assertEqual( + list(project_state.relations['tests', 'user']), + [('tests', 'post'), ('tests', 'comment')], + ) + self.assertEqual( + project_state.relations['tests', 'user']['tests', 'comment'], + [('user', m2m_field)], + ) + + def test_alter_field_m2m_to_fk(self): + project_state = self.get_base_project_state() + project_state.add_model(ModelState( + app_label='tests_other', + name='user_other', + fields=[('id', models.AutoField(primary_key=True))], + )) + self.assertEqual( + list(project_state.relations['tests', 'user']), + [('tests', 'comment'), ('tests', 'post')], + ) + self.assertNotIn(('tests_other', 'user_other'), project_state.relations) + # Alter a many-to-many field to a foreign key. + foreign_key = models.ForeignKey('tests_other.user_other', models.CASCADE) + project_state.alter_field( + 'tests', 'post', 'authors', foreign_key, preserve_default=True, + ) + self.assertEqual( + list(project_state.relations['tests', 'user']), + [('tests', 'comment')], + ) + self.assertEqual( + list(project_state.relations['tests_other', 'user_other']), + [('tests', 'post')], + ) + self.assertEqual( + project_state.relations['tests_other', 'user_other']['tests', 'post'], + [('authors', foreign_key)], + ) + + def test_many_relations_to_same_model(self): + project_state = self.get_base_project_state() + new_field = models.ForeignKey('tests.user', models.CASCADE) + project_state.add_field( + 'tests', 'comment', 'reviewer', new_field, preserve_default=True, + ) + self.assertEqual( + list(project_state.relations['tests', 'user']), + [('tests', 'comment'), ('tests', 'post')], + ) + comment_rels = project_state.relations['tests', 'user']['tests', 'comment'] + # Two foreign keys to the same model. + self.assertEqual(len(comment_rels), 2) + self.assertEqual(comment_rels[1], ('reviewer', new_field)) + # Rename the second foreign key. + project_state.rename_field('tests', 'comment', 'reviewer', 'supervisor') + self.assertEqual(len(comment_rels), 2) + self.assertEqual(comment_rels[1], ('supervisor', new_field)) + # Remove the first foreign key. + project_state.remove_field('tests', 'comment', 'user') + self.assertEqual(comment_rels, [('supervisor', new_field)]) + + class ModelStateTests(SimpleTestCase): def test_custom_model_base(self): state = ModelState.from_model(ModelWithCustomBase)