From fa1d7ba5b9c7707cae4a23ce9876c0a7abd924a5 Mon Sep 17 00:00:00 2001 From: Mariusz Felisiak Date: Thu, 26 Aug 2021 07:49:37 +0200 Subject: [PATCH] Refs #29898 -- Changed fields in ProjectState's relation registry to dict. --- django/db/migrations/autodetector.py | 11 +++++++---- django/db/migrations/state.py | 26 +++++++++++++------------- tests/migrations/test_state.py | 20 ++++++++++---------- 3 files changed, 30 insertions(+), 27 deletions(-) diff --git a/django/db/migrations/autodetector.py b/django/db/migrations/autodetector.py index 2848adce7da..594658ce99c 100644 --- a/django/db/migrations/autodetector.py +++ b/django/db/migrations/autodetector.py @@ -483,7 +483,7 @@ class MigrationAutodetector: fields = list(model_state.fields.values()) + [ field.remote_field for relations in self.to_state.relations[app_label, model_name].values() - for _, field in relations + for field in relations.values() ] for field in fields: if field.is_relation: @@ -672,7 +672,7 @@ class MigrationAutodetector: if (app_label, model_name) in self.old_proxy_keys: for related_model_key, related_fields in relations[app_label, model_name].items(): related_model_state = self.to_state.models[related_model_key] - for related_field_name, related_field in related_fields: + for related_field_name, related_field in related_fields.items(): self.add_operation( related_model_state.app_label, operations.AlterField( @@ -777,7 +777,7 @@ class MigrationAutodetector: for (related_object_app_label, object_name), relation_related_fields in ( relations[app_label, model_name].items() ): - for field_name, field in relation_related_fields: + for field_name, field in relation_related_fields.items(): dependencies.append( (related_object_app_label, object_name, field_name, False), ) @@ -1082,7 +1082,10 @@ class MigrationAutodetector: else: relations = project_state.relations[app_label, model_name] for (remote_app_label, remote_model_name), fields in relations.items(): - if any(field == related_field.remote_field for _, related_field in fields): + if any( + field == related_field.remote_field + for related_field in fields.values() + ): remote_field_model = f'{remote_app_label}.{remote_model_name}' break # Account for FKs to swappable models diff --git a/django/db/migrations/state.py b/django/db/migrations/state.py index e04213b6566..2927e66f835 100644 --- a/django/db/migrations/state.py +++ b/django/db/migrations/state.py @@ -97,7 +97,7 @@ class ProjectState: assert isinstance(real_apps, set) self.real_apps = real_apps self.is_delayed = False - # {remote_model_key: {model_key: [(field_name, field)]}} + # {remote_model_key: {model_key: {field_name: field}}} self._relations = None @property @@ -302,14 +302,10 @@ class ProjectState: old_name_lower = old_name.lower() new_name_lower = new_name.lower() for to_model in self._relations.values(): - # It's safe to modify the same collection that is iterated - # because `break` is called right after. - for field_name, field in to_model[model_key]: - if field_name == old_name_lower: - field.name = new_name_lower - to_model[model_key].remove((old_name_lower, field)) - to_model[model_key].append((new_name_lower, field)) - break + if old_name_lower in to_model[model_key]: + field = to_model[model_key].pop(old_name_lower) + field.name = new_name_lower + to_model[model_key][new_name_lower] = field self.reload_model(*model_key, delay=delay) def _find_reload_model(self, app_label, model_name, delay=False): @@ -406,9 +402,13 @@ class ProjectState: remote_model_key = concretes[remote_model_key] relations_to_remote_model = self._relations[remote_model_key] if field_name in self.models[model_key].fields: - relations_to_remote_model[model_key].append((field_name, field)) + # The assert holds because it's a new relation, or an altered + # relation, in which case references have been removed by + # alter_field(). + assert field_name not in relations_to_remote_model[model_key] + relations_to_remote_model[model_key][field_name] = field else: - relations_to_remote_model[model_key].remove((field_name, field)) + del relations_to_remote_model[model_key][field_name] if not relations_to_remote_model[model_key]: del relations_to_remote_model[model_key] @@ -444,8 +444,8 @@ class ProjectState: for field_name, field in model_state.fields.items(): field.name = field_name # Resolve relations. - # {remote_model_key: {model_key: [(field_name, field)]}} - self._relations = defaultdict(partial(defaultdict, list)) + # {remote_model_key: {model_key: {field_name: field}}} + self._relations = defaultdict(partial(defaultdict, dict)) concretes, proxies = self._get_concrete_models_mapping_and_proxy_models() for model_key in concretes: diff --git a/tests/migrations/test_state.py b/tests/migrations/test_state.py index 550245fb8ea..8ad0d195008 100644 --- a/tests/migrations/test_state.py +++ b/tests/migrations/test_state.py @@ -1216,7 +1216,7 @@ class StateRelationsTests(SimpleTestCase): ) self.assertEqual( project_state.relations['tests', 'post']['tests', 'post'], - [('next_post', new_field)], + {'next_post': new_field}, ) # Add a foreign key. new_field = models.ForeignKey('tests.post', models.CASCADE) @@ -1229,7 +1229,7 @@ class StateRelationsTests(SimpleTestCase): ) self.assertEqual( project_state.relations['tests', 'post']['tests', 'comment'], - [('post', new_field)], + {'post': new_field}, ) def test_add_field_m2m_with_through(self): @@ -1271,7 +1271,7 @@ class StateRelationsTests(SimpleTestCase): ) self.assertEqual( project_state.relations['tests', 'tag']['tests', 'post'], - [('tags', new_field)], + {'tags': new_field}, ) def test_remove_field(self): @@ -1308,14 +1308,14 @@ class StateRelationsTests(SimpleTestCase): field = project_state.models['tests', 'comment'].fields['user'] self.assertEqual( project_state.relations['tests', 'user']['tests', 'comment'], - [('user', field)], + {'user': field}, ) project_state.rename_field('tests', 'comment', 'user', 'author') renamed_field = project_state.models['tests', 'comment'].fields['author'] self.assertEqual( project_state.relations['tests', 'user']['tests', 'comment'], - [('author', renamed_field)], + {'author': renamed_field}, ) self.assertEqual(field, renamed_field) @@ -1357,7 +1357,7 @@ class StateRelationsTests(SimpleTestCase): ) self.assertEqual( project_state.relations['tests', 'user']['tests', 'comment'], - [('user', m2m_field)], + {'user': m2m_field}, ) def test_alter_field_m2m_to_fk(self): @@ -1387,7 +1387,7 @@ class StateRelationsTests(SimpleTestCase): ) self.assertEqual( project_state.relations['tests_other', 'user_other']['tests', 'post'], - [('authors', foreign_key)], + {'authors': foreign_key}, ) def test_many_relations_to_same_model(self): @@ -1403,14 +1403,14 @@ class StateRelationsTests(SimpleTestCase): comment_rels = project_state.relations['tests', 'user']['tests', 'comment'] # Two foreign keys to the same model. self.assertEqual(len(comment_rels), 2) - self.assertEqual(comment_rels[1], ('reviewer', new_field)) + self.assertEqual(comment_rels['reviewer'], new_field) # Rename the second foreign key. project_state.rename_field('tests', 'comment', 'reviewer', 'supervisor') self.assertEqual(len(comment_rels), 2) - self.assertEqual(comment_rels[1], ('supervisor', new_field)) + self.assertEqual(comment_rels['supervisor'], new_field) # Remove the first foreign key. project_state.remove_field('tests', 'comment', 'user') - self.assertEqual(comment_rels, [('supervisor', new_field)]) + self.assertEqual(comment_rels, {'supervisor': new_field}) class ModelStateTests(SimpleTestCase):