Repoint ForeignKeys when their to= changes.

This commit is contained in:
Andrew Godwin 2012-09-07 13:31:05 -04:00
parent d683263f97
commit a92bae0f06
3 changed files with 56 additions and 7 deletions

View File

@ -118,7 +118,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
"columns": set(), "columns": set(),
"primary_key": kind.lower() == "primary key", "primary_key": kind.lower() == "primary key",
"unique": kind.lower() in ["primary key", "unique"], "unique": kind.lower() in ["primary key", "unique"],
"foreign_key": set([tuple(x.split(".", 1)) for x in used_cols]) if kind.lower() == "foreign key" else None, "foreign_key": tuple(used_cols[0].split(".", 1)) if kind.lower() == "foreign key" else None,
"check": False, "check": False,
"index": False, "index": False,
} }

View File

@ -21,7 +21,6 @@ class BaseDatabaseSchemaEditor(object):
commit() is called. commit() is called.
TODO: TODO:
- Repointing of FKs
- Repointing of M2Ms - Repointing of M2Ms
- Check constraints (PosIntField) - Check constraints (PosIntField)
""" """
@ -401,6 +400,22 @@ class BaseDatabaseSchemaEditor(object):
"name": index_name, "name": index_name,
} }
) )
# Drop any FK constraints, we'll remake them later
if getattr(old_field, "rel"):
fk_names = self._constraint_names(model, [old_field.column], foreign_key=True)
if strict and len(fk_names) != 1:
raise ValueError("Found wrong number (%s) of foreign key constraints for %s.%s" % (
len(fk_names),
model._meta.db_table,
old_field.column,
))
for fk_name in fk_names:
self.execute(
self.sql_delete_fk % {
"table": self.quote_name(model._meta.db_table),
"name": fk_name,
}
)
# Have they renamed the column? # Have they renamed the column?
if old_field.column != new_field.column: if old_field.column != new_field.column:
self.execute(self.sql_rename_column % { self.execute(self.sql_rename_column % {
@ -516,6 +531,17 @@ class BaseDatabaseSchemaEditor(object):
"columns": self.quote_name(new_field.column), "columns": self.quote_name(new_field.column),
} }
) )
# Does it have a foreign key?
if getattr(new_field, "rel"):
self.execute(
self.sql_create_fk % {
"table": self.quote_name(model._meta.db_table),
"name": self._create_index_name(model, [new_field.column], suffix="_fk"),
"column": self.quote_name(new_field.column),
"to_table": self.quote_name(new_field.rel.to._meta.db_table),
"to_column": self.quote_name(new_field.rel.get_related_field().column),
}
)
def _type_for_alter(self, field): def _type_for_alter(self, field):
""" """
@ -543,7 +569,7 @@ class BaseDatabaseSchemaEditor(object):
index_name = '%s%s' % (table_name[:(self.connection.features.max_index_name_length - len(part))], part) index_name = '%s%s' % (table_name[:(self.connection.features.max_index_name_length - len(part))], part)
return index_name return index_name
def _constraint_names(self, model, column_names=None, unique=None, primary_key=None, index=None): def _constraint_names(self, model, column_names=None, unique=None, primary_key=None, index=None, foreign_key=None):
"Returns all constraint names matching the columns and conditions" "Returns all constraint names matching the columns and conditions"
column_names = set(column_names) if column_names else None column_names = set(column_names) if column_names else None
constraints = self.connection.introspection.get_constraints(self.connection.cursor(), model._meta.db_table) constraints = self.connection.introspection.get_constraints(self.connection.cursor(), model._meta.db_table)
@ -556,5 +582,7 @@ class BaseDatabaseSchemaEditor(object):
continue continue
if index is not None and infodict['index'] != index: if index is not None and infodict['index'] != index:
continue continue
if foreign_key is not None and not infodict['foreign_key']:
continue
result.append(name) result.append(name)
return result return result

View File

@ -5,7 +5,7 @@ from django.test import TestCase
from django.utils.unittest import skipUnless from django.utils.unittest import skipUnless
from django.db import connection, DatabaseError, IntegrityError from django.db import connection, DatabaseError, IntegrityError
from django.db.models.fields import IntegerField, TextField, CharField, SlugField from django.db.models.fields import IntegerField, TextField, CharField, SlugField
from django.db.models.fields.related import ManyToManyField from django.db.models.fields.related import ManyToManyField, ForeignKey
from django.db.models.loading import cache from django.db.models.loading import cache
from .models import Author, Book, BookWithSlug, AuthorWithM2M, Tag, TagUniqueRename, UniqueTest from .models import Author, Book, BookWithSlug, AuthorWithM2M, Tag, TagUniqueRename, UniqueTest
@ -114,15 +114,16 @@ class SchemaTests(TestCase):
) )
@skipUnless(connection.features.supports_foreign_keys, "No FK support") @skipUnless(connection.features.supports_foreign_keys, "No FK support")
def test_creation_fk(self): def test_fk(self):
"Tests that creating tables out of FK order works" "Tests that creating tables out of FK order, then repointing, works"
# Create the table # Create the table
editor = connection.schema_editor() editor = connection.schema_editor()
editor.start() editor.start()
editor.create_model(Book) editor.create_model(Book)
editor.create_model(Author) editor.create_model(Author)
editor.create_model(Tag)
editor.commit() editor.commit()
# Check that both tables are there # Check that initial tables are there
try: try:
list(Author.objects.all()) list(Author.objects.all())
except DatabaseError, e: except DatabaseError, e:
@ -139,6 +140,26 @@ class SchemaTests(TestCase):
pub_date = datetime.datetime.now(), pub_date = datetime.datetime.now(),
) )
connection.commit() connection.commit()
# Repoint the FK constraint
new_field = ForeignKey(Tag)
new_field.set_attributes_from_name("author")
editor = connection.schema_editor()
editor.start()
editor.alter_field(
Book,
Book._meta.get_field_by_name("author")[0],
new_field,
strict=True,
)
editor.commit()
# Make sure the new FK constraint is present
constraints = connection.introspection.get_constraints(connection.cursor(), Book._meta.db_table)
for name, details in constraints.items():
if details['columns'] == set(["author_id"]) and details['foreign_key']:
self.assertEqual(details['foreign_key'], ('schema_tag', 'id'))
break
else:
self.fail("No FK constraint for author_id found")
def test_create_field(self): def test_create_field(self):
""" """