Add some field schema alteration methods and tests.
This commit is contained in:
parent
8ba5bf3198
commit
959a3f9791
|
@ -419,6 +419,9 @@ class BaseDatabaseFeatures(object):
|
||||||
# Can we roll back DDL in a transaction?
|
# Can we roll back DDL in a transaction?
|
||||||
can_rollback_ddl = False
|
can_rollback_ddl = False
|
||||||
|
|
||||||
|
# Can we issue more than one ALTER COLUMN clause in an ALTER TABLE?
|
||||||
|
supports_combined_alters = False
|
||||||
|
|
||||||
def __init__(self, connection):
|
def __init__(self, connection):
|
||||||
self.connection = connection
|
self.connection = connection
|
||||||
|
|
||||||
|
|
|
@ -85,6 +85,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
||||||
supports_tablespaces = True
|
supports_tablespaces = True
|
||||||
can_distinct_on_fields = True
|
can_distinct_on_fields = True
|
||||||
can_rollback_ddl = True
|
can_rollback_ddl = True
|
||||||
|
supports_combined_alters = True
|
||||||
|
|
||||||
class DatabaseWrapper(BaseDatabaseWrapper):
|
class DatabaseWrapper(BaseDatabaseWrapper):
|
||||||
vendor = 'postgresql'
|
vendor = 'postgresql'
|
||||||
|
|
|
@ -5,6 +5,7 @@ from django.conf import settings
|
||||||
from django.db import transaction
|
from django.db import transaction
|
||||||
from django.db.utils import load_backend
|
from django.db.utils import load_backend
|
||||||
from django.utils.log import getLogger
|
from django.utils.log import getLogger
|
||||||
|
from django.db.models.fields.related import ManyToManyField
|
||||||
|
|
||||||
logger = getLogger('django.db.backends.schema')
|
logger = getLogger('django.db.backends.schema')
|
||||||
|
|
||||||
|
@ -29,11 +30,15 @@ class BaseDatabaseSchemaEditor(object):
|
||||||
sql_rename_table = "ALTER TABLE %(old_table)s RENAME TO %(new_table)s"
|
sql_rename_table = "ALTER TABLE %(old_table)s RENAME TO %(new_table)s"
|
||||||
sql_delete_table = "DROP TABLE %(table)s CASCADE"
|
sql_delete_table = "DROP TABLE %(table)s CASCADE"
|
||||||
|
|
||||||
sql_create_column = "ALTER TABLE %(table)s ADD COLUMN %(definition)s"
|
sql_create_column = "ALTER TABLE %(table)s ADD COLUMN %(column)s %(definition)s"
|
||||||
|
sql_alter_column = "ALTER TABLE %(table)s %(changes)s"
|
||||||
sql_alter_column_type = "ALTER COLUMN %(column)s TYPE %(type)s"
|
sql_alter_column_type = "ALTER COLUMN %(column)s TYPE %(type)s"
|
||||||
sql_alter_column_null = "ALTER COLUMN %(column)s DROP NOT NULL"
|
sql_alter_column_null = "ALTER COLUMN %(column)s DROP NOT NULL"
|
||||||
sql_alter_column_not_null = "ALTER COLUMN %(column)s SET NOT NULL"
|
sql_alter_column_not_null = "ALTER COLUMN %(column)s SET NOT NULL"
|
||||||
sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s CASCADE;"
|
sql_alter_column_default = "ALTER COLUMN %(column)s SET DEFAULT %(default)s"
|
||||||
|
sql_alter_column_no_default = "ALTER COLUMN %(column)s DROP DEFAULT"
|
||||||
|
sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s CASCADE"
|
||||||
|
sql_rename_column = "ALTER TABLE %(table)s RENAME COLUMN %(old_column)s TO %(new_column)s"
|
||||||
|
|
||||||
sql_create_check = "ADD CONSTRAINT %(name)s CHECK (%(check)s)"
|
sql_create_check = "ADD CONSTRAINT %(name)s CHECK (%(check)s)"
|
||||||
sql_delete_check = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
|
sql_delete_check = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
|
||||||
|
@ -91,50 +96,7 @@ class BaseDatabaseSchemaEditor(object):
|
||||||
def quote_name(self, name):
|
def quote_name(self, name):
|
||||||
return self.connection.ops.quote_name(name)
|
return self.connection.ops.quote_name(name)
|
||||||
|
|
||||||
# Actions
|
# Field <-> database mapping functions
|
||||||
|
|
||||||
def create_model(self, model):
|
|
||||||
"""
|
|
||||||
Takes a model and creates a table for it in the database.
|
|
||||||
Will also create any accompanying indexes or unique constraints.
|
|
||||||
"""
|
|
||||||
# Do nothing if this is an unmanaged or proxy model
|
|
||||||
if not model._meta.managed or model._meta.proxy:
|
|
||||||
return [], {}
|
|
||||||
# Create column SQL, add FK deferreds if needed
|
|
||||||
column_sqls = []
|
|
||||||
for field in model._meta.local_fields:
|
|
||||||
# SQL
|
|
||||||
definition = self.column_sql(model, field)
|
|
||||||
if definition is None:
|
|
||||||
continue
|
|
||||||
column_sqls.append("%s %s" % (
|
|
||||||
self.quote_name(field.column),
|
|
||||||
definition,
|
|
||||||
))
|
|
||||||
# FK
|
|
||||||
if field.rel:
|
|
||||||
to_table = field.rel.to._meta.db_table
|
|
||||||
to_column = field.rel.to._meta.get_field(field.rel.field_name).column
|
|
||||||
self.deferred_sql.append(
|
|
||||||
self.sql_create_fk % {
|
|
||||||
"name": '%s_refs_%s_%x' % (
|
|
||||||
field.column,
|
|
||||||
to_column,
|
|
||||||
abs(hash((model._meta.db_table, to_table)))
|
|
||||||
),
|
|
||||||
"table": self.quote_name(model._meta.db_table),
|
|
||||||
"column": self.quote_name(field.column),
|
|
||||||
"to_table": self.quote_name(to_table),
|
|
||||||
"to_column": self.quote_name(to_column),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
# Make the table
|
|
||||||
sql = self.sql_create_table % {
|
|
||||||
"table": model._meta.db_table,
|
|
||||||
"definition": ", ".join(column_sqls)
|
|
||||||
}
|
|
||||||
self.execute(sql)
|
|
||||||
|
|
||||||
def column_sql(self, model, field, include_default=False):
|
def column_sql(self, model, field, include_default=False):
|
||||||
"""
|
"""
|
||||||
|
@ -143,6 +105,7 @@ class BaseDatabaseSchemaEditor(object):
|
||||||
"""
|
"""
|
||||||
# Get the column's type and use that as the basis of the SQL
|
# Get the column's type and use that as the basis of the SQL
|
||||||
sql = field.db_type(connection=self.connection)
|
sql = field.db_type(connection=self.connection)
|
||||||
|
params = []
|
||||||
# Check for fields that aren't actually columns (e.g. M2M)
|
# Check for fields that aren't actually columns (e.g. M2M)
|
||||||
if sql is None:
|
if sql is None:
|
||||||
return None
|
return None
|
||||||
|
@ -168,11 +131,232 @@ class BaseDatabaseSchemaEditor(object):
|
||||||
sql += " UNIQUE"
|
sql += " UNIQUE"
|
||||||
# If we were told to include a default value, do so
|
# If we were told to include a default value, do so
|
||||||
if include_default:
|
if include_default:
|
||||||
raise NotImplementedError()
|
sql += " DEFAULT %s"
|
||||||
|
params += [self.effective_default(field)]
|
||||||
# Return the sql
|
# Return the sql
|
||||||
return sql
|
return sql, params
|
||||||
|
|
||||||
|
def effective_default(self, field):
|
||||||
|
"Returns a field's effective database default value"
|
||||||
|
if field.has_default():
|
||||||
|
default = field.get_default()
|
||||||
|
elif not field.null and field.blank and field.empty_strings_allowed:
|
||||||
|
default = ""
|
||||||
|
else:
|
||||||
|
default = None
|
||||||
|
# If it's a callable, call it
|
||||||
|
if callable(default):
|
||||||
|
default = default()
|
||||||
|
return default
|
||||||
|
|
||||||
|
# Actions
|
||||||
|
|
||||||
|
def create_model(self, model):
|
||||||
|
"""
|
||||||
|
Takes a model and creates a table for it in the database.
|
||||||
|
Will also create any accompanying indexes or unique constraints.
|
||||||
|
"""
|
||||||
|
# Do nothing if this is an unmanaged or proxy model
|
||||||
|
if not model._meta.managed or model._meta.proxy:
|
||||||
|
return
|
||||||
|
# Create column SQL, add FK deferreds if needed
|
||||||
|
column_sqls = []
|
||||||
|
params = []
|
||||||
|
for field in model._meta.local_fields:
|
||||||
|
# SQL
|
||||||
|
definition, extra_params = self.column_sql(model, field)
|
||||||
|
if definition is None:
|
||||||
|
continue
|
||||||
|
column_sqls.append("%s %s" % (
|
||||||
|
self.quote_name(field.column),
|
||||||
|
definition,
|
||||||
|
))
|
||||||
|
params.extend(extra_params)
|
||||||
|
# FK
|
||||||
|
if field.rel:
|
||||||
|
to_table = field.rel.to._meta.db_table
|
||||||
|
to_column = field.rel.to._meta.get_field(field.rel.field_name).column
|
||||||
|
self.deferred_sql.append(
|
||||||
|
self.sql_create_fk % {
|
||||||
|
"name": '%s_refs_%s_%x' % (
|
||||||
|
field.column,
|
||||||
|
to_column,
|
||||||
|
abs(hash((model._meta.db_table, to_table)))
|
||||||
|
),
|
||||||
|
"table": self.quote_name(model._meta.db_table),
|
||||||
|
"column": self.quote_name(field.column),
|
||||||
|
"to_table": self.quote_name(to_table),
|
||||||
|
"to_column": self.quote_name(to_column),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
# Make the table
|
||||||
|
sql = self.sql_create_table % {
|
||||||
|
"table": model._meta.db_table,
|
||||||
|
"definition": ", ".join(column_sqls)
|
||||||
|
}
|
||||||
|
self.execute(sql, params)
|
||||||
|
|
||||||
def delete_model(self, model):
|
def delete_model(self, model):
|
||||||
|
"""
|
||||||
|
Deletes a model from the database.
|
||||||
|
"""
|
||||||
|
# Do nothing if this is an unmanaged or proxy model
|
||||||
|
if not model._meta.managed or model._meta.proxy:
|
||||||
|
return
|
||||||
|
# Delete the table
|
||||||
self.execute(self.sql_delete_table % {
|
self.execute(self.sql_delete_table % {
|
||||||
"table": self.quote_name(model._meta.db_table),
|
"table": self.quote_name(model._meta.db_table),
|
||||||
})
|
})
|
||||||
|
|
||||||
|
def create_field(self, model, field, keep_default=False):
|
||||||
|
"""
|
||||||
|
Creates a field on a model.
|
||||||
|
Usually involves adding a column, but may involve adding a
|
||||||
|
table instead (for M2M fields)
|
||||||
|
"""
|
||||||
|
# Special-case implicit M2M tables
|
||||||
|
if isinstance(field, ManyToManyField) and field.rel.through._meta.auto_created:
|
||||||
|
return self.create_model(field.rel.through)
|
||||||
|
# Get the column's definition
|
||||||
|
definition, params = self.column_sql(model, field, include_default=True)
|
||||||
|
# It might not actually have a column behind it
|
||||||
|
if definition is None:
|
||||||
|
return
|
||||||
|
# Build the SQL and run it
|
||||||
|
sql = self.sql_create_column % {
|
||||||
|
"table": self.quote_name(model._meta.db_table),
|
||||||
|
"column": self.quote_name(field.column),
|
||||||
|
"definition": definition,
|
||||||
|
}
|
||||||
|
self.execute(sql, params)
|
||||||
|
# Drop the default if we need to
|
||||||
|
# (Django usually does not use in-database defaults)
|
||||||
|
if not keep_default and field.default is not None:
|
||||||
|
sql = self.sql_alter_column % {
|
||||||
|
"table": self.quote_name(model._meta.db_table),
|
||||||
|
"changes": self.sql_alter_column_no_default % {
|
||||||
|
"column": self.quote_name(field.column),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
# Add any FK constraints later
|
||||||
|
if field.rel:
|
||||||
|
to_table = field.rel.to._meta.db_table
|
||||||
|
to_column = field.rel.to._meta.get_field(field.rel.field_name).column
|
||||||
|
self.deferred_sql.append(
|
||||||
|
self.sql_create_fk % {
|
||||||
|
"name": '%s_refs_%s_%x' % (
|
||||||
|
field.column,
|
||||||
|
to_column,
|
||||||
|
abs(hash((model._meta.db_table, to_table)))
|
||||||
|
),
|
||||||
|
"table": self.quote_name(model._meta.db_table),
|
||||||
|
"column": self.quote_name(field.column),
|
||||||
|
"to_table": self.quote_name(to_table),
|
||||||
|
"to_column": self.quote_name(to_column),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def delete_field(self, model, field):
|
||||||
|
"""
|
||||||
|
Removes a field from a model. Usually involves deleting a column,
|
||||||
|
but for M2Ms may involve deleting a table.
|
||||||
|
"""
|
||||||
|
# Special-case implicit M2M tables
|
||||||
|
if isinstance(field, ManyToManyField) and field.rel.through._meta.auto_created:
|
||||||
|
return self.delete_model(field.rel.through)
|
||||||
|
# Get the column's definition
|
||||||
|
definition, params = self.column_sql(model, field)
|
||||||
|
# It might not actually have a column behind it
|
||||||
|
if definition is None:
|
||||||
|
return
|
||||||
|
# Delete the column
|
||||||
|
sql = self.sql_delete_column % {
|
||||||
|
"table": self.quote_name(model._meta.db_table),
|
||||||
|
"column": self.quote_name(field.column),
|
||||||
|
}
|
||||||
|
self.execute(sql)
|
||||||
|
|
||||||
|
def alter_field(self, model, old_field, new_field):
|
||||||
|
"""
|
||||||
|
Allows a field's type, uniqueness, nullability, default, column,
|
||||||
|
constraints etc. to be modified.
|
||||||
|
Requires a copy of the old field as well so we can only perform
|
||||||
|
changes that are required.
|
||||||
|
"""
|
||||||
|
# Ensure this field is even column-based
|
||||||
|
old_type = old_field.db_type(connection=self.connection)
|
||||||
|
new_type = new_field.db_type(connection=self.connection)
|
||||||
|
if old_type is None and new_type is None:
|
||||||
|
# TODO: Handle M2M fields being repointed
|
||||||
|
return
|
||||||
|
elif old_type is None or new_type is None:
|
||||||
|
raise ValueError("Cannot alter field %s into %s - they are not compatible types" % (
|
||||||
|
old_field,
|
||||||
|
new_field,
|
||||||
|
))
|
||||||
|
# First, have they renamed the column?
|
||||||
|
if old_field.column != new_field.column:
|
||||||
|
self.execute(self.sql_rename_column % {
|
||||||
|
"table": self.quote_name(model._meta.db_table),
|
||||||
|
"old_column": self.quote_name(old_field.column),
|
||||||
|
"new_column": self.quote_name(new_field.column),
|
||||||
|
})
|
||||||
|
# Next, start accumulating actions to do
|
||||||
|
actions = []
|
||||||
|
# Type change?
|
||||||
|
if old_type != new_type:
|
||||||
|
actions.append((
|
||||||
|
self.sql_alter_column_type % {
|
||||||
|
"column": self.quote_name(new_field.column),
|
||||||
|
"type": new_type,
|
||||||
|
},
|
||||||
|
[],
|
||||||
|
))
|
||||||
|
# Default change?
|
||||||
|
old_default = self.effective_default(old_field)
|
||||||
|
new_default = self.effective_default(new_field)
|
||||||
|
if old_default != new_default:
|
||||||
|
if new_default is None:
|
||||||
|
actions.append((
|
||||||
|
self.sql_alter_column_no_default % {
|
||||||
|
"column": self.quote_name(new_field.column),
|
||||||
|
},
|
||||||
|
[],
|
||||||
|
))
|
||||||
|
else:
|
||||||
|
actions.append((
|
||||||
|
self.sql_alter_column_default % {
|
||||||
|
"column": self.quote_name(new_field.column),
|
||||||
|
"default": "%s",
|
||||||
|
},
|
||||||
|
[new_default],
|
||||||
|
))
|
||||||
|
# Nullability change?
|
||||||
|
if old_field.null != new_field.null:
|
||||||
|
if new_field.null:
|
||||||
|
actions.append((
|
||||||
|
self.sql_alter_column_null % {
|
||||||
|
"column": self.quote_name(new_field.column),
|
||||||
|
},
|
||||||
|
[],
|
||||||
|
))
|
||||||
|
else:
|
||||||
|
actions.append((
|
||||||
|
self.sql_alter_column_null % {
|
||||||
|
"column": self.quote_name(new_field.column),
|
||||||
|
},
|
||||||
|
[],
|
||||||
|
))
|
||||||
|
# Combine actions together if we can (e.g. postgres)
|
||||||
|
if self.connection.features.supports_combined_alters:
|
||||||
|
sql, params = tuple(zip(*actions))
|
||||||
|
actions = [(", ".join(sql), params)]
|
||||||
|
# Apply those actions
|
||||||
|
for sql, params in actions:
|
||||||
|
self.execute(
|
||||||
|
self.sql_alter_column % {
|
||||||
|
"table": self.quote_name(model._meta.db_table),
|
||||||
|
"changes": sql,
|
||||||
|
},
|
||||||
|
params,
|
||||||
|
)
|
||||||
|
|
|
@ -2,8 +2,9 @@ from __future__ import absolute_import
|
||||||
import copy
|
import copy
|
||||||
import datetime
|
import datetime
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
from django.db.models.loading import cache
|
|
||||||
from django.db import connection, DatabaseError, IntegrityError
|
from django.db import connection, DatabaseError, IntegrityError
|
||||||
|
from django.db.models.fields import IntegerField, TextField
|
||||||
|
from django.db.models.loading import cache
|
||||||
from .models import Author, Book
|
from .models import Author, Book
|
||||||
|
|
||||||
|
|
||||||
|
@ -18,6 +19,8 @@ class SchemaTests(TestCase):
|
||||||
|
|
||||||
models = [Author, Book]
|
models = [Author, Book]
|
||||||
|
|
||||||
|
# Utility functions
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
# Make sure we're in manual transaction mode
|
# Make sure we're in manual transaction mode
|
||||||
connection.commit_unless_managed()
|
connection.commit_unless_managed()
|
||||||
|
@ -51,6 +54,18 @@ class SchemaTests(TestCase):
|
||||||
cache.app_store = self.old_app_store
|
cache.app_store = self.old_app_store
|
||||||
cache._get_models_cache = {}
|
cache._get_models_cache = {}
|
||||||
|
|
||||||
|
def column_classes(self, model):
|
||||||
|
cursor = connection.cursor()
|
||||||
|
return dict(
|
||||||
|
(d[0], (connection.introspection.get_field_type(d[1], d), d))
|
||||||
|
for d in connection.introspection.get_table_description(
|
||||||
|
cursor,
|
||||||
|
model._meta.db_table,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Tests
|
||||||
|
|
||||||
def test_creation_deletion(self):
|
def test_creation_deletion(self):
|
||||||
"""
|
"""
|
||||||
Tries creating a model's table, and then deleting it.
|
Tries creating a model's table, and then deleting it.
|
||||||
|
@ -100,3 +115,60 @@ class SchemaTests(TestCase):
|
||||||
pub_date = datetime.datetime.now(),
|
pub_date = datetime.datetime.now(),
|
||||||
)
|
)
|
||||||
connection.commit()
|
connection.commit()
|
||||||
|
|
||||||
|
def test_create_field(self):
|
||||||
|
"""
|
||||||
|
Tests adding fields to models
|
||||||
|
"""
|
||||||
|
# Create the table
|
||||||
|
editor = connection.schema_editor()
|
||||||
|
editor.start()
|
||||||
|
editor.create_model(Author)
|
||||||
|
editor.commit()
|
||||||
|
# Ensure there's no age field
|
||||||
|
columns = self.column_classes(Author)
|
||||||
|
self.assertNotIn("age", columns)
|
||||||
|
# Alter the name field to a TextField
|
||||||
|
new_field = IntegerField(null=True)
|
||||||
|
new_field.set_attributes_from_name("age")
|
||||||
|
editor = connection.schema_editor()
|
||||||
|
editor.start()
|
||||||
|
editor.create_field(
|
||||||
|
Author,
|
||||||
|
new_field,
|
||||||
|
)
|
||||||
|
editor.commit()
|
||||||
|
# Ensure the field is right afterwards
|
||||||
|
columns = self.column_classes(Author)
|
||||||
|
self.assertEqual(columns['age'][0], "IntegerField")
|
||||||
|
self.assertEqual(columns['age'][1][6], True)
|
||||||
|
|
||||||
|
def test_alter(self):
|
||||||
|
"""
|
||||||
|
Tests simple altering of fields
|
||||||
|
"""
|
||||||
|
# Create the table
|
||||||
|
editor = connection.schema_editor()
|
||||||
|
editor.start()
|
||||||
|
editor.create_model(Author)
|
||||||
|
editor.commit()
|
||||||
|
# Ensure the field is right to begin with
|
||||||
|
columns = self.column_classes(Author)
|
||||||
|
self.assertEqual(columns['name'][0], "CharField")
|
||||||
|
self.assertEqual(columns['name'][1][3], 255)
|
||||||
|
self.assertEqual(columns['name'][1][6], False)
|
||||||
|
# Alter the name field to a TextField
|
||||||
|
new_field = TextField(null=True)
|
||||||
|
new_field.set_attributes_from_name("name")
|
||||||
|
editor = connection.schema_editor()
|
||||||
|
editor.start()
|
||||||
|
editor.alter_field(
|
||||||
|
Author,
|
||||||
|
Author._meta.get_field_by_name("name")[0],
|
||||||
|
new_field,
|
||||||
|
)
|
||||||
|
editor.commit()
|
||||||
|
# Ensure the field is right afterwards
|
||||||
|
columns = self.column_classes(Author)
|
||||||
|
self.assertEqual(columns['name'][0], "TextField")
|
||||||
|
self.assertEqual(columns['name'][1][6], True)
|
||||||
|
|
Loading…
Reference in New Issue