Add M2M tests and some unique support
This commit is contained in:
parent
4a2e80fff4
commit
b139315f1c
|
@ -427,6 +427,9 @@ class BaseDatabaseFeatures(object):
|
||||||
# Can we issue more than one ALTER COLUMN clause in an ALTER TABLE?
|
# Can we issue more than one ALTER COLUMN clause in an ALTER TABLE?
|
||||||
supports_combined_alters = False
|
supports_combined_alters = False
|
||||||
|
|
||||||
|
# What's the maximum length for index names?
|
||||||
|
max_index_name_length = 63
|
||||||
|
|
||||||
def __init__(self, connection):
|
def __init__(self, connection):
|
||||||
self.connection = connection
|
self.connection = connection
|
||||||
|
|
||||||
|
@ -1056,6 +1059,15 @@ class BaseDatabaseIntrospection(object):
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_constraints(self, cursor, table_name):
|
||||||
|
"""
|
||||||
|
Returns {'cnname': {'columns': set(columns), 'primary_key': bool, 'unique': bool}}
|
||||||
|
|
||||||
|
Both single- and multi-column constraints are introspected.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class BaseDatabaseClient(object):
|
class BaseDatabaseClient(object):
|
||||||
"""
|
"""
|
||||||
This class encapsulates all backend-specific methods for opening a
|
This class encapsulates all backend-specific methods for opening a
|
||||||
|
|
|
@ -21,7 +21,8 @@ class BaseDatabaseCreation(object):
|
||||||
def __init__(self, connection):
|
def __init__(self, connection):
|
||||||
self.connection = connection
|
self.connection = connection
|
||||||
|
|
||||||
def _digest(self, *args):
|
@classmethod
|
||||||
|
def _digest(cls, *args):
|
||||||
"""
|
"""
|
||||||
Generates a 32-bit digest of a set of arguments that can be used to
|
Generates a 32-bit digest of a set of arguments that can be used to
|
||||||
shorten identifying names.
|
shorten identifying names.
|
||||||
|
|
|
@ -88,3 +88,35 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
|
||||||
continue
|
continue
|
||||||
indexes[row[0]] = {'primary_key': row[3], 'unique': row[2]}
|
indexes[row[0]] = {'primary_key': row[3], 'unique': row[2]}
|
||||||
return indexes
|
return indexes
|
||||||
|
|
||||||
|
def get_constraints(self, cursor, table_name):
|
||||||
|
"""
|
||||||
|
Retrieves any constraints (unique, pk, check) across one or more columns.
|
||||||
|
Returns {'cnname': {'columns': set(columns), 'primary_key': bool, 'unique': bool}}
|
||||||
|
"""
|
||||||
|
constraints = {}
|
||||||
|
# Loop over the constraint tables, collecting things as constraints
|
||||||
|
ifsc_tables = ["constraint_column_usage", "key_column_usage"]
|
||||||
|
for ifsc_table in ifsc_tables:
|
||||||
|
cursor.execute("""
|
||||||
|
SELECT kc.constraint_name, kc.column_name, c.constraint_type
|
||||||
|
FROM information_schema.%s AS kc
|
||||||
|
JOIN information_schema.table_constraints AS c ON
|
||||||
|
kc.table_schema = c.table_schema AND
|
||||||
|
kc.table_name = c.table_name AND
|
||||||
|
kc.constraint_name = c.constraint_name
|
||||||
|
WHERE
|
||||||
|
kc.table_schema = %%s AND
|
||||||
|
kc.table_name = %%s
|
||||||
|
""" % ifsc_table, ["public", table_name])
|
||||||
|
for constraint, column, kind in cursor.fetchall():
|
||||||
|
# If we're the first column, make the record
|
||||||
|
if constraint not in constraints:
|
||||||
|
constraints[constraint] = {
|
||||||
|
"columns": set(),
|
||||||
|
"primary_key": kind.lower() == "primary key",
|
||||||
|
"unique": kind.lower() in ["primary key", "unique"],
|
||||||
|
}
|
||||||
|
# Record the details
|
||||||
|
constraints[constraint]['columns'].add(column)
|
||||||
|
return constraints
|
||||||
|
|
|
@ -4,6 +4,8 @@ import time
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.db import transaction
|
from django.db import transaction
|
||||||
from django.db.utils import load_backend
|
from django.db.utils import load_backend
|
||||||
|
from django.db.backends.creation import BaseDatabaseCreation
|
||||||
|
from django.db.backends.util import truncate_name
|
||||||
from django.utils.log import getLogger
|
from django.utils.log import getLogger
|
||||||
from django.db.models.fields.related import ManyToManyField
|
from django.db.models.fields.related import ManyToManyField
|
||||||
|
|
||||||
|
@ -294,7 +296,23 @@ class BaseDatabaseSchemaEditor(object):
|
||||||
old_field,
|
old_field,
|
||||||
new_field,
|
new_field,
|
||||||
))
|
))
|
||||||
# First, have they renamed the column?
|
# Has unique been removed?
|
||||||
|
if old_field.unique and not new_field.unique:
|
||||||
|
# Find the unique constraint for this field
|
||||||
|
constraint_names = self._constraint_names(model, [old_field.column], unique=True)
|
||||||
|
if len(constraint_names) != 1:
|
||||||
|
raise ValueError("Found wrong number (%s) of constraints for %s.%s" % (
|
||||||
|
len(constraint_names),
|
||||||
|
model._meta.db_table,
|
||||||
|
old_field.column,
|
||||||
|
))
|
||||||
|
self.execute(
|
||||||
|
self.sql_delete_unique % {
|
||||||
|
"table": self.quote_name(model._meta.db_table),
|
||||||
|
"name": constraint_names[0],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
# Have they renamed the column?
|
||||||
if old_field.column != new_field.column:
|
if old_field.column != new_field.column:
|
||||||
self.execute(self.sql_rename_column % {
|
self.execute(self.sql_rename_column % {
|
||||||
"table": self.quote_name(model._meta.db_table),
|
"table": self.quote_name(model._meta.db_table),
|
||||||
|
@ -347,6 +365,7 @@ class BaseDatabaseSchemaEditor(object):
|
||||||
},
|
},
|
||||||
[],
|
[],
|
||||||
))
|
))
|
||||||
|
if actions:
|
||||||
# Combine actions together if we can (e.g. postgres)
|
# Combine actions together if we can (e.g. postgres)
|
||||||
if self.connection.features.supports_combined_alters:
|
if self.connection.features.supports_combined_alters:
|
||||||
sql, params = tuple(zip(*actions))
|
sql, params = tuple(zip(*actions))
|
||||||
|
@ -360,3 +379,44 @@ class BaseDatabaseSchemaEditor(object):
|
||||||
},
|
},
|
||||||
params,
|
params,
|
||||||
)
|
)
|
||||||
|
# Added a unique?
|
||||||
|
if not old_field.unique and new_field.unique:
|
||||||
|
self.execute(
|
||||||
|
self.sql_create_unique % {
|
||||||
|
"table": self.quote_name(model._meta.db_table),
|
||||||
|
"name": self._create_index_name(model, [new_field.column], suffix="_uniq"),
|
||||||
|
"columns": self.quote_name(new_field.column),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def _create_index_name(self, model, column_names, suffix=""):
|
||||||
|
"Generates a unique name for an index/unique constraint."
|
||||||
|
# If there is just one column in the index, use a default algorithm from Django
|
||||||
|
if len(column_names) == 1 and not suffix:
|
||||||
|
return truncate_name(
|
||||||
|
'%s_%s' % (model._meta.db_table, BaseDatabaseCreation._digest(column_names[0])),
|
||||||
|
self.connection.ops.max_name_length()
|
||||||
|
)
|
||||||
|
# Else generate the name for the index by South
|
||||||
|
table_name = model._meta.db_table.replace('"', '').replace('.', '_')
|
||||||
|
index_unique_name = '_%x' % abs(hash((table_name, ','.join(column_names))))
|
||||||
|
# If the index name is too long, truncate it
|
||||||
|
index_name = ('%s_%s%s%s' % (table_name, column_names[0], index_unique_name, suffix)).replace('"', '').replace('.', '_')
|
||||||
|
if len(index_name) > self.connection.features.max_index_name_length:
|
||||||
|
part = ('_%s%s%s' % (column_names[0], index_unique_name, suffix))
|
||||||
|
index_name = '%s%s' % (table_name[:(self.connection.features.max_index_name_length - len(part))], part)
|
||||||
|
return index_name
|
||||||
|
|
||||||
|
def _constraint_names(self, model, column_names, unique=None, primary_key=None):
|
||||||
|
"Returns all constraint names matching the columns and conditions"
|
||||||
|
column_names = set(column_names)
|
||||||
|
constraints = self.connection.introspection.get_constraints(self.connection.cursor(), model._meta.db_table)
|
||||||
|
result = []
|
||||||
|
for name, infodict in constraints.items():
|
||||||
|
if column_names == infodict['columns']:
|
||||||
|
if unique is not None and infodict['unique'] != unique:
|
||||||
|
continue
|
||||||
|
if primary_key is not None and infodict['primary_key'] != unique:
|
||||||
|
continue
|
||||||
|
result.append(name)
|
||||||
|
return result
|
||||||
|
|
|
@ -12,10 +12,23 @@ class Author(models.Model):
|
||||||
managed = False
|
managed = False
|
||||||
|
|
||||||
|
|
||||||
|
class AuthorWithM2M(models.Model):
|
||||||
|
name = models.CharField(max_length=255)
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
managed = False
|
||||||
|
|
||||||
|
|
||||||
class Book(models.Model):
|
class Book(models.Model):
|
||||||
author = models.ForeignKey(Author)
|
author = models.ForeignKey(Author)
|
||||||
title = models.CharField(max_length=100)
|
title = models.CharField(max_length=100)
|
||||||
pub_date = models.DateTimeField()
|
pub_date = models.DateTimeField()
|
||||||
|
#tags = models.ManyToManyField("Tag", related_name="books")
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
managed = False
|
managed = False
|
||||||
|
|
||||||
|
|
||||||
|
class Tag(models.Model):
|
||||||
|
title = models.CharField(max_length=255)
|
||||||
|
slug = models.SlugField(unique=True)
|
||||||
|
|
|
@ -3,9 +3,10 @@ import copy
|
||||||
import datetime
|
import datetime
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
from django.db import connection, DatabaseError, IntegrityError
|
from django.db import connection, DatabaseError, IntegrityError
|
||||||
from django.db.models.fields import IntegerField, TextField
|
from django.db.models.fields import IntegerField, TextField, CharField, SlugField
|
||||||
|
from django.db.models.fields.related import ManyToManyField
|
||||||
from django.db.models.loading import cache
|
from django.db.models.loading import cache
|
||||||
from .models import Author, Book
|
from .models import Author, Book, AuthorWithM2M, Tag
|
||||||
|
|
||||||
|
|
||||||
class SchemaTests(TestCase):
|
class SchemaTests(TestCase):
|
||||||
|
@ -17,7 +18,7 @@ class SchemaTests(TestCase):
|
||||||
as the code it is testing.
|
as the code it is testing.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
models = [Author, Book]
|
models = [Author, Book, AuthorWithM2M, Tag]
|
||||||
|
|
||||||
# Utility functions
|
# Utility functions
|
||||||
|
|
||||||
|
@ -39,6 +40,17 @@ class SchemaTests(TestCase):
|
||||||
# Delete any tables made for our models
|
# Delete any tables made for our models
|
||||||
cursor = connection.cursor()
|
cursor = connection.cursor()
|
||||||
for model in self.models:
|
for model in self.models:
|
||||||
|
# Remove any M2M tables first
|
||||||
|
for field in model._meta.local_many_to_many:
|
||||||
|
try:
|
||||||
|
cursor.execute("DROP TABLE %s CASCADE" % (
|
||||||
|
connection.ops.quote_name(field.rel.through._meta.db_table),
|
||||||
|
))
|
||||||
|
except DatabaseError:
|
||||||
|
connection.rollback()
|
||||||
|
else:
|
||||||
|
connection.commit()
|
||||||
|
# Then remove the main tables
|
||||||
try:
|
try:
|
||||||
cursor.execute("DROP TABLE %s CASCADE" % (
|
cursor.execute("DROP TABLE %s CASCADE" % (
|
||||||
connection.ops.quote_name(model._meta.db_table),
|
connection.ops.quote_name(model._meta.db_table),
|
||||||
|
@ -172,3 +184,117 @@ class SchemaTests(TestCase):
|
||||||
columns = self.column_classes(Author)
|
columns = self.column_classes(Author)
|
||||||
self.assertEqual(columns['name'][0], "TextField")
|
self.assertEqual(columns['name'][0], "TextField")
|
||||||
self.assertEqual(columns['name'][1][6], True)
|
self.assertEqual(columns['name'][1][6], True)
|
||||||
|
|
||||||
|
def test_rename(self):
|
||||||
|
"""
|
||||||
|
Tests simple altering of fields
|
||||||
|
"""
|
||||||
|
# Create the table
|
||||||
|
editor = connection.schema_editor()
|
||||||
|
editor.start()
|
||||||
|
editor.create_model(Author)
|
||||||
|
editor.commit()
|
||||||
|
# Ensure the field is right to begin with
|
||||||
|
columns = self.column_classes(Author)
|
||||||
|
self.assertEqual(columns['name'][0], "CharField")
|
||||||
|
self.assertEqual(columns['name'][1][3], 255)
|
||||||
|
self.assertNotIn("display_name", columns)
|
||||||
|
# Alter the name field's name
|
||||||
|
new_field = CharField(max_length=254)
|
||||||
|
new_field.set_attributes_from_name("display_name")
|
||||||
|
editor = connection.schema_editor()
|
||||||
|
editor.start()
|
||||||
|
editor.alter_field(
|
||||||
|
Author,
|
||||||
|
Author._meta.get_field_by_name("name")[0],
|
||||||
|
new_field,
|
||||||
|
)
|
||||||
|
editor.commit()
|
||||||
|
# Ensure the field is right afterwards
|
||||||
|
columns = self.column_classes(Author)
|
||||||
|
self.assertEqual(columns['display_name'][0], "CharField")
|
||||||
|
self.assertEqual(columns['display_name'][1][3], 254)
|
||||||
|
self.assertNotIn("name", columns)
|
||||||
|
|
||||||
|
def test_m2m(self):
|
||||||
|
"""
|
||||||
|
Tests adding/removing M2M fields on models
|
||||||
|
"""
|
||||||
|
# Create the tables
|
||||||
|
editor = connection.schema_editor()
|
||||||
|
editor.start()
|
||||||
|
editor.create_model(AuthorWithM2M)
|
||||||
|
editor.create_model(Tag)
|
||||||
|
editor.commit()
|
||||||
|
# Create an M2M field
|
||||||
|
new_field = ManyToManyField("schema.Tag", related_name="authors")
|
||||||
|
new_field.contribute_to_class(AuthorWithM2M, "tags")
|
||||||
|
# Ensure there's no m2m table there
|
||||||
|
self.assertRaises(DatabaseError, self.column_classes, new_field.rel.through)
|
||||||
|
connection.rollback()
|
||||||
|
# Add the field
|
||||||
|
editor = connection.schema_editor()
|
||||||
|
editor.start()
|
||||||
|
editor.create_field(
|
||||||
|
Author,
|
||||||
|
new_field,
|
||||||
|
)
|
||||||
|
editor.commit()
|
||||||
|
# Ensure there is now an m2m table there
|
||||||
|
columns = self.column_classes(new_field.rel.through)
|
||||||
|
self.assertEqual(columns['tag_id'][0], "IntegerField")
|
||||||
|
# Remove the M2M table again
|
||||||
|
editor = connection.schema_editor()
|
||||||
|
editor.start()
|
||||||
|
editor.delete_field(
|
||||||
|
Author,
|
||||||
|
new_field,
|
||||||
|
)
|
||||||
|
editor.commit()
|
||||||
|
# Ensure there's no m2m table there
|
||||||
|
self.assertRaises(DatabaseError, self.column_classes, new_field.rel.through)
|
||||||
|
connection.rollback()
|
||||||
|
|
||||||
|
def test_unique(self):
|
||||||
|
"""
|
||||||
|
Tests removing and adding unique constraints to a single column.
|
||||||
|
"""
|
||||||
|
# Create the table
|
||||||
|
editor = connection.schema_editor()
|
||||||
|
editor.start()
|
||||||
|
editor.create_model(Tag)
|
||||||
|
editor.commit()
|
||||||
|
# Ensure the field is unique to begin with
|
||||||
|
Tag.objects.create(title="foo", slug="foo")
|
||||||
|
self.assertRaises(IntegrityError, Tag.objects.create, title="bar", slug="foo")
|
||||||
|
connection.rollback()
|
||||||
|
# Alter the slug field to be non-unique
|
||||||
|
new_field = SlugField(unique=False)
|
||||||
|
new_field.set_attributes_from_name("slug")
|
||||||
|
editor = connection.schema_editor()
|
||||||
|
editor.start()
|
||||||
|
editor.alter_field(
|
||||||
|
Tag,
|
||||||
|
Tag._meta.get_field_by_name("slug")[0],
|
||||||
|
new_field,
|
||||||
|
)
|
||||||
|
editor.commit()
|
||||||
|
# Ensure the field is no longer unique
|
||||||
|
Tag.objects.create(title="foo", slug="foo")
|
||||||
|
Tag.objects.create(title="bar", slug="foo")
|
||||||
|
connection.rollback()
|
||||||
|
# Alter the slug field to be non-unique
|
||||||
|
new_new_field = SlugField(unique=True)
|
||||||
|
new_new_field.set_attributes_from_name("slug")
|
||||||
|
editor = connection.schema_editor()
|
||||||
|
editor.start()
|
||||||
|
editor.alter_field(
|
||||||
|
Tag,
|
||||||
|
new_field,
|
||||||
|
new_new_field,
|
||||||
|
)
|
||||||
|
editor.commit()
|
||||||
|
# Ensure the field is unique again
|
||||||
|
Tag.objects.create(title="foo", slug="foo")
|
||||||
|
self.assertRaises(IntegrityError, Tag.objects.create, title="bar", slug="foo")
|
||||||
|
connection.rollback()
|
||||||
|
|
Loading…
Reference in New Issue