Refs #30593 -- Added _parse_constraint_columns() hook to introspection on MariaDB.

This commit is contained in:
Hasan Ramezani 2019-07-30 15:18:12 +02:00 committed by Mariusz Felisiak
parent 421c4cd2ee
commit b2aad9ad4d
2 changed files with 32 additions and 7 deletions

View File

@ -147,6 +147,15 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
return self.connection.features._mysql_storage_engine return self.connection.features._mysql_storage_engine
return result[0] return result[0]
def _parse_constraint_columns(self, check_clause):
check_columns = OrderedSet()
statement = sqlparse.parse(check_clause)[0]
tokens = (token for token in statement.flatten() if not token.is_whitespace)
for token in tokens:
if token.ttype in [sqlparse.tokens.Name, sqlparse.tokens.Literal.String.Single]:
check_columns.add(token.value[1:-1])
return check_columns
def get_constraints(self, cursor, table_name): def get_constraints(self, cursor, table_name):
""" """
Retrieve any constraints or keys (unique, pk, fk, check, index) across Retrieve any constraints or keys (unique, pk, fk, check, index) across
@ -201,14 +210,8 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
""" """
cursor.execute(type_query, [table_name]) cursor.execute(type_query, [table_name])
for constraint, check_clause in cursor.fetchall(): for constraint, check_clause in cursor.fetchall():
# Parse columns.
columns = OrderedSet()
for statement in sqlparse.parse(check_clause):
for token in statement.flatten():
if token.ttype in [sqlparse.tokens.Name, sqlparse.tokens.Literal.String.Single]:
columns.add(token.value[1:-1])
constraints[constraint] = { constraints[constraint] = {
'columns': columns, 'columns': self._parse_constraint_columns(check_clause),
'primary_key': False, 'primary_key': False,
'unique': False, 'unique': False,
'index': False, 'index': False,

View File

@ -0,0 +1,22 @@
from unittest import skipUnless
from django.db import connection
from django.test import TestCase
@skipUnless(connection.vendor == 'mysql', 'MySQL tests')
class ParsingTests(TestCase):
def test_parse_constraint_columns(self):
tests = (
('`height` >= 0', ['height']),
('`cost` BETWEEN 1 AND 10', ['cost']),
('`ref1` > `ref2`', ['ref1', 'ref2']),
(
'`start` IS NULL OR `end` IS NULL OR `start` < `end`',
['start', 'end'],
),
)
for check_clause, expected_columns in tests:
with self.subTest(check_clause):
check_columns = connection.introspection._parse_constraint_columns(check_clause)
self.assertEqual(list(check_columns), expected_columns)