Fixed #30593 -- Added support for check constraints on MariaDB 10.2+.

This commit is contained in:
Hasan Ramezani 2019-07-14 01:24:35 +02:00 committed by Mariusz Felisiak
parent 7f612eda80
commit 1fc2c70f76
8 changed files with 66 additions and 7 deletions

View File

@ -181,6 +181,8 @@ class BaseDatabaseFeatures:
# Does it support CHECK constraints? # Does it support CHECK constraints?
supports_column_check_constraints = True supports_column_check_constraints = True
supports_table_check_constraints = True supports_table_check_constraints = True
# Does the backend support introspection of CHECK constraints?
can_introspect_check_constraints = True
# Does the backend support 'pyformat' style ("... %(name)s ...", {'name': value}) # Does the backend support 'pyformat' style ("... %(name)s ...", {'name': value})
# parameter passing? Note this can be provided by the backend even if not # parameter passing? Note this can be provided by the backend even if not

View File

@ -61,6 +61,7 @@ class CursorWrapper:
codes_for_integrityerror = ( codes_for_integrityerror = (
1048, # Column cannot be null 1048, # Column cannot be null
1690, # BIGINT UNSIGNED value is out of range 1690, # BIGINT UNSIGNED value is out of range
4025, # CHECK constraint failed
) )
def __init__(self, cursor): def __init__(self, cursor):
@ -328,6 +329,15 @@ class DatabaseWrapper(BaseDatabaseWrapper):
else: else:
return True return True
@cached_property
def data_type_check_constraints(self):
if self.features.supports_column_check_constraints:
return {
'PositiveIntegerField': '`%(column)s` >= 0',
'PositiveSmallIntegerField': '`%(column)s` >= 0',
}
return {}
@cached_property @cached_property
def mysql_server_info(self): def mysql_server_info(self):
with self.temporary_connection() as cursor: with self.temporary_connection() as cursor:

View File

@ -27,8 +27,6 @@ class DatabaseFeatures(BaseDatabaseFeatures):
allows_auto_pk_0 = False allows_auto_pk_0 = False
can_release_savepoints = True can_release_savepoints = True
atomic_transactions = False atomic_transactions = False
supports_column_check_constraints = False
supports_table_check_constraints = False
can_clone_databases = True can_clone_databases = True
supports_temporal_subtraction = True supports_temporal_subtraction = True
supports_select_intersection = False supports_select_intersection = False
@ -89,6 +87,20 @@ class DatabaseFeatures(BaseDatabaseFeatures):
return self.connection.mysql_version >= (10, 2) return self.connection.mysql_version >= (10, 2)
return self.connection.mysql_version >= (8, 0, 2) return self.connection.mysql_version >= (8, 0, 2)
@cached_property
def supports_column_check_constraints(self):
return self.connection.mysql_is_mariadb and self.connection.mysql_version >= (10, 2, 1)
supports_table_check_constraints = property(operator.attrgetter('supports_column_check_constraints'))
@cached_property
def can_introspect_check_constraints(self):
if self.connection.mysql_is_mariadb:
version = self.connection.mysql_version
if (version >= (10, 2, 22) and version < (10, 3)) or version >= (10, 3, 10):
return True
return False
@cached_property @cached_property
def has_select_for_update_skip_locked(self): def has_select_for_update_skip_locked(self):
return not self.connection.mysql_is_mariadb and self.connection.mysql_version >= (8, 0, 1) return not self.connection.mysql_is_mariadb and self.connection.mysql_version >= (8, 0, 1)

View File

@ -1,5 +1,6 @@
from collections import namedtuple from collections import namedtuple
import sqlparse
from MySQLdb.constants import FIELD_TYPE from MySQLdb.constants import FIELD_TYPE
from django.db.backends.base.introspection import ( from django.db.backends.base.introspection import (
@ -189,6 +190,31 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
constraints[constraint]['unique'] = True constraints[constraint]['unique'] = True
elif kind.lower() == "unique": elif kind.lower() == "unique":
constraints[constraint]['unique'] = True constraints[constraint]['unique'] = True
# Add check constraints.
if self.connection.features.can_introspect_check_constraints:
type_query = """
SELECT c.constraint_name, c.check_clause
FROM information_schema.check_constraints AS c
WHERE
c.constraint_schema = DATABASE() AND
c.table_name = %s
"""
cursor.execute(type_query, [table_name])
for constraint, check_clause in cursor.fetchall():
# Parse columns.
columns = OrderedSet()
for statement in sqlparse.parse(check_clause):
for token in statement.flatten():
if token.ttype in [sqlparse.tokens.Name, sqlparse.tokens.Literal.String.Single]:
columns.add(token.value[1:-1])
constraints[constraint] = {
'columns': columns,
'primary_key': False,
'unique': False,
'index': False,
'check': True,
'foreign_key': None,
}
# Now add in the indexes # Now add in the indexes
cursor.execute("SHOW INDEX FROM %s" % self.connection.ops.quote_name(table_name)) cursor.execute("SHOW INDEX FROM %s" % self.connection.ops.quote_name(table_name))
for table, non_unique, index, colseq, column, type_ in [x[:5] + (x[10],) for x in cursor.fetchall()]: for table, non_unique, index, colseq, column, type_ in [x[:5] + (x[10],) for x in cursor.fetchall()]:

View File

@ -28,9 +28,15 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
sql_delete_pk = "ALTER TABLE %(table)s DROP PRIMARY KEY" sql_delete_pk = "ALTER TABLE %(table)s DROP PRIMARY KEY"
sql_create_index = 'CREATE INDEX %(name)s ON %(table)s (%(columns)s)%(extra)s' sql_create_index = 'CREATE INDEX %(name)s ON %(table)s (%(columns)s)%(extra)s'
# The name of the column check constraint is the same as the field name on
# MariaDB. Adding IF EXISTS clause prevents migrations crash. Constraint is
# removed during a "MODIFY" column statement.
sql_delete_check = 'ALTER TABLE %(table)s DROP CONSTRAINT IF EXISTS %(name)s'
def quote_value(self, value): def quote_value(self, value):
self.connection.ensure_connection() self.connection.ensure_connection()
if isinstance(value, str):
value = value.replace('%', '%%')
# MySQLdb escapes to string, PyMySQL to bytes. # MySQLdb escapes to string, PyMySQL to bytes.
quoted = self.connection.connection.escape(value, self.connection.connection.encoders) quoted = self.connection.connection.escape(value, self.connection.connection.encoders)
if isinstance(value, str) and isinstance(quoted, bytes): if isinstance(value, str) and isinstance(quoted, bytes):

View File

@ -73,7 +73,7 @@ class CheckConstraintTests(TestCase):
with self.assertRaises(IntegrityError): with self.assertRaises(IntegrityError):
Product.objects.create(name='Invalid', price=10, discounted_price=20) Product.objects.create(name='Invalid', price=10, discounted_price=20)
@skipUnlessDBFeature('supports_table_check_constraints') @skipUnlessDBFeature('supports_table_check_constraints', 'can_introspect_check_constraints')
def test_name(self): def test_name(self):
constraints = get_constraints(Product._meta.db_table) constraints = get_constraints(Product._meta.db_table)
for expected_name in ( for expected_name in (
@ -83,7 +83,7 @@ class CheckConstraintTests(TestCase):
with self.subTest(expected_name): with self.subTest(expected_name):
self.assertIn(expected_name, constraints) self.assertIn(expected_name, constraints)
@skipUnlessDBFeature('supports_table_check_constraints') @skipUnlessDBFeature('supports_table_check_constraints', 'can_introspect_check_constraints')
def test_abstract_name(self): def test_abstract_name(self):
constraints = get_constraints(ChildModel._meta.db_table) constraints = get_constraints(ChildModel._meta.db_table)
self.assertIn('constraints_childmodel_adult', constraints) self.assertIn('constraints_childmodel_adult', constraints)

View File

@ -237,7 +237,10 @@ class IntrospectionTests(TransactionTestCase):
'article_email_pub_date_uniq', 'article_email_pub_date_uniq',
'email_pub_date_idx', 'email_pub_date_idx',
} }
if connection.features.supports_column_check_constraints: if (
connection.features.supports_column_check_constraints and
connection.features.can_introspect_check_constraints
):
custom_constraints.add('up_votes_gte_0_check') custom_constraints.add('up_votes_gte_0_check')
assertDetails(constraints['up_votes_gte_0_check'], ['up_votes'], check=True) assertDetails(constraints['up_votes_gte_0_check'], ['up_votes'], check=True)
assertDetails(constraints['article_email_pub_date_uniq'], ['article_id', 'email', 'pub_date'], unique=True) assertDetails(constraints['article_email_pub_date_uniq'], ['article_id', 'email', 'pub_date'], unique=True)

View File

@ -1556,7 +1556,7 @@ class SchemaTests(TransactionTestCase):
# Ensure the m2m table is still there. # Ensure the m2m table is still there.
self.assertEqual(len(self.column_classes(LocalM2M)), 1) self.assertEqual(len(self.column_classes(LocalM2M)), 1)
@skipUnlessDBFeature('supports_column_check_constraints') @skipUnlessDBFeature('supports_column_check_constraints', 'can_introspect_check_constraints')
def test_check_constraints(self): def test_check_constraints(self):
""" """
Tests creating/deleting CHECK constraints Tests creating/deleting CHECK constraints
@ -1586,7 +1586,7 @@ class SchemaTests(TransactionTestCase):
if not any(details['columns'] == ['height'] and details['check'] for details in constraints.values()): if not any(details['columns'] == ['height'] and details['check'] for details in constraints.values()):
self.fail("No check constraint for height found") self.fail("No check constraint for height found")
@skipUnlessDBFeature('supports_column_check_constraints') @skipUnlessDBFeature('supports_column_check_constraints', 'can_introspect_check_constraints')
def test_remove_field_check_does_not_remove_meta_constraints(self): def test_remove_field_check_does_not_remove_meta_constraints(self):
with connection.schema_editor() as editor: with connection.schema_editor() as editor:
editor.create_model(Author) editor.create_model(Author)