Fixed #30183 -- Added introspection of inline SQLite constraints.

This commit is contained in:
Paveł Tyślacki 2019-02-28 00:47:29 +03:00 committed by Tim Graham
parent 406de977ea
commit 782d85b6df
5 changed files with 332 additions and 51 deletions

View File

@ -217,50 +217,124 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
} }
return constraints return constraints
def _parse_table_constraints(self, sql): def _parse_column_or_constraint_definition(self, tokens, columns):
# Check constraint parsing is based of SQLite syntax diagram. token = None
# https://www.sqlite.org/syntaxdiagrams.html#table-constraint is_constraint_definition = None
def next_ttype(ttype): field_name = None
constraint_name = None
unique = False
unique_columns = []
check = False
check_columns = []
braces_deep = 0
for token in tokens: for token in tokens:
if token.ttype == ttype: if token.match(sqlparse.tokens.Punctuation, '('):
return token braces_deep += 1
elif token.match(sqlparse.tokens.Punctuation, ')'):
statement = sqlparse.parse(sql)[0] braces_deep -= 1
constraints = {} if braces_deep < 0:
tokens = statement.flatten() # End of columns and constraints for table definition.
for token in tokens: break
name = None elif braces_deep == 0 and token.match(sqlparse.tokens.Punctuation, ','):
if token.match(sqlparse.tokens.Keyword, 'CONSTRAINT'): # End of current column or constraint definition.
# Table constraint break
name_token = next_ttype(sqlparse.tokens.Literal.String.Symbol) # Detect column or constraint definition by first token.
name = name_token.value[1:-1] if is_constraint_definition is None:
token = next_ttype(sqlparse.tokens.Keyword) is_constraint_definition = token.match(sqlparse.tokens.Keyword, 'CONSTRAINT')
if is_constraint_definition:
continue
if is_constraint_definition:
# Detect constraint name by second token.
if constraint_name is None:
if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
constraint_name = token.value
elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
constraint_name = token.value[1:-1]
# Start constraint columns parsing after UNIQUE keyword.
if token.match(sqlparse.tokens.Keyword, 'UNIQUE'): if token.match(sqlparse.tokens.Keyword, 'UNIQUE'):
constraints[name] = { unique = True
unique_braces_deep = braces_deep
elif unique:
if unique_braces_deep == braces_deep:
if unique_columns:
# Stop constraint parsing.
unique = False
continue
if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
unique_columns.append(token.value)
elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
unique_columns.append(token.value[1:-1])
else:
# Detect field name by first token.
if field_name is None:
if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
field_name = token.value
elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
field_name = token.value[1:-1]
if token.match(sqlparse.tokens.Keyword, 'UNIQUE'):
unique_columns = [field_name]
# Start constraint columns parsing after CHECK keyword.
if token.match(sqlparse.tokens.Keyword, 'CHECK'):
check = True
check_braces_deep = braces_deep
elif check:
if check_braces_deep == braces_deep:
if check_columns:
# Stop constraint parsing.
check = False
continue
if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
if token.value in columns:
check_columns.append(token.value)
elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
if token.value[1:-1] in columns:
check_columns.append(token.value[1:-1])
unique_constraint = {
'unique': True, 'unique': True,
'columns': [], 'columns': unique_columns,
'primary_key': False, 'primary_key': False,
'foreign_key': None, 'foreign_key': None,
'check': False, 'check': False,
'index': False, 'index': False,
} } if unique_columns else None
if token.match(sqlparse.tokens.Keyword, 'CHECK'): check_constraint = {
# Column check constraint
if name is None:
column_token = next_ttype(sqlparse.tokens.Literal.String.Symbol)
column = column_token.value[1:-1]
name = '__check__%s' % column
columns = [column]
else:
columns = []
constraints[name] = {
'check': True, 'check': True,
'columns': columns, 'columns': check_columns,
'primary_key': False, 'primary_key': False,
'unique': False, 'unique': False,
'foreign_key': None, 'foreign_key': None,
'index': False, 'index': False,
} } if check_columns else None
return constraint_name, unique_constraint, check_constraint, token
def _parse_table_constraints(self, sql, columns):
# Check constraint parsing is based of SQLite syntax diagram.
# https://www.sqlite.org/syntaxdiagrams.html#table-constraint
statement = sqlparse.parse(sql)[0]
constraints = {}
unnamed_constrains_index = 0
tokens = (token for token in statement.flatten() if not token.is_whitespace)
# Go to columns and constraint definition
for token in tokens:
if token.match(sqlparse.tokens.Punctuation, '('):
break
# Parse columns and constraint definition
while True:
constraint_name, unique, check, end_token = self._parse_column_or_constraint_definition(tokens, columns)
if unique:
if constraint_name:
constraints[constraint_name] = unique
else:
unnamed_constrains_index += 1
constraints['__unnamed_constraint_%s__' % unnamed_constrains_index] = unique
if check:
if constraint_name:
constraints[constraint_name] = check
else:
unnamed_constrains_index += 1
constraints['__unnamed_constraint_%s__' % unnamed_constrains_index] = check
if end_token.match(sqlparse.tokens.Punctuation, ')'):
break
return constraints return constraints
def get_constraints(self, cursor, table_name): def get_constraints(self, cursor, table_name):
@ -280,7 +354,8 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
# table_name is a view. # table_name is a view.
pass pass
else: else:
constraints.update(self._parse_table_constraints(table_schema)) columns = {info.name for info in self.get_table_description(cursor, table_name)}
constraints.update(self._parse_table_constraints(table_schema, columns))
# Get the index info # Get the index info
cursor.execute("PRAGMA index_list(%s)" % self.connection.ops.quote_name(table_name)) cursor.execute("PRAGMA index_list(%s)" % self.connection.ops.quote_name(table_name))
@ -288,6 +363,21 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
# SQLite 3.8.9+ has 5 columns, however older versions only give 3 # SQLite 3.8.9+ has 5 columns, however older versions only give 3
# columns. Discard last 2 columns if there. # columns. Discard last 2 columns if there.
number, index, unique = row[:3] number, index, unique = row[:3]
cursor.execute(
"SELECT sql FROM sqlite_master "
"WHERE type='index' AND name=%s" % self.connection.ops.quote_name(index)
)
# There's at most one row.
sql, = cursor.fetchone() or (None,)
# Inline constraints are already detected in
# _parse_table_constraints(). The reasons to avoid fetching inline
# constraints from `PRAGMA index_list` are:
# - Inline constraints can have a different name and information
# than what `PRAGMA index_list` gives.
# - Not all inline constraints may appear in `PRAGMA index_list`.
if not sql:
# An inline constraint
continue
# Get the index info for that index # Get the index info for that index
cursor.execute('PRAGMA index_info(%s)' % self.connection.ops.quote_name(index)) cursor.execute('PRAGMA index_info(%s)' % self.connection.ops.quote_name(index))
for index_rank, column_rank, column in cursor.fetchall(): for index_rank, column_rank, column in cursor.fetchall():
@ -305,13 +395,6 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
if constraints[index]['index'] and not constraints[index]['unique']: if constraints[index]['index'] and not constraints[index]['unique']:
# SQLite doesn't support any index type other than b-tree # SQLite doesn't support any index type other than b-tree
constraints[index]['type'] = Index.suffix constraints[index]['type'] = Index.suffix
cursor.execute(
"SELECT sql FROM sqlite_master "
"WHERE type='index' AND name=%s" % self.connection.ops.quote_name(index)
)
orders = []
# There would be only 1 row to loop over
for sql, in cursor.fetchall():
order_info = sql.split('(')[-1].split(')')[0].split(',') order_info = sql.split('(')[-1].split(')')[0].split(',')
orders = ['DESC' if info.endswith('DESC') else 'ASC' for info in order_info] orders = ['DESC' if info.endswith('DESC') else 'ASC' for info in order_info]
constraints[index]['orders'] = orders constraints[index]['orders'] = orders

View File

@ -1,5 +1,7 @@
import unittest import unittest
import sqlparse
from django.db import connection from django.db import connection
from django.test import TestCase from django.test import TestCase
@ -25,3 +27,116 @@ class IntrospectionTests(TestCase):
self.assertEqual(field, expected_string) self.assertEqual(field, expected_string)
finally: finally:
cursor.execute('DROP TABLE test_primary') cursor.execute('DROP TABLE test_primary')
@unittest.skipUnless(connection.vendor == 'sqlite', 'SQLite tests')
class ParsingTests(TestCase):
def parse_definition(self, sql, columns):
"""Parse a column or constraint definition."""
statement = sqlparse.parse(sql)[0]
tokens = (token for token in statement.flatten() if not token.is_whitespace)
with connection.cursor():
return connection.introspection._parse_column_or_constraint_definition(tokens, set(columns))
def assertConstraint(self, constraint_details, cols, unique=False, check=False):
self.assertEqual(constraint_details, {
'unique': unique,
'columns': cols,
'primary_key': False,
'foreign_key': None,
'check': check,
'index': False,
})
def test_unique_column(self):
tests = (
('"ref" integer UNIQUE,', ['ref']),
('ref integer UNIQUE,', ['ref']),
('"customname" integer UNIQUE,', ['customname']),
('customname integer UNIQUE,', ['customname']),
)
for sql, columns in tests:
with self.subTest(sql=sql):
constraint, details, check, _ = self.parse_definition(sql, columns)
self.assertIsNone(constraint)
self.assertConstraint(details, columns, unique=True)
self.assertIsNone(check)
def test_unique_constraint(self):
tests = (
('CONSTRAINT "ref" UNIQUE ("ref"),', 'ref', ['ref']),
('CONSTRAINT ref UNIQUE (ref),', 'ref', ['ref']),
('CONSTRAINT "customname1" UNIQUE ("customname2"),', 'customname1', ['customname2']),
('CONSTRAINT customname1 UNIQUE (customname2),', 'customname1', ['customname2']),
)
for sql, constraint_name, columns in tests:
with self.subTest(sql=sql):
constraint, details, check, _ = self.parse_definition(sql, columns)
self.assertEqual(constraint, constraint_name)
self.assertConstraint(details, columns, unique=True)
self.assertIsNone(check)
def test_unique_constraint_multicolumn(self):
tests = (
('CONSTRAINT "ref" UNIQUE ("ref", "customname"),', 'ref', ['ref', 'customname']),
('CONSTRAINT ref UNIQUE (ref, customname),', 'ref', ['ref', 'customname']),
)
for sql, constraint_name, columns in tests:
with self.subTest(sql=sql):
constraint, details, check, _ = self.parse_definition(sql, columns)
self.assertEqual(constraint, constraint_name)
self.assertConstraint(details, columns, unique=True)
self.assertIsNone(check)
def test_check_column(self):
tests = (
('"ref" varchar(255) CHECK ("ref" != \'test\'),', ['ref']),
('ref varchar(255) CHECK (ref != \'test\'),', ['ref']),
('"customname1" varchar(255) CHECK ("customname2" != \'test\'),', ['customname2']),
('customname1 varchar(255) CHECK (customname2 != \'test\'),', ['customname2']),
)
for sql, columns in tests:
with self.subTest(sql=sql):
constraint, details, check, _ = self.parse_definition(sql, columns)
self.assertIsNone(constraint)
self.assertIsNone(details)
self.assertConstraint(check, columns, check=True)
def test_check_constraint(self):
tests = (
('CONSTRAINT "ref" CHECK ("ref" != \'test\'),', 'ref', ['ref']),
('CONSTRAINT ref CHECK (ref != \'test\'),', 'ref', ['ref']),
('CONSTRAINT "customname1" CHECK ("customname2" != \'test\'),', 'customname1', ['customname2']),
('CONSTRAINT customname1 CHECK (customname2 != \'test\'),', 'customname1', ['customname2']),
)
for sql, constraint_name, columns in tests:
with self.subTest(sql=sql):
constraint, details, check, _ = self.parse_definition(sql, columns)
self.assertEqual(constraint, constraint_name)
self.assertIsNone(details)
self.assertConstraint(check, columns, check=True)
def test_check_column_with_operators_and_functions(self):
tests = (
('"ref" integer CHECK ("ref" BETWEEN 1 AND 10),', ['ref']),
('"ref" varchar(255) CHECK ("ref" LIKE \'test%\'),', ['ref']),
('"ref" varchar(255) CHECK (LENGTH(ref) > "max_length"),', ['ref', 'max_length']),
)
for sql, columns in tests:
with self.subTest(sql=sql):
constraint, details, check, _ = self.parse_definition(sql, columns)
self.assertIsNone(constraint)
self.assertIsNone(details)
self.assertConstraint(check, columns, check=True)
def test_check_and_unique_column(self):
tests = (
('"ref" varchar(255) CHECK ("ref" != \'test\') UNIQUE,', ['ref']),
('ref varchar(255) UNIQUE CHECK (ref != \'test\'),', ['ref']),
)
for sql, columns in tests:
with self.subTest(sql=sql):
constraint, details, check, _ = self.parse_definition(sql, columns)
self.assertIsNone(constraint)
self.assertConstraint(details, columns, unique=True)
self.assertConstraint(check, columns, check=True)

View File

@ -58,3 +58,21 @@ class ArticleReporter(models.Model):
class Meta: class Meta:
managed = False managed = False
class Comment(models.Model):
ref = models.UUIDField(unique=True)
article = models.ForeignKey(Article, models.CASCADE, db_index=True)
email = models.EmailField()
pub_date = models.DateTimeField()
up_votes = models.PositiveIntegerField()
body = models.TextField()
class Meta:
constraints = [
models.CheckConstraint(name='up_votes_gte_0_check', check=models.Q(up_votes__gte=0)),
models.UniqueConstraint(fields=['article', 'email', 'pub_date'], name='article_email_pub_date_uniq'),
]
indexes = [
models.Index(fields=['email', 'pub_date'], name='email_pub_date_idx'),
]

View File

@ -5,7 +5,7 @@ from django.db.models import Index
from django.db.utils import DatabaseError from django.db.utils import DatabaseError
from django.test import TransactionTestCase, skipUnlessDBFeature from django.test import TransactionTestCase, skipUnlessDBFeature
from .models import Article, ArticleReporter, City, District, Reporter from .models import Article, ArticleReporter, City, Comment, District, Reporter
class IntrospectionTests(TransactionTestCase): class IntrospectionTests(TransactionTestCase):
@ -211,3 +211,60 @@ class IntrospectionTests(TransactionTestCase):
self.assertEqual(val['orders'], ['ASC'] * len(val['columns'])) self.assertEqual(val['orders'], ['ASC'] * len(val['columns']))
indexes_verified += 1 indexes_verified += 1
self.assertEqual(indexes_verified, 4) self.assertEqual(indexes_verified, 4)
def test_get_constraints(self):
def assertDetails(details, cols, primary_key=False, unique=False, index=False, check=False, foreign_key=None):
# Different backends have different values for same constraints:
# PRIMARY KEY UNIQUE CONSTRAINT UNIQUE INDEX
# MySQL pk=1 uniq=1 idx=1 pk=0 uniq=1 idx=1 pk=0 uniq=1 idx=1
# PostgreSQL pk=1 uniq=1 idx=0 pk=0 uniq=1 idx=0 pk=0 uniq=1 idx=1
# SQLite pk=1 uniq=0 idx=0 pk=0 uniq=1 idx=0 pk=0 uniq=1 idx=1
if details['primary_key']:
details['unique'] = True
if details['unique']:
details['index'] = False
self.assertEqual(details['columns'], cols)
self.assertEqual(details['primary_key'], primary_key)
self.assertEqual(details['unique'], unique)
self.assertEqual(details['index'], index)
self.assertEqual(details['check'], check)
self.assertEqual(details['foreign_key'], foreign_key)
with connection.cursor() as cursor:
constraints = connection.introspection.get_constraints(cursor, Comment._meta.db_table)
# Test custom constraints
custom_constraints = {
'article_email_pub_date_uniq',
'email_pub_date_idx',
}
if connection.features.supports_column_check_constraints:
custom_constraints.add('up_votes_gte_0_check')
assertDetails(constraints['up_votes_gte_0_check'], ['up_votes'], check=True)
assertDetails(constraints['article_email_pub_date_uniq'], ['article_id', 'email', 'pub_date'], unique=True)
assertDetails(constraints['email_pub_date_idx'], ['email', 'pub_date'], index=True)
# Test field constraints
field_constraints = set()
for name, details in constraints.items():
if name in custom_constraints:
continue
elif details['columns'] == ['up_votes'] and details['check']:
assertDetails(details, ['up_votes'], check=True)
field_constraints.add(name)
elif details['columns'] == ['ref'] and details['unique']:
assertDetails(details, ['ref'], unique=True)
field_constraints.add(name)
elif details['columns'] == ['article_id'] and details['index']:
assertDetails(details, ['article_id'], index=True)
field_constraints.add(name)
elif details['columns'] == ['id'] and details['primary_key']:
assertDetails(details, ['id'], primary_key=True, unique=True)
field_constraints.add(name)
elif details['columns'] == ['article_id'] and details['foreign_key']:
assertDetails(details, ['article_id'], foreign_key=('introspection_article', 'id'))
field_constraints.add(name)
elif details['check']:
# Some databases (e.g. Oracle) include additional check
# constraints.
field_constraints.add(name)
# All constraints are accounted for.
self.assertEqual(constraints.keys() ^ (custom_constraints | field_constraints), set())

View File

@ -129,6 +129,14 @@ class SchemaTests(TransactionTestCase):
if c['index'] and len(c['columns']) == 1 if c['index'] and len(c['columns']) == 1
] ]
def get_uniques(self, table):
with connection.cursor() as cursor:
return [
c['columns'][0]
for c in connection.introspection.get_constraints(cursor, table).values()
if c['unique'] and len(c['columns']) == 1
]
def get_constraints(self, table): def get_constraints(self, table):
""" """
Get the constraints on a table using a new cursor. Get the constraints on a table using a new cursor.
@ -1971,7 +1979,7 @@ class SchemaTests(TransactionTestCase):
editor.add_field(Book, new_field3) editor.add_field(Book, new_field3)
self.assertIn( self.assertIn(
"slug", "slug",
self.get_indexes(Book._meta.db_table), self.get_uniques(Book._meta.db_table),
) )
# Remove the unique, check the index goes with it # Remove the unique, check the index goes with it
new_field4 = CharField(max_length=20, unique=False) new_field4 = CharField(max_length=20, unique=False)
@ -1980,7 +1988,7 @@ class SchemaTests(TransactionTestCase):
editor.alter_field(BookWithSlug, new_field3, new_field4, strict=True) editor.alter_field(BookWithSlug, new_field3, new_field4, strict=True)
self.assertNotIn( self.assertNotIn(
"slug", "slug",
self.get_indexes(Book._meta.db_table), self.get_uniques(Book._meta.db_table),
) )
def test_text_field_with_db_index(self): def test_text_field_with_db_index(self):