From 248fdb1110abefbdc14e3464c556b8a22abe4edc Mon Sep 17 00:00:00 2001 From: Andrew Godwin Date: Wed, 11 Dec 2013 13:16:29 +0000 Subject: [PATCH] Change FKs when what they point to changes --- django/db/backends/schema.py | 20 +++ django/db/migrations/autodetector.py | 153 ++++++++++++---------- django/db/migrations/operations/fields.py | 12 +- 3 files changed, 107 insertions(+), 78 deletions(-) diff --git a/django/db/backends/schema.py b/django/db/backends/schema.py index df17ab5a7f..5e3ad4dd80 100644 --- a/django/db/backends/schema.py +++ b/django/db/backends/schema.py @@ -615,6 +615,11 @@ class BaseDatabaseSchemaEditor(object): "extra": "", } ) + # Type alteration on primary key? Then we need to alter the column + # referring to us. + rels_to_update = [] + if old_field.primary_key and new_field.primary_key and old_type != new_type: + rels_to_update.extend(model._meta.get_all_related_objects()) # Changed to become primary key? # Note that we don't detect unsetting of a PK, as we assume another field # will always come along and replace it. @@ -641,6 +646,21 @@ class BaseDatabaseSchemaEditor(object): "columns": self.quote_name(new_field.column), } ) + # Update all referencing columns + rels_to_update.extend(model._meta.get_all_related_objects()) + # Handle out type alters on the other end of rels from the PK stuff above + for rel in rels_to_update: + rel_db_params = rel.field.db_parameters(connection=self.connection) + rel_type = rel_db_params['type'] + self.execute( + self.sql_alter_column % { + "table": self.quote_name(rel.model._meta.db_table), + "changes": self.sql_alter_column_type % { + "column": self.quote_name(rel.field.column), + "type": rel_type, + } + } + ) # Does it have a foreign key? if new_field.rel: self.execute( diff --git a/django/db/migrations/autodetector.py b/django/db/migrations/autodetector.py index 78c9770366..6bec6d56c4 100644 --- a/django/db/migrations/autodetector.py +++ b/django/db/migrations/autodetector.py @@ -153,80 +153,16 @@ class MigrationAutodetector(object): ) # Changes within models kept_models = set(old_model_keys).intersection(new_model_keys) + old_fields = set() + new_fields = set() for app_label, model_name in kept_models: old_model_state = self.from_state.models[app_label, model_name] new_model_state = self.to_state.models[app_label, model_name] - # New fields - old_field_names = set(x for x, y in old_model_state.fields) - new_field_names = set(x for x, y in new_model_state.fields) - for field_name in new_field_names - old_field_names: - field = new_model_state.get_field_by_name(field_name) - # Scan to see if this is actually a rename! - field_dec = field.deconstruct()[1:] - found_rename = False - for removed_field_name in (old_field_names - new_field_names): - if old_model_state.get_field_by_name(removed_field_name).deconstruct()[1:] == field_dec: - if self.questioner.ask_rename(model_name, removed_field_name, field_name, field): - self.add_to_migration( - app_label, - operations.RenameField( - model_name=model_name, - old_name=removed_field_name, - new_name=field_name, - ) - ) - old_field_names.remove(removed_field_name) - new_field_names.remove(field_name) - found_rename = True - break - if found_rename: - continue - # You can't just add NOT NULL fields with no default - if not field.null and not field.has_default(): - field = field.clone() - field.default = self.questioner.ask_not_null_addition(field_name, model_name) - self.add_to_migration( - app_label, - operations.AddField( - model_name=model_name, - name=field_name, - field=field, - preserve_default=False, - ) - ) - else: - self.add_to_migration( - app_label, - operations.AddField( - model_name=model_name, - name=field_name, - field=field, - ) - ) - # Old fields - for field_name in old_field_names - new_field_names: - self.add_to_migration( - app_label, - operations.RemoveField( - model_name=model_name, - name=field_name, - ) - ) - # The same fields - for field_name in old_field_names.intersection(new_field_names): - # Did the field change? - old_field_dec = old_model_state.get_field_by_name(field_name).deconstruct() - new_field_dec = new_model_state.get_field_by_name(field_name).deconstruct() - if old_field_dec != new_field_dec: - self.add_to_migration( - app_label, - operations.AlterField( - model_name=model_name, - name=field_name, - field=new_model_state.get_field_by_name(field_name), - ) - ) - # unique_together changes + # Collect field changes for later global dealing with (so AddFields + # always come before AlterFields even on separate models) + old_fields.update((app_label, model_name, x) for x, y in old_model_state.fields) + new_fields.update((app_label, model_name, x) for x, y in new_model_state.fields) + # Unique_together changes if old_model_state.options.get("unique_together", set()) != new_model_state.options.get("unique_together", set()): self.add_to_migration( app_label, @@ -235,6 +171,81 @@ class MigrationAutodetector(object): unique_together=new_model_state.options.get("unique_together", set()), ) ) + # New fields + for app_label, model_name, field_name in new_fields - old_fields: + old_model_state = self.from_state.models[app_label, model_name] + new_model_state = self.to_state.models[app_label, model_name] + field = new_model_state.get_field_by_name(field_name) + # Scan to see if this is actually a rename! + field_dec = field.deconstruct()[1:] + found_rename = False + for rem_app_label, rem_model_name, rem_field_name in (old_fields - new_fields): + if rem_app_label == app_label and rem_model_name == model_name: + if old_model_state.get_field_by_name(rem_field_name).deconstruct()[1:] == field_dec: + if self.questioner.ask_rename(model_name, rem_field_name, field_name, field): + self.add_to_migration( + app_label, + operations.RenameField( + model_name=model_name, + old_name=rem_field_name, + new_name=field_name, + ) + ) + old_fields.remove((rem_app_label, rem_model_name, rem_field_name)) + new_fields.remove((app_label, model_name, field_name)) + found_rename = True + break + if found_rename: + continue + # You can't just add NOT NULL fields with no default + if not field.null and not field.has_default(): + field = field.clone() + field.default = self.questioner.ask_not_null_addition(field_name, model_name) + self.add_to_migration( + app_label, + operations.AddField( + model_name=model_name, + name=field_name, + field=field, + preserve_default=False, + ) + ) + else: + self.add_to_migration( + app_label, + operations.AddField( + model_name=model_name, + name=field_name, + field=field, + ) + ) + # Old fields + for app_label, model_name, field_name in old_fields - new_fields: + old_model_state = self.from_state.models[app_label, model_name] + new_model_state = self.to_state.models[app_label, model_name] + self.add_to_migration( + app_label, + operations.RemoveField( + model_name=model_name, + name=field_name, + ) + ) + # The same fields + for app_label, model_name, field_name in old_fields.intersection(new_fields): + # Did the field change? + old_model_state = self.from_state.models[app_label, model_name] + new_model_state = self.to_state.models[app_label, model_name] + old_field_dec = old_model_state.get_field_by_name(field_name).deconstruct() + new_field_dec = new_model_state.get_field_by_name(field_name).deconstruct() + if old_field_dec != new_field_dec: + self.add_to_migration( + app_label, + operations.AlterField( + model_name=model_name, + name=field_name, + field=new_model_state.get_field_by_name(field_name), + ) + ) # Alright, now add internal dependencies for app_label, migrations in self.migrations.items(): for m1, m2 in zip(migrations, migrations[1:]): diff --git a/django/db/migrations/operations/fields.py b/django/db/migrations/operations/fields.py index c5f0bd1e2b..73efef4691 100644 --- a/django/db/migrations/operations/fields.py +++ b/django/db/migrations/operations/fields.py @@ -24,10 +24,9 @@ class AddField(Operation): state.models[app_label, self.model_name.lower()].fields.append((self.name, field)) def database_forwards(self, app_label, schema_editor, from_state, to_state): - from_model = from_state.render().get_model(app_label, self.model_name) to_model = to_state.render().get_model(app_label, self.model_name) if router.allow_migrate(schema_editor.connection.alias, to_model): - schema_editor.add_field(from_model, to_model._meta.get_field_by_name(self.name)[0]) + schema_editor.add_field(to_model, to_model._meta.get_field_by_name(self.name)[0]) def database_backwards(self, app_label, schema_editor, from_state, to_state): from_model = from_state.render().get_model(app_label, self.model_name) @@ -74,10 +73,9 @@ class RemoveField(Operation): schema_editor.remove_field(from_model, from_model._meta.get_field_by_name(self.name)[0]) def database_backwards(self, app_label, schema_editor, from_state, to_state): - from_model = from_state.render().get_model(app_label, self.model_name) to_model = to_state.render().get_model(app_label, self.model_name) if router.allow_migrate(schema_editor.connection.alias, to_model): - schema_editor.add_field(from_model, to_model._meta.get_field_by_name(self.name)[0]) + schema_editor.add_field(to_model, to_model._meta.get_field_by_name(self.name)[0]) def describe(self): return "Remove field %s from %s" % (self.name, self.model_name) @@ -109,7 +107,7 @@ class AlterField(Operation): to_model = to_state.render().get_model(app_label, self.model_name) if router.allow_migrate(schema_editor.connection.alias, to_model): schema_editor.alter_field( - from_model, + to_model, from_model._meta.get_field_by_name(self.name)[0], to_model._meta.get_field_by_name(self.name)[0], ) @@ -155,7 +153,7 @@ class RenameField(Operation): to_model = to_state.render().get_model(app_label, self.model_name) if router.allow_migrate(schema_editor.connection.alias, to_model): schema_editor.alter_field( - from_model, + to_model, from_model._meta.get_field_by_name(self.old_name)[0], to_model._meta.get_field_by_name(self.new_name)[0], ) @@ -165,7 +163,7 @@ class RenameField(Operation): to_model = to_state.render().get_model(app_label, self.model_name) if router.allow_migrate(schema_editor.connection.alias, to_model): schema_editor.alter_field( - from_model, + to_model, from_model._meta.get_field_by_name(self.new_name)[0], to_model._meta.get_field_by_name(self.old_name)[0], )