Add M2M tests and some unique support

This commit is contained in:
Andrew Godwin 2012-08-02 15:08:39 +01:00
parent 4a2e80fff4
commit b139315f1c
6 changed files with 259 additions and 15 deletions

View File

@ -427,6 +427,9 @@ class BaseDatabaseFeatures(object):
# Can we issue more than one ALTER COLUMN clause in an ALTER TABLE?
supports_combined_alters = False
# What's the maximum length for index names?
max_index_name_length = 63
def __init__(self, connection):
self.connection = connection
@ -1056,6 +1059,15 @@ class BaseDatabaseIntrospection(object):
"""
raise NotImplementedError
def get_constraints(self, cursor, table_name):
"""
Returns {'cnname': {'columns': set(columns), 'primary_key': bool, 'unique': bool}}
Both single- and multi-column constraints are introspected.
"""
raise NotImplementedError
class BaseDatabaseClient(object):
"""
This class encapsulates all backend-specific methods for opening a

View File

@ -21,7 +21,8 @@ class BaseDatabaseCreation(object):
def __init__(self, connection):
self.connection = connection
def _digest(self, *args):
@classmethod
def _digest(cls, *args):
"""
Generates a 32-bit digest of a set of arguments that can be used to
shorten identifying names.

View File

@ -88,3 +88,35 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
continue
indexes[row[0]] = {'primary_key': row[3], 'unique': row[2]}
return indexes
def get_constraints(self, cursor, table_name):
"""
Retrieves any constraints (unique, pk, check) across one or more columns.
Returns {'cnname': {'columns': set(columns), 'primary_key': bool, 'unique': bool}}
"""
constraints = {}
# Loop over the constraint tables, collecting things as constraints
ifsc_tables = ["constraint_column_usage", "key_column_usage"]
for ifsc_table in ifsc_tables:
cursor.execute("""
SELECT kc.constraint_name, kc.column_name, c.constraint_type
FROM information_schema.%s AS kc
JOIN information_schema.table_constraints AS c ON
kc.table_schema = c.table_schema AND
kc.table_name = c.table_name AND
kc.constraint_name = c.constraint_name
WHERE
kc.table_schema = %%s AND
kc.table_name = %%s
""" % ifsc_table, ["public", table_name])
for constraint, column, kind in cursor.fetchall():
# If we're the first column, make the record
if constraint not in constraints:
constraints[constraint] = {
"columns": set(),
"primary_key": kind.lower() == "primary key",
"unique": kind.lower() in ["primary key", "unique"],
}
# Record the details
constraints[constraint]['columns'].add(column)
return constraints

View File

@ -4,6 +4,8 @@ import time
from django.conf import settings
from django.db import transaction
from django.db.utils import load_backend
from django.db.backends.creation import BaseDatabaseCreation
from django.db.backends.util import truncate_name
from django.utils.log import getLogger
from django.db.models.fields.related import ManyToManyField
@ -294,7 +296,23 @@ class BaseDatabaseSchemaEditor(object):
old_field,
new_field,
))
# First, have they renamed the column?
# Has unique been removed?
if old_field.unique and not new_field.unique:
# Find the unique constraint for this field
constraint_names = self._constraint_names(model, [old_field.column], unique=True)
if len(constraint_names) != 1:
raise ValueError("Found wrong number (%s) of constraints for %s.%s" % (
len(constraint_names),
model._meta.db_table,
old_field.column,
))
self.execute(
self.sql_delete_unique % {
"table": self.quote_name(model._meta.db_table),
"name": constraint_names[0],
},
)
# Have they renamed the column?
if old_field.column != new_field.column:
self.execute(self.sql_rename_column % {
"table": self.quote_name(model._meta.db_table),
@ -347,16 +365,58 @@ class BaseDatabaseSchemaEditor(object):
},
[],
))
# Combine actions together if we can (e.g. postgres)
if self.connection.features.supports_combined_alters:
sql, params = tuple(zip(*actions))
actions = [(", ".join(sql), params)]
# Apply those actions
for sql, params in actions:
if actions:
# Combine actions together if we can (e.g. postgres)
if self.connection.features.supports_combined_alters:
sql, params = tuple(zip(*actions))
actions = [(", ".join(sql), params)]
# Apply those actions
for sql, params in actions:
self.execute(
self.sql_alter_column % {
"table": self.quote_name(model._meta.db_table),
"changes": sql,
},
params,
)
# Added a unique?
if not old_field.unique and new_field.unique:
self.execute(
self.sql_alter_column % {
self.sql_create_unique % {
"table": self.quote_name(model._meta.db_table),
"changes": sql,
},
params,
"name": self._create_index_name(model, [new_field.column], suffix="_uniq"),
"columns": self.quote_name(new_field.column),
}
)
def _create_index_name(self, model, column_names, suffix=""):
"Generates a unique name for an index/unique constraint."
# If there is just one column in the index, use a default algorithm from Django
if len(column_names) == 1 and not suffix:
return truncate_name(
'%s_%s' % (model._meta.db_table, BaseDatabaseCreation._digest(column_names[0])),
self.connection.ops.max_name_length()
)
# Else generate the name for the index by South
table_name = model._meta.db_table.replace('"', '').replace('.', '_')
index_unique_name = '_%x' % abs(hash((table_name, ','.join(column_names))))
# If the index name is too long, truncate it
index_name = ('%s_%s%s%s' % (table_name, column_names[0], index_unique_name, suffix)).replace('"', '').replace('.', '_')
if len(index_name) > self.connection.features.max_index_name_length:
part = ('_%s%s%s' % (column_names[0], index_unique_name, suffix))
index_name = '%s%s' % (table_name[:(self.connection.features.max_index_name_length - len(part))], part)
return index_name
def _constraint_names(self, model, column_names, unique=None, primary_key=None):
"Returns all constraint names matching the columns and conditions"
column_names = set(column_names)
constraints = self.connection.introspection.get_constraints(self.connection.cursor(), model._meta.db_table)
result = []
for name, infodict in constraints.items():
if column_names == infodict['columns']:
if unique is not None and infodict['unique'] != unique:
continue
if primary_key is not None and infodict['primary_key'] != unique:
continue
result.append(name)
return result

View File

@ -12,10 +12,23 @@ class Author(models.Model):
managed = False
class AuthorWithM2M(models.Model):
name = models.CharField(max_length=255)
class Meta:
managed = False
class Book(models.Model):
author = models.ForeignKey(Author)
title = models.CharField(max_length=100)
pub_date = models.DateTimeField()
#tags = models.ManyToManyField("Tag", related_name="books")
class Meta:
managed = False
class Tag(models.Model):
title = models.CharField(max_length=255)
slug = models.SlugField(unique=True)

View File

@ -3,9 +3,10 @@ import copy
import datetime
from django.test import TestCase
from django.db import connection, DatabaseError, IntegrityError
from django.db.models.fields import IntegerField, TextField
from django.db.models.fields import IntegerField, TextField, CharField, SlugField
from django.db.models.fields.related import ManyToManyField
from django.db.models.loading import cache
from .models import Author, Book
from .models import Author, Book, AuthorWithM2M, Tag
class SchemaTests(TestCase):
@ -17,7 +18,7 @@ class SchemaTests(TestCase):
as the code it is testing.
"""
models = [Author, Book]
models = [Author, Book, AuthorWithM2M, Tag]
# Utility functions
@ -39,6 +40,17 @@ class SchemaTests(TestCase):
# Delete any tables made for our models
cursor = connection.cursor()
for model in self.models:
# Remove any M2M tables first
for field in model._meta.local_many_to_many:
try:
cursor.execute("DROP TABLE %s CASCADE" % (
connection.ops.quote_name(field.rel.through._meta.db_table),
))
except DatabaseError:
connection.rollback()
else:
connection.commit()
# Then remove the main tables
try:
cursor.execute("DROP TABLE %s CASCADE" % (
connection.ops.quote_name(model._meta.db_table),
@ -172,3 +184,117 @@ class SchemaTests(TestCase):
columns = self.column_classes(Author)
self.assertEqual(columns['name'][0], "TextField")
self.assertEqual(columns['name'][1][6], True)
def test_rename(self):
"""
Tests simple altering of fields
"""
# Create the table
editor = connection.schema_editor()
editor.start()
editor.create_model(Author)
editor.commit()
# Ensure the field is right to begin with
columns = self.column_classes(Author)
self.assertEqual(columns['name'][0], "CharField")
self.assertEqual(columns['name'][1][3], 255)
self.assertNotIn("display_name", columns)
# Alter the name field's name
new_field = CharField(max_length=254)
new_field.set_attributes_from_name("display_name")
editor = connection.schema_editor()
editor.start()
editor.alter_field(
Author,
Author._meta.get_field_by_name("name")[0],
new_field,
)
editor.commit()
# Ensure the field is right afterwards
columns = self.column_classes(Author)
self.assertEqual(columns['display_name'][0], "CharField")
self.assertEqual(columns['display_name'][1][3], 254)
self.assertNotIn("name", columns)
def test_m2m(self):
"""
Tests adding/removing M2M fields on models
"""
# Create the tables
editor = connection.schema_editor()
editor.start()
editor.create_model(AuthorWithM2M)
editor.create_model(Tag)
editor.commit()
# Create an M2M field
new_field = ManyToManyField("schema.Tag", related_name="authors")
new_field.contribute_to_class(AuthorWithM2M, "tags")
# Ensure there's no m2m table there
self.assertRaises(DatabaseError, self.column_classes, new_field.rel.through)
connection.rollback()
# Add the field
editor = connection.schema_editor()
editor.start()
editor.create_field(
Author,
new_field,
)
editor.commit()
# Ensure there is now an m2m table there
columns = self.column_classes(new_field.rel.through)
self.assertEqual(columns['tag_id'][0], "IntegerField")
# Remove the M2M table again
editor = connection.schema_editor()
editor.start()
editor.delete_field(
Author,
new_field,
)
editor.commit()
# Ensure there's no m2m table there
self.assertRaises(DatabaseError, self.column_classes, new_field.rel.through)
connection.rollback()
def test_unique(self):
"""
Tests removing and adding unique constraints to a single column.
"""
# Create the table
editor = connection.schema_editor()
editor.start()
editor.create_model(Tag)
editor.commit()
# Ensure the field is unique to begin with
Tag.objects.create(title="foo", slug="foo")
self.assertRaises(IntegrityError, Tag.objects.create, title="bar", slug="foo")
connection.rollback()
# Alter the slug field to be non-unique
new_field = SlugField(unique=False)
new_field.set_attributes_from_name("slug")
editor = connection.schema_editor()
editor.start()
editor.alter_field(
Tag,
Tag._meta.get_field_by_name("slug")[0],
new_field,
)
editor.commit()
# Ensure the field is no longer unique
Tag.objects.create(title="foo", slug="foo")
Tag.objects.create(title="bar", slug="foo")
connection.rollback()
# Alter the slug field to be non-unique
new_new_field = SlugField(unique=True)
new_new_field.set_attributes_from_name("slug")
editor = connection.schema_editor()
editor.start()
editor.alter_field(
Tag,
new_field,
new_new_field,
)
editor.commit()
# Ensure the field is unique again
Tag.objects.create(title="foo", slug="foo")
self.assertRaises(IntegrityError, Tag.objects.create, title="bar", slug="foo")
connection.rollback()