Add some field schema alteration methods and tests.

This commit is contained in:
Andrew Godwin 2012-06-19 13:25:22 +01:00
parent 8ba5bf3198
commit 959a3f9791
4 changed files with 309 additions and 49 deletions

View File

@ -419,6 +419,9 @@ class BaseDatabaseFeatures(object):
# Can we roll back DDL in a transaction?
can_rollback_ddl = False
# Can we issue more than one ALTER COLUMN clause in an ALTER TABLE?
supports_combined_alters = False
def __init__(self, connection):
self.connection = connection

View File

@ -85,6 +85,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
supports_tablespaces = True
can_distinct_on_fields = True
can_rollback_ddl = True
supports_combined_alters = True
class DatabaseWrapper(BaseDatabaseWrapper):
vendor = 'postgresql'

View File

@ -5,6 +5,7 @@ from django.conf import settings
from django.db import transaction
from django.db.utils import load_backend
from django.utils.log import getLogger
from django.db.models.fields.related import ManyToManyField
logger = getLogger('django.db.backends.schema')
@ -29,11 +30,15 @@ class BaseDatabaseSchemaEditor(object):
sql_rename_table = "ALTER TABLE %(old_table)s RENAME TO %(new_table)s"
sql_delete_table = "DROP TABLE %(table)s CASCADE"
sql_create_column = "ALTER TABLE %(table)s ADD COLUMN %(definition)s"
sql_create_column = "ALTER TABLE %(table)s ADD COLUMN %(column)s %(definition)s"
sql_alter_column = "ALTER TABLE %(table)s %(changes)s"
sql_alter_column_type = "ALTER COLUMN %(column)s TYPE %(type)s"
sql_alter_column_null = "ALTER COLUMN %(column)s DROP NOT NULL"
sql_alter_column_not_null = "ALTER COLUMN %(column)s SET NOT NULL"
sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s CASCADE;"
sql_alter_column_default = "ALTER COLUMN %(column)s SET DEFAULT %(default)s"
sql_alter_column_no_default = "ALTER COLUMN %(column)s DROP DEFAULT"
sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s CASCADE"
sql_rename_column = "ALTER TABLE %(table)s RENAME COLUMN %(old_column)s TO %(new_column)s"
sql_create_check = "ADD CONSTRAINT %(name)s CHECK (%(check)s)"
sql_delete_check = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
@ -91,50 +96,7 @@ class BaseDatabaseSchemaEditor(object):
def quote_name(self, name):
return self.connection.ops.quote_name(name)
# Actions
def create_model(self, model):
"""
Takes a model and creates a table for it in the database.
Will also create any accompanying indexes or unique constraints.
"""
# Do nothing if this is an unmanaged or proxy model
if not model._meta.managed or model._meta.proxy:
return [], {}
# Create column SQL, add FK deferreds if needed
column_sqls = []
for field in model._meta.local_fields:
# SQL
definition = self.column_sql(model, field)
if definition is None:
continue
column_sqls.append("%s %s" % (
self.quote_name(field.column),
definition,
))
# FK
if field.rel:
to_table = field.rel.to._meta.db_table
to_column = field.rel.to._meta.get_field(field.rel.field_name).column
self.deferred_sql.append(
self.sql_create_fk % {
"name": '%s_refs_%s_%x' % (
field.column,
to_column,
abs(hash((model._meta.db_table, to_table)))
),
"table": self.quote_name(model._meta.db_table),
"column": self.quote_name(field.column),
"to_table": self.quote_name(to_table),
"to_column": self.quote_name(to_column),
}
)
# Make the table
sql = self.sql_create_table % {
"table": model._meta.db_table,
"definition": ", ".join(column_sqls)
}
self.execute(sql)
# Field <-> database mapping functions
def column_sql(self, model, field, include_default=False):
"""
@ -143,6 +105,7 @@ class BaseDatabaseSchemaEditor(object):
"""
# Get the column's type and use that as the basis of the SQL
sql = field.db_type(connection=self.connection)
params = []
# Check for fields that aren't actually columns (e.g. M2M)
if sql is None:
return None
@ -168,11 +131,232 @@ class BaseDatabaseSchemaEditor(object):
sql += " UNIQUE"
# If we were told to include a default value, do so
if include_default:
raise NotImplementedError()
sql += " DEFAULT %s"
params += [self.effective_default(field)]
# Return the sql
return sql
return sql, params
def effective_default(self, field):
"Returns a field's effective database default value"
if field.has_default():
default = field.get_default()
elif not field.null and field.blank and field.empty_strings_allowed:
default = ""
else:
default = None
# If it's a callable, call it
if callable(default):
default = default()
return default
# Actions
def create_model(self, model):
"""
Takes a model and creates a table for it in the database.
Will also create any accompanying indexes or unique constraints.
"""
# Do nothing if this is an unmanaged or proxy model
if not model._meta.managed or model._meta.proxy:
return
# Create column SQL, add FK deferreds if needed
column_sqls = []
params = []
for field in model._meta.local_fields:
# SQL
definition, extra_params = self.column_sql(model, field)
if definition is None:
continue
column_sqls.append("%s %s" % (
self.quote_name(field.column),
definition,
))
params.extend(extra_params)
# FK
if field.rel:
to_table = field.rel.to._meta.db_table
to_column = field.rel.to._meta.get_field(field.rel.field_name).column
self.deferred_sql.append(
self.sql_create_fk % {
"name": '%s_refs_%s_%x' % (
field.column,
to_column,
abs(hash((model._meta.db_table, to_table)))
),
"table": self.quote_name(model._meta.db_table),
"column": self.quote_name(field.column),
"to_table": self.quote_name(to_table),
"to_column": self.quote_name(to_column),
}
)
# Make the table
sql = self.sql_create_table % {
"table": model._meta.db_table,
"definition": ", ".join(column_sqls)
}
self.execute(sql, params)
def delete_model(self, model):
"""
Deletes a model from the database.
"""
# Do nothing if this is an unmanaged or proxy model
if not model._meta.managed or model._meta.proxy:
return
# Delete the table
self.execute(self.sql_delete_table % {
"table": self.quote_name(model._meta.db_table),
})
def create_field(self, model, field, keep_default=False):
"""
Creates a field on a model.
Usually involves adding a column, but may involve adding a
table instead (for M2M fields)
"""
# Special-case implicit M2M tables
if isinstance(field, ManyToManyField) and field.rel.through._meta.auto_created:
return self.create_model(field.rel.through)
# Get the column's definition
definition, params = self.column_sql(model, field, include_default=True)
# It might not actually have a column behind it
if definition is None:
return
# Build the SQL and run it
sql = self.sql_create_column % {
"table": self.quote_name(model._meta.db_table),
"column": self.quote_name(field.column),
"definition": definition,
}
self.execute(sql, params)
# Drop the default if we need to
# (Django usually does not use in-database defaults)
if not keep_default and field.default is not None:
sql = self.sql_alter_column % {
"table": self.quote_name(model._meta.db_table),
"changes": self.sql_alter_column_no_default % {
"column": self.quote_name(field.column),
}
}
# Add any FK constraints later
if field.rel:
to_table = field.rel.to._meta.db_table
to_column = field.rel.to._meta.get_field(field.rel.field_name).column
self.deferred_sql.append(
self.sql_create_fk % {
"name": '%s_refs_%s_%x' % (
field.column,
to_column,
abs(hash((model._meta.db_table, to_table)))
),
"table": self.quote_name(model._meta.db_table),
"column": self.quote_name(field.column),
"to_table": self.quote_name(to_table),
"to_column": self.quote_name(to_column),
}
)
def delete_field(self, model, field):
"""
Removes a field from a model. Usually involves deleting a column,
but for M2Ms may involve deleting a table.
"""
# Special-case implicit M2M tables
if isinstance(field, ManyToManyField) and field.rel.through._meta.auto_created:
return self.delete_model(field.rel.through)
# Get the column's definition
definition, params = self.column_sql(model, field)
# It might not actually have a column behind it
if definition is None:
return
# Delete the column
sql = self.sql_delete_column % {
"table": self.quote_name(model._meta.db_table),
"column": self.quote_name(field.column),
}
self.execute(sql)
def alter_field(self, model, old_field, new_field):
"""
Allows a field's type, uniqueness, nullability, default, column,
constraints etc. to be modified.
Requires a copy of the old field as well so we can only perform
changes that are required.
"""
# Ensure this field is even column-based
old_type = old_field.db_type(connection=self.connection)
new_type = new_field.db_type(connection=self.connection)
if old_type is None and new_type is None:
# TODO: Handle M2M fields being repointed
return
elif old_type is None or new_type is None:
raise ValueError("Cannot alter field %s into %s - they are not compatible types" % (
old_field,
new_field,
))
# First, have they renamed the column?
if old_field.column != new_field.column:
self.execute(self.sql_rename_column % {
"table": self.quote_name(model._meta.db_table),
"old_column": self.quote_name(old_field.column),
"new_column": self.quote_name(new_field.column),
})
# Next, start accumulating actions to do
actions = []
# Type change?
if old_type != new_type:
actions.append((
self.sql_alter_column_type % {
"column": self.quote_name(new_field.column),
"type": new_type,
},
[],
))
# Default change?
old_default = self.effective_default(old_field)
new_default = self.effective_default(new_field)
if old_default != new_default:
if new_default is None:
actions.append((
self.sql_alter_column_no_default % {
"column": self.quote_name(new_field.column),
},
[],
))
else:
actions.append((
self.sql_alter_column_default % {
"column": self.quote_name(new_field.column),
"default": "%s",
},
[new_default],
))
# Nullability change?
if old_field.null != new_field.null:
if new_field.null:
actions.append((
self.sql_alter_column_null % {
"column": self.quote_name(new_field.column),
},
[],
))
else:
actions.append((
self.sql_alter_column_null % {
"column": self.quote_name(new_field.column),
},
[],
))
# Combine actions together if we can (e.g. postgres)
if self.connection.features.supports_combined_alters:
sql, params = tuple(zip(*actions))
actions = [(", ".join(sql), params)]
# Apply those actions
for sql, params in actions:
self.execute(
self.sql_alter_column % {
"table": self.quote_name(model._meta.db_table),
"changes": sql,
},
params,
)

View File

@ -2,8 +2,9 @@ from __future__ import absolute_import
import copy
import datetime
from django.test import TestCase
from django.db.models.loading import cache
from django.db import connection, DatabaseError, IntegrityError
from django.db.models.fields import IntegerField, TextField
from django.db.models.loading import cache
from .models import Author, Book
@ -18,6 +19,8 @@ class SchemaTests(TestCase):
models = [Author, Book]
# Utility functions
def setUp(self):
# Make sure we're in manual transaction mode
connection.commit_unless_managed()
@ -51,6 +54,18 @@ class SchemaTests(TestCase):
cache.app_store = self.old_app_store
cache._get_models_cache = {}
def column_classes(self, model):
cursor = connection.cursor()
return dict(
(d[0], (connection.introspection.get_field_type(d[1], d), d))
for d in connection.introspection.get_table_description(
cursor,
model._meta.db_table,
)
)
# Tests
def test_creation_deletion(self):
"""
Tries creating a model's table, and then deleting it.
@ -100,3 +115,60 @@ class SchemaTests(TestCase):
pub_date = datetime.datetime.now(),
)
connection.commit()
def test_create_field(self):
"""
Tests adding fields to models
"""
# Create the table
editor = connection.schema_editor()
editor.start()
editor.create_model(Author)
editor.commit()
# Ensure there's no age field
columns = self.column_classes(Author)
self.assertNotIn("age", columns)
# Alter the name field to a TextField
new_field = IntegerField(null=True)
new_field.set_attributes_from_name("age")
editor = connection.schema_editor()
editor.start()
editor.create_field(
Author,
new_field,
)
editor.commit()
# Ensure the field is right afterwards
columns = self.column_classes(Author)
self.assertEqual(columns['age'][0], "IntegerField")
self.assertEqual(columns['age'][1][6], True)
def test_alter(self):
"""
Tests simple altering of fields
"""
# Create the table
editor = connection.schema_editor()
editor.start()
editor.create_model(Author)
editor.commit()
# Ensure the field is right to begin with
columns = self.column_classes(Author)
self.assertEqual(columns['name'][0], "CharField")
self.assertEqual(columns['name'][1][3], 255)
self.assertEqual(columns['name'][1][6], False)
# Alter the name field to a TextField
new_field = TextField(null=True)
new_field.set_attributes_from_name("name")
editor = connection.schema_editor()
editor.start()
editor.alter_field(
Author,
Author._meta.get_field_by_name("name")[0],
new_field,
)
editor.commit()
# Ensure the field is right afterwards
columns = self.column_classes(Author)
self.assertEqual(columns['name'][0], "TextField")
self.assertEqual(columns['name'][1][6], True)