mirror of https://github.com/django/django.git
Fixed #30183 -- Added introspection of inline SQLite constraints.
This commit is contained in:
parent
406de977ea
commit
782d85b6df
|
@ -217,50 +217,124 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
|
|||
}
|
||||
return constraints
|
||||
|
||||
def _parse_table_constraints(self, sql):
|
||||
# Check constraint parsing is based of SQLite syntax diagram.
|
||||
# https://www.sqlite.org/syntaxdiagrams.html#table-constraint
|
||||
def next_ttype(ttype):
|
||||
def _parse_column_or_constraint_definition(self, tokens, columns):
|
||||
token = None
|
||||
is_constraint_definition = None
|
||||
field_name = None
|
||||
constraint_name = None
|
||||
unique = False
|
||||
unique_columns = []
|
||||
check = False
|
||||
check_columns = []
|
||||
braces_deep = 0
|
||||
for token in tokens:
|
||||
if token.ttype == ttype:
|
||||
return token
|
||||
|
||||
statement = sqlparse.parse(sql)[0]
|
||||
constraints = {}
|
||||
tokens = statement.flatten()
|
||||
for token in tokens:
|
||||
name = None
|
||||
if token.match(sqlparse.tokens.Keyword, 'CONSTRAINT'):
|
||||
# Table constraint
|
||||
name_token = next_ttype(sqlparse.tokens.Literal.String.Symbol)
|
||||
name = name_token.value[1:-1]
|
||||
token = next_ttype(sqlparse.tokens.Keyword)
|
||||
if token.match(sqlparse.tokens.Punctuation, '('):
|
||||
braces_deep += 1
|
||||
elif token.match(sqlparse.tokens.Punctuation, ')'):
|
||||
braces_deep -= 1
|
||||
if braces_deep < 0:
|
||||
# End of columns and constraints for table definition.
|
||||
break
|
||||
elif braces_deep == 0 and token.match(sqlparse.tokens.Punctuation, ','):
|
||||
# End of current column or constraint definition.
|
||||
break
|
||||
# Detect column or constraint definition by first token.
|
||||
if is_constraint_definition is None:
|
||||
is_constraint_definition = token.match(sqlparse.tokens.Keyword, 'CONSTRAINT')
|
||||
if is_constraint_definition:
|
||||
continue
|
||||
if is_constraint_definition:
|
||||
# Detect constraint name by second token.
|
||||
if constraint_name is None:
|
||||
if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
|
||||
constraint_name = token.value
|
||||
elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
|
||||
constraint_name = token.value[1:-1]
|
||||
# Start constraint columns parsing after UNIQUE keyword.
|
||||
if token.match(sqlparse.tokens.Keyword, 'UNIQUE'):
|
||||
constraints[name] = {
|
||||
unique = True
|
||||
unique_braces_deep = braces_deep
|
||||
elif unique:
|
||||
if unique_braces_deep == braces_deep:
|
||||
if unique_columns:
|
||||
# Stop constraint parsing.
|
||||
unique = False
|
||||
continue
|
||||
if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
|
||||
unique_columns.append(token.value)
|
||||
elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
|
||||
unique_columns.append(token.value[1:-1])
|
||||
else:
|
||||
# Detect field name by first token.
|
||||
if field_name is None:
|
||||
if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
|
||||
field_name = token.value
|
||||
elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
|
||||
field_name = token.value[1:-1]
|
||||
if token.match(sqlparse.tokens.Keyword, 'UNIQUE'):
|
||||
unique_columns = [field_name]
|
||||
# Start constraint columns parsing after CHECK keyword.
|
||||
if token.match(sqlparse.tokens.Keyword, 'CHECK'):
|
||||
check = True
|
||||
check_braces_deep = braces_deep
|
||||
elif check:
|
||||
if check_braces_deep == braces_deep:
|
||||
if check_columns:
|
||||
# Stop constraint parsing.
|
||||
check = False
|
||||
continue
|
||||
if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
|
||||
if token.value in columns:
|
||||
check_columns.append(token.value)
|
||||
elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
|
||||
if token.value[1:-1] in columns:
|
||||
check_columns.append(token.value[1:-1])
|
||||
unique_constraint = {
|
||||
'unique': True,
|
||||
'columns': [],
|
||||
'columns': unique_columns,
|
||||
'primary_key': False,
|
||||
'foreign_key': None,
|
||||
'check': False,
|
||||
'index': False,
|
||||
}
|
||||
if token.match(sqlparse.tokens.Keyword, 'CHECK'):
|
||||
# Column check constraint
|
||||
if name is None:
|
||||
column_token = next_ttype(sqlparse.tokens.Literal.String.Symbol)
|
||||
column = column_token.value[1:-1]
|
||||
name = '__check__%s' % column
|
||||
columns = [column]
|
||||
else:
|
||||
columns = []
|
||||
constraints[name] = {
|
||||
} if unique_columns else None
|
||||
check_constraint = {
|
||||
'check': True,
|
||||
'columns': columns,
|
||||
'columns': check_columns,
|
||||
'primary_key': False,
|
||||
'unique': False,
|
||||
'foreign_key': None,
|
||||
'index': False,
|
||||
}
|
||||
} if check_columns else None
|
||||
return constraint_name, unique_constraint, check_constraint, token
|
||||
|
||||
def _parse_table_constraints(self, sql, columns):
|
||||
# Check constraint parsing is based of SQLite syntax diagram.
|
||||
# https://www.sqlite.org/syntaxdiagrams.html#table-constraint
|
||||
statement = sqlparse.parse(sql)[0]
|
||||
constraints = {}
|
||||
unnamed_constrains_index = 0
|
||||
tokens = (token for token in statement.flatten() if not token.is_whitespace)
|
||||
# Go to columns and constraint definition
|
||||
for token in tokens:
|
||||
if token.match(sqlparse.tokens.Punctuation, '('):
|
||||
break
|
||||
# Parse columns and constraint definition
|
||||
while True:
|
||||
constraint_name, unique, check, end_token = self._parse_column_or_constraint_definition(tokens, columns)
|
||||
if unique:
|
||||
if constraint_name:
|
||||
constraints[constraint_name] = unique
|
||||
else:
|
||||
unnamed_constrains_index += 1
|
||||
constraints['__unnamed_constraint_%s__' % unnamed_constrains_index] = unique
|
||||
if check:
|
||||
if constraint_name:
|
||||
constraints[constraint_name] = check
|
||||
else:
|
||||
unnamed_constrains_index += 1
|
||||
constraints['__unnamed_constraint_%s__' % unnamed_constrains_index] = check
|
||||
if end_token.match(sqlparse.tokens.Punctuation, ')'):
|
||||
break
|
||||
return constraints
|
||||
|
||||
def get_constraints(self, cursor, table_name):
|
||||
|
@ -280,7 +354,8 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
|
|||
# table_name is a view.
|
||||
pass
|
||||
else:
|
||||
constraints.update(self._parse_table_constraints(table_schema))
|
||||
columns = {info.name for info in self.get_table_description(cursor, table_name)}
|
||||
constraints.update(self._parse_table_constraints(table_schema, columns))
|
||||
|
||||
# Get the index info
|
||||
cursor.execute("PRAGMA index_list(%s)" % self.connection.ops.quote_name(table_name))
|
||||
|
@ -288,6 +363,21 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
|
|||
# SQLite 3.8.9+ has 5 columns, however older versions only give 3
|
||||
# columns. Discard last 2 columns if there.
|
||||
number, index, unique = row[:3]
|
||||
cursor.execute(
|
||||
"SELECT sql FROM sqlite_master "
|
||||
"WHERE type='index' AND name=%s" % self.connection.ops.quote_name(index)
|
||||
)
|
||||
# There's at most one row.
|
||||
sql, = cursor.fetchone() or (None,)
|
||||
# Inline constraints are already detected in
|
||||
# _parse_table_constraints(). The reasons to avoid fetching inline
|
||||
# constraints from `PRAGMA index_list` are:
|
||||
# - Inline constraints can have a different name and information
|
||||
# than what `PRAGMA index_list` gives.
|
||||
# - Not all inline constraints may appear in `PRAGMA index_list`.
|
||||
if not sql:
|
||||
# An inline constraint
|
||||
continue
|
||||
# Get the index info for that index
|
||||
cursor.execute('PRAGMA index_info(%s)' % self.connection.ops.quote_name(index))
|
||||
for index_rank, column_rank, column in cursor.fetchall():
|
||||
|
@ -305,13 +395,6 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
|
|||
if constraints[index]['index'] and not constraints[index]['unique']:
|
||||
# SQLite doesn't support any index type other than b-tree
|
||||
constraints[index]['type'] = Index.suffix
|
||||
cursor.execute(
|
||||
"SELECT sql FROM sqlite_master "
|
||||
"WHERE type='index' AND name=%s" % self.connection.ops.quote_name(index)
|
||||
)
|
||||
orders = []
|
||||
# There would be only 1 row to loop over
|
||||
for sql, in cursor.fetchall():
|
||||
order_info = sql.split('(')[-1].split(')')[0].split(',')
|
||||
orders = ['DESC' if info.endswith('DESC') else 'ASC' for info in order_info]
|
||||
constraints[index]['orders'] = orders
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
import unittest
|
||||
|
||||
import sqlparse
|
||||
|
||||
from django.db import connection
|
||||
from django.test import TestCase
|
||||
|
||||
|
@ -25,3 +27,116 @@ class IntrospectionTests(TestCase):
|
|||
self.assertEqual(field, expected_string)
|
||||
finally:
|
||||
cursor.execute('DROP TABLE test_primary')
|
||||
|
||||
|
||||
@unittest.skipUnless(connection.vendor == 'sqlite', 'SQLite tests')
|
||||
class ParsingTests(TestCase):
|
||||
def parse_definition(self, sql, columns):
|
||||
"""Parse a column or constraint definition."""
|
||||
statement = sqlparse.parse(sql)[0]
|
||||
tokens = (token for token in statement.flatten() if not token.is_whitespace)
|
||||
with connection.cursor():
|
||||
return connection.introspection._parse_column_or_constraint_definition(tokens, set(columns))
|
||||
|
||||
def assertConstraint(self, constraint_details, cols, unique=False, check=False):
|
||||
self.assertEqual(constraint_details, {
|
||||
'unique': unique,
|
||||
'columns': cols,
|
||||
'primary_key': False,
|
||||
'foreign_key': None,
|
||||
'check': check,
|
||||
'index': False,
|
||||
})
|
||||
|
||||
def test_unique_column(self):
|
||||
tests = (
|
||||
('"ref" integer UNIQUE,', ['ref']),
|
||||
('ref integer UNIQUE,', ['ref']),
|
||||
('"customname" integer UNIQUE,', ['customname']),
|
||||
('customname integer UNIQUE,', ['customname']),
|
||||
)
|
||||
for sql, columns in tests:
|
||||
with self.subTest(sql=sql):
|
||||
constraint, details, check, _ = self.parse_definition(sql, columns)
|
||||
self.assertIsNone(constraint)
|
||||
self.assertConstraint(details, columns, unique=True)
|
||||
self.assertIsNone(check)
|
||||
|
||||
def test_unique_constraint(self):
|
||||
tests = (
|
||||
('CONSTRAINT "ref" UNIQUE ("ref"),', 'ref', ['ref']),
|
||||
('CONSTRAINT ref UNIQUE (ref),', 'ref', ['ref']),
|
||||
('CONSTRAINT "customname1" UNIQUE ("customname2"),', 'customname1', ['customname2']),
|
||||
('CONSTRAINT customname1 UNIQUE (customname2),', 'customname1', ['customname2']),
|
||||
)
|
||||
for sql, constraint_name, columns in tests:
|
||||
with self.subTest(sql=sql):
|
||||
constraint, details, check, _ = self.parse_definition(sql, columns)
|
||||
self.assertEqual(constraint, constraint_name)
|
||||
self.assertConstraint(details, columns, unique=True)
|
||||
self.assertIsNone(check)
|
||||
|
||||
def test_unique_constraint_multicolumn(self):
|
||||
tests = (
|
||||
('CONSTRAINT "ref" UNIQUE ("ref", "customname"),', 'ref', ['ref', 'customname']),
|
||||
('CONSTRAINT ref UNIQUE (ref, customname),', 'ref', ['ref', 'customname']),
|
||||
)
|
||||
for sql, constraint_name, columns in tests:
|
||||
with self.subTest(sql=sql):
|
||||
constraint, details, check, _ = self.parse_definition(sql, columns)
|
||||
self.assertEqual(constraint, constraint_name)
|
||||
self.assertConstraint(details, columns, unique=True)
|
||||
self.assertIsNone(check)
|
||||
|
||||
def test_check_column(self):
|
||||
tests = (
|
||||
('"ref" varchar(255) CHECK ("ref" != \'test\'),', ['ref']),
|
||||
('ref varchar(255) CHECK (ref != \'test\'),', ['ref']),
|
||||
('"customname1" varchar(255) CHECK ("customname2" != \'test\'),', ['customname2']),
|
||||
('customname1 varchar(255) CHECK (customname2 != \'test\'),', ['customname2']),
|
||||
)
|
||||
for sql, columns in tests:
|
||||
with self.subTest(sql=sql):
|
||||
constraint, details, check, _ = self.parse_definition(sql, columns)
|
||||
self.assertIsNone(constraint)
|
||||
self.assertIsNone(details)
|
||||
self.assertConstraint(check, columns, check=True)
|
||||
|
||||
def test_check_constraint(self):
|
||||
tests = (
|
||||
('CONSTRAINT "ref" CHECK ("ref" != \'test\'),', 'ref', ['ref']),
|
||||
('CONSTRAINT ref CHECK (ref != \'test\'),', 'ref', ['ref']),
|
||||
('CONSTRAINT "customname1" CHECK ("customname2" != \'test\'),', 'customname1', ['customname2']),
|
||||
('CONSTRAINT customname1 CHECK (customname2 != \'test\'),', 'customname1', ['customname2']),
|
||||
)
|
||||
for sql, constraint_name, columns in tests:
|
||||
with self.subTest(sql=sql):
|
||||
constraint, details, check, _ = self.parse_definition(sql, columns)
|
||||
self.assertEqual(constraint, constraint_name)
|
||||
self.assertIsNone(details)
|
||||
self.assertConstraint(check, columns, check=True)
|
||||
|
||||
def test_check_column_with_operators_and_functions(self):
|
||||
tests = (
|
||||
('"ref" integer CHECK ("ref" BETWEEN 1 AND 10),', ['ref']),
|
||||
('"ref" varchar(255) CHECK ("ref" LIKE \'test%\'),', ['ref']),
|
||||
('"ref" varchar(255) CHECK (LENGTH(ref) > "max_length"),', ['ref', 'max_length']),
|
||||
)
|
||||
for sql, columns in tests:
|
||||
with self.subTest(sql=sql):
|
||||
constraint, details, check, _ = self.parse_definition(sql, columns)
|
||||
self.assertIsNone(constraint)
|
||||
self.assertIsNone(details)
|
||||
self.assertConstraint(check, columns, check=True)
|
||||
|
||||
def test_check_and_unique_column(self):
|
||||
tests = (
|
||||
('"ref" varchar(255) CHECK ("ref" != \'test\') UNIQUE,', ['ref']),
|
||||
('ref varchar(255) UNIQUE CHECK (ref != \'test\'),', ['ref']),
|
||||
)
|
||||
for sql, columns in tests:
|
||||
with self.subTest(sql=sql):
|
||||
constraint, details, check, _ = self.parse_definition(sql, columns)
|
||||
self.assertIsNone(constraint)
|
||||
self.assertConstraint(details, columns, unique=True)
|
||||
self.assertConstraint(check, columns, check=True)
|
||||
|
|
|
@ -58,3 +58,21 @@ class ArticleReporter(models.Model):
|
|||
|
||||
class Meta:
|
||||
managed = False
|
||||
|
||||
|
||||
class Comment(models.Model):
|
||||
ref = models.UUIDField(unique=True)
|
||||
article = models.ForeignKey(Article, models.CASCADE, db_index=True)
|
||||
email = models.EmailField()
|
||||
pub_date = models.DateTimeField()
|
||||
up_votes = models.PositiveIntegerField()
|
||||
body = models.TextField()
|
||||
|
||||
class Meta:
|
||||
constraints = [
|
||||
models.CheckConstraint(name='up_votes_gte_0_check', check=models.Q(up_votes__gte=0)),
|
||||
models.UniqueConstraint(fields=['article', 'email', 'pub_date'], name='article_email_pub_date_uniq'),
|
||||
]
|
||||
indexes = [
|
||||
models.Index(fields=['email', 'pub_date'], name='email_pub_date_idx'),
|
||||
]
|
||||
|
|
|
@ -5,7 +5,7 @@ from django.db.models import Index
|
|||
from django.db.utils import DatabaseError
|
||||
from django.test import TransactionTestCase, skipUnlessDBFeature
|
||||
|
||||
from .models import Article, ArticleReporter, City, District, Reporter
|
||||
from .models import Article, ArticleReporter, City, Comment, District, Reporter
|
||||
|
||||
|
||||
class IntrospectionTests(TransactionTestCase):
|
||||
|
@ -211,3 +211,60 @@ class IntrospectionTests(TransactionTestCase):
|
|||
self.assertEqual(val['orders'], ['ASC'] * len(val['columns']))
|
||||
indexes_verified += 1
|
||||
self.assertEqual(indexes_verified, 4)
|
||||
|
||||
def test_get_constraints(self):
|
||||
def assertDetails(details, cols, primary_key=False, unique=False, index=False, check=False, foreign_key=None):
|
||||
# Different backends have different values for same constraints:
|
||||
# PRIMARY KEY UNIQUE CONSTRAINT UNIQUE INDEX
|
||||
# MySQL pk=1 uniq=1 idx=1 pk=0 uniq=1 idx=1 pk=0 uniq=1 idx=1
|
||||
# PostgreSQL pk=1 uniq=1 idx=0 pk=0 uniq=1 idx=0 pk=0 uniq=1 idx=1
|
||||
# SQLite pk=1 uniq=0 idx=0 pk=0 uniq=1 idx=0 pk=0 uniq=1 idx=1
|
||||
if details['primary_key']:
|
||||
details['unique'] = True
|
||||
if details['unique']:
|
||||
details['index'] = False
|
||||
self.assertEqual(details['columns'], cols)
|
||||
self.assertEqual(details['primary_key'], primary_key)
|
||||
self.assertEqual(details['unique'], unique)
|
||||
self.assertEqual(details['index'], index)
|
||||
self.assertEqual(details['check'], check)
|
||||
self.assertEqual(details['foreign_key'], foreign_key)
|
||||
|
||||
with connection.cursor() as cursor:
|
||||
constraints = connection.introspection.get_constraints(cursor, Comment._meta.db_table)
|
||||
# Test custom constraints
|
||||
custom_constraints = {
|
||||
'article_email_pub_date_uniq',
|
||||
'email_pub_date_idx',
|
||||
}
|
||||
if connection.features.supports_column_check_constraints:
|
||||
custom_constraints.add('up_votes_gte_0_check')
|
||||
assertDetails(constraints['up_votes_gte_0_check'], ['up_votes'], check=True)
|
||||
assertDetails(constraints['article_email_pub_date_uniq'], ['article_id', 'email', 'pub_date'], unique=True)
|
||||
assertDetails(constraints['email_pub_date_idx'], ['email', 'pub_date'], index=True)
|
||||
# Test field constraints
|
||||
field_constraints = set()
|
||||
for name, details in constraints.items():
|
||||
if name in custom_constraints:
|
||||
continue
|
||||
elif details['columns'] == ['up_votes'] and details['check']:
|
||||
assertDetails(details, ['up_votes'], check=True)
|
||||
field_constraints.add(name)
|
||||
elif details['columns'] == ['ref'] and details['unique']:
|
||||
assertDetails(details, ['ref'], unique=True)
|
||||
field_constraints.add(name)
|
||||
elif details['columns'] == ['article_id'] and details['index']:
|
||||
assertDetails(details, ['article_id'], index=True)
|
||||
field_constraints.add(name)
|
||||
elif details['columns'] == ['id'] and details['primary_key']:
|
||||
assertDetails(details, ['id'], primary_key=True, unique=True)
|
||||
field_constraints.add(name)
|
||||
elif details['columns'] == ['article_id'] and details['foreign_key']:
|
||||
assertDetails(details, ['article_id'], foreign_key=('introspection_article', 'id'))
|
||||
field_constraints.add(name)
|
||||
elif details['check']:
|
||||
# Some databases (e.g. Oracle) include additional check
|
||||
# constraints.
|
||||
field_constraints.add(name)
|
||||
# All constraints are accounted for.
|
||||
self.assertEqual(constraints.keys() ^ (custom_constraints | field_constraints), set())
|
||||
|
|
|
@ -129,6 +129,14 @@ class SchemaTests(TransactionTestCase):
|
|||
if c['index'] and len(c['columns']) == 1
|
||||
]
|
||||
|
||||
def get_uniques(self, table):
|
||||
with connection.cursor() as cursor:
|
||||
return [
|
||||
c['columns'][0]
|
||||
for c in connection.introspection.get_constraints(cursor, table).values()
|
||||
if c['unique'] and len(c['columns']) == 1
|
||||
]
|
||||
|
||||
def get_constraints(self, table):
|
||||
"""
|
||||
Get the constraints on a table using a new cursor.
|
||||
|
@ -1971,7 +1979,7 @@ class SchemaTests(TransactionTestCase):
|
|||
editor.add_field(Book, new_field3)
|
||||
self.assertIn(
|
||||
"slug",
|
||||
self.get_indexes(Book._meta.db_table),
|
||||
self.get_uniques(Book._meta.db_table),
|
||||
)
|
||||
# Remove the unique, check the index goes with it
|
||||
new_field4 = CharField(max_length=20, unique=False)
|
||||
|
@ -1980,7 +1988,7 @@ class SchemaTests(TransactionTestCase):
|
|||
editor.alter_field(BookWithSlug, new_field3, new_field4, strict=True)
|
||||
self.assertNotIn(
|
||||
"slug",
|
||||
self.get_indexes(Book._meta.db_table),
|
||||
self.get_uniques(Book._meta.db_table),
|
||||
)
|
||||
|
||||
def test_text_field_with_db_index(self):
|
||||
|
|
Loading…
Reference in New Issue