Add M2M repointing

This commit is contained in:
Andrew Godwin 2012-09-07 14:39:22 -04:00
parent a92bae0f06
commit 375178fc19
4 changed files with 119 additions and 16 deletions

View File

@ -21,7 +21,6 @@ class BaseDatabaseSchemaEditor(object):
commit() is called. commit() is called.
TODO: TODO:
- Repointing of M2Ms
- Check constraints (PosIntField) - Check constraints (PosIntField)
""" """
@ -154,13 +153,13 @@ class BaseDatabaseSchemaEditor(object):
# Actions # Actions
def create_model(self, model): def create_model(self, model, force=False):
""" """
Takes a model and creates a table for it in the database. Takes a model and creates a table for it in the database.
Will also create any accompanying indexes or unique constraints. Will also create any accompanying indexes or unique constraints.
""" """
# Do nothing if this is an unmanaged or proxy model # Do nothing if this is an unmanaged or proxy model
if not model._meta.managed or model._meta.proxy: if not force and (not model._meta.managed or model._meta.proxy):
return return
# Create column SQL, add FK deferreds if needed # Create column SQL, add FK deferreds if needed
column_sqls = [] column_sqls = []
@ -214,13 +213,16 @@ class BaseDatabaseSchemaEditor(object):
"definition": ", ".join(column_sqls) "definition": ", ".join(column_sqls)
} }
self.execute(sql, params) self.execute(sql, params)
# Make M2M tables
for field in model._meta.local_many_to_many:
self.create_model(field.rel.through, force=True)
def delete_model(self, model): def delete_model(self, model, force=False):
""" """
Deletes a model from the database. Deletes a model from the database.
""" """
# Do nothing if this is an unmanaged or proxy model # Do nothing if this is an unmanaged or proxy model
if not model._meta.managed or model._meta.proxy: if not force and (not model._meta.managed or model._meta.proxy):
return return
# Delete the table # Delete the table
self.execute(self.sql_delete_table % { self.execute(self.sql_delete_table % {
@ -287,7 +289,7 @@ class BaseDatabaseSchemaEditor(object):
""" """
# Special-case implicit M2M tables # Special-case implicit M2M tables
if isinstance(field, ManyToManyField) and field.rel.through._meta.auto_created: if isinstance(field, ManyToManyField) and field.rel.through._meta.auto_created:
return self.create_model(field.rel.through) return self.create_model(field.rel.through, force=True)
# Get the column's definition # Get the column's definition
definition, params = self.column_sql(model, field, include_default=True) definition, params = self.column_sql(model, field, include_default=True)
# It might not actually have a column behind it # It might not actually have a column behind it
@ -358,11 +360,10 @@ class BaseDatabaseSchemaEditor(object):
# Ensure this field is even column-based # Ensure this field is even column-based
old_type = old_field.db_type(connection=self.connection) old_type = old_field.db_type(connection=self.connection)
new_type = self._type_for_alter(new_field) new_type = self._type_for_alter(new_field)
if old_type is None and new_type is None: if old_type is None and new_type is None and (old_field.rel.through and new_field.rel.through and old_field.rel.through._meta.auto_created and new_field.rel.through._meta.auto_created):
# TODO: Handle M2M fields being repointed return self._alter_many_to_many(model, old_field, new_field, strict)
return
elif old_type is None or new_type is None: elif old_type is None or new_type is None:
raise ValueError("Cannot alter field %s into %s - they are not compatible types" % ( raise ValueError("Cannot alter field %s into %s - they are not compatible types (probably means only one is an M2M with implicit through model)" % (
old_field, old_field,
new_field, new_field,
)) ))
@ -543,6 +544,17 @@ class BaseDatabaseSchemaEditor(object):
} }
) )
def _alter_many_to_many(self, model, old_field, new_field, strict):
"Alters M2Ms to repoint their to= endpoints."
# Rename the through table
self.alter_db_table(old_field.rel.through, old_field.rel.through._meta.db_table, new_field.rel.through._meta.db_table)
# Repoint the FK to the other side
self.alter_field(
new_field.rel.through,
old_field.rel.through._meta.get_field_by_name(old_field.m2m_reverse_field_name())[0],
new_field.rel.through._meta.get_field_by_name(new_field.m2m_reverse_field_name())[0],
)
def _type_for_alter(self, field): def _type_for_alter(self, field):
""" """
Returns a field's type suitable for ALTER COLUMN. Returns a field's type suitable for ALTER COLUMN.

View File

@ -101,11 +101,10 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
# Ensure this field is even column-based # Ensure this field is even column-based
old_type = old_field.db_type(connection=self.connection) old_type = old_field.db_type(connection=self.connection)
new_type = self._type_for_alter(new_field) new_type = self._type_for_alter(new_field)
if old_type is None and new_type is None: if old_type is None and new_type is None and (old_field.rel.through and new_field.rel.through and old_field.rel.through._meta.auto_created and new_field.rel.through._meta.auto_created):
# TODO: Handle M2M fields being repointed return self._alter_many_to_many(model, old_field, new_field, strict)
return
elif old_type is None or new_type is None: elif old_type is None or new_type is None:
raise ValueError("Cannot alter field %s into %s - they are not compatible types" % ( raise ValueError("Cannot alter field %s into %s - they are not compatible types (probably means only one is an M2M with implicit through model)" % (
old_field, old_field,
new_field, new_field,
)) ))
@ -114,3 +113,25 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
def alter_unique_together(self, model, old_unique_together, new_unique_together): def alter_unique_together(self, model, old_unique_together, new_unique_together):
self._remake_table(model, override_uniques=new_unique_together) self._remake_table(model, override_uniques=new_unique_together)
def _alter_many_to_many(self, model, old_field, new_field, strict):
"Alters M2Ms to repoint their to= endpoints."
# Make a new through table
self.create_model(new_field.rel.through)
# Copy the data across
self.execute("INSERT INTO %s (%s) SELECT %s FROM %s;" % (
self.quote_name(new_field.rel.through._meta.db_table),
', '.join([
"id",
new_field.m2m_column_name(),
new_field.m2m_reverse_name(),
]),
', '.join([
"id",
old_field.m2m_column_name(),
old_field.m2m_reverse_name(),
]),
self.quote_name(old_field.rel.through._meta.db_table),
))
# Delete the old through table
self.delete_model(old_field.rel.through, force=True)

View File

@ -29,6 +29,16 @@ class Book(models.Model):
managed = False managed = False
class BookWithM2M(models.Model):
author = models.ForeignKey(Author)
title = models.CharField(max_length=100, db_index=True)
pub_date = models.DateTimeField()
tags = models.ManyToManyField("Tag", related_name="books")
class Meta:
managed = False
class BookWithSlug(models.Model): class BookWithSlug(models.Model):
author = models.ForeignKey(Author) author = models.ForeignKey(Author)
title = models.CharField(max_length=100, db_index=True) title = models.CharField(max_length=100, db_index=True)

View File

@ -7,7 +7,7 @@ from django.db import connection, DatabaseError, IntegrityError
from django.db.models.fields import IntegerField, TextField, CharField, SlugField from django.db.models.fields import IntegerField, TextField, CharField, SlugField
from django.db.models.fields.related import ManyToManyField, ForeignKey from django.db.models.fields.related import ManyToManyField, ForeignKey
from django.db.models.loading import cache from django.db.models.loading import cache
from .models import Author, Book, BookWithSlug, AuthorWithM2M, Tag, TagUniqueRename, UniqueTest from .models import Author, Book, BookWithSlug, BookWithM2M, AuthorWithM2M, Tag, TagUniqueRename, UniqueTest
class SchemaTests(TestCase): class SchemaTests(TestCase):
@ -19,7 +19,7 @@ class SchemaTests(TestCase):
as the code it is testing. as the code it is testing.
""" """
models = [Author, Book, BookWithSlug, AuthorWithM2M, Tag, TagUniqueRename, UniqueTest] models = [Author, Book, BookWithSlug, BookWithM2M, AuthorWithM2M, Tag, TagUniqueRename, UniqueTest]
# Utility functions # Utility functions
@ -248,6 +248,21 @@ class SchemaTests(TestCase):
self.assertEqual(columns['display_name'][0], "CharField") self.assertEqual(columns['display_name'][0], "CharField")
self.assertNotIn("name", columns) self.assertNotIn("name", columns)
def test_m2m_create(self):
"""
Tests M2M fields on models during creation
"""
# Create the tables
editor = connection.schema_editor()
editor.start()
editor.create_model(Author)
editor.create_model(Tag)
editor.create_model(BookWithM2M)
editor.commit()
# Ensure there is now an m2m table there
columns = self.column_classes(BookWithM2M._meta.get_field_by_name("tags")[0].rel.through)
self.assertEqual(columns['tag_id'][0], "IntegerField")
def test_m2m(self): def test_m2m(self):
""" """
Tests adding/removing M2M fields on models Tests adding/removing M2M fields on models
@ -287,6 +302,51 @@ class SchemaTests(TestCase):
self.assertRaises(DatabaseError, self.column_classes, new_field.rel.through) self.assertRaises(DatabaseError, self.column_classes, new_field.rel.through)
connection.rollback() connection.rollback()
def test_m2m_repoint(self):
"""
Tests repointing M2M fields
"""
# Create the tables
editor = connection.schema_editor()
editor.start()
editor.create_model(Author)
editor.create_model(BookWithM2M)
editor.create_model(Tag)
editor.create_model(UniqueTest)
editor.commit()
# Ensure the M2M exists and points to Tag
constraints = connection.introspection.get_constraints(connection.cursor(), BookWithM2M._meta.get_field_by_name("tags")[0].rel.through._meta.db_table)
if connection.features.supports_foreign_keys:
for name, details in constraints.items():
if details['columns'] == set(["tag_id"]) and details['foreign_key']:
self.assertEqual(details['foreign_key'], ('schema_tag', 'id'))
break
else:
self.fail("No FK constraint for tag_id found")
# Repoint the M2M
new_field = ManyToManyField(UniqueTest)
new_field.contribute_to_class(BookWithM2M, "uniques")
editor = connection.schema_editor()
editor.start()
editor.alter_field(
Author,
BookWithM2M._meta.get_field_by_name("tags")[0],
new_field,
)
editor.commit()
# Ensure old M2M is gone
self.assertRaises(DatabaseError, self.column_classes, BookWithM2M._meta.get_field_by_name("tags")[0].rel.through)
connection.rollback()
# Ensure the new M2M exists and points to UniqueTest
constraints = connection.introspection.get_constraints(connection.cursor(), new_field.rel.through._meta.db_table)
if connection.features.supports_foreign_keys:
for name, details in constraints.items():
if details['columns'] == set(["uniquetest_id"]) and details['foreign_key']:
self.assertEqual(details['foreign_key'], ('schema_uniquetest', 'id'))
break
else:
self.fail("No FK constraint for tag_id found")
def test_unique(self): def test_unique(self):
""" """
Tests removing and adding unique constraints to a single column. Tests removing and adding unique constraints to a single column.