mirror of https://github.com/django/django.git
Add check constraint support - needed a few Field changes
This commit is contained in:
parent
375178fc19
commit
ca9c3cd39f
|
@ -435,6 +435,9 @@ class BaseDatabaseFeatures(object):
|
|||
# Does it support foreign keys?
|
||||
supports_foreign_keys = True
|
||||
|
||||
# Does it support CHECK constraints?
|
||||
supports_check_constraints = True
|
||||
|
||||
def __init__(self, connection):
|
||||
self.connection = connection
|
||||
|
||||
|
|
|
@ -18,6 +18,7 @@ class BaseDatabaseCreation(object):
|
|||
destruction of test databases.
|
||||
"""
|
||||
data_types = {}
|
||||
data_type_check_constraints = {}
|
||||
|
||||
def __init__(self, connection):
|
||||
self.connection = connection
|
||||
|
|
|
@ -170,6 +170,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
|||
requires_explicit_null_ordering_when_grouping = True
|
||||
allows_primary_key_0 = False
|
||||
uses_savepoints = True
|
||||
supports_check_constraints = False
|
||||
|
||||
def __init__(self, connection):
|
||||
super(DatabaseFeatures, self).__init__(connection)
|
||||
|
|
|
@ -26,14 +26,19 @@ class DatabaseCreation(BaseDatabaseCreation):
|
|||
'GenericIPAddressField': 'inet',
|
||||
'NullBooleanField': 'boolean',
|
||||
'OneToOneField': 'integer',
|
||||
'PositiveIntegerField': 'integer CHECK ("%(column)s" >= 0)',
|
||||
'PositiveSmallIntegerField': 'smallint CHECK ("%(column)s" >= 0)',
|
||||
'PositiveIntegerField': 'integer',
|
||||
'PositiveSmallIntegerField': 'smallint',
|
||||
'SlugField': 'varchar(%(max_length)s)',
|
||||
'SmallIntegerField': 'smallint',
|
||||
'TextField': 'text',
|
||||
'TimeField': 'time',
|
||||
}
|
||||
|
||||
data_type_check_constraints = {
|
||||
'PositiveIntegerField': '"%(column)s" >= 0',
|
||||
'PositiveSmallIntegerField': '"%(column)s" >= 0',
|
||||
}
|
||||
|
||||
def sql_table_creation_suffix(self):
|
||||
assert self.connection.settings_dict['TEST_COLLATION'] is None, "PostgreSQL does not support collation setting at database creation time."
|
||||
if self.connection.settings_dict['TEST_CHARSET']:
|
||||
|
|
|
@ -137,7 +137,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
|
|||
kc.table_schema = %s AND
|
||||
kc.table_name = %s
|
||||
""", ["public", table_name])
|
||||
for constraint, column, kind in cursor.fetchall():
|
||||
for constraint, column in cursor.fetchall():
|
||||
# If we're the first column, make the record
|
||||
if constraint not in constraints:
|
||||
constraints[constraint] = {
|
||||
|
|
|
@ -19,9 +19,6 @@ class BaseDatabaseSchemaEditor(object):
|
|||
then the relevant actions, and then commit(). This is necessary to allow
|
||||
things like circular foreign key references - FKs will only be created once
|
||||
commit() is called.
|
||||
|
||||
TODO:
|
||||
- Check constraints (PosIntField)
|
||||
"""
|
||||
|
||||
# Overrideable SQL templates
|
||||
|
@ -41,7 +38,7 @@ class BaseDatabaseSchemaEditor(object):
|
|||
sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s CASCADE"
|
||||
sql_rename_column = "ALTER TABLE %(table)s RENAME COLUMN %(old_column)s TO %(new_column)s"
|
||||
|
||||
sql_create_check = "ADD CONSTRAINT %(name)s CHECK (%(check)s)"
|
||||
sql_create_check = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s CHECK (%(check)s)"
|
||||
sql_delete_check = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
|
||||
|
||||
sql_create_unique = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s UNIQUE (%(columns)s)"
|
||||
|
@ -105,7 +102,8 @@ class BaseDatabaseSchemaEditor(object):
|
|||
The field must already have had set_attributes_from_name called.
|
||||
"""
|
||||
# Get the column's type and use that as the basis of the SQL
|
||||
sql = field.db_type(connection=self.connection)
|
||||
db_params = field.db_parameters(connection=self.connection)
|
||||
sql = db_params['type']
|
||||
params = []
|
||||
# Check for fields that aren't actually columns (e.g. M2M)
|
||||
if sql is None:
|
||||
|
@ -169,6 +167,11 @@ class BaseDatabaseSchemaEditor(object):
|
|||
definition, extra_params = self.column_sql(model, field)
|
||||
if definition is None:
|
||||
continue
|
||||
# Check constraints can go on the column SQL here
|
||||
db_params = field.db_parameters(connection=self.connection)
|
||||
if db_params['check']:
|
||||
definition += " CHECK (%s)" % db_params['check']
|
||||
# Add the SQL to our big list
|
||||
column_sqls.append("%s %s" % (
|
||||
self.quote_name(field.column),
|
||||
definition,
|
||||
|
@ -295,6 +298,10 @@ class BaseDatabaseSchemaEditor(object):
|
|||
# It might not actually have a column behind it
|
||||
if definition is None:
|
||||
return
|
||||
# Check constraints can go on the column SQL here
|
||||
db_params = field.db_parameters(connection=self.connection)
|
||||
if db_params['check']:
|
||||
definition += " CHECK (%s)" % db_params['check']
|
||||
# Build the SQL and run it
|
||||
sql = self.sql_create_column % {
|
||||
"table": self.quote_name(model._meta.db_table),
|
||||
|
@ -358,8 +365,10 @@ class BaseDatabaseSchemaEditor(object):
|
|||
If strict is true, raises errors if the old column does not match old_field precisely.
|
||||
"""
|
||||
# Ensure this field is even column-based
|
||||
old_type = old_field.db_type(connection=self.connection)
|
||||
new_type = self._type_for_alter(new_field)
|
||||
old_db_params = old_field.db_parameters(connection=self.connection)
|
||||
old_type = old_db_params['type']
|
||||
new_db_params = new_field.db_parameters(connection=self.connection)
|
||||
new_type = new_db_params['type']
|
||||
if old_type is None and new_type is None and (old_field.rel.through and new_field.rel.through and old_field.rel.through._meta.auto_created and new_field.rel.through._meta.auto_created):
|
||||
return self._alter_many_to_many(model, old_field, new_field, strict)
|
||||
elif old_type is None or new_type is None:
|
||||
|
@ -417,6 +426,22 @@ class BaseDatabaseSchemaEditor(object):
|
|||
"name": fk_name,
|
||||
}
|
||||
)
|
||||
# Change check constraints?
|
||||
if old_db_params['check'] != new_db_params['check'] and old_db_params['check']:
|
||||
constraint_names = self._constraint_names(model, [old_field.column], check=True)
|
||||
if strict and len(constraint_names) != 1:
|
||||
raise ValueError("Found wrong number (%s) of check constraints for %s.%s" % (
|
||||
len(constraint_names),
|
||||
model._meta.db_table,
|
||||
old_field.column,
|
||||
))
|
||||
for constraint_name in constraint_names:
|
||||
self.execute(
|
||||
self.sql_delete_check % {
|
||||
"table": self.quote_name(model._meta.db_table),
|
||||
"name": constraint_name,
|
||||
}
|
||||
)
|
||||
# Have they renamed the column?
|
||||
if old_field.column != new_field.column:
|
||||
self.execute(self.sql_rename_column % {
|
||||
|
@ -543,6 +568,16 @@ class BaseDatabaseSchemaEditor(object):
|
|||
"to_column": self.quote_name(new_field.rel.get_related_field().column),
|
||||
}
|
||||
)
|
||||
# Does it have check constraints we need to add?
|
||||
if old_db_params['check'] != new_db_params['check'] and new_db_params['check']:
|
||||
self.execute(
|
||||
self.sql_create_check % {
|
||||
"table": self.quote_name(model._meta.db_table),
|
||||
"name": self._create_index_name(model, [new_field.column], suffix="_check"),
|
||||
"column": self.quote_name(new_field.column),
|
||||
"check": new_db_params['check'],
|
||||
}
|
||||
)
|
||||
|
||||
def _alter_many_to_many(self, model, old_field, new_field, strict):
|
||||
"Alters M2Ms to repoint their to= endpoints."
|
||||
|
@ -555,14 +590,6 @@ class BaseDatabaseSchemaEditor(object):
|
|||
new_field.rel.through._meta.get_field_by_name(new_field.m2m_reverse_field_name())[0],
|
||||
)
|
||||
|
||||
def _type_for_alter(self, field):
|
||||
"""
|
||||
Returns a field's type suitable for ALTER COLUMN.
|
||||
By default it just returns field.db_type().
|
||||
To be overriden by backend specific subclasses
|
||||
"""
|
||||
return field.db_type(connection=self.connection)
|
||||
|
||||
def _create_index_name(self, model, column_names, suffix=""):
|
||||
"Generates a unique name for an index/unique constraint."
|
||||
# If there is just one column in the index, use a default algorithm from Django
|
||||
|
@ -581,7 +608,7 @@ class BaseDatabaseSchemaEditor(object):
|
|||
index_name = '%s%s' % (table_name[:(self.connection.features.max_index_name_length - len(part))], part)
|
||||
return index_name
|
||||
|
||||
def _constraint_names(self, model, column_names=None, unique=None, primary_key=None, index=None, foreign_key=None):
|
||||
def _constraint_names(self, model, column_names=None, unique=None, primary_key=None, index=None, foreign_key=None, check=None):
|
||||
"Returns all constraint names matching the columns and conditions"
|
||||
column_names = set(column_names) if column_names else None
|
||||
constraints = self.connection.introspection.get_constraints(self.connection.cursor(), model._meta.db_table)
|
||||
|
@ -594,6 +621,8 @@ class BaseDatabaseSchemaEditor(object):
|
|||
continue
|
||||
if index is not None and infodict['index'] != index:
|
||||
continue
|
||||
if check is not None and infodict['check'] != check:
|
||||
continue
|
||||
if foreign_key is not None and not infodict['foreign_key']:
|
||||
continue
|
||||
result.append(name)
|
||||
|
|
|
@ -97,6 +97,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
|||
has_bulk_insert = True
|
||||
can_combine_inserts_with_and_without_auto_increment_pk = False
|
||||
supports_foreign_keys = False
|
||||
supports_check_constraints = False
|
||||
|
||||
@cached_property
|
||||
def supports_stddev(self):
|
||||
|
|
|
@ -99,8 +99,10 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
|
|||
|
||||
def alter_field(self, model, old_field, new_field, strict=False):
|
||||
# Ensure this field is even column-based
|
||||
old_type = old_field.db_type(connection=self.connection)
|
||||
new_type = self._type_for_alter(new_field)
|
||||
old_db_params = old_field.db_parameters(connection=self.connection)
|
||||
old_type = old_db_params['type']
|
||||
new_db_params = new_field.db_parameters(connection=self.connection)
|
||||
new_type = new_db_params['type']
|
||||
if old_type is None and new_type is None and (old_field.rel.through and new_field.rel.through and old_field.rel.through._meta.auto_created and new_field.rel.through._meta.auto_created):
|
||||
return self._alter_many_to_many(model, old_field, new_field, strict)
|
||||
elif old_type is None or new_type is None:
|
||||
|
|
|
@ -232,12 +232,32 @@ class Field(object):
|
|||
# mapped to one of the built-in Django field types. In this case, you
|
||||
# can implement db_type() instead of get_internal_type() to specify
|
||||
# exactly which wacky database column type you want to use.
|
||||
params = self.db_parameters(connection)
|
||||
if params['type']:
|
||||
if params['check']:
|
||||
return "%s CHECK (%s)" % (params['type'], params['check'])
|
||||
else:
|
||||
return params['type']
|
||||
return None
|
||||
|
||||
def db_parameters(self, connection):
|
||||
"""
|
||||
Replacement for db_type, providing a range of different return
|
||||
values (type, checks)
|
||||
"""
|
||||
data = DictWrapper(self.__dict__, connection.ops.quote_name, "qn_")
|
||||
try:
|
||||
return (connection.creation.data_types[self.get_internal_type()]
|
||||
% data)
|
||||
type_string = connection.creation.data_types[self.get_internal_type()] % data
|
||||
except KeyError:
|
||||
return None
|
||||
type_string = None
|
||||
try:
|
||||
check_string = connection.creation.data_type_check_constraints[self.get_internal_type()] % data
|
||||
except KeyError:
|
||||
check_string = None
|
||||
return {
|
||||
"type": type_string,
|
||||
"check": check_string,
|
||||
}
|
||||
|
||||
@property
|
||||
def unique(self):
|
||||
|
|
|
@ -1050,6 +1050,9 @@ class ForeignKey(RelatedField, Field):
|
|||
return IntegerField().db_type(connection=connection)
|
||||
return rel_field.db_type(connection=connection)
|
||||
|
||||
def db_parameters(self, connection):
|
||||
return {"type": self.db_type(connection), "check": []}
|
||||
|
||||
class OneToOneField(ForeignKey):
|
||||
"""
|
||||
A OneToOneField is essentially the same as a ForeignKey, with the exception
|
||||
|
@ -1292,3 +1295,6 @@ class ManyToManyField(RelatedField, Field):
|
|||
# A ManyToManyField is not represented by a single column,
|
||||
# so return None.
|
||||
return None
|
||||
|
||||
def db_parameters(self, connection):
|
||||
return {"type": None, "check": None}
|
||||
|
|
|
@ -7,6 +7,7 @@ from django.db import models
|
|||
|
||||
class Author(models.Model):
|
||||
name = models.CharField(max_length=255)
|
||||
height = models.PositiveIntegerField(null=True, blank=True)
|
||||
|
||||
class Meta:
|
||||
managed = False
|
||||
|
|
|
@ -347,6 +347,56 @@ class SchemaTests(TestCase):
|
|||
else:
|
||||
self.fail("No FK constraint for tag_id found")
|
||||
|
||||
@skipUnless(connection.features.supports_check_constraints, "No check constraints")
|
||||
def test_check_constraints(self):
|
||||
"""
|
||||
Tests creating/deleting CHECK constraints
|
||||
"""
|
||||
# Create the tables
|
||||
editor = connection.schema_editor()
|
||||
editor.start()
|
||||
editor.create_model(Author)
|
||||
editor.commit()
|
||||
# Ensure the constraint exists
|
||||
constraints = connection.introspection.get_constraints(connection.cursor(), Author._meta.db_table)
|
||||
for name, details in constraints.items():
|
||||
if details['columns'] == set(["height"]) and details['check']:
|
||||
break
|
||||
else:
|
||||
self.fail("No check constraint for height found")
|
||||
# Alter the column to remove it
|
||||
new_field = IntegerField(null=True, blank=True)
|
||||
new_field.set_attributes_from_name("height")
|
||||
editor = connection.schema_editor()
|
||||
editor.start()
|
||||
editor.alter_field(
|
||||
Author,
|
||||
Author._meta.get_field_by_name("height")[0],
|
||||
new_field,
|
||||
strict = True,
|
||||
)
|
||||
editor.commit()
|
||||
constraints = connection.introspection.get_constraints(connection.cursor(), Author._meta.db_table)
|
||||
for name, details in constraints.items():
|
||||
if details['columns'] == set(["height"]) and details['check']:
|
||||
self.fail("Check constraint for height found")
|
||||
# Alter the column to re-add it
|
||||
editor = connection.schema_editor()
|
||||
editor.start()
|
||||
editor.alter_field(
|
||||
Author,
|
||||
new_field,
|
||||
Author._meta.get_field_by_name("height")[0],
|
||||
strict = True,
|
||||
)
|
||||
editor.commit()
|
||||
constraints = connection.introspection.get_constraints(connection.cursor(), Author._meta.db_table)
|
||||
for name, details in constraints.items():
|
||||
if details['columns'] == set(["height"]) and details['check']:
|
||||
break
|
||||
else:
|
||||
self.fail("No check constraint for height found")
|
||||
|
||||
def test_unique(self):
|
||||
"""
|
||||
Tests removing and adding unique constraints to a single column.
|
||||
|
|
Loading…
Reference in New Issue