Refs #30593 -- Fixed introspection of check constraints columns on MariaDB.

This commit is contained in:
Hasan Ramezani 2019-07-30 16:21:05 +02:00 committed by Mariusz Felisiak
parent b2aad9ad4d
commit e3fc9af4ab
2 changed files with 20 additions and 8 deletions

View File

@ -147,12 +147,16 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
return self.connection.features._mysql_storage_engine return self.connection.features._mysql_storage_engine
return result[0] return result[0]
def _parse_constraint_columns(self, check_clause): def _parse_constraint_columns(self, check_clause, columns):
check_columns = OrderedSet() check_columns = OrderedSet()
statement = sqlparse.parse(check_clause)[0] statement = sqlparse.parse(check_clause)[0]
tokens = (token for token in statement.flatten() if not token.is_whitespace) tokens = (token for token in statement.flatten() if not token.is_whitespace)
for token in tokens: for token in tokens:
if token.ttype in [sqlparse.tokens.Name, sqlparse.tokens.Literal.String.Single]: if (
token.ttype == sqlparse.tokens.Name and
self.connection.ops.quote_name(token.value) == token.value and
token.value[1:-1] in columns
):
check_columns.add(token.value[1:-1]) check_columns.add(token.value[1:-1])
return check_columns return check_columns
@ -201,6 +205,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
constraints[constraint]['unique'] = True constraints[constraint]['unique'] = True
# Add check constraints. # Add check constraints.
if self.connection.features.can_introspect_check_constraints: if self.connection.features.can_introspect_check_constraints:
columns = {info.name for info in self.get_table_description(cursor, table_name)}
type_query = """ type_query = """
SELECT c.constraint_name, c.check_clause SELECT c.constraint_name, c.check_clause
FROM information_schema.check_constraints AS c FROM information_schema.check_constraints AS c
@ -211,7 +216,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
cursor.execute(type_query, [table_name]) cursor.execute(type_query, [table_name])
for constraint, check_clause in cursor.fetchall(): for constraint, check_clause in cursor.fetchall():
constraints[constraint] = { constraints[constraint] = {
'columns': self._parse_constraint_columns(check_clause), 'columns': self._parse_constraint_columns(check_clause, columns),
'primary_key': False, 'primary_key': False,
'unique': False, 'unique': False,
'index': False, 'index': False,

View File

@ -7,16 +7,23 @@ from django.test import TestCase
@skipUnless(connection.vendor == 'mysql', 'MySQL tests') @skipUnless(connection.vendor == 'mysql', 'MySQL tests')
class ParsingTests(TestCase): class ParsingTests(TestCase):
def test_parse_constraint_columns(self): def test_parse_constraint_columns(self):
_parse_constraint_columns = connection.introspection._parse_constraint_columns
tests = ( tests = (
('`height` >= 0', ['height']), ('`height` >= 0', ['height'], ['height']),
('`cost` BETWEEN 1 AND 10', ['cost']), ('`cost` BETWEEN 1 AND 10', ['cost'], ['cost']),
('`ref1` > `ref2`', ['ref1', 'ref2']), ('`ref1` > `ref2`', ['id', 'ref1', 'ref2'], ['ref1', 'ref2']),
( (
'`start` IS NULL OR `end` IS NULL OR `start` < `end`', '`start` IS NULL OR `end` IS NULL OR `start` < `end`',
['id', 'start', 'end'],
['start', 'end'], ['start', 'end'],
), ),
('JSON_VALID(`json_field`)', ['json_field'], ['json_field']),
('CHAR_LENGTH(`name`) > 2', ['name'], ['name']),
("lower(`ref1`) != 'test'", ['id', 'owe', 'ref1'], ['ref1']),
("lower(`ref1`) != 'test'", ['id', 'lower', 'ref1'], ['ref1']),
("`name` LIKE 'test%'", ['name'], ['name']),
) )
for check_clause, expected_columns in tests: for check_clause, table_columns, expected_columns in tests:
with self.subTest(check_clause): with self.subTest(check_clause):
check_columns = connection.introspection._parse_constraint_columns(check_clause) check_columns = _parse_constraint_columns(check_clause, table_columns)
self.assertEqual(list(check_columns), expected_columns) self.assertEqual(list(check_columns), expected_columns)