Add M2M tests and some unique support
This commit is contained in:
parent
4a2e80fff4
commit
b139315f1c
|
@ -427,6 +427,9 @@ class BaseDatabaseFeatures(object):
|
|||
# Can we issue more than one ALTER COLUMN clause in an ALTER TABLE?
|
||||
supports_combined_alters = False
|
||||
|
||||
# What's the maximum length for index names?
|
||||
max_index_name_length = 63
|
||||
|
||||
def __init__(self, connection):
|
||||
self.connection = connection
|
||||
|
||||
|
@ -1056,6 +1059,15 @@ class BaseDatabaseIntrospection(object):
|
|||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_constraints(self, cursor, table_name):
|
||||
"""
|
||||
Returns {'cnname': {'columns': set(columns), 'primary_key': bool, 'unique': bool}}
|
||||
|
||||
Both single- and multi-column constraints are introspected.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class BaseDatabaseClient(object):
|
||||
"""
|
||||
This class encapsulates all backend-specific methods for opening a
|
||||
|
|
|
@ -21,7 +21,8 @@ class BaseDatabaseCreation(object):
|
|||
def __init__(self, connection):
|
||||
self.connection = connection
|
||||
|
||||
def _digest(self, *args):
|
||||
@classmethod
|
||||
def _digest(cls, *args):
|
||||
"""
|
||||
Generates a 32-bit digest of a set of arguments that can be used to
|
||||
shorten identifying names.
|
||||
|
|
|
@ -88,3 +88,35 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
|
|||
continue
|
||||
indexes[row[0]] = {'primary_key': row[3], 'unique': row[2]}
|
||||
return indexes
|
||||
|
||||
def get_constraints(self, cursor, table_name):
|
||||
"""
|
||||
Retrieves any constraints (unique, pk, check) across one or more columns.
|
||||
Returns {'cnname': {'columns': set(columns), 'primary_key': bool, 'unique': bool}}
|
||||
"""
|
||||
constraints = {}
|
||||
# Loop over the constraint tables, collecting things as constraints
|
||||
ifsc_tables = ["constraint_column_usage", "key_column_usage"]
|
||||
for ifsc_table in ifsc_tables:
|
||||
cursor.execute("""
|
||||
SELECT kc.constraint_name, kc.column_name, c.constraint_type
|
||||
FROM information_schema.%s AS kc
|
||||
JOIN information_schema.table_constraints AS c ON
|
||||
kc.table_schema = c.table_schema AND
|
||||
kc.table_name = c.table_name AND
|
||||
kc.constraint_name = c.constraint_name
|
||||
WHERE
|
||||
kc.table_schema = %%s AND
|
||||
kc.table_name = %%s
|
||||
""" % ifsc_table, ["public", table_name])
|
||||
for constraint, column, kind in cursor.fetchall():
|
||||
# If we're the first column, make the record
|
||||
if constraint not in constraints:
|
||||
constraints[constraint] = {
|
||||
"columns": set(),
|
||||
"primary_key": kind.lower() == "primary key",
|
||||
"unique": kind.lower() in ["primary key", "unique"],
|
||||
}
|
||||
# Record the details
|
||||
constraints[constraint]['columns'].add(column)
|
||||
return constraints
|
||||
|
|
|
@ -4,6 +4,8 @@ import time
|
|||
from django.conf import settings
|
||||
from django.db import transaction
|
||||
from django.db.utils import load_backend
|
||||
from django.db.backends.creation import BaseDatabaseCreation
|
||||
from django.db.backends.util import truncate_name
|
||||
from django.utils.log import getLogger
|
||||
from django.db.models.fields.related import ManyToManyField
|
||||
|
||||
|
@ -294,7 +296,23 @@ class BaseDatabaseSchemaEditor(object):
|
|||
old_field,
|
||||
new_field,
|
||||
))
|
||||
# First, have they renamed the column?
|
||||
# Has unique been removed?
|
||||
if old_field.unique and not new_field.unique:
|
||||
# Find the unique constraint for this field
|
||||
constraint_names = self._constraint_names(model, [old_field.column], unique=True)
|
||||
if len(constraint_names) != 1:
|
||||
raise ValueError("Found wrong number (%s) of constraints for %s.%s" % (
|
||||
len(constraint_names),
|
||||
model._meta.db_table,
|
||||
old_field.column,
|
||||
))
|
||||
self.execute(
|
||||
self.sql_delete_unique % {
|
||||
"table": self.quote_name(model._meta.db_table),
|
||||
"name": constraint_names[0],
|
||||
},
|
||||
)
|
||||
# Have they renamed the column?
|
||||
if old_field.column != new_field.column:
|
||||
self.execute(self.sql_rename_column % {
|
||||
"table": self.quote_name(model._meta.db_table),
|
||||
|
@ -347,16 +365,58 @@ class BaseDatabaseSchemaEditor(object):
|
|||
},
|
||||
[],
|
||||
))
|
||||
# Combine actions together if we can (e.g. postgres)
|
||||
if self.connection.features.supports_combined_alters:
|
||||
sql, params = tuple(zip(*actions))
|
||||
actions = [(", ".join(sql), params)]
|
||||
# Apply those actions
|
||||
for sql, params in actions:
|
||||
if actions:
|
||||
# Combine actions together if we can (e.g. postgres)
|
||||
if self.connection.features.supports_combined_alters:
|
||||
sql, params = tuple(zip(*actions))
|
||||
actions = [(", ".join(sql), params)]
|
||||
# Apply those actions
|
||||
for sql, params in actions:
|
||||
self.execute(
|
||||
self.sql_alter_column % {
|
||||
"table": self.quote_name(model._meta.db_table),
|
||||
"changes": sql,
|
||||
},
|
||||
params,
|
||||
)
|
||||
# Added a unique?
|
||||
if not old_field.unique and new_field.unique:
|
||||
self.execute(
|
||||
self.sql_alter_column % {
|
||||
self.sql_create_unique % {
|
||||
"table": self.quote_name(model._meta.db_table),
|
||||
"changes": sql,
|
||||
},
|
||||
params,
|
||||
"name": self._create_index_name(model, [new_field.column], suffix="_uniq"),
|
||||
"columns": self.quote_name(new_field.column),
|
||||
}
|
||||
)
|
||||
|
||||
def _create_index_name(self, model, column_names, suffix=""):
|
||||
"Generates a unique name for an index/unique constraint."
|
||||
# If there is just one column in the index, use a default algorithm from Django
|
||||
if len(column_names) == 1 and not suffix:
|
||||
return truncate_name(
|
||||
'%s_%s' % (model._meta.db_table, BaseDatabaseCreation._digest(column_names[0])),
|
||||
self.connection.ops.max_name_length()
|
||||
)
|
||||
# Else generate the name for the index by South
|
||||
table_name = model._meta.db_table.replace('"', '').replace('.', '_')
|
||||
index_unique_name = '_%x' % abs(hash((table_name, ','.join(column_names))))
|
||||
# If the index name is too long, truncate it
|
||||
index_name = ('%s_%s%s%s' % (table_name, column_names[0], index_unique_name, suffix)).replace('"', '').replace('.', '_')
|
||||
if len(index_name) > self.connection.features.max_index_name_length:
|
||||
part = ('_%s%s%s' % (column_names[0], index_unique_name, suffix))
|
||||
index_name = '%s%s' % (table_name[:(self.connection.features.max_index_name_length - len(part))], part)
|
||||
return index_name
|
||||
|
||||
def _constraint_names(self, model, column_names, unique=None, primary_key=None):
|
||||
"Returns all constraint names matching the columns and conditions"
|
||||
column_names = set(column_names)
|
||||
constraints = self.connection.introspection.get_constraints(self.connection.cursor(), model._meta.db_table)
|
||||
result = []
|
||||
for name, infodict in constraints.items():
|
||||
if column_names == infodict['columns']:
|
||||
if unique is not None and infodict['unique'] != unique:
|
||||
continue
|
||||
if primary_key is not None and infodict['primary_key'] != unique:
|
||||
continue
|
||||
result.append(name)
|
||||
return result
|
||||
|
|
|
@ -12,10 +12,23 @@ class Author(models.Model):
|
|||
managed = False
|
||||
|
||||
|
||||
class AuthorWithM2M(models.Model):
|
||||
name = models.CharField(max_length=255)
|
||||
|
||||
class Meta:
|
||||
managed = False
|
||||
|
||||
|
||||
class Book(models.Model):
|
||||
author = models.ForeignKey(Author)
|
||||
title = models.CharField(max_length=100)
|
||||
pub_date = models.DateTimeField()
|
||||
#tags = models.ManyToManyField("Tag", related_name="books")
|
||||
|
||||
class Meta:
|
||||
managed = False
|
||||
|
||||
|
||||
class Tag(models.Model):
|
||||
title = models.CharField(max_length=255)
|
||||
slug = models.SlugField(unique=True)
|
||||
|
|
|
@ -3,9 +3,10 @@ import copy
|
|||
import datetime
|
||||
from django.test import TestCase
|
||||
from django.db import connection, DatabaseError, IntegrityError
|
||||
from django.db.models.fields import IntegerField, TextField
|
||||
from django.db.models.fields import IntegerField, TextField, CharField, SlugField
|
||||
from django.db.models.fields.related import ManyToManyField
|
||||
from django.db.models.loading import cache
|
||||
from .models import Author, Book
|
||||
from .models import Author, Book, AuthorWithM2M, Tag
|
||||
|
||||
|
||||
class SchemaTests(TestCase):
|
||||
|
@ -17,7 +18,7 @@ class SchemaTests(TestCase):
|
|||
as the code it is testing.
|
||||
"""
|
||||
|
||||
models = [Author, Book]
|
||||
models = [Author, Book, AuthorWithM2M, Tag]
|
||||
|
||||
# Utility functions
|
||||
|
||||
|
@ -39,6 +40,17 @@ class SchemaTests(TestCase):
|
|||
# Delete any tables made for our models
|
||||
cursor = connection.cursor()
|
||||
for model in self.models:
|
||||
# Remove any M2M tables first
|
||||
for field in model._meta.local_many_to_many:
|
||||
try:
|
||||
cursor.execute("DROP TABLE %s CASCADE" % (
|
||||
connection.ops.quote_name(field.rel.through._meta.db_table),
|
||||
))
|
||||
except DatabaseError:
|
||||
connection.rollback()
|
||||
else:
|
||||
connection.commit()
|
||||
# Then remove the main tables
|
||||
try:
|
||||
cursor.execute("DROP TABLE %s CASCADE" % (
|
||||
connection.ops.quote_name(model._meta.db_table),
|
||||
|
@ -172,3 +184,117 @@ class SchemaTests(TestCase):
|
|||
columns = self.column_classes(Author)
|
||||
self.assertEqual(columns['name'][0], "TextField")
|
||||
self.assertEqual(columns['name'][1][6], True)
|
||||
|
||||
def test_rename(self):
|
||||
"""
|
||||
Tests simple altering of fields
|
||||
"""
|
||||
# Create the table
|
||||
editor = connection.schema_editor()
|
||||
editor.start()
|
||||
editor.create_model(Author)
|
||||
editor.commit()
|
||||
# Ensure the field is right to begin with
|
||||
columns = self.column_classes(Author)
|
||||
self.assertEqual(columns['name'][0], "CharField")
|
||||
self.assertEqual(columns['name'][1][3], 255)
|
||||
self.assertNotIn("display_name", columns)
|
||||
# Alter the name field's name
|
||||
new_field = CharField(max_length=254)
|
||||
new_field.set_attributes_from_name("display_name")
|
||||
editor = connection.schema_editor()
|
||||
editor.start()
|
||||
editor.alter_field(
|
||||
Author,
|
||||
Author._meta.get_field_by_name("name")[0],
|
||||
new_field,
|
||||
)
|
||||
editor.commit()
|
||||
# Ensure the field is right afterwards
|
||||
columns = self.column_classes(Author)
|
||||
self.assertEqual(columns['display_name'][0], "CharField")
|
||||
self.assertEqual(columns['display_name'][1][3], 254)
|
||||
self.assertNotIn("name", columns)
|
||||
|
||||
def test_m2m(self):
|
||||
"""
|
||||
Tests adding/removing M2M fields on models
|
||||
"""
|
||||
# Create the tables
|
||||
editor = connection.schema_editor()
|
||||
editor.start()
|
||||
editor.create_model(AuthorWithM2M)
|
||||
editor.create_model(Tag)
|
||||
editor.commit()
|
||||
# Create an M2M field
|
||||
new_field = ManyToManyField("schema.Tag", related_name="authors")
|
||||
new_field.contribute_to_class(AuthorWithM2M, "tags")
|
||||
# Ensure there's no m2m table there
|
||||
self.assertRaises(DatabaseError, self.column_classes, new_field.rel.through)
|
||||
connection.rollback()
|
||||
# Add the field
|
||||
editor = connection.schema_editor()
|
||||
editor.start()
|
||||
editor.create_field(
|
||||
Author,
|
||||
new_field,
|
||||
)
|
||||
editor.commit()
|
||||
# Ensure there is now an m2m table there
|
||||
columns = self.column_classes(new_field.rel.through)
|
||||
self.assertEqual(columns['tag_id'][0], "IntegerField")
|
||||
# Remove the M2M table again
|
||||
editor = connection.schema_editor()
|
||||
editor.start()
|
||||
editor.delete_field(
|
||||
Author,
|
||||
new_field,
|
||||
)
|
||||
editor.commit()
|
||||
# Ensure there's no m2m table there
|
||||
self.assertRaises(DatabaseError, self.column_classes, new_field.rel.through)
|
||||
connection.rollback()
|
||||
|
||||
def test_unique(self):
|
||||
"""
|
||||
Tests removing and adding unique constraints to a single column.
|
||||
"""
|
||||
# Create the table
|
||||
editor = connection.schema_editor()
|
||||
editor.start()
|
||||
editor.create_model(Tag)
|
||||
editor.commit()
|
||||
# Ensure the field is unique to begin with
|
||||
Tag.objects.create(title="foo", slug="foo")
|
||||
self.assertRaises(IntegrityError, Tag.objects.create, title="bar", slug="foo")
|
||||
connection.rollback()
|
||||
# Alter the slug field to be non-unique
|
||||
new_field = SlugField(unique=False)
|
||||
new_field.set_attributes_from_name("slug")
|
||||
editor = connection.schema_editor()
|
||||
editor.start()
|
||||
editor.alter_field(
|
||||
Tag,
|
||||
Tag._meta.get_field_by_name("slug")[0],
|
||||
new_field,
|
||||
)
|
||||
editor.commit()
|
||||
# Ensure the field is no longer unique
|
||||
Tag.objects.create(title="foo", slug="foo")
|
||||
Tag.objects.create(title="bar", slug="foo")
|
||||
connection.rollback()
|
||||
# Alter the slug field to be non-unique
|
||||
new_new_field = SlugField(unique=True)
|
||||
new_new_field.set_attributes_from_name("slug")
|
||||
editor = connection.schema_editor()
|
||||
editor.start()
|
||||
editor.alter_field(
|
||||
Tag,
|
||||
new_field,
|
||||
new_new_field,
|
||||
)
|
||||
editor.commit()
|
||||
# Ensure the field is unique again
|
||||
Tag.objects.create(title="foo", slug="foo")
|
||||
self.assertRaises(IntegrityError, Tag.objects.create, title="bar", slug="foo")
|
||||
connection.rollback()
|
||||
|
|
Loading…
Reference in New Issue