Add M2M tests and some unique support

This commit is contained in:
Andrew Godwin 2012-08-02 15:08:39 +01:00
parent 4a2e80fff4
commit b139315f1c
6 changed files with 259 additions and 15 deletions

View File

@ -427,6 +427,9 @@ class BaseDatabaseFeatures(object):
# Can we issue more than one ALTER COLUMN clause in an ALTER TABLE? # Can we issue more than one ALTER COLUMN clause in an ALTER TABLE?
supports_combined_alters = False supports_combined_alters = False
# What's the maximum length for index names?
max_index_name_length = 63
def __init__(self, connection): def __init__(self, connection):
self.connection = connection self.connection = connection
@ -1056,6 +1059,15 @@ class BaseDatabaseIntrospection(object):
""" """
raise NotImplementedError raise NotImplementedError
def get_constraints(self, cursor, table_name):
"""
Returns {'cnname': {'columns': set(columns), 'primary_key': bool, 'unique': bool}}
Both single- and multi-column constraints are introspected.
"""
raise NotImplementedError
class BaseDatabaseClient(object): class BaseDatabaseClient(object):
""" """
This class encapsulates all backend-specific methods for opening a This class encapsulates all backend-specific methods for opening a

View File

@ -21,7 +21,8 @@ class BaseDatabaseCreation(object):
def __init__(self, connection): def __init__(self, connection):
self.connection = connection self.connection = connection
def _digest(self, *args): @classmethod
def _digest(cls, *args):
""" """
Generates a 32-bit digest of a set of arguments that can be used to Generates a 32-bit digest of a set of arguments that can be used to
shorten identifying names. shorten identifying names.

View File

@ -88,3 +88,35 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
continue continue
indexes[row[0]] = {'primary_key': row[3], 'unique': row[2]} indexes[row[0]] = {'primary_key': row[3], 'unique': row[2]}
return indexes return indexes
def get_constraints(self, cursor, table_name):
"""
Retrieves any constraints (unique, pk, check) across one or more columns.
Returns {'cnname': {'columns': set(columns), 'primary_key': bool, 'unique': bool}}
"""
constraints = {}
# Loop over the constraint tables, collecting things as constraints
ifsc_tables = ["constraint_column_usage", "key_column_usage"]
for ifsc_table in ifsc_tables:
cursor.execute("""
SELECT kc.constraint_name, kc.column_name, c.constraint_type
FROM information_schema.%s AS kc
JOIN information_schema.table_constraints AS c ON
kc.table_schema = c.table_schema AND
kc.table_name = c.table_name AND
kc.constraint_name = c.constraint_name
WHERE
kc.table_schema = %%s AND
kc.table_name = %%s
""" % ifsc_table, ["public", table_name])
for constraint, column, kind in cursor.fetchall():
# If we're the first column, make the record
if constraint not in constraints:
constraints[constraint] = {
"columns": set(),
"primary_key": kind.lower() == "primary key",
"unique": kind.lower() in ["primary key", "unique"],
}
# Record the details
constraints[constraint]['columns'].add(column)
return constraints

View File

@ -4,6 +4,8 @@ import time
from django.conf import settings from django.conf import settings
from django.db import transaction from django.db import transaction
from django.db.utils import load_backend from django.db.utils import load_backend
from django.db.backends.creation import BaseDatabaseCreation
from django.db.backends.util import truncate_name
from django.utils.log import getLogger from django.utils.log import getLogger
from django.db.models.fields.related import ManyToManyField from django.db.models.fields.related import ManyToManyField
@ -294,7 +296,23 @@ class BaseDatabaseSchemaEditor(object):
old_field, old_field,
new_field, new_field,
)) ))
# First, have they renamed the column? # Has unique been removed?
if old_field.unique and not new_field.unique:
# Find the unique constraint for this field
constraint_names = self._constraint_names(model, [old_field.column], unique=True)
if len(constraint_names) != 1:
raise ValueError("Found wrong number (%s) of constraints for %s.%s" % (
len(constraint_names),
model._meta.db_table,
old_field.column,
))
self.execute(
self.sql_delete_unique % {
"table": self.quote_name(model._meta.db_table),
"name": constraint_names[0],
},
)
# Have they renamed the column?
if old_field.column != new_field.column: if old_field.column != new_field.column:
self.execute(self.sql_rename_column % { self.execute(self.sql_rename_column % {
"table": self.quote_name(model._meta.db_table), "table": self.quote_name(model._meta.db_table),
@ -347,16 +365,58 @@ class BaseDatabaseSchemaEditor(object):
}, },
[], [],
)) ))
# Combine actions together if we can (e.g. postgres) if actions:
if self.connection.features.supports_combined_alters: # Combine actions together if we can (e.g. postgres)
sql, params = tuple(zip(*actions)) if self.connection.features.supports_combined_alters:
actions = [(", ".join(sql), params)] sql, params = tuple(zip(*actions))
# Apply those actions actions = [(", ".join(sql), params)]
for sql, params in actions: # Apply those actions
for sql, params in actions:
self.execute(
self.sql_alter_column % {
"table": self.quote_name(model._meta.db_table),
"changes": sql,
},
params,
)
# Added a unique?
if not old_field.unique and new_field.unique:
self.execute( self.execute(
self.sql_alter_column % { self.sql_create_unique % {
"table": self.quote_name(model._meta.db_table), "table": self.quote_name(model._meta.db_table),
"changes": sql, "name": self._create_index_name(model, [new_field.column], suffix="_uniq"),
}, "columns": self.quote_name(new_field.column),
params, }
) )
def _create_index_name(self, model, column_names, suffix=""):
"Generates a unique name for an index/unique constraint."
# If there is just one column in the index, use a default algorithm from Django
if len(column_names) == 1 and not suffix:
return truncate_name(
'%s_%s' % (model._meta.db_table, BaseDatabaseCreation._digest(column_names[0])),
self.connection.ops.max_name_length()
)
# Else generate the name for the index by South
table_name = model._meta.db_table.replace('"', '').replace('.', '_')
index_unique_name = '_%x' % abs(hash((table_name, ','.join(column_names))))
# If the index name is too long, truncate it
index_name = ('%s_%s%s%s' % (table_name, column_names[0], index_unique_name, suffix)).replace('"', '').replace('.', '_')
if len(index_name) > self.connection.features.max_index_name_length:
part = ('_%s%s%s' % (column_names[0], index_unique_name, suffix))
index_name = '%s%s' % (table_name[:(self.connection.features.max_index_name_length - len(part))], part)
return index_name
def _constraint_names(self, model, column_names, unique=None, primary_key=None):
"Returns all constraint names matching the columns and conditions"
column_names = set(column_names)
constraints = self.connection.introspection.get_constraints(self.connection.cursor(), model._meta.db_table)
result = []
for name, infodict in constraints.items():
if column_names == infodict['columns']:
if unique is not None and infodict['unique'] != unique:
continue
if primary_key is not None and infodict['primary_key'] != unique:
continue
result.append(name)
return result

View File

@ -12,10 +12,23 @@ class Author(models.Model):
managed = False managed = False
class AuthorWithM2M(models.Model):
name = models.CharField(max_length=255)
class Meta:
managed = False
class Book(models.Model): class Book(models.Model):
author = models.ForeignKey(Author) author = models.ForeignKey(Author)
title = models.CharField(max_length=100) title = models.CharField(max_length=100)
pub_date = models.DateTimeField() pub_date = models.DateTimeField()
#tags = models.ManyToManyField("Tag", related_name="books")
class Meta: class Meta:
managed = False managed = False
class Tag(models.Model):
title = models.CharField(max_length=255)
slug = models.SlugField(unique=True)

View File

@ -3,9 +3,10 @@ import copy
import datetime import datetime
from django.test import TestCase from django.test import TestCase
from django.db import connection, DatabaseError, IntegrityError from django.db import connection, DatabaseError, IntegrityError
from django.db.models.fields import IntegerField, TextField from django.db.models.fields import IntegerField, TextField, CharField, SlugField
from django.db.models.fields.related import ManyToManyField
from django.db.models.loading import cache from django.db.models.loading import cache
from .models import Author, Book from .models import Author, Book, AuthorWithM2M, Tag
class SchemaTests(TestCase): class SchemaTests(TestCase):
@ -17,7 +18,7 @@ class SchemaTests(TestCase):
as the code it is testing. as the code it is testing.
""" """
models = [Author, Book] models = [Author, Book, AuthorWithM2M, Tag]
# Utility functions # Utility functions
@ -39,6 +40,17 @@ class SchemaTests(TestCase):
# Delete any tables made for our models # Delete any tables made for our models
cursor = connection.cursor() cursor = connection.cursor()
for model in self.models: for model in self.models:
# Remove any M2M tables first
for field in model._meta.local_many_to_many:
try:
cursor.execute("DROP TABLE %s CASCADE" % (
connection.ops.quote_name(field.rel.through._meta.db_table),
))
except DatabaseError:
connection.rollback()
else:
connection.commit()
# Then remove the main tables
try: try:
cursor.execute("DROP TABLE %s CASCADE" % ( cursor.execute("DROP TABLE %s CASCADE" % (
connection.ops.quote_name(model._meta.db_table), connection.ops.quote_name(model._meta.db_table),
@ -172,3 +184,117 @@ class SchemaTests(TestCase):
columns = self.column_classes(Author) columns = self.column_classes(Author)
self.assertEqual(columns['name'][0], "TextField") self.assertEqual(columns['name'][0], "TextField")
self.assertEqual(columns['name'][1][6], True) self.assertEqual(columns['name'][1][6], True)
def test_rename(self):
"""
Tests simple altering of fields
"""
# Create the table
editor = connection.schema_editor()
editor.start()
editor.create_model(Author)
editor.commit()
# Ensure the field is right to begin with
columns = self.column_classes(Author)
self.assertEqual(columns['name'][0], "CharField")
self.assertEqual(columns['name'][1][3], 255)
self.assertNotIn("display_name", columns)
# Alter the name field's name
new_field = CharField(max_length=254)
new_field.set_attributes_from_name("display_name")
editor = connection.schema_editor()
editor.start()
editor.alter_field(
Author,
Author._meta.get_field_by_name("name")[0],
new_field,
)
editor.commit()
# Ensure the field is right afterwards
columns = self.column_classes(Author)
self.assertEqual(columns['display_name'][0], "CharField")
self.assertEqual(columns['display_name'][1][3], 254)
self.assertNotIn("name", columns)
def test_m2m(self):
"""
Tests adding/removing M2M fields on models
"""
# Create the tables
editor = connection.schema_editor()
editor.start()
editor.create_model(AuthorWithM2M)
editor.create_model(Tag)
editor.commit()
# Create an M2M field
new_field = ManyToManyField("schema.Tag", related_name="authors")
new_field.contribute_to_class(AuthorWithM2M, "tags")
# Ensure there's no m2m table there
self.assertRaises(DatabaseError, self.column_classes, new_field.rel.through)
connection.rollback()
# Add the field
editor = connection.schema_editor()
editor.start()
editor.create_field(
Author,
new_field,
)
editor.commit()
# Ensure there is now an m2m table there
columns = self.column_classes(new_field.rel.through)
self.assertEqual(columns['tag_id'][0], "IntegerField")
# Remove the M2M table again
editor = connection.schema_editor()
editor.start()
editor.delete_field(
Author,
new_field,
)
editor.commit()
# Ensure there's no m2m table there
self.assertRaises(DatabaseError, self.column_classes, new_field.rel.through)
connection.rollback()
def test_unique(self):
"""
Tests removing and adding unique constraints to a single column.
"""
# Create the table
editor = connection.schema_editor()
editor.start()
editor.create_model(Tag)
editor.commit()
# Ensure the field is unique to begin with
Tag.objects.create(title="foo", slug="foo")
self.assertRaises(IntegrityError, Tag.objects.create, title="bar", slug="foo")
connection.rollback()
# Alter the slug field to be non-unique
new_field = SlugField(unique=False)
new_field.set_attributes_from_name("slug")
editor = connection.schema_editor()
editor.start()
editor.alter_field(
Tag,
Tag._meta.get_field_by_name("slug")[0],
new_field,
)
editor.commit()
# Ensure the field is no longer unique
Tag.objects.create(title="foo", slug="foo")
Tag.objects.create(title="bar", slug="foo")
connection.rollback()
# Alter the slug field to be non-unique
new_new_field = SlugField(unique=True)
new_new_field.set_attributes_from_name("slug")
editor = connection.schema_editor()
editor.start()
editor.alter_field(
Tag,
new_field,
new_new_field,
)
editor.commit()
# Ensure the field is unique again
Tag.objects.create(title="foo", slug="foo")
self.assertRaises(IntegrityError, Tag.objects.create, title="bar", slug="foo")
connection.rollback()