diff --git a/django/db/backends/mysql/introspection.py b/django/db/backends/mysql/introspection.py index a64e82e004..86133a254e 100644 --- a/django/db/backends/mysql/introspection.py +++ b/django/db/backends/mysql/introspection.py @@ -147,6 +147,15 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): return self.connection.features._mysql_storage_engine return result[0] + def _parse_constraint_columns(self, check_clause): + check_columns = OrderedSet() + statement = sqlparse.parse(check_clause)[0] + tokens = (token for token in statement.flatten() if not token.is_whitespace) + for token in tokens: + if token.ttype in [sqlparse.tokens.Name, sqlparse.tokens.Literal.String.Single]: + check_columns.add(token.value[1:-1]) + return check_columns + def get_constraints(self, cursor, table_name): """ Retrieve any constraints or keys (unique, pk, fk, check, index) across @@ -201,14 +210,8 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): """ cursor.execute(type_query, [table_name]) for constraint, check_clause in cursor.fetchall(): - # Parse columns. - columns = OrderedSet() - for statement in sqlparse.parse(check_clause): - for token in statement.flatten(): - if token.ttype in [sqlparse.tokens.Name, sqlparse.tokens.Literal.String.Single]: - columns.add(token.value[1:-1]) constraints[constraint] = { - 'columns': columns, + 'columns': self._parse_constraint_columns(check_clause), 'primary_key': False, 'unique': False, 'index': False, diff --git a/tests/backends/mysql/test_introspection.py b/tests/backends/mysql/test_introspection.py new file mode 100644 index 0000000000..1b8fb94ed9 --- /dev/null +++ b/tests/backends/mysql/test_introspection.py @@ -0,0 +1,22 @@ +from unittest import skipUnless + +from django.db import connection +from django.test import TestCase + + +@skipUnless(connection.vendor == 'mysql', 'MySQL tests') +class ParsingTests(TestCase): + def test_parse_constraint_columns(self): + tests = ( + ('`height` >= 0', ['height']), + ('`cost` BETWEEN 1 AND 10', ['cost']), + ('`ref1` > `ref2`', ['ref1', 'ref2']), + ( + '`start` IS NULL OR `end` IS NULL OR `start` < `end`', + ['start', 'end'], + ), + ) + for check_clause, expected_columns in tests: + with self.subTest(check_clause): + check_columns = connection.introspection._parse_constraint_columns(check_clause) + self.assertEqual(list(check_columns), expected_columns)