Add some field schema alteration methods and tests.
This commit is contained in:
parent
8ba5bf3198
commit
959a3f9791
|
@ -419,6 +419,9 @@ class BaseDatabaseFeatures(object):
|
|||
# Can we roll back DDL in a transaction?
|
||||
can_rollback_ddl = False
|
||||
|
||||
# Can we issue more than one ALTER COLUMN clause in an ALTER TABLE?
|
||||
supports_combined_alters = False
|
||||
|
||||
def __init__(self, connection):
|
||||
self.connection = connection
|
||||
|
||||
|
|
|
@ -85,6 +85,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
|||
supports_tablespaces = True
|
||||
can_distinct_on_fields = True
|
||||
can_rollback_ddl = True
|
||||
supports_combined_alters = True
|
||||
|
||||
class DatabaseWrapper(BaseDatabaseWrapper):
|
||||
vendor = 'postgresql'
|
||||
|
|
|
@ -5,6 +5,7 @@ from django.conf import settings
|
|||
from django.db import transaction
|
||||
from django.db.utils import load_backend
|
||||
from django.utils.log import getLogger
|
||||
from django.db.models.fields.related import ManyToManyField
|
||||
|
||||
logger = getLogger('django.db.backends.schema')
|
||||
|
||||
|
@ -29,11 +30,15 @@ class BaseDatabaseSchemaEditor(object):
|
|||
sql_rename_table = "ALTER TABLE %(old_table)s RENAME TO %(new_table)s"
|
||||
sql_delete_table = "DROP TABLE %(table)s CASCADE"
|
||||
|
||||
sql_create_column = "ALTER TABLE %(table)s ADD COLUMN %(definition)s"
|
||||
sql_create_column = "ALTER TABLE %(table)s ADD COLUMN %(column)s %(definition)s"
|
||||
sql_alter_column = "ALTER TABLE %(table)s %(changes)s"
|
||||
sql_alter_column_type = "ALTER COLUMN %(column)s TYPE %(type)s"
|
||||
sql_alter_column_null = "ALTER COLUMN %(column)s DROP NOT NULL"
|
||||
sql_alter_column_not_null = "ALTER COLUMN %(column)s SET NOT NULL"
|
||||
sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s CASCADE;"
|
||||
sql_alter_column_default = "ALTER COLUMN %(column)s SET DEFAULT %(default)s"
|
||||
sql_alter_column_no_default = "ALTER COLUMN %(column)s DROP DEFAULT"
|
||||
sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s CASCADE"
|
||||
sql_rename_column = "ALTER TABLE %(table)s RENAME COLUMN %(old_column)s TO %(new_column)s"
|
||||
|
||||
sql_create_check = "ADD CONSTRAINT %(name)s CHECK (%(check)s)"
|
||||
sql_delete_check = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
|
||||
|
@ -91,50 +96,7 @@ class BaseDatabaseSchemaEditor(object):
|
|||
def quote_name(self, name):
|
||||
return self.connection.ops.quote_name(name)
|
||||
|
||||
# Actions
|
||||
|
||||
def create_model(self, model):
|
||||
"""
|
||||
Takes a model and creates a table for it in the database.
|
||||
Will also create any accompanying indexes or unique constraints.
|
||||
"""
|
||||
# Do nothing if this is an unmanaged or proxy model
|
||||
if not model._meta.managed or model._meta.proxy:
|
||||
return [], {}
|
||||
# Create column SQL, add FK deferreds if needed
|
||||
column_sqls = []
|
||||
for field in model._meta.local_fields:
|
||||
# SQL
|
||||
definition = self.column_sql(model, field)
|
||||
if definition is None:
|
||||
continue
|
||||
column_sqls.append("%s %s" % (
|
||||
self.quote_name(field.column),
|
||||
definition,
|
||||
))
|
||||
# FK
|
||||
if field.rel:
|
||||
to_table = field.rel.to._meta.db_table
|
||||
to_column = field.rel.to._meta.get_field(field.rel.field_name).column
|
||||
self.deferred_sql.append(
|
||||
self.sql_create_fk % {
|
||||
"name": '%s_refs_%s_%x' % (
|
||||
field.column,
|
||||
to_column,
|
||||
abs(hash((model._meta.db_table, to_table)))
|
||||
),
|
||||
"table": self.quote_name(model._meta.db_table),
|
||||
"column": self.quote_name(field.column),
|
||||
"to_table": self.quote_name(to_table),
|
||||
"to_column": self.quote_name(to_column),
|
||||
}
|
||||
)
|
||||
# Make the table
|
||||
sql = self.sql_create_table % {
|
||||
"table": model._meta.db_table,
|
||||
"definition": ", ".join(column_sqls)
|
||||
}
|
||||
self.execute(sql)
|
||||
# Field <-> database mapping functions
|
||||
|
||||
def column_sql(self, model, field, include_default=False):
|
||||
"""
|
||||
|
@ -143,6 +105,7 @@ class BaseDatabaseSchemaEditor(object):
|
|||
"""
|
||||
# Get the column's type and use that as the basis of the SQL
|
||||
sql = field.db_type(connection=self.connection)
|
||||
params = []
|
||||
# Check for fields that aren't actually columns (e.g. M2M)
|
||||
if sql is None:
|
||||
return None
|
||||
|
@ -168,11 +131,232 @@ class BaseDatabaseSchemaEditor(object):
|
|||
sql += " UNIQUE"
|
||||
# If we were told to include a default value, do so
|
||||
if include_default:
|
||||
raise NotImplementedError()
|
||||
sql += " DEFAULT %s"
|
||||
params += [self.effective_default(field)]
|
||||
# Return the sql
|
||||
return sql
|
||||
return sql, params
|
||||
|
||||
def effective_default(self, field):
|
||||
"Returns a field's effective database default value"
|
||||
if field.has_default():
|
||||
default = field.get_default()
|
||||
elif not field.null and field.blank and field.empty_strings_allowed:
|
||||
default = ""
|
||||
else:
|
||||
default = None
|
||||
# If it's a callable, call it
|
||||
if callable(default):
|
||||
default = default()
|
||||
return default
|
||||
|
||||
# Actions
|
||||
|
||||
def create_model(self, model):
|
||||
"""
|
||||
Takes a model and creates a table for it in the database.
|
||||
Will also create any accompanying indexes or unique constraints.
|
||||
"""
|
||||
# Do nothing if this is an unmanaged or proxy model
|
||||
if not model._meta.managed or model._meta.proxy:
|
||||
return
|
||||
# Create column SQL, add FK deferreds if needed
|
||||
column_sqls = []
|
||||
params = []
|
||||
for field in model._meta.local_fields:
|
||||
# SQL
|
||||
definition, extra_params = self.column_sql(model, field)
|
||||
if definition is None:
|
||||
continue
|
||||
column_sqls.append("%s %s" % (
|
||||
self.quote_name(field.column),
|
||||
definition,
|
||||
))
|
||||
params.extend(extra_params)
|
||||
# FK
|
||||
if field.rel:
|
||||
to_table = field.rel.to._meta.db_table
|
||||
to_column = field.rel.to._meta.get_field(field.rel.field_name).column
|
||||
self.deferred_sql.append(
|
||||
self.sql_create_fk % {
|
||||
"name": '%s_refs_%s_%x' % (
|
||||
field.column,
|
||||
to_column,
|
||||
abs(hash((model._meta.db_table, to_table)))
|
||||
),
|
||||
"table": self.quote_name(model._meta.db_table),
|
||||
"column": self.quote_name(field.column),
|
||||
"to_table": self.quote_name(to_table),
|
||||
"to_column": self.quote_name(to_column),
|
||||
}
|
||||
)
|
||||
# Make the table
|
||||
sql = self.sql_create_table % {
|
||||
"table": model._meta.db_table,
|
||||
"definition": ", ".join(column_sqls)
|
||||
}
|
||||
self.execute(sql, params)
|
||||
|
||||
def delete_model(self, model):
|
||||
"""
|
||||
Deletes a model from the database.
|
||||
"""
|
||||
# Do nothing if this is an unmanaged or proxy model
|
||||
if not model._meta.managed or model._meta.proxy:
|
||||
return
|
||||
# Delete the table
|
||||
self.execute(self.sql_delete_table % {
|
||||
"table": self.quote_name(model._meta.db_table),
|
||||
})
|
||||
|
||||
def create_field(self, model, field, keep_default=False):
|
||||
"""
|
||||
Creates a field on a model.
|
||||
Usually involves adding a column, but may involve adding a
|
||||
table instead (for M2M fields)
|
||||
"""
|
||||
# Special-case implicit M2M tables
|
||||
if isinstance(field, ManyToManyField) and field.rel.through._meta.auto_created:
|
||||
return self.create_model(field.rel.through)
|
||||
# Get the column's definition
|
||||
definition, params = self.column_sql(model, field, include_default=True)
|
||||
# It might not actually have a column behind it
|
||||
if definition is None:
|
||||
return
|
||||
# Build the SQL and run it
|
||||
sql = self.sql_create_column % {
|
||||
"table": self.quote_name(model._meta.db_table),
|
||||
"column": self.quote_name(field.column),
|
||||
"definition": definition,
|
||||
}
|
||||
self.execute(sql, params)
|
||||
# Drop the default if we need to
|
||||
# (Django usually does not use in-database defaults)
|
||||
if not keep_default and field.default is not None:
|
||||
sql = self.sql_alter_column % {
|
||||
"table": self.quote_name(model._meta.db_table),
|
||||
"changes": self.sql_alter_column_no_default % {
|
||||
"column": self.quote_name(field.column),
|
||||
}
|
||||
}
|
||||
# Add any FK constraints later
|
||||
if field.rel:
|
||||
to_table = field.rel.to._meta.db_table
|
||||
to_column = field.rel.to._meta.get_field(field.rel.field_name).column
|
||||
self.deferred_sql.append(
|
||||
self.sql_create_fk % {
|
||||
"name": '%s_refs_%s_%x' % (
|
||||
field.column,
|
||||
to_column,
|
||||
abs(hash((model._meta.db_table, to_table)))
|
||||
),
|
||||
"table": self.quote_name(model._meta.db_table),
|
||||
"column": self.quote_name(field.column),
|
||||
"to_table": self.quote_name(to_table),
|
||||
"to_column": self.quote_name(to_column),
|
||||
}
|
||||
)
|
||||
|
||||
def delete_field(self, model, field):
|
||||
"""
|
||||
Removes a field from a model. Usually involves deleting a column,
|
||||
but for M2Ms may involve deleting a table.
|
||||
"""
|
||||
# Special-case implicit M2M tables
|
||||
if isinstance(field, ManyToManyField) and field.rel.through._meta.auto_created:
|
||||
return self.delete_model(field.rel.through)
|
||||
# Get the column's definition
|
||||
definition, params = self.column_sql(model, field)
|
||||
# It might not actually have a column behind it
|
||||
if definition is None:
|
||||
return
|
||||
# Delete the column
|
||||
sql = self.sql_delete_column % {
|
||||
"table": self.quote_name(model._meta.db_table),
|
||||
"column": self.quote_name(field.column),
|
||||
}
|
||||
self.execute(sql)
|
||||
|
||||
def alter_field(self, model, old_field, new_field):
|
||||
"""
|
||||
Allows a field's type, uniqueness, nullability, default, column,
|
||||
constraints etc. to be modified.
|
||||
Requires a copy of the old field as well so we can only perform
|
||||
changes that are required.
|
||||
"""
|
||||
# Ensure this field is even column-based
|
||||
old_type = old_field.db_type(connection=self.connection)
|
||||
new_type = new_field.db_type(connection=self.connection)
|
||||
if old_type is None and new_type is None:
|
||||
# TODO: Handle M2M fields being repointed
|
||||
return
|
||||
elif old_type is None or new_type is None:
|
||||
raise ValueError("Cannot alter field %s into %s - they are not compatible types" % (
|
||||
old_field,
|
||||
new_field,
|
||||
))
|
||||
# First, have they renamed the column?
|
||||
if old_field.column != new_field.column:
|
||||
self.execute(self.sql_rename_column % {
|
||||
"table": self.quote_name(model._meta.db_table),
|
||||
"old_column": self.quote_name(old_field.column),
|
||||
"new_column": self.quote_name(new_field.column),
|
||||
})
|
||||
# Next, start accumulating actions to do
|
||||
actions = []
|
||||
# Type change?
|
||||
if old_type != new_type:
|
||||
actions.append((
|
||||
self.sql_alter_column_type % {
|
||||
"column": self.quote_name(new_field.column),
|
||||
"type": new_type,
|
||||
},
|
||||
[],
|
||||
))
|
||||
# Default change?
|
||||
old_default = self.effective_default(old_field)
|
||||
new_default = self.effective_default(new_field)
|
||||
if old_default != new_default:
|
||||
if new_default is None:
|
||||
actions.append((
|
||||
self.sql_alter_column_no_default % {
|
||||
"column": self.quote_name(new_field.column),
|
||||
},
|
||||
[],
|
||||
))
|
||||
else:
|
||||
actions.append((
|
||||
self.sql_alter_column_default % {
|
||||
"column": self.quote_name(new_field.column),
|
||||
"default": "%s",
|
||||
},
|
||||
[new_default],
|
||||
))
|
||||
# Nullability change?
|
||||
if old_field.null != new_field.null:
|
||||
if new_field.null:
|
||||
actions.append((
|
||||
self.sql_alter_column_null % {
|
||||
"column": self.quote_name(new_field.column),
|
||||
},
|
||||
[],
|
||||
))
|
||||
else:
|
||||
actions.append((
|
||||
self.sql_alter_column_null % {
|
||||
"column": self.quote_name(new_field.column),
|
||||
},
|
||||
[],
|
||||
))
|
||||
# Combine actions together if we can (e.g. postgres)
|
||||
if self.connection.features.supports_combined_alters:
|
||||
sql, params = tuple(zip(*actions))
|
||||
actions = [(", ".join(sql), params)]
|
||||
# Apply those actions
|
||||
for sql, params in actions:
|
||||
self.execute(
|
||||
self.sql_alter_column % {
|
||||
"table": self.quote_name(model._meta.db_table),
|
||||
"changes": sql,
|
||||
},
|
||||
params,
|
||||
)
|
||||
|
|
|
@ -2,8 +2,9 @@ from __future__ import absolute_import
|
|||
import copy
|
||||
import datetime
|
||||
from django.test import TestCase
|
||||
from django.db.models.loading import cache
|
||||
from django.db import connection, DatabaseError, IntegrityError
|
||||
from django.db.models.fields import IntegerField, TextField
|
||||
from django.db.models.loading import cache
|
||||
from .models import Author, Book
|
||||
|
||||
|
||||
|
@ -18,6 +19,8 @@ class SchemaTests(TestCase):
|
|||
|
||||
models = [Author, Book]
|
||||
|
||||
# Utility functions
|
||||
|
||||
def setUp(self):
|
||||
# Make sure we're in manual transaction mode
|
||||
connection.commit_unless_managed()
|
||||
|
@ -51,6 +54,18 @@ class SchemaTests(TestCase):
|
|||
cache.app_store = self.old_app_store
|
||||
cache._get_models_cache = {}
|
||||
|
||||
def column_classes(self, model):
|
||||
cursor = connection.cursor()
|
||||
return dict(
|
||||
(d[0], (connection.introspection.get_field_type(d[1], d), d))
|
||||
for d in connection.introspection.get_table_description(
|
||||
cursor,
|
||||
model._meta.db_table,
|
||||
)
|
||||
)
|
||||
|
||||
# Tests
|
||||
|
||||
def test_creation_deletion(self):
|
||||
"""
|
||||
Tries creating a model's table, and then deleting it.
|
||||
|
@ -100,3 +115,60 @@ class SchemaTests(TestCase):
|
|||
pub_date = datetime.datetime.now(),
|
||||
)
|
||||
connection.commit()
|
||||
|
||||
def test_create_field(self):
|
||||
"""
|
||||
Tests adding fields to models
|
||||
"""
|
||||
# Create the table
|
||||
editor = connection.schema_editor()
|
||||
editor.start()
|
||||
editor.create_model(Author)
|
||||
editor.commit()
|
||||
# Ensure there's no age field
|
||||
columns = self.column_classes(Author)
|
||||
self.assertNotIn("age", columns)
|
||||
# Alter the name field to a TextField
|
||||
new_field = IntegerField(null=True)
|
||||
new_field.set_attributes_from_name("age")
|
||||
editor = connection.schema_editor()
|
||||
editor.start()
|
||||
editor.create_field(
|
||||
Author,
|
||||
new_field,
|
||||
)
|
||||
editor.commit()
|
||||
# Ensure the field is right afterwards
|
||||
columns = self.column_classes(Author)
|
||||
self.assertEqual(columns['age'][0], "IntegerField")
|
||||
self.assertEqual(columns['age'][1][6], True)
|
||||
|
||||
def test_alter(self):
|
||||
"""
|
||||
Tests simple altering of fields
|
||||
"""
|
||||
# Create the table
|
||||
editor = connection.schema_editor()
|
||||
editor.start()
|
||||
editor.create_model(Author)
|
||||
editor.commit()
|
||||
# Ensure the field is right to begin with
|
||||
columns = self.column_classes(Author)
|
||||
self.assertEqual(columns['name'][0], "CharField")
|
||||
self.assertEqual(columns['name'][1][3], 255)
|
||||
self.assertEqual(columns['name'][1][6], False)
|
||||
# Alter the name field to a TextField
|
||||
new_field = TextField(null=True)
|
||||
new_field.set_attributes_from_name("name")
|
||||
editor = connection.schema_editor()
|
||||
editor.start()
|
||||
editor.alter_field(
|
||||
Author,
|
||||
Author._meta.get_field_by_name("name")[0],
|
||||
new_field,
|
||||
)
|
||||
editor.commit()
|
||||
# Ensure the field is right afterwards
|
||||
columns = self.column_classes(Author)
|
||||
self.assertEqual(columns['name'][0], "TextField")
|
||||
self.assertEqual(columns['name'][1][6], True)
|
||||
|
|
Loading…
Reference in New Issue