From dba4a634ba999bf376caee193b3378bc0b730bd4 Mon Sep 17 00:00:00 2001 From: Simon Charette Date: Sun, 5 Aug 2018 21:06:52 -0400 Subject: [PATCH] Refs #29641 -- Refactored database schema constraint creation. Added a test for constraint names in the database. Updated SQLite introspection to use sqlparse to allow reading the constraint name for table check and unique constraints. Co-authored-by: Ian Foote --- django/db/backends/base/schema.py | 108 ++++++++++++-------- django/db/backends/sqlite3/introspection.py | 50 ++++++--- django/db/backends/sqlite3/schema.py | 2 +- django/db/models/constraints.py | 19 ++-- docs/releases/2.2.txt | 7 ++ tests/constraints/tests.py | 15 ++- tests/schema/tests.py | 28 ++--- 7 files changed, 147 insertions(+), 82 deletions(-) diff --git a/django/db/backends/base/schema.py b/django/db/backends/base/schema.py index 783f7cd64e..f5cb433d6c 100644 --- a/django/db/backends/base/schema.py +++ b/django/db/backends/base/schema.py @@ -61,25 +61,24 @@ class BaseDatabaseSchemaEditor: sql_rename_column = "ALTER TABLE %(table)s RENAME COLUMN %(old_column)s TO %(new_column)s" sql_update_with_default = "UPDATE %(table)s SET %(column)s = %(default)s WHERE %(column)s IS NULL" - sql_check = "CONSTRAINT %(name)s CHECK (%(check)s)" - sql_create_check = "ALTER TABLE %(table)s ADD %(check)s" - sql_delete_check = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s" + sql_foreign_key_constraint = "FOREIGN KEY (%(column)s) REFERENCES %(to_table)s (%(to_column)s)%(deferrable)s" + sql_unique_constraint = "UNIQUE (%(columns)s)" + sql_check_constraint = "CHECK (%(check)s)" + sql_create_constraint = "ALTER TABLE %(table)s ADD %(constraint)s" + sql_delete_constraint = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s" + sql_constraint = "CONSTRAINT %(name)s %(constraint)s" - sql_create_unique = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s UNIQUE (%(columns)s)" - sql_delete_unique = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s" + sql_create_unique = None + sql_delete_unique = sql_delete_constraint - sql_create_fk = ( - "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s FOREIGN KEY (%(column)s) " - "REFERENCES %(to_table)s (%(to_column)s)%(deferrable)s" - ) sql_create_inline_fk = None - sql_delete_fk = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s" + sql_delete_fk = sql_delete_constraint sql_create_index = "CREATE INDEX %(name)s ON %(table)s (%(columns)s)%(extra)s%(condition)s" sql_delete_index = "DROP INDEX %(name)s" sql_create_pk = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s PRIMARY KEY (%(columns)s)" - sql_delete_pk = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s" + sql_delete_pk = sql_delete_constraint sql_delete_procedure = 'DROP PROCEDURE %(procedure)s' @@ -254,7 +253,7 @@ class BaseDatabaseSchemaEditor: # Check constraints can go on the column SQL here db_params = field.db_parameters(connection=self.connection) if db_params['check']: - definition += " CHECK (%s)" % db_params['check'] + definition += " " + self.sql_check_constraint % db_params # Autoincrement SQL (for backends with inline variant) col_type_suffix = field.db_type_suffix(connection=self.connection) if col_type_suffix: @@ -287,7 +286,7 @@ class BaseDatabaseSchemaEditor: for fields in model._meta.unique_together: columns = [model._meta.get_field(field).column for field in fields] self.deferred_sql.append(self._create_unique_sql(model, columns)) - constraints = [check.constraint_sql(model, self) for check in model._meta.constraints] + constraints = [check.full_constraint_sql(model, self) for check in model._meta.constraints] # Make the table sql = self.sql_create_table % { "table": self.quote_name(model._meta.db_table), @@ -596,7 +595,7 @@ class BaseDatabaseSchemaEditor: old_field.column, )) for constraint_name in constraint_names: - self.execute(self._delete_constraint_sql(self.sql_delete_check, model, constraint_name)) + self.execute(self._delete_constraint_sql(self.sql_delete_constraint, model, constraint_name)) # Have they renamed the column? if old_field.column != new_field.column: self.execute(self._rename_field_sql(model._meta.db_table, old_field, new_field, new_type)) @@ -746,15 +745,16 @@ class BaseDatabaseSchemaEditor: self.execute(self._create_fk_sql(rel.related_model, rel.field, "_fk")) # Does it have check constraints we need to add? if old_db_params['check'] != new_db_params['check'] and new_db_params['check']: + constraint = self.sql_constraint % { + 'name': self.quote_name( + self._create_index_name(model._meta.db_table, [new_field.column], suffix='_check'), + ), + 'constraint': self.sql_check_constraint % new_db_params, + } self.execute( - self.sql_create_check % { - "table": self.quote_name(model._meta.db_table), - "check": self.sql_check % { - 'name': self.quote_name( - self._create_index_name(model._meta.db_table, [new_field.column], suffix='_check'), - ), - 'check': new_db_params['check'], - }, + self.sql_create_constraint % { + 'table': self.quote_name(model._meta.db_table), + 'constraint': constraint, } ) # Drop the default if we need to @@ -983,35 +983,57 @@ class BaseDatabaseSchemaEditor: "type": new_type, } - def _create_fk_sql(self, model, field, suffix): - from_table = model._meta.db_table - from_column = field.column - _, to_table = split_identifier(field.target_field.model._meta.db_table) - to_column = field.target_field.column + def _create_constraint_sql(self, table, name, constraint): + constraint = Statement(self.sql_constraint, name=name, constraint=constraint) + return Statement(self.sql_create_constraint, table=table, constraint=constraint) + def _create_fk_sql(self, model, field, suffix): def create_fk_name(*args, **kwargs): return self.quote_name(self._create_index_name(*args, **kwargs)) - return Statement( - self.sql_create_fk, - table=Table(from_table, self.quote_name), - name=ForeignKeyName(from_table, [from_column], to_table, [to_column], suffix, create_fk_name), - column=Columns(from_table, [from_column], self.quote_name), - to_table=Table(field.target_field.model._meta.db_table, self.quote_name), - to_column=Columns(field.target_field.model._meta.db_table, [to_column], self.quote_name), - deferrable=self.connection.ops.deferrable_sql(), + table = Table(model._meta.db_table, self.quote_name) + name = ForeignKeyName( + model._meta.db_table, + [field.column], + split_identifier(field.target_field.model._meta.db_table)[1], + [field.target_field.column], + suffix, + create_fk_name, ) + column = Columns(model._meta.db_table, [field.column], self.quote_name) + to_table = Table(field.target_field.model._meta.db_table, self.quote_name) + to_column = Columns(field.target_field.model._meta.db_table, [field.target_field.column], self.quote_name) + deferrable = self.connection.ops.deferrable_sql() + constraint = Statement( + self.sql_foreign_key_constraint, + column=column, + to_table=to_table, + to_column=to_column, + deferrable=deferrable, + ) + return self._create_constraint_sql(table, name, constraint) - def _create_unique_sql(self, model, columns): + def _create_unique_sql(self, model, columns, name=None): def create_unique_name(*args, **kwargs): return self.quote_name(self._create_index_name(*args, **kwargs)) - table = model._meta.db_table - return Statement( - self.sql_create_unique, - table=Table(table, self.quote_name), - name=IndexName(table, columns, '_uniq', create_unique_name), - columns=Columns(table, columns, self.quote_name), - ) + + table = Table(model._meta.db_table, self.quote_name) + if name is None: + name = IndexName(model._meta.db_table, columns, '_uniq', create_unique_name) + else: + name = self.quote_name(name) + columns = Columns(table, columns, self.quote_name) + if self.sql_create_unique: + # Some databases use a different syntax for unique constraint + # creation. + return Statement( + self.sql_create_unique, + table=table, + name=name, + columns=columns, + ) + constraint = Statement(self.sql_unique_constraint, columns=columns) + return self._create_constraint_sql(table, name, constraint) def _delete_constraint_sql(self, template, model, name): return template % { diff --git a/django/db/backends/sqlite3/introspection.py b/django/db/backends/sqlite3/introspection.py index 0c82ea8844..47ca25a78a 100644 --- a/django/db/backends/sqlite3/introspection.py +++ b/django/db/backends/sqlite3/introspection.py @@ -1,5 +1,7 @@ import re +import sqlparse + from django.db.backends.base.introspection import ( BaseDatabaseIntrospection, FieldInfo, TableInfo, ) @@ -242,21 +244,39 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): # table_name is a view. pass else: - fields_with_check_constraints = [ - schema_row.strip().split(' ')[0][1:-1] - for schema_row in table_schema.split(',') - if schema_row.find('CHECK') >= 0 - ] - for field_name in fields_with_check_constraints: - # An arbitrary made up name. - constraints['__check__%s' % field_name] = { - 'columns': [field_name], - 'primary_key': False, - 'unique': False, - 'foreign_key': False, - 'check': True, - 'index': False, - } + # Check constraint parsing is based of SQLite syntax diagram. + # https://www.sqlite.org/syntaxdiagrams.html#table-constraint + def next_ttype(ttype): + for token in tokens: + if token.ttype == ttype: + return token + + statement = sqlparse.parse(table_schema)[0] + tokens = statement.flatten() + for token in tokens: + name = None + if token.match(sqlparse.tokens.Keyword, 'CONSTRAINT'): + # Table constraint + name_token = next_ttype(sqlparse.tokens.Literal.String.Symbol) + name = name_token.value[1:-1] + token = next_ttype(sqlparse.tokens.Keyword) + if token.match(sqlparse.tokens.Keyword, 'CHECK'): + # Column check constraint + if name is None: + column_token = next_ttype(sqlparse.tokens.Literal.String.Symbol) + column = column_token.value[1:-1] + name = '__check__%s' % column + columns = [column] + else: + columns = [] + constraints[name] = { + 'check': True, + 'columns': columns, + 'primary_key': False, + 'unique': False, + 'foreign_key': False, + 'index': False, + } # Get the index info cursor.execute("PRAGMA index_list(%s)" % self.connection.ops.quote_name(table_name)) for row in cursor.fetchall(): diff --git a/django/db/backends/sqlite3/schema.py b/django/db/backends/sqlite3/schema.py index 7aa1f28f53..ad22d03763 100644 --- a/django/db/backends/sqlite3/schema.py +++ b/django/db/backends/sqlite3/schema.py @@ -12,10 +12,10 @@ from django.db.utils import NotSupportedError class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): sql_delete_table = "DROP TABLE %(table)s" - sql_create_fk = None sql_create_inline_fk = "REFERENCES %(to_table)s (%(to_column)s) DEFERRABLE INITIALLY DEFERRED" sql_create_unique = "CREATE UNIQUE INDEX %(name)s ON %(table)s (%(columns)s)" sql_delete_unique = "DROP INDEX %(name)s" + sql_foreign_key_constraint = None def __enter__(self): # Some SQLite schema alterations need foreign key constraints to be diff --git a/django/db/models/constraints.py b/django/db/models/constraints.py index 2bad8db221..698b278fe8 100644 --- a/django/db/models/constraints.py +++ b/django/db/models/constraints.py @@ -10,16 +10,22 @@ class BaseConstraint: def constraint_sql(self, model, schema_editor): raise NotImplementedError('This method must be implemented by a subclass.') + def full_constraint_sql(self, model, schema_editor): + return schema_editor.sql_constraint % { + 'name': schema_editor.quote_name(self.name), + 'constraint': self.constraint_sql(model, schema_editor), + } + def create_sql(self, model, schema_editor): - sql = self.constraint_sql(model, schema_editor) - return schema_editor.sql_create_check % { + sql = self.full_constraint_sql(model, schema_editor) + return schema_editor.sql_create_constraint % { 'table': schema_editor.quote_name(model._meta.db_table), - 'check': sql, + 'constraint': sql, } def remove_sql(self, model, schema_editor): quote_name = schema_editor.quote_name - return schema_editor.sql_delete_check % { + return schema_editor.sql_delete_constraint % { 'table': quote_name(model._meta.db_table), 'name': quote_name(self.name), } @@ -46,10 +52,7 @@ class CheckConstraint(BaseConstraint): compiler = connection.ops.compiler('SQLCompiler')(query, connection, 'default') sql, params = where.as_sql(compiler, connection) params = tuple(schema_editor.quote_value(p) for p in params) - return schema_editor.sql_check % { - 'name': schema_editor.quote_name(self.name), - 'check': sql % params, - } + return schema_editor.sql_check_constraint % {'check': sql % params} def __repr__(self): return "<%s: check='%s' name=%r>" % (self.__class__.__name__, self.check, self.name) diff --git a/docs/releases/2.2.txt b/docs/releases/2.2.txt index 4dd3c72e20..6be7e3a544 100644 --- a/docs/releases/2.2.txt +++ b/docs/releases/2.2.txt @@ -293,6 +293,13 @@ Database backend API * Third party database backends must implement support for partial indexes or set ``DatabaseFeatures.supports_partial_indexes`` to ``False``. +* Several ``SchemaEditor`` attributes are changed: + + * ``sql_create_check`` is replaced with ``sql_create_constraint``. + * ``sql_delete_check`` is replaced with ``sql_delete_constraint``. + * ``sql_create_fk`` is replaced with ``sql_foreign_key_constraint``, + ``sql_constraint``, and ``sql_create_constraint``. + Admin actions are no longer collected from base ``ModelAdmin`` classes ---------------------------------------------------------------------- diff --git a/tests/constraints/tests.py b/tests/constraints/tests.py index 28a5c4ba34..b144c24a28 100644 --- a/tests/constraints/tests.py +++ b/tests/constraints/tests.py @@ -1,10 +1,15 @@ -from django.db import IntegrityError, models +from django.db import IntegrityError, connection, models from django.db.models.constraints import BaseConstraint from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature from .models import Product +def get_constraints(table): + with connection.cursor() as cursor: + return connection.introspection.get_constraints(cursor, table) + + class BaseConstraintTests(SimpleTestCase): def test_constraint_sql(self): c = BaseConstraint('name') @@ -37,3 +42,11 @@ class CheckConstraintTests(TestCase): Product.objects.create(name='Valid', price=10, discounted_price=5) with self.assertRaises(IntegrityError): Product.objects.create(name='Invalid', price=10, discounted_price=20) + + @skipUnlessDBFeature('supports_table_check_constraints') + def test_name(self): + constraints = get_constraints(Product._meta.db_table) + expected_name = 'price_gt_discounted_price' + if connection.features.uppercases_column_names: + expected_name = expected_name.upper() + self.assertIn(expected_name, constraints) diff --git a/tests/schema/tests.py b/tests/schema/tests.py index 7f170c863e..6f2b6df765 100644 --- a/tests/schema/tests.py +++ b/tests/schema/tests.py @@ -2145,29 +2145,29 @@ class SchemaTests(TransactionTestCase): self.assertNotIn(constraint_name, self.get_constraints(model._meta.db_table)) constraint_name = "CamelCaseUniqConstraint" - editor.execute( - editor.sql_create_unique % { - "table": editor.quote_name(table), - "name": editor.quote_name(constraint_name), - "columns": editor.quote_name(field.column), - } - ) + editor.execute(editor._create_unique_sql(model, [field.column], constraint_name)) if connection.features.uppercases_column_names: constraint_name = constraint_name.upper() self.assertIn(constraint_name, self.get_constraints(model._meta.db_table)) editor.alter_field(model, get_field(unique=True), field, strict=True) self.assertNotIn(constraint_name, self.get_constraints(model._meta.db_table)) - if editor.sql_create_fk: + if editor.sql_foreign_key_constraint: constraint_name = "CamelCaseFKConstraint" + fk_sql = editor.sql_foreign_key_constraint % { + "column": editor.quote_name(column), + "to_table": editor.quote_name(table), + "to_column": editor.quote_name(model._meta.auto_field.column), + "deferrable": connection.ops.deferrable_sql(), + } + constraint_sql = editor.sql_constraint % { + "name": editor.quote_name(constraint_name), + "constraint": fk_sql, + } editor.execute( - editor.sql_create_fk % { + editor.sql_create_constraint % { "table": editor.quote_name(table), - "name": editor.quote_name(constraint_name), - "column": editor.quote_name(column), - "to_table": editor.quote_name(table), - "to_column": editor.quote_name(model._meta.auto_field.column), - "deferrable": connection.ops.deferrable_sql(), + "constraint": constraint_sql, } ) if connection.features.uppercases_column_names: