Refs #29641 -- Refactored database schema constraint creation.

Added a test for constraint names in the database.

Updated SQLite introspection to use sqlparse to allow reading the
constraint name for table check and unique constraints.

Co-authored-by: Ian Foote <python@ian.feete.org>
This commit is contained in:
Simon Charette 2018-08-05 21:06:52 -04:00 committed by Tim Graham
parent 2f120ac517
commit dba4a634ba
7 changed files with 147 additions and 82 deletions

View File

@ -61,25 +61,24 @@ class BaseDatabaseSchemaEditor:
sql_rename_column = "ALTER TABLE %(table)s RENAME COLUMN %(old_column)s TO %(new_column)s"
sql_update_with_default = "UPDATE %(table)s SET %(column)s = %(default)s WHERE %(column)s IS NULL"
sql_check = "CONSTRAINT %(name)s CHECK (%(check)s)"
sql_create_check = "ALTER TABLE %(table)s ADD %(check)s"
sql_delete_check = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
sql_foreign_key_constraint = "FOREIGN KEY (%(column)s) REFERENCES %(to_table)s (%(to_column)s)%(deferrable)s"
sql_unique_constraint = "UNIQUE (%(columns)s)"
sql_check_constraint = "CHECK (%(check)s)"
sql_create_constraint = "ALTER TABLE %(table)s ADD %(constraint)s"
sql_delete_constraint = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
sql_constraint = "CONSTRAINT %(name)s %(constraint)s"
sql_create_unique = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s UNIQUE (%(columns)s)"
sql_delete_unique = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
sql_create_unique = None
sql_delete_unique = sql_delete_constraint
sql_create_fk = (
"ALTER TABLE %(table)s ADD CONSTRAINT %(name)s FOREIGN KEY (%(column)s) "
"REFERENCES %(to_table)s (%(to_column)s)%(deferrable)s"
)
sql_create_inline_fk = None
sql_delete_fk = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
sql_delete_fk = sql_delete_constraint
sql_create_index = "CREATE INDEX %(name)s ON %(table)s (%(columns)s)%(extra)s%(condition)s"
sql_delete_index = "DROP INDEX %(name)s"
sql_create_pk = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s PRIMARY KEY (%(columns)s)"
sql_delete_pk = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
sql_delete_pk = sql_delete_constraint
sql_delete_procedure = 'DROP PROCEDURE %(procedure)s'
@ -254,7 +253,7 @@ class BaseDatabaseSchemaEditor:
# Check constraints can go on the column SQL here
db_params = field.db_parameters(connection=self.connection)
if db_params['check']:
definition += " CHECK (%s)" % db_params['check']
definition += " " + self.sql_check_constraint % db_params
# Autoincrement SQL (for backends with inline variant)
col_type_suffix = field.db_type_suffix(connection=self.connection)
if col_type_suffix:
@ -287,7 +286,7 @@ class BaseDatabaseSchemaEditor:
for fields in model._meta.unique_together:
columns = [model._meta.get_field(field).column for field in fields]
self.deferred_sql.append(self._create_unique_sql(model, columns))
constraints = [check.constraint_sql(model, self) for check in model._meta.constraints]
constraints = [check.full_constraint_sql(model, self) for check in model._meta.constraints]
# Make the table
sql = self.sql_create_table % {
"table": self.quote_name(model._meta.db_table),
@ -596,7 +595,7 @@ class BaseDatabaseSchemaEditor:
old_field.column,
))
for constraint_name in constraint_names:
self.execute(self._delete_constraint_sql(self.sql_delete_check, model, constraint_name))
self.execute(self._delete_constraint_sql(self.sql_delete_constraint, model, constraint_name))
# Have they renamed the column?
if old_field.column != new_field.column:
self.execute(self._rename_field_sql(model._meta.db_table, old_field, new_field, new_type))
@ -746,15 +745,16 @@ class BaseDatabaseSchemaEditor:
self.execute(self._create_fk_sql(rel.related_model, rel.field, "_fk"))
# Does it have check constraints we need to add?
if old_db_params['check'] != new_db_params['check'] and new_db_params['check']:
self.execute(
self.sql_create_check % {
"table": self.quote_name(model._meta.db_table),
"check": self.sql_check % {
constraint = self.sql_constraint % {
'name': self.quote_name(
self._create_index_name(model._meta.db_table, [new_field.column], suffix='_check'),
),
'check': new_db_params['check'],
},
'constraint': self.sql_check_constraint % new_db_params,
}
self.execute(
self.sql_create_constraint % {
'table': self.quote_name(model._meta.db_table),
'constraint': constraint,
}
)
# Drop the default if we need to
@ -983,35 +983,57 @@ class BaseDatabaseSchemaEditor:
"type": new_type,
}
def _create_fk_sql(self, model, field, suffix):
from_table = model._meta.db_table
from_column = field.column
_, to_table = split_identifier(field.target_field.model._meta.db_table)
to_column = field.target_field.column
def _create_constraint_sql(self, table, name, constraint):
constraint = Statement(self.sql_constraint, name=name, constraint=constraint)
return Statement(self.sql_create_constraint, table=table, constraint=constraint)
def _create_fk_sql(self, model, field, suffix):
def create_fk_name(*args, **kwargs):
return self.quote_name(self._create_index_name(*args, **kwargs))
return Statement(
self.sql_create_fk,
table=Table(from_table, self.quote_name),
name=ForeignKeyName(from_table, [from_column], to_table, [to_column], suffix, create_fk_name),
column=Columns(from_table, [from_column], self.quote_name),
to_table=Table(field.target_field.model._meta.db_table, self.quote_name),
to_column=Columns(field.target_field.model._meta.db_table, [to_column], self.quote_name),
deferrable=self.connection.ops.deferrable_sql(),
table = Table(model._meta.db_table, self.quote_name)
name = ForeignKeyName(
model._meta.db_table,
[field.column],
split_identifier(field.target_field.model._meta.db_table)[1],
[field.target_field.column],
suffix,
create_fk_name,
)
column = Columns(model._meta.db_table, [field.column], self.quote_name)
to_table = Table(field.target_field.model._meta.db_table, self.quote_name)
to_column = Columns(field.target_field.model._meta.db_table, [field.target_field.column], self.quote_name)
deferrable = self.connection.ops.deferrable_sql()
constraint = Statement(
self.sql_foreign_key_constraint,
column=column,
to_table=to_table,
to_column=to_column,
deferrable=deferrable,
)
return self._create_constraint_sql(table, name, constraint)
def _create_unique_sql(self, model, columns):
def _create_unique_sql(self, model, columns, name=None):
def create_unique_name(*args, **kwargs):
return self.quote_name(self._create_index_name(*args, **kwargs))
table = model._meta.db_table
table = Table(model._meta.db_table, self.quote_name)
if name is None:
name = IndexName(model._meta.db_table, columns, '_uniq', create_unique_name)
else:
name = self.quote_name(name)
columns = Columns(table, columns, self.quote_name)
if self.sql_create_unique:
# Some databases use a different syntax for unique constraint
# creation.
return Statement(
self.sql_create_unique,
table=Table(table, self.quote_name),
name=IndexName(table, columns, '_uniq', create_unique_name),
columns=Columns(table, columns, self.quote_name),
table=table,
name=name,
columns=columns,
)
constraint = Statement(self.sql_unique_constraint, columns=columns)
return self._create_constraint_sql(table, name, constraint)
def _delete_constraint_sql(self, template, model, name):
return template % {

View File

@ -1,5 +1,7 @@
import re
import sqlparse
from django.db.backends.base.introspection import (
BaseDatabaseIntrospection, FieldInfo, TableInfo,
)
@ -242,19 +244,37 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
# table_name is a view.
pass
else:
fields_with_check_constraints = [
schema_row.strip().split(' ')[0][1:-1]
for schema_row in table_schema.split(',')
if schema_row.find('CHECK') >= 0
]
for field_name in fields_with_check_constraints:
# An arbitrary made up name.
constraints['__check__%s' % field_name] = {
'columns': [field_name],
# Check constraint parsing is based of SQLite syntax diagram.
# https://www.sqlite.org/syntaxdiagrams.html#table-constraint
def next_ttype(ttype):
for token in tokens:
if token.ttype == ttype:
return token
statement = sqlparse.parse(table_schema)[0]
tokens = statement.flatten()
for token in tokens:
name = None
if token.match(sqlparse.tokens.Keyword, 'CONSTRAINT'):
# Table constraint
name_token = next_ttype(sqlparse.tokens.Literal.String.Symbol)
name = name_token.value[1:-1]
token = next_ttype(sqlparse.tokens.Keyword)
if token.match(sqlparse.tokens.Keyword, 'CHECK'):
# Column check constraint
if name is None:
column_token = next_ttype(sqlparse.tokens.Literal.String.Symbol)
column = column_token.value[1:-1]
name = '__check__%s' % column
columns = [column]
else:
columns = []
constraints[name] = {
'check': True,
'columns': columns,
'primary_key': False,
'unique': False,
'foreign_key': False,
'check': True,
'index': False,
}
# Get the index info

View File

@ -12,10 +12,10 @@ from django.db.utils import NotSupportedError
class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
sql_delete_table = "DROP TABLE %(table)s"
sql_create_fk = None
sql_create_inline_fk = "REFERENCES %(to_table)s (%(to_column)s) DEFERRABLE INITIALLY DEFERRED"
sql_create_unique = "CREATE UNIQUE INDEX %(name)s ON %(table)s (%(columns)s)"
sql_delete_unique = "DROP INDEX %(name)s"
sql_foreign_key_constraint = None
def __enter__(self):
# Some SQLite schema alterations need foreign key constraints to be

View File

@ -10,16 +10,22 @@ class BaseConstraint:
def constraint_sql(self, model, schema_editor):
raise NotImplementedError('This method must be implemented by a subclass.')
def full_constraint_sql(self, model, schema_editor):
return schema_editor.sql_constraint % {
'name': schema_editor.quote_name(self.name),
'constraint': self.constraint_sql(model, schema_editor),
}
def create_sql(self, model, schema_editor):
sql = self.constraint_sql(model, schema_editor)
return schema_editor.sql_create_check % {
sql = self.full_constraint_sql(model, schema_editor)
return schema_editor.sql_create_constraint % {
'table': schema_editor.quote_name(model._meta.db_table),
'check': sql,
'constraint': sql,
}
def remove_sql(self, model, schema_editor):
quote_name = schema_editor.quote_name
return schema_editor.sql_delete_check % {
return schema_editor.sql_delete_constraint % {
'table': quote_name(model._meta.db_table),
'name': quote_name(self.name),
}
@ -46,10 +52,7 @@ class CheckConstraint(BaseConstraint):
compiler = connection.ops.compiler('SQLCompiler')(query, connection, 'default')
sql, params = where.as_sql(compiler, connection)
params = tuple(schema_editor.quote_value(p) for p in params)
return schema_editor.sql_check % {
'name': schema_editor.quote_name(self.name),
'check': sql % params,
}
return schema_editor.sql_check_constraint % {'check': sql % params}
def __repr__(self):
return "<%s: check='%s' name=%r>" % (self.__class__.__name__, self.check, self.name)

View File

@ -293,6 +293,13 @@ Database backend API
* Third party database backends must implement support for partial indexes or
set ``DatabaseFeatures.supports_partial_indexes`` to ``False``.
* Several ``SchemaEditor`` attributes are changed:
* ``sql_create_check`` is replaced with ``sql_create_constraint``.
* ``sql_delete_check`` is replaced with ``sql_delete_constraint``.
* ``sql_create_fk`` is replaced with ``sql_foreign_key_constraint``,
``sql_constraint``, and ``sql_create_constraint``.
Admin actions are no longer collected from base ``ModelAdmin`` classes
----------------------------------------------------------------------

View File

@ -1,10 +1,15 @@
from django.db import IntegrityError, models
from django.db import IntegrityError, connection, models
from django.db.models.constraints import BaseConstraint
from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature
from .models import Product
def get_constraints(table):
with connection.cursor() as cursor:
return connection.introspection.get_constraints(cursor, table)
class BaseConstraintTests(SimpleTestCase):
def test_constraint_sql(self):
c = BaseConstraint('name')
@ -37,3 +42,11 @@ class CheckConstraintTests(TestCase):
Product.objects.create(name='Valid', price=10, discounted_price=5)
with self.assertRaises(IntegrityError):
Product.objects.create(name='Invalid', price=10, discounted_price=20)
@skipUnlessDBFeature('supports_table_check_constraints')
def test_name(self):
constraints = get_constraints(Product._meta.db_table)
expected_name = 'price_gt_discounted_price'
if connection.features.uppercases_column_names:
expected_name = expected_name.upper()
self.assertIn(expected_name, constraints)

View File

@ -2145,30 +2145,30 @@ class SchemaTests(TransactionTestCase):
self.assertNotIn(constraint_name, self.get_constraints(model._meta.db_table))
constraint_name = "CamelCaseUniqConstraint"
editor.execute(
editor.sql_create_unique % {
"table": editor.quote_name(table),
"name": editor.quote_name(constraint_name),
"columns": editor.quote_name(field.column),
}
)
editor.execute(editor._create_unique_sql(model, [field.column], constraint_name))
if connection.features.uppercases_column_names:
constraint_name = constraint_name.upper()
self.assertIn(constraint_name, self.get_constraints(model._meta.db_table))
editor.alter_field(model, get_field(unique=True), field, strict=True)
self.assertNotIn(constraint_name, self.get_constraints(model._meta.db_table))
if editor.sql_create_fk:
if editor.sql_foreign_key_constraint:
constraint_name = "CamelCaseFKConstraint"
editor.execute(
editor.sql_create_fk % {
"table": editor.quote_name(table),
"name": editor.quote_name(constraint_name),
fk_sql = editor.sql_foreign_key_constraint % {
"column": editor.quote_name(column),
"to_table": editor.quote_name(table),
"to_column": editor.quote_name(model._meta.auto_field.column),
"deferrable": connection.ops.deferrable_sql(),
}
constraint_sql = editor.sql_constraint % {
"name": editor.quote_name(constraint_name),
"constraint": fk_sql,
}
editor.execute(
editor.sql_create_constraint % {
"table": editor.quote_name(table),
"constraint": constraint_sql,
}
)
if connection.features.uppercases_column_names:
constraint_name = constraint_name.upper()