diff --git a/django/db/backends/__init__.py b/django/db/backends/__init__.py index 816cbcda63..b9c642d093 100644 --- a/django/db/backends/__init__.py +++ b/django/db/backends/__init__.py @@ -427,6 +427,9 @@ class BaseDatabaseFeatures(object): # Can we issue more than one ALTER COLUMN clause in an ALTER TABLE? supports_combined_alters = False + # What's the maximum length for index names? + max_index_name_length = 63 + def __init__(self, connection): self.connection = connection @@ -1056,6 +1059,15 @@ class BaseDatabaseIntrospection(object): """ raise NotImplementedError + def get_constraints(self, cursor, table_name): + """ + Returns {'cnname': {'columns': set(columns), 'primary_key': bool, 'unique': bool}} + + Both single- and multi-column constraints are introspected. + """ + raise NotImplementedError + + class BaseDatabaseClient(object): """ This class encapsulates all backend-specific methods for opening a diff --git a/django/db/backends/creation.py b/django/db/backends/creation.py index fcc6ab7584..4dffd78f44 100644 --- a/django/db/backends/creation.py +++ b/django/db/backends/creation.py @@ -21,7 +21,8 @@ class BaseDatabaseCreation(object): def __init__(self, connection): self.connection = connection - def _digest(self, *args): + @classmethod + def _digest(cls, *args): """ Generates a 32-bit digest of a set of arguments that can be used to shorten identifying names. diff --git a/django/db/backends/postgresql_psycopg2/introspection.py b/django/db/backends/postgresql_psycopg2/introspection.py index 99573b9019..c8b8ec833b 100644 --- a/django/db/backends/postgresql_psycopg2/introspection.py +++ b/django/db/backends/postgresql_psycopg2/introspection.py @@ -88,3 +88,35 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): continue indexes[row[0]] = {'primary_key': row[3], 'unique': row[2]} return indexes + + def get_constraints(self, cursor, table_name): + """ + Retrieves any constraints (unique, pk, check) across one or more columns. + Returns {'cnname': {'columns': set(columns), 'primary_key': bool, 'unique': bool}} + """ + constraints = {} + # Loop over the constraint tables, collecting things as constraints + ifsc_tables = ["constraint_column_usage", "key_column_usage"] + for ifsc_table in ifsc_tables: + cursor.execute(""" + SELECT kc.constraint_name, kc.column_name, c.constraint_type + FROM information_schema.%s AS kc + JOIN information_schema.table_constraints AS c ON + kc.table_schema = c.table_schema AND + kc.table_name = c.table_name AND + kc.constraint_name = c.constraint_name + WHERE + kc.table_schema = %%s AND + kc.table_name = %%s + """ % ifsc_table, ["public", table_name]) + for constraint, column, kind in cursor.fetchall(): + # If we're the first column, make the record + if constraint not in constraints: + constraints[constraint] = { + "columns": set(), + "primary_key": kind.lower() == "primary key", + "unique": kind.lower() in ["primary key", "unique"], + } + # Record the details + constraints[constraint]['columns'].add(column) + return constraints diff --git a/django/db/backends/schema.py b/django/db/backends/schema.py index bf838d2094..5f4e0146b4 100644 --- a/django/db/backends/schema.py +++ b/django/db/backends/schema.py @@ -4,6 +4,8 @@ import time from django.conf import settings from django.db import transaction from django.db.utils import load_backend +from django.db.backends.creation import BaseDatabaseCreation +from django.db.backends.util import truncate_name from django.utils.log import getLogger from django.db.models.fields.related import ManyToManyField @@ -294,7 +296,23 @@ class BaseDatabaseSchemaEditor(object): old_field, new_field, )) - # First, have they renamed the column? + # Has unique been removed? + if old_field.unique and not new_field.unique: + # Find the unique constraint for this field + constraint_names = self._constraint_names(model, [old_field.column], unique=True) + if len(constraint_names) != 1: + raise ValueError("Found wrong number (%s) of constraints for %s.%s" % ( + len(constraint_names), + model._meta.db_table, + old_field.column, + )) + self.execute( + self.sql_delete_unique % { + "table": self.quote_name(model._meta.db_table), + "name": constraint_names[0], + }, + ) + # Have they renamed the column? if old_field.column != new_field.column: self.execute(self.sql_rename_column % { "table": self.quote_name(model._meta.db_table), @@ -347,16 +365,58 @@ class BaseDatabaseSchemaEditor(object): }, [], )) - # Combine actions together if we can (e.g. postgres) - if self.connection.features.supports_combined_alters: - sql, params = tuple(zip(*actions)) - actions = [(", ".join(sql), params)] - # Apply those actions - for sql, params in actions: + if actions: + # Combine actions together if we can (e.g. postgres) + if self.connection.features.supports_combined_alters: + sql, params = tuple(zip(*actions)) + actions = [(", ".join(sql), params)] + # Apply those actions + for sql, params in actions: + self.execute( + self.sql_alter_column % { + "table": self.quote_name(model._meta.db_table), + "changes": sql, + }, + params, + ) + # Added a unique? + if not old_field.unique and new_field.unique: self.execute( - self.sql_alter_column % { + self.sql_create_unique % { "table": self.quote_name(model._meta.db_table), - "changes": sql, - }, - params, + "name": self._create_index_name(model, [new_field.column], suffix="_uniq"), + "columns": self.quote_name(new_field.column), + } ) + + def _create_index_name(self, model, column_names, suffix=""): + "Generates a unique name for an index/unique constraint." + # If there is just one column in the index, use a default algorithm from Django + if len(column_names) == 1 and not suffix: + return truncate_name( + '%s_%s' % (model._meta.db_table, BaseDatabaseCreation._digest(column_names[0])), + self.connection.ops.max_name_length() + ) + # Else generate the name for the index by South + table_name = model._meta.db_table.replace('"', '').replace('.', '_') + index_unique_name = '_%x' % abs(hash((table_name, ','.join(column_names)))) + # If the index name is too long, truncate it + index_name = ('%s_%s%s%s' % (table_name, column_names[0], index_unique_name, suffix)).replace('"', '').replace('.', '_') + if len(index_name) > self.connection.features.max_index_name_length: + part = ('_%s%s%s' % (column_names[0], index_unique_name, suffix)) + index_name = '%s%s' % (table_name[:(self.connection.features.max_index_name_length - len(part))], part) + return index_name + + def _constraint_names(self, model, column_names, unique=None, primary_key=None): + "Returns all constraint names matching the columns and conditions" + column_names = set(column_names) + constraints = self.connection.introspection.get_constraints(self.connection.cursor(), model._meta.db_table) + result = [] + for name, infodict in constraints.items(): + if column_names == infodict['columns']: + if unique is not None and infodict['unique'] != unique: + continue + if primary_key is not None and infodict['primary_key'] != unique: + continue + result.append(name) + return result diff --git a/tests/modeltests/schema/models.py b/tests/modeltests/schema/models.py index 2c5dc829c6..2362718bf3 100644 --- a/tests/modeltests/schema/models.py +++ b/tests/modeltests/schema/models.py @@ -12,10 +12,23 @@ class Author(models.Model): managed = False +class AuthorWithM2M(models.Model): + name = models.CharField(max_length=255) + + class Meta: + managed = False + + class Book(models.Model): author = models.ForeignKey(Author) title = models.CharField(max_length=100) pub_date = models.DateTimeField() + #tags = models.ManyToManyField("Tag", related_name="books") class Meta: managed = False + + +class Tag(models.Model): + title = models.CharField(max_length=255) + slug = models.SlugField(unique=True) diff --git a/tests/modeltests/schema/tests.py b/tests/modeltests/schema/tests.py index 83b2dabd45..8708fd7c8d 100644 --- a/tests/modeltests/schema/tests.py +++ b/tests/modeltests/schema/tests.py @@ -3,9 +3,10 @@ import copy import datetime from django.test import TestCase from django.db import connection, DatabaseError, IntegrityError -from django.db.models.fields import IntegerField, TextField +from django.db.models.fields import IntegerField, TextField, CharField, SlugField +from django.db.models.fields.related import ManyToManyField from django.db.models.loading import cache -from .models import Author, Book +from .models import Author, Book, AuthorWithM2M, Tag class SchemaTests(TestCase): @@ -17,7 +18,7 @@ class SchemaTests(TestCase): as the code it is testing. """ - models = [Author, Book] + models = [Author, Book, AuthorWithM2M, Tag] # Utility functions @@ -39,6 +40,17 @@ class SchemaTests(TestCase): # Delete any tables made for our models cursor = connection.cursor() for model in self.models: + # Remove any M2M tables first + for field in model._meta.local_many_to_many: + try: + cursor.execute("DROP TABLE %s CASCADE" % ( + connection.ops.quote_name(field.rel.through._meta.db_table), + )) + except DatabaseError: + connection.rollback() + else: + connection.commit() + # Then remove the main tables try: cursor.execute("DROP TABLE %s CASCADE" % ( connection.ops.quote_name(model._meta.db_table), @@ -172,3 +184,117 @@ class SchemaTests(TestCase): columns = self.column_classes(Author) self.assertEqual(columns['name'][0], "TextField") self.assertEqual(columns['name'][1][6], True) + + def test_rename(self): + """ + Tests simple altering of fields + """ + # Create the table + editor = connection.schema_editor() + editor.start() + editor.create_model(Author) + editor.commit() + # Ensure the field is right to begin with + columns = self.column_classes(Author) + self.assertEqual(columns['name'][0], "CharField") + self.assertEqual(columns['name'][1][3], 255) + self.assertNotIn("display_name", columns) + # Alter the name field's name + new_field = CharField(max_length=254) + new_field.set_attributes_from_name("display_name") + editor = connection.schema_editor() + editor.start() + editor.alter_field( + Author, + Author._meta.get_field_by_name("name")[0], + new_field, + ) + editor.commit() + # Ensure the field is right afterwards + columns = self.column_classes(Author) + self.assertEqual(columns['display_name'][0], "CharField") + self.assertEqual(columns['display_name'][1][3], 254) + self.assertNotIn("name", columns) + + def test_m2m(self): + """ + Tests adding/removing M2M fields on models + """ + # Create the tables + editor = connection.schema_editor() + editor.start() + editor.create_model(AuthorWithM2M) + editor.create_model(Tag) + editor.commit() + # Create an M2M field + new_field = ManyToManyField("schema.Tag", related_name="authors") + new_field.contribute_to_class(AuthorWithM2M, "tags") + # Ensure there's no m2m table there + self.assertRaises(DatabaseError, self.column_classes, new_field.rel.through) + connection.rollback() + # Add the field + editor = connection.schema_editor() + editor.start() + editor.create_field( + Author, + new_field, + ) + editor.commit() + # Ensure there is now an m2m table there + columns = self.column_classes(new_field.rel.through) + self.assertEqual(columns['tag_id'][0], "IntegerField") + # Remove the M2M table again + editor = connection.schema_editor() + editor.start() + editor.delete_field( + Author, + new_field, + ) + editor.commit() + # Ensure there's no m2m table there + self.assertRaises(DatabaseError, self.column_classes, new_field.rel.through) + connection.rollback() + + def test_unique(self): + """ + Tests removing and adding unique constraints to a single column. + """ + # Create the table + editor = connection.schema_editor() + editor.start() + editor.create_model(Tag) + editor.commit() + # Ensure the field is unique to begin with + Tag.objects.create(title="foo", slug="foo") + self.assertRaises(IntegrityError, Tag.objects.create, title="bar", slug="foo") + connection.rollback() + # Alter the slug field to be non-unique + new_field = SlugField(unique=False) + new_field.set_attributes_from_name("slug") + editor = connection.schema_editor() + editor.start() + editor.alter_field( + Tag, + Tag._meta.get_field_by_name("slug")[0], + new_field, + ) + editor.commit() + # Ensure the field is no longer unique + Tag.objects.create(title="foo", slug="foo") + Tag.objects.create(title="bar", slug="foo") + connection.rollback() + # Alter the slug field to be non-unique + new_new_field = SlugField(unique=True) + new_new_field.set_attributes_from_name("slug") + editor = connection.schema_editor() + editor.start() + editor.alter_field( + Tag, + new_field, + new_new_field, + ) + editor.commit() + # Ensure the field is unique again + Tag.objects.create(title="foo", slug="foo") + self.assertRaises(IntegrityError, Tag.objects.create, title="bar", slug="foo") + connection.rollback()