From 959a3f9791d780062c4efe8765404a8ef95e87f0 Mon Sep 17 00:00:00 2001 From: Andrew Godwin Date: Tue, 19 Jun 2012 13:25:22 +0100 Subject: [PATCH] Add some field schema alteration methods and tests. --- django/db/backends/__init__.py | 3 + .../db/backends/postgresql_psycopg2/base.py | 1 + django/db/backends/schema.py | 280 +++++++++++++++--- tests/modeltests/schema/tests.py | 74 ++++- 4 files changed, 309 insertions(+), 49 deletions(-) diff --git a/django/db/backends/__init__.py b/django/db/backends/__init__.py index ed2a54277f..0c1905c6b8 100644 --- a/django/db/backends/__init__.py +++ b/django/db/backends/__init__.py @@ -419,6 +419,9 @@ class BaseDatabaseFeatures(object): # Can we roll back DDL in a transaction? can_rollback_ddl = False + # Can we issue more than one ALTER COLUMN clause in an ALTER TABLE? + supports_combined_alters = False + def __init__(self, connection): self.connection = connection diff --git a/django/db/backends/postgresql_psycopg2/base.py b/django/db/backends/postgresql_psycopg2/base.py index 6c56bb9c91..ebb4109f79 100644 --- a/django/db/backends/postgresql_psycopg2/base.py +++ b/django/db/backends/postgresql_psycopg2/base.py @@ -85,6 +85,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): supports_tablespaces = True can_distinct_on_fields = True can_rollback_ddl = True + supports_combined_alters = True class DatabaseWrapper(BaseDatabaseWrapper): vendor = 'postgresql' diff --git a/django/db/backends/schema.py b/django/db/backends/schema.py index 73a9b99b50..bf838d2094 100644 --- a/django/db/backends/schema.py +++ b/django/db/backends/schema.py @@ -5,6 +5,7 @@ from django.conf import settings from django.db import transaction from django.db.utils import load_backend from django.utils.log import getLogger +from django.db.models.fields.related import ManyToManyField logger = getLogger('django.db.backends.schema') @@ -29,11 +30,15 @@ class BaseDatabaseSchemaEditor(object): sql_rename_table = "ALTER TABLE %(old_table)s RENAME TO %(new_table)s" sql_delete_table = "DROP TABLE %(table)s CASCADE" - sql_create_column = "ALTER TABLE %(table)s ADD COLUMN %(definition)s" + sql_create_column = "ALTER TABLE %(table)s ADD COLUMN %(column)s %(definition)s" + sql_alter_column = "ALTER TABLE %(table)s %(changes)s" sql_alter_column_type = "ALTER COLUMN %(column)s TYPE %(type)s" sql_alter_column_null = "ALTER COLUMN %(column)s DROP NOT NULL" sql_alter_column_not_null = "ALTER COLUMN %(column)s SET NOT NULL" - sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s CASCADE;" + sql_alter_column_default = "ALTER COLUMN %(column)s SET DEFAULT %(default)s" + sql_alter_column_no_default = "ALTER COLUMN %(column)s DROP DEFAULT" + sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s CASCADE" + sql_rename_column = "ALTER TABLE %(table)s RENAME COLUMN %(old_column)s TO %(new_column)s" sql_create_check = "ADD CONSTRAINT %(name)s CHECK (%(check)s)" sql_delete_check = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s" @@ -91,50 +96,7 @@ class BaseDatabaseSchemaEditor(object): def quote_name(self, name): return self.connection.ops.quote_name(name) - # Actions - - def create_model(self, model): - """ - Takes a model and creates a table for it in the database. - Will also create any accompanying indexes or unique constraints. - """ - # Do nothing if this is an unmanaged or proxy model - if not model._meta.managed or model._meta.proxy: - return [], {} - # Create column SQL, add FK deferreds if needed - column_sqls = [] - for field in model._meta.local_fields: - # SQL - definition = self.column_sql(model, field) - if definition is None: - continue - column_sqls.append("%s %s" % ( - self.quote_name(field.column), - definition, - )) - # FK - if field.rel: - to_table = field.rel.to._meta.db_table - to_column = field.rel.to._meta.get_field(field.rel.field_name).column - self.deferred_sql.append( - self.sql_create_fk % { - "name": '%s_refs_%s_%x' % ( - field.column, - to_column, - abs(hash((model._meta.db_table, to_table))) - ), - "table": self.quote_name(model._meta.db_table), - "column": self.quote_name(field.column), - "to_table": self.quote_name(to_table), - "to_column": self.quote_name(to_column), - } - ) - # Make the table - sql = self.sql_create_table % { - "table": model._meta.db_table, - "definition": ", ".join(column_sqls) - } - self.execute(sql) + # Field <-> database mapping functions def column_sql(self, model, field, include_default=False): """ @@ -143,6 +105,7 @@ class BaseDatabaseSchemaEditor(object): """ # Get the column's type and use that as the basis of the SQL sql = field.db_type(connection=self.connection) + params = [] # Check for fields that aren't actually columns (e.g. M2M) if sql is None: return None @@ -168,11 +131,232 @@ class BaseDatabaseSchemaEditor(object): sql += " UNIQUE" # If we were told to include a default value, do so if include_default: - raise NotImplementedError() + sql += " DEFAULT %s" + params += [self.effective_default(field)] # Return the sql - return sql + return sql, params + + def effective_default(self, field): + "Returns a field's effective database default value" + if field.has_default(): + default = field.get_default() + elif not field.null and field.blank and field.empty_strings_allowed: + default = "" + else: + default = None + # If it's a callable, call it + if callable(default): + default = default() + return default + + # Actions + + def create_model(self, model): + """ + Takes a model and creates a table for it in the database. + Will also create any accompanying indexes or unique constraints. + """ + # Do nothing if this is an unmanaged or proxy model + if not model._meta.managed or model._meta.proxy: + return + # Create column SQL, add FK deferreds if needed + column_sqls = [] + params = [] + for field in model._meta.local_fields: + # SQL + definition, extra_params = self.column_sql(model, field) + if definition is None: + continue + column_sqls.append("%s %s" % ( + self.quote_name(field.column), + definition, + )) + params.extend(extra_params) + # FK + if field.rel: + to_table = field.rel.to._meta.db_table + to_column = field.rel.to._meta.get_field(field.rel.field_name).column + self.deferred_sql.append( + self.sql_create_fk % { + "name": '%s_refs_%s_%x' % ( + field.column, + to_column, + abs(hash((model._meta.db_table, to_table))) + ), + "table": self.quote_name(model._meta.db_table), + "column": self.quote_name(field.column), + "to_table": self.quote_name(to_table), + "to_column": self.quote_name(to_column), + } + ) + # Make the table + sql = self.sql_create_table % { + "table": model._meta.db_table, + "definition": ", ".join(column_sqls) + } + self.execute(sql, params) def delete_model(self, model): + """ + Deletes a model from the database. + """ + # Do nothing if this is an unmanaged or proxy model + if not model._meta.managed or model._meta.proxy: + return + # Delete the table self.execute(self.sql_delete_table % { "table": self.quote_name(model._meta.db_table), }) + + def create_field(self, model, field, keep_default=False): + """ + Creates a field on a model. + Usually involves adding a column, but may involve adding a + table instead (for M2M fields) + """ + # Special-case implicit M2M tables + if isinstance(field, ManyToManyField) and field.rel.through._meta.auto_created: + return self.create_model(field.rel.through) + # Get the column's definition + definition, params = self.column_sql(model, field, include_default=True) + # It might not actually have a column behind it + if definition is None: + return + # Build the SQL and run it + sql = self.sql_create_column % { + "table": self.quote_name(model._meta.db_table), + "column": self.quote_name(field.column), + "definition": definition, + } + self.execute(sql, params) + # Drop the default if we need to + # (Django usually does not use in-database defaults) + if not keep_default and field.default is not None: + sql = self.sql_alter_column % { + "table": self.quote_name(model._meta.db_table), + "changes": self.sql_alter_column_no_default % { + "column": self.quote_name(field.column), + } + } + # Add any FK constraints later + if field.rel: + to_table = field.rel.to._meta.db_table + to_column = field.rel.to._meta.get_field(field.rel.field_name).column + self.deferred_sql.append( + self.sql_create_fk % { + "name": '%s_refs_%s_%x' % ( + field.column, + to_column, + abs(hash((model._meta.db_table, to_table))) + ), + "table": self.quote_name(model._meta.db_table), + "column": self.quote_name(field.column), + "to_table": self.quote_name(to_table), + "to_column": self.quote_name(to_column), + } + ) + + def delete_field(self, model, field): + """ + Removes a field from a model. Usually involves deleting a column, + but for M2Ms may involve deleting a table. + """ + # Special-case implicit M2M tables + if isinstance(field, ManyToManyField) and field.rel.through._meta.auto_created: + return self.delete_model(field.rel.through) + # Get the column's definition + definition, params = self.column_sql(model, field) + # It might not actually have a column behind it + if definition is None: + return + # Delete the column + sql = self.sql_delete_column % { + "table": self.quote_name(model._meta.db_table), + "column": self.quote_name(field.column), + } + self.execute(sql) + + def alter_field(self, model, old_field, new_field): + """ + Allows a field's type, uniqueness, nullability, default, column, + constraints etc. to be modified. + Requires a copy of the old field as well so we can only perform + changes that are required. + """ + # Ensure this field is even column-based + old_type = old_field.db_type(connection=self.connection) + new_type = new_field.db_type(connection=self.connection) + if old_type is None and new_type is None: + # TODO: Handle M2M fields being repointed + return + elif old_type is None or new_type is None: + raise ValueError("Cannot alter field %s into %s - they are not compatible types" % ( + old_field, + new_field, + )) + # First, have they renamed the column? + if old_field.column != new_field.column: + self.execute(self.sql_rename_column % { + "table": self.quote_name(model._meta.db_table), + "old_column": self.quote_name(old_field.column), + "new_column": self.quote_name(new_field.column), + }) + # Next, start accumulating actions to do + actions = [] + # Type change? + if old_type != new_type: + actions.append(( + self.sql_alter_column_type % { + "column": self.quote_name(new_field.column), + "type": new_type, + }, + [], + )) + # Default change? + old_default = self.effective_default(old_field) + new_default = self.effective_default(new_field) + if old_default != new_default: + if new_default is None: + actions.append(( + self.sql_alter_column_no_default % { + "column": self.quote_name(new_field.column), + }, + [], + )) + else: + actions.append(( + self.sql_alter_column_default % { + "column": self.quote_name(new_field.column), + "default": "%s", + }, + [new_default], + )) + # Nullability change? + if old_field.null != new_field.null: + if new_field.null: + actions.append(( + self.sql_alter_column_null % { + "column": self.quote_name(new_field.column), + }, + [], + )) + else: + actions.append(( + self.sql_alter_column_null % { + "column": self.quote_name(new_field.column), + }, + [], + )) + # Combine actions together if we can (e.g. postgres) + if self.connection.features.supports_combined_alters: + sql, params = tuple(zip(*actions)) + actions = [(", ".join(sql), params)] + # Apply those actions + for sql, params in actions: + self.execute( + self.sql_alter_column % { + "table": self.quote_name(model._meta.db_table), + "changes": sql, + }, + params, + ) diff --git a/tests/modeltests/schema/tests.py b/tests/modeltests/schema/tests.py index 6d5d27cdf1..83b2dabd45 100644 --- a/tests/modeltests/schema/tests.py +++ b/tests/modeltests/schema/tests.py @@ -2,8 +2,9 @@ from __future__ import absolute_import import copy import datetime from django.test import TestCase -from django.db.models.loading import cache from django.db import connection, DatabaseError, IntegrityError +from django.db.models.fields import IntegerField, TextField +from django.db.models.loading import cache from .models import Author, Book @@ -18,6 +19,8 @@ class SchemaTests(TestCase): models = [Author, Book] + # Utility functions + def setUp(self): # Make sure we're in manual transaction mode connection.commit_unless_managed() @@ -51,6 +54,18 @@ class SchemaTests(TestCase): cache.app_store = self.old_app_store cache._get_models_cache = {} + def column_classes(self, model): + cursor = connection.cursor() + return dict( + (d[0], (connection.introspection.get_field_type(d[1], d), d)) + for d in connection.introspection.get_table_description( + cursor, + model._meta.db_table, + ) + ) + + # Tests + def test_creation_deletion(self): """ Tries creating a model's table, and then deleting it. @@ -100,3 +115,60 @@ class SchemaTests(TestCase): pub_date = datetime.datetime.now(), ) connection.commit() + + def test_create_field(self): + """ + Tests adding fields to models + """ + # Create the table + editor = connection.schema_editor() + editor.start() + editor.create_model(Author) + editor.commit() + # Ensure there's no age field + columns = self.column_classes(Author) + self.assertNotIn("age", columns) + # Alter the name field to a TextField + new_field = IntegerField(null=True) + new_field.set_attributes_from_name("age") + editor = connection.schema_editor() + editor.start() + editor.create_field( + Author, + new_field, + ) + editor.commit() + # Ensure the field is right afterwards + columns = self.column_classes(Author) + self.assertEqual(columns['age'][0], "IntegerField") + self.assertEqual(columns['age'][1][6], True) + + def test_alter(self): + """ + Tests simple altering of fields + """ + # Create the table + editor = connection.schema_editor() + editor.start() + editor.create_model(Author) + editor.commit() + # Ensure the field is right to begin with + columns = self.column_classes(Author) + self.assertEqual(columns['name'][0], "CharField") + self.assertEqual(columns['name'][1][3], 255) + self.assertEqual(columns['name'][1][6], False) + # Alter the name field to a TextField + new_field = TextField(null=True) + new_field.set_attributes_from_name("name") + editor = connection.schema_editor() + editor.start() + editor.alter_field( + Author, + Author._meta.get_field_by_name("name")[0], + new_field, + ) + editor.commit() + # Ensure the field is right afterwards + columns = self.column_classes(Author) + self.assertEqual(columns['name'][0], "TextField") + self.assertEqual(columns['name'][1][6], True)