Add check constraint support - needed a few Field changes
This commit is contained in:
parent
375178fc19
commit
ca9c3cd39f
|
@ -435,6 +435,9 @@ class BaseDatabaseFeatures(object):
|
||||||
# Does it support foreign keys?
|
# Does it support foreign keys?
|
||||||
supports_foreign_keys = True
|
supports_foreign_keys = True
|
||||||
|
|
||||||
|
# Does it support CHECK constraints?
|
||||||
|
supports_check_constraints = True
|
||||||
|
|
||||||
def __init__(self, connection):
|
def __init__(self, connection):
|
||||||
self.connection = connection
|
self.connection = connection
|
||||||
|
|
||||||
|
|
|
@ -18,6 +18,7 @@ class BaseDatabaseCreation(object):
|
||||||
destruction of test databases.
|
destruction of test databases.
|
||||||
"""
|
"""
|
||||||
data_types = {}
|
data_types = {}
|
||||||
|
data_type_check_constraints = {}
|
||||||
|
|
||||||
def __init__(self, connection):
|
def __init__(self, connection):
|
||||||
self.connection = connection
|
self.connection = connection
|
||||||
|
|
|
@ -170,6 +170,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
||||||
requires_explicit_null_ordering_when_grouping = True
|
requires_explicit_null_ordering_when_grouping = True
|
||||||
allows_primary_key_0 = False
|
allows_primary_key_0 = False
|
||||||
uses_savepoints = True
|
uses_savepoints = True
|
||||||
|
supports_check_constraints = False
|
||||||
|
|
||||||
def __init__(self, connection):
|
def __init__(self, connection):
|
||||||
super(DatabaseFeatures, self).__init__(connection)
|
super(DatabaseFeatures, self).__init__(connection)
|
||||||
|
|
|
@ -26,14 +26,19 @@ class DatabaseCreation(BaseDatabaseCreation):
|
||||||
'GenericIPAddressField': 'inet',
|
'GenericIPAddressField': 'inet',
|
||||||
'NullBooleanField': 'boolean',
|
'NullBooleanField': 'boolean',
|
||||||
'OneToOneField': 'integer',
|
'OneToOneField': 'integer',
|
||||||
'PositiveIntegerField': 'integer CHECK ("%(column)s" >= 0)',
|
'PositiveIntegerField': 'integer',
|
||||||
'PositiveSmallIntegerField': 'smallint CHECK ("%(column)s" >= 0)',
|
'PositiveSmallIntegerField': 'smallint',
|
||||||
'SlugField': 'varchar(%(max_length)s)',
|
'SlugField': 'varchar(%(max_length)s)',
|
||||||
'SmallIntegerField': 'smallint',
|
'SmallIntegerField': 'smallint',
|
||||||
'TextField': 'text',
|
'TextField': 'text',
|
||||||
'TimeField': 'time',
|
'TimeField': 'time',
|
||||||
}
|
}
|
||||||
|
|
||||||
|
data_type_check_constraints = {
|
||||||
|
'PositiveIntegerField': '"%(column)s" >= 0',
|
||||||
|
'PositiveSmallIntegerField': '"%(column)s" >= 0',
|
||||||
|
}
|
||||||
|
|
||||||
def sql_table_creation_suffix(self):
|
def sql_table_creation_suffix(self):
|
||||||
assert self.connection.settings_dict['TEST_COLLATION'] is None, "PostgreSQL does not support collation setting at database creation time."
|
assert self.connection.settings_dict['TEST_COLLATION'] is None, "PostgreSQL does not support collation setting at database creation time."
|
||||||
if self.connection.settings_dict['TEST_CHARSET']:
|
if self.connection.settings_dict['TEST_CHARSET']:
|
||||||
|
|
|
@ -137,7 +137,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
|
||||||
kc.table_schema = %s AND
|
kc.table_schema = %s AND
|
||||||
kc.table_name = %s
|
kc.table_name = %s
|
||||||
""", ["public", table_name])
|
""", ["public", table_name])
|
||||||
for constraint, column, kind in cursor.fetchall():
|
for constraint, column in cursor.fetchall():
|
||||||
# If we're the first column, make the record
|
# If we're the first column, make the record
|
||||||
if constraint not in constraints:
|
if constraint not in constraints:
|
||||||
constraints[constraint] = {
|
constraints[constraint] = {
|
||||||
|
|
|
@ -19,9 +19,6 @@ class BaseDatabaseSchemaEditor(object):
|
||||||
then the relevant actions, and then commit(). This is necessary to allow
|
then the relevant actions, and then commit(). This is necessary to allow
|
||||||
things like circular foreign key references - FKs will only be created once
|
things like circular foreign key references - FKs will only be created once
|
||||||
commit() is called.
|
commit() is called.
|
||||||
|
|
||||||
TODO:
|
|
||||||
- Check constraints (PosIntField)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Overrideable SQL templates
|
# Overrideable SQL templates
|
||||||
|
@ -41,7 +38,7 @@ class BaseDatabaseSchemaEditor(object):
|
||||||
sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s CASCADE"
|
sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s CASCADE"
|
||||||
sql_rename_column = "ALTER TABLE %(table)s RENAME COLUMN %(old_column)s TO %(new_column)s"
|
sql_rename_column = "ALTER TABLE %(table)s RENAME COLUMN %(old_column)s TO %(new_column)s"
|
||||||
|
|
||||||
sql_create_check = "ADD CONSTRAINT %(name)s CHECK (%(check)s)"
|
sql_create_check = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s CHECK (%(check)s)"
|
||||||
sql_delete_check = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
|
sql_delete_check = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
|
||||||
|
|
||||||
sql_create_unique = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s UNIQUE (%(columns)s)"
|
sql_create_unique = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s UNIQUE (%(columns)s)"
|
||||||
|
@ -105,7 +102,8 @@ class BaseDatabaseSchemaEditor(object):
|
||||||
The field must already have had set_attributes_from_name called.
|
The field must already have had set_attributes_from_name called.
|
||||||
"""
|
"""
|
||||||
# Get the column's type and use that as the basis of the SQL
|
# Get the column's type and use that as the basis of the SQL
|
||||||
sql = field.db_type(connection=self.connection)
|
db_params = field.db_parameters(connection=self.connection)
|
||||||
|
sql = db_params['type']
|
||||||
params = []
|
params = []
|
||||||
# Check for fields that aren't actually columns (e.g. M2M)
|
# Check for fields that aren't actually columns (e.g. M2M)
|
||||||
if sql is None:
|
if sql is None:
|
||||||
|
@ -169,6 +167,11 @@ class BaseDatabaseSchemaEditor(object):
|
||||||
definition, extra_params = self.column_sql(model, field)
|
definition, extra_params = self.column_sql(model, field)
|
||||||
if definition is None:
|
if definition is None:
|
||||||
continue
|
continue
|
||||||
|
# Check constraints can go on the column SQL here
|
||||||
|
db_params = field.db_parameters(connection=self.connection)
|
||||||
|
if db_params['check']:
|
||||||
|
definition += " CHECK (%s)" % db_params['check']
|
||||||
|
# Add the SQL to our big list
|
||||||
column_sqls.append("%s %s" % (
|
column_sqls.append("%s %s" % (
|
||||||
self.quote_name(field.column),
|
self.quote_name(field.column),
|
||||||
definition,
|
definition,
|
||||||
|
@ -295,6 +298,10 @@ class BaseDatabaseSchemaEditor(object):
|
||||||
# It might not actually have a column behind it
|
# It might not actually have a column behind it
|
||||||
if definition is None:
|
if definition is None:
|
||||||
return
|
return
|
||||||
|
# Check constraints can go on the column SQL here
|
||||||
|
db_params = field.db_parameters(connection=self.connection)
|
||||||
|
if db_params['check']:
|
||||||
|
definition += " CHECK (%s)" % db_params['check']
|
||||||
# Build the SQL and run it
|
# Build the SQL and run it
|
||||||
sql = self.sql_create_column % {
|
sql = self.sql_create_column % {
|
||||||
"table": self.quote_name(model._meta.db_table),
|
"table": self.quote_name(model._meta.db_table),
|
||||||
|
@ -358,8 +365,10 @@ class BaseDatabaseSchemaEditor(object):
|
||||||
If strict is true, raises errors if the old column does not match old_field precisely.
|
If strict is true, raises errors if the old column does not match old_field precisely.
|
||||||
"""
|
"""
|
||||||
# Ensure this field is even column-based
|
# Ensure this field is even column-based
|
||||||
old_type = old_field.db_type(connection=self.connection)
|
old_db_params = old_field.db_parameters(connection=self.connection)
|
||||||
new_type = self._type_for_alter(new_field)
|
old_type = old_db_params['type']
|
||||||
|
new_db_params = new_field.db_parameters(connection=self.connection)
|
||||||
|
new_type = new_db_params['type']
|
||||||
if old_type is None and new_type is None and (old_field.rel.through and new_field.rel.through and old_field.rel.through._meta.auto_created and new_field.rel.through._meta.auto_created):
|
if old_type is None and new_type is None and (old_field.rel.through and new_field.rel.through and old_field.rel.through._meta.auto_created and new_field.rel.through._meta.auto_created):
|
||||||
return self._alter_many_to_many(model, old_field, new_field, strict)
|
return self._alter_many_to_many(model, old_field, new_field, strict)
|
||||||
elif old_type is None or new_type is None:
|
elif old_type is None or new_type is None:
|
||||||
|
@ -417,6 +426,22 @@ class BaseDatabaseSchemaEditor(object):
|
||||||
"name": fk_name,
|
"name": fk_name,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
# Change check constraints?
|
||||||
|
if old_db_params['check'] != new_db_params['check'] and old_db_params['check']:
|
||||||
|
constraint_names = self._constraint_names(model, [old_field.column], check=True)
|
||||||
|
if strict and len(constraint_names) != 1:
|
||||||
|
raise ValueError("Found wrong number (%s) of check constraints for %s.%s" % (
|
||||||
|
len(constraint_names),
|
||||||
|
model._meta.db_table,
|
||||||
|
old_field.column,
|
||||||
|
))
|
||||||
|
for constraint_name in constraint_names:
|
||||||
|
self.execute(
|
||||||
|
self.sql_delete_check % {
|
||||||
|
"table": self.quote_name(model._meta.db_table),
|
||||||
|
"name": constraint_name,
|
||||||
|
}
|
||||||
|
)
|
||||||
# Have they renamed the column?
|
# Have they renamed the column?
|
||||||
if old_field.column != new_field.column:
|
if old_field.column != new_field.column:
|
||||||
self.execute(self.sql_rename_column % {
|
self.execute(self.sql_rename_column % {
|
||||||
|
@ -543,6 +568,16 @@ class BaseDatabaseSchemaEditor(object):
|
||||||
"to_column": self.quote_name(new_field.rel.get_related_field().column),
|
"to_column": self.quote_name(new_field.rel.get_related_field().column),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
# Does it have check constraints we need to add?
|
||||||
|
if old_db_params['check'] != new_db_params['check'] and new_db_params['check']:
|
||||||
|
self.execute(
|
||||||
|
self.sql_create_check % {
|
||||||
|
"table": self.quote_name(model._meta.db_table),
|
||||||
|
"name": self._create_index_name(model, [new_field.column], suffix="_check"),
|
||||||
|
"column": self.quote_name(new_field.column),
|
||||||
|
"check": new_db_params['check'],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
def _alter_many_to_many(self, model, old_field, new_field, strict):
|
def _alter_many_to_many(self, model, old_field, new_field, strict):
|
||||||
"Alters M2Ms to repoint their to= endpoints."
|
"Alters M2Ms to repoint their to= endpoints."
|
||||||
|
@ -555,14 +590,6 @@ class BaseDatabaseSchemaEditor(object):
|
||||||
new_field.rel.through._meta.get_field_by_name(new_field.m2m_reverse_field_name())[0],
|
new_field.rel.through._meta.get_field_by_name(new_field.m2m_reverse_field_name())[0],
|
||||||
)
|
)
|
||||||
|
|
||||||
def _type_for_alter(self, field):
|
|
||||||
"""
|
|
||||||
Returns a field's type suitable for ALTER COLUMN.
|
|
||||||
By default it just returns field.db_type().
|
|
||||||
To be overriden by backend specific subclasses
|
|
||||||
"""
|
|
||||||
return field.db_type(connection=self.connection)
|
|
||||||
|
|
||||||
def _create_index_name(self, model, column_names, suffix=""):
|
def _create_index_name(self, model, column_names, suffix=""):
|
||||||
"Generates a unique name for an index/unique constraint."
|
"Generates a unique name for an index/unique constraint."
|
||||||
# If there is just one column in the index, use a default algorithm from Django
|
# If there is just one column in the index, use a default algorithm from Django
|
||||||
|
@ -581,7 +608,7 @@ class BaseDatabaseSchemaEditor(object):
|
||||||
index_name = '%s%s' % (table_name[:(self.connection.features.max_index_name_length - len(part))], part)
|
index_name = '%s%s' % (table_name[:(self.connection.features.max_index_name_length - len(part))], part)
|
||||||
return index_name
|
return index_name
|
||||||
|
|
||||||
def _constraint_names(self, model, column_names=None, unique=None, primary_key=None, index=None, foreign_key=None):
|
def _constraint_names(self, model, column_names=None, unique=None, primary_key=None, index=None, foreign_key=None, check=None):
|
||||||
"Returns all constraint names matching the columns and conditions"
|
"Returns all constraint names matching the columns and conditions"
|
||||||
column_names = set(column_names) if column_names else None
|
column_names = set(column_names) if column_names else None
|
||||||
constraints = self.connection.introspection.get_constraints(self.connection.cursor(), model._meta.db_table)
|
constraints = self.connection.introspection.get_constraints(self.connection.cursor(), model._meta.db_table)
|
||||||
|
@ -594,6 +621,8 @@ class BaseDatabaseSchemaEditor(object):
|
||||||
continue
|
continue
|
||||||
if index is not None and infodict['index'] != index:
|
if index is not None and infodict['index'] != index:
|
||||||
continue
|
continue
|
||||||
|
if check is not None and infodict['check'] != check:
|
||||||
|
continue
|
||||||
if foreign_key is not None and not infodict['foreign_key']:
|
if foreign_key is not None and not infodict['foreign_key']:
|
||||||
continue
|
continue
|
||||||
result.append(name)
|
result.append(name)
|
||||||
|
|
|
@ -97,6 +97,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
||||||
has_bulk_insert = True
|
has_bulk_insert = True
|
||||||
can_combine_inserts_with_and_without_auto_increment_pk = False
|
can_combine_inserts_with_and_without_auto_increment_pk = False
|
||||||
supports_foreign_keys = False
|
supports_foreign_keys = False
|
||||||
|
supports_check_constraints = False
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def supports_stddev(self):
|
def supports_stddev(self):
|
||||||
|
|
|
@ -99,8 +99,10 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
|
||||||
|
|
||||||
def alter_field(self, model, old_field, new_field, strict=False):
|
def alter_field(self, model, old_field, new_field, strict=False):
|
||||||
# Ensure this field is even column-based
|
# Ensure this field is even column-based
|
||||||
old_type = old_field.db_type(connection=self.connection)
|
old_db_params = old_field.db_parameters(connection=self.connection)
|
||||||
new_type = self._type_for_alter(new_field)
|
old_type = old_db_params['type']
|
||||||
|
new_db_params = new_field.db_parameters(connection=self.connection)
|
||||||
|
new_type = new_db_params['type']
|
||||||
if old_type is None and new_type is None and (old_field.rel.through and new_field.rel.through and old_field.rel.through._meta.auto_created and new_field.rel.through._meta.auto_created):
|
if old_type is None and new_type is None and (old_field.rel.through and new_field.rel.through and old_field.rel.through._meta.auto_created and new_field.rel.through._meta.auto_created):
|
||||||
return self._alter_many_to_many(model, old_field, new_field, strict)
|
return self._alter_many_to_many(model, old_field, new_field, strict)
|
||||||
elif old_type is None or new_type is None:
|
elif old_type is None or new_type is None:
|
||||||
|
|
|
@ -232,12 +232,32 @@ class Field(object):
|
||||||
# mapped to one of the built-in Django field types. In this case, you
|
# mapped to one of the built-in Django field types. In this case, you
|
||||||
# can implement db_type() instead of get_internal_type() to specify
|
# can implement db_type() instead of get_internal_type() to specify
|
||||||
# exactly which wacky database column type you want to use.
|
# exactly which wacky database column type you want to use.
|
||||||
|
params = self.db_parameters(connection)
|
||||||
|
if params['type']:
|
||||||
|
if params['check']:
|
||||||
|
return "%s CHECK (%s)" % (params['type'], params['check'])
|
||||||
|
else:
|
||||||
|
return params['type']
|
||||||
|
return None
|
||||||
|
|
||||||
|
def db_parameters(self, connection):
|
||||||
|
"""
|
||||||
|
Replacement for db_type, providing a range of different return
|
||||||
|
values (type, checks)
|
||||||
|
"""
|
||||||
data = DictWrapper(self.__dict__, connection.ops.quote_name, "qn_")
|
data = DictWrapper(self.__dict__, connection.ops.quote_name, "qn_")
|
||||||
try:
|
try:
|
||||||
return (connection.creation.data_types[self.get_internal_type()]
|
type_string = connection.creation.data_types[self.get_internal_type()] % data
|
||||||
% data)
|
|
||||||
except KeyError:
|
except KeyError:
|
||||||
return None
|
type_string = None
|
||||||
|
try:
|
||||||
|
check_string = connection.creation.data_type_check_constraints[self.get_internal_type()] % data
|
||||||
|
except KeyError:
|
||||||
|
check_string = None
|
||||||
|
return {
|
||||||
|
"type": type_string,
|
||||||
|
"check": check_string,
|
||||||
|
}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def unique(self):
|
def unique(self):
|
||||||
|
|
|
@ -1050,6 +1050,9 @@ class ForeignKey(RelatedField, Field):
|
||||||
return IntegerField().db_type(connection=connection)
|
return IntegerField().db_type(connection=connection)
|
||||||
return rel_field.db_type(connection=connection)
|
return rel_field.db_type(connection=connection)
|
||||||
|
|
||||||
|
def db_parameters(self, connection):
|
||||||
|
return {"type": self.db_type(connection), "check": []}
|
||||||
|
|
||||||
class OneToOneField(ForeignKey):
|
class OneToOneField(ForeignKey):
|
||||||
"""
|
"""
|
||||||
A OneToOneField is essentially the same as a ForeignKey, with the exception
|
A OneToOneField is essentially the same as a ForeignKey, with the exception
|
||||||
|
@ -1292,3 +1295,6 @@ class ManyToManyField(RelatedField, Field):
|
||||||
# A ManyToManyField is not represented by a single column,
|
# A ManyToManyField is not represented by a single column,
|
||||||
# so return None.
|
# so return None.
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def db_parameters(self, connection):
|
||||||
|
return {"type": None, "check": None}
|
||||||
|
|
|
@ -7,6 +7,7 @@ from django.db import models
|
||||||
|
|
||||||
class Author(models.Model):
|
class Author(models.Model):
|
||||||
name = models.CharField(max_length=255)
|
name = models.CharField(max_length=255)
|
||||||
|
height = models.PositiveIntegerField(null=True, blank=True)
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
managed = False
|
managed = False
|
||||||
|
|
|
@ -347,6 +347,56 @@ class SchemaTests(TestCase):
|
||||||
else:
|
else:
|
||||||
self.fail("No FK constraint for tag_id found")
|
self.fail("No FK constraint for tag_id found")
|
||||||
|
|
||||||
|
@skipUnless(connection.features.supports_check_constraints, "No check constraints")
|
||||||
|
def test_check_constraints(self):
|
||||||
|
"""
|
||||||
|
Tests creating/deleting CHECK constraints
|
||||||
|
"""
|
||||||
|
# Create the tables
|
||||||
|
editor = connection.schema_editor()
|
||||||
|
editor.start()
|
||||||
|
editor.create_model(Author)
|
||||||
|
editor.commit()
|
||||||
|
# Ensure the constraint exists
|
||||||
|
constraints = connection.introspection.get_constraints(connection.cursor(), Author._meta.db_table)
|
||||||
|
for name, details in constraints.items():
|
||||||
|
if details['columns'] == set(["height"]) and details['check']:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
self.fail("No check constraint for height found")
|
||||||
|
# Alter the column to remove it
|
||||||
|
new_field = IntegerField(null=True, blank=True)
|
||||||
|
new_field.set_attributes_from_name("height")
|
||||||
|
editor = connection.schema_editor()
|
||||||
|
editor.start()
|
||||||
|
editor.alter_field(
|
||||||
|
Author,
|
||||||
|
Author._meta.get_field_by_name("height")[0],
|
||||||
|
new_field,
|
||||||
|
strict = True,
|
||||||
|
)
|
||||||
|
editor.commit()
|
||||||
|
constraints = connection.introspection.get_constraints(connection.cursor(), Author._meta.db_table)
|
||||||
|
for name, details in constraints.items():
|
||||||
|
if details['columns'] == set(["height"]) and details['check']:
|
||||||
|
self.fail("Check constraint for height found")
|
||||||
|
# Alter the column to re-add it
|
||||||
|
editor = connection.schema_editor()
|
||||||
|
editor.start()
|
||||||
|
editor.alter_field(
|
||||||
|
Author,
|
||||||
|
new_field,
|
||||||
|
Author._meta.get_field_by_name("height")[0],
|
||||||
|
strict = True,
|
||||||
|
)
|
||||||
|
editor.commit()
|
||||||
|
constraints = connection.introspection.get_constraints(connection.cursor(), Author._meta.db_table)
|
||||||
|
for name, details in constraints.items():
|
||||||
|
if details['columns'] == set(["height"]) and details['check']:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
self.fail("No check constraint for height found")
|
||||||
|
|
||||||
def test_unique(self):
|
def test_unique(self):
|
||||||
"""
|
"""
|
||||||
Tests removing and adding unique constraints to a single column.
|
Tests removing and adding unique constraints to a single column.
|
||||||
|
|
Loading…
Reference in New Issue