From 782d85b6dfa191e67c0f1d572641d8236c79174c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pave=C5=82=20Ty=C5=9Blacki?= Date: Thu, 28 Feb 2019 00:47:29 +0300 Subject: [PATCH] Fixed #30183 -- Added introspection of inline SQLite constraints. --- django/db/backends/sqlite3/introspection.py | 179 ++++++++++++++------ tests/backends/sqlite/test_introspection.py | 115 +++++++++++++ tests/introspection/models.py | 18 ++ tests/introspection/tests.py | 59 ++++++- tests/schema/tests.py | 12 +- 5 files changed, 332 insertions(+), 51 deletions(-) diff --git a/django/db/backends/sqlite3/introspection.py b/django/db/backends/sqlite3/introspection.py index 1104571bb2..faf9328415 100644 --- a/django/db/backends/sqlite3/introspection.py +++ b/django/db/backends/sqlite3/introspection.py @@ -217,50 +217,124 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): } return constraints - def _parse_table_constraints(self, sql): + def _parse_column_or_constraint_definition(self, tokens, columns): + token = None + is_constraint_definition = None + field_name = None + constraint_name = None + unique = False + unique_columns = [] + check = False + check_columns = [] + braces_deep = 0 + for token in tokens: + if token.match(sqlparse.tokens.Punctuation, '('): + braces_deep += 1 + elif token.match(sqlparse.tokens.Punctuation, ')'): + braces_deep -= 1 + if braces_deep < 0: + # End of columns and constraints for table definition. + break + elif braces_deep == 0 and token.match(sqlparse.tokens.Punctuation, ','): + # End of current column or constraint definition. + break + # Detect column or constraint definition by first token. + if is_constraint_definition is None: + is_constraint_definition = token.match(sqlparse.tokens.Keyword, 'CONSTRAINT') + if is_constraint_definition: + continue + if is_constraint_definition: + # Detect constraint name by second token. + if constraint_name is None: + if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword): + constraint_name = token.value + elif token.ttype == sqlparse.tokens.Literal.String.Symbol: + constraint_name = token.value[1:-1] + # Start constraint columns parsing after UNIQUE keyword. + if token.match(sqlparse.tokens.Keyword, 'UNIQUE'): + unique = True + unique_braces_deep = braces_deep + elif unique: + if unique_braces_deep == braces_deep: + if unique_columns: + # Stop constraint parsing. + unique = False + continue + if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword): + unique_columns.append(token.value) + elif token.ttype == sqlparse.tokens.Literal.String.Symbol: + unique_columns.append(token.value[1:-1]) + else: + # Detect field name by first token. + if field_name is None: + if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword): + field_name = token.value + elif token.ttype == sqlparse.tokens.Literal.String.Symbol: + field_name = token.value[1:-1] + if token.match(sqlparse.tokens.Keyword, 'UNIQUE'): + unique_columns = [field_name] + # Start constraint columns parsing after CHECK keyword. + if token.match(sqlparse.tokens.Keyword, 'CHECK'): + check = True + check_braces_deep = braces_deep + elif check: + if check_braces_deep == braces_deep: + if check_columns: + # Stop constraint parsing. + check = False + continue + if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword): + if token.value in columns: + check_columns.append(token.value) + elif token.ttype == sqlparse.tokens.Literal.String.Symbol: + if token.value[1:-1] in columns: + check_columns.append(token.value[1:-1]) + unique_constraint = { + 'unique': True, + 'columns': unique_columns, + 'primary_key': False, + 'foreign_key': None, + 'check': False, + 'index': False, + } if unique_columns else None + check_constraint = { + 'check': True, + 'columns': check_columns, + 'primary_key': False, + 'unique': False, + 'foreign_key': None, + 'index': False, + } if check_columns else None + return constraint_name, unique_constraint, check_constraint, token + + def _parse_table_constraints(self, sql, columns): # Check constraint parsing is based of SQLite syntax diagram. # https://www.sqlite.org/syntaxdiagrams.html#table-constraint - def next_ttype(ttype): - for token in tokens: - if token.ttype == ttype: - return token - statement = sqlparse.parse(sql)[0] constraints = {} - tokens = statement.flatten() + unnamed_constrains_index = 0 + tokens = (token for token in statement.flatten() if not token.is_whitespace) + # Go to columns and constraint definition for token in tokens: - name = None - if token.match(sqlparse.tokens.Keyword, 'CONSTRAINT'): - # Table constraint - name_token = next_ttype(sqlparse.tokens.Literal.String.Symbol) - name = name_token.value[1:-1] - token = next_ttype(sqlparse.tokens.Keyword) - if token.match(sqlparse.tokens.Keyword, 'UNIQUE'): - constraints[name] = { - 'unique': True, - 'columns': [], - 'primary_key': False, - 'foreign_key': None, - 'check': False, - 'index': False, - } - if token.match(sqlparse.tokens.Keyword, 'CHECK'): - # Column check constraint - if name is None: - column_token = next_ttype(sqlparse.tokens.Literal.String.Symbol) - column = column_token.value[1:-1] - name = '__check__%s' % column - columns = [column] + if token.match(sqlparse.tokens.Punctuation, '('): + break + # Parse columns and constraint definition + while True: + constraint_name, unique, check, end_token = self._parse_column_or_constraint_definition(tokens, columns) + if unique: + if constraint_name: + constraints[constraint_name] = unique else: - columns = [] - constraints[name] = { - 'check': True, - 'columns': columns, - 'primary_key': False, - 'unique': False, - 'foreign_key': None, - 'index': False, - } + unnamed_constrains_index += 1 + constraints['__unnamed_constraint_%s__' % unnamed_constrains_index] = unique + if check: + if constraint_name: + constraints[constraint_name] = check + else: + unnamed_constrains_index += 1 + constraints['__unnamed_constraint_%s__' % unnamed_constrains_index] = check + if end_token.match(sqlparse.tokens.Punctuation, ')'): + break return constraints def get_constraints(self, cursor, table_name): @@ -280,7 +354,8 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): # table_name is a view. pass else: - constraints.update(self._parse_table_constraints(table_schema)) + columns = {info.name for info in self.get_table_description(cursor, table_name)} + constraints.update(self._parse_table_constraints(table_schema, columns)) # Get the index info cursor.execute("PRAGMA index_list(%s)" % self.connection.ops.quote_name(table_name)) @@ -288,6 +363,21 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): # SQLite 3.8.9+ has 5 columns, however older versions only give 3 # columns. Discard last 2 columns if there. number, index, unique = row[:3] + cursor.execute( + "SELECT sql FROM sqlite_master " + "WHERE type='index' AND name=%s" % self.connection.ops.quote_name(index) + ) + # There's at most one row. + sql, = cursor.fetchone() or (None,) + # Inline constraints are already detected in + # _parse_table_constraints(). The reasons to avoid fetching inline + # constraints from `PRAGMA index_list` are: + # - Inline constraints can have a different name and information + # than what `PRAGMA index_list` gives. + # - Not all inline constraints may appear in `PRAGMA index_list`. + if not sql: + # An inline constraint + continue # Get the index info for that index cursor.execute('PRAGMA index_info(%s)' % self.connection.ops.quote_name(index)) for index_rank, column_rank, column in cursor.fetchall(): @@ -305,15 +395,8 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): if constraints[index]['index'] and not constraints[index]['unique']: # SQLite doesn't support any index type other than b-tree constraints[index]['type'] = Index.suffix - cursor.execute( - "SELECT sql FROM sqlite_master " - "WHERE type='index' AND name=%s" % self.connection.ops.quote_name(index) - ) - orders = [] - # There would be only 1 row to loop over - for sql, in cursor.fetchall(): - order_info = sql.split('(')[-1].split(')')[0].split(',') - orders = ['DESC' if info.endswith('DESC') else 'ASC' for info in order_info] + order_info = sql.split('(')[-1].split(')')[0].split(',') + orders = ['DESC' if info.endswith('DESC') else 'ASC' for info in order_info] constraints[index]['orders'] = orders # Get the PK pk_column = self.get_primary_key_column(cursor, table_name) diff --git a/tests/backends/sqlite/test_introspection.py b/tests/backends/sqlite/test_introspection.py index 1695ee549e..e378e0ee56 100644 --- a/tests/backends/sqlite/test_introspection.py +++ b/tests/backends/sqlite/test_introspection.py @@ -1,5 +1,7 @@ import unittest +import sqlparse + from django.db import connection from django.test import TestCase @@ -25,3 +27,116 @@ class IntrospectionTests(TestCase): self.assertEqual(field, expected_string) finally: cursor.execute('DROP TABLE test_primary') + + +@unittest.skipUnless(connection.vendor == 'sqlite', 'SQLite tests') +class ParsingTests(TestCase): + def parse_definition(self, sql, columns): + """Parse a column or constraint definition.""" + statement = sqlparse.parse(sql)[0] + tokens = (token for token in statement.flatten() if not token.is_whitespace) + with connection.cursor(): + return connection.introspection._parse_column_or_constraint_definition(tokens, set(columns)) + + def assertConstraint(self, constraint_details, cols, unique=False, check=False): + self.assertEqual(constraint_details, { + 'unique': unique, + 'columns': cols, + 'primary_key': False, + 'foreign_key': None, + 'check': check, + 'index': False, + }) + + def test_unique_column(self): + tests = ( + ('"ref" integer UNIQUE,', ['ref']), + ('ref integer UNIQUE,', ['ref']), + ('"customname" integer UNIQUE,', ['customname']), + ('customname integer UNIQUE,', ['customname']), + ) + for sql, columns in tests: + with self.subTest(sql=sql): + constraint, details, check, _ = self.parse_definition(sql, columns) + self.assertIsNone(constraint) + self.assertConstraint(details, columns, unique=True) + self.assertIsNone(check) + + def test_unique_constraint(self): + tests = ( + ('CONSTRAINT "ref" UNIQUE ("ref"),', 'ref', ['ref']), + ('CONSTRAINT ref UNIQUE (ref),', 'ref', ['ref']), + ('CONSTRAINT "customname1" UNIQUE ("customname2"),', 'customname1', ['customname2']), + ('CONSTRAINT customname1 UNIQUE (customname2),', 'customname1', ['customname2']), + ) + for sql, constraint_name, columns in tests: + with self.subTest(sql=sql): + constraint, details, check, _ = self.parse_definition(sql, columns) + self.assertEqual(constraint, constraint_name) + self.assertConstraint(details, columns, unique=True) + self.assertIsNone(check) + + def test_unique_constraint_multicolumn(self): + tests = ( + ('CONSTRAINT "ref" UNIQUE ("ref", "customname"),', 'ref', ['ref', 'customname']), + ('CONSTRAINT ref UNIQUE (ref, customname),', 'ref', ['ref', 'customname']), + ) + for sql, constraint_name, columns in tests: + with self.subTest(sql=sql): + constraint, details, check, _ = self.parse_definition(sql, columns) + self.assertEqual(constraint, constraint_name) + self.assertConstraint(details, columns, unique=True) + self.assertIsNone(check) + + def test_check_column(self): + tests = ( + ('"ref" varchar(255) CHECK ("ref" != \'test\'),', ['ref']), + ('ref varchar(255) CHECK (ref != \'test\'),', ['ref']), + ('"customname1" varchar(255) CHECK ("customname2" != \'test\'),', ['customname2']), + ('customname1 varchar(255) CHECK (customname2 != \'test\'),', ['customname2']), + ) + for sql, columns in tests: + with self.subTest(sql=sql): + constraint, details, check, _ = self.parse_definition(sql, columns) + self.assertIsNone(constraint) + self.assertIsNone(details) + self.assertConstraint(check, columns, check=True) + + def test_check_constraint(self): + tests = ( + ('CONSTRAINT "ref" CHECK ("ref" != \'test\'),', 'ref', ['ref']), + ('CONSTRAINT ref CHECK (ref != \'test\'),', 'ref', ['ref']), + ('CONSTRAINT "customname1" CHECK ("customname2" != \'test\'),', 'customname1', ['customname2']), + ('CONSTRAINT customname1 CHECK (customname2 != \'test\'),', 'customname1', ['customname2']), + ) + for sql, constraint_name, columns in tests: + with self.subTest(sql=sql): + constraint, details, check, _ = self.parse_definition(sql, columns) + self.assertEqual(constraint, constraint_name) + self.assertIsNone(details) + self.assertConstraint(check, columns, check=True) + + def test_check_column_with_operators_and_functions(self): + tests = ( + ('"ref" integer CHECK ("ref" BETWEEN 1 AND 10),', ['ref']), + ('"ref" varchar(255) CHECK ("ref" LIKE \'test%\'),', ['ref']), + ('"ref" varchar(255) CHECK (LENGTH(ref) > "max_length"),', ['ref', 'max_length']), + ) + for sql, columns in tests: + with self.subTest(sql=sql): + constraint, details, check, _ = self.parse_definition(sql, columns) + self.assertIsNone(constraint) + self.assertIsNone(details) + self.assertConstraint(check, columns, check=True) + + def test_check_and_unique_column(self): + tests = ( + ('"ref" varchar(255) CHECK ("ref" != \'test\') UNIQUE,', ['ref']), + ('ref varchar(255) UNIQUE CHECK (ref != \'test\'),', ['ref']), + ) + for sql, columns in tests: + with self.subTest(sql=sql): + constraint, details, check, _ = self.parse_definition(sql, columns) + self.assertIsNone(constraint) + self.assertConstraint(details, columns, unique=True) + self.assertConstraint(check, columns, check=True) diff --git a/tests/introspection/models.py b/tests/introspection/models.py index 32acc323bd..fa663de2fd 100644 --- a/tests/introspection/models.py +++ b/tests/introspection/models.py @@ -58,3 +58,21 @@ class ArticleReporter(models.Model): class Meta: managed = False + + +class Comment(models.Model): + ref = models.UUIDField(unique=True) + article = models.ForeignKey(Article, models.CASCADE, db_index=True) + email = models.EmailField() + pub_date = models.DateTimeField() + up_votes = models.PositiveIntegerField() + body = models.TextField() + + class Meta: + constraints = [ + models.CheckConstraint(name='up_votes_gte_0_check', check=models.Q(up_votes__gte=0)), + models.UniqueConstraint(fields=['article', 'email', 'pub_date'], name='article_email_pub_date_uniq'), + ] + indexes = [ + models.Index(fields=['email', 'pub_date'], name='email_pub_date_idx'), + ] diff --git a/tests/introspection/tests.py b/tests/introspection/tests.py index d851352cae..10524cdacb 100644 --- a/tests/introspection/tests.py +++ b/tests/introspection/tests.py @@ -5,7 +5,7 @@ from django.db.models import Index from django.db.utils import DatabaseError from django.test import TransactionTestCase, skipUnlessDBFeature -from .models import Article, ArticleReporter, City, District, Reporter +from .models import Article, ArticleReporter, City, Comment, District, Reporter class IntrospectionTests(TransactionTestCase): @@ -211,3 +211,60 @@ class IntrospectionTests(TransactionTestCase): self.assertEqual(val['orders'], ['ASC'] * len(val['columns'])) indexes_verified += 1 self.assertEqual(indexes_verified, 4) + + def test_get_constraints(self): + def assertDetails(details, cols, primary_key=False, unique=False, index=False, check=False, foreign_key=None): + # Different backends have different values for same constraints: + # PRIMARY KEY UNIQUE CONSTRAINT UNIQUE INDEX + # MySQL pk=1 uniq=1 idx=1 pk=0 uniq=1 idx=1 pk=0 uniq=1 idx=1 + # PostgreSQL pk=1 uniq=1 idx=0 pk=0 uniq=1 idx=0 pk=0 uniq=1 idx=1 + # SQLite pk=1 uniq=0 idx=0 pk=0 uniq=1 idx=0 pk=0 uniq=1 idx=1 + if details['primary_key']: + details['unique'] = True + if details['unique']: + details['index'] = False + self.assertEqual(details['columns'], cols) + self.assertEqual(details['primary_key'], primary_key) + self.assertEqual(details['unique'], unique) + self.assertEqual(details['index'], index) + self.assertEqual(details['check'], check) + self.assertEqual(details['foreign_key'], foreign_key) + + with connection.cursor() as cursor: + constraints = connection.introspection.get_constraints(cursor, Comment._meta.db_table) + # Test custom constraints + custom_constraints = { + 'article_email_pub_date_uniq', + 'email_pub_date_idx', + } + if connection.features.supports_column_check_constraints: + custom_constraints.add('up_votes_gte_0_check') + assertDetails(constraints['up_votes_gte_0_check'], ['up_votes'], check=True) + assertDetails(constraints['article_email_pub_date_uniq'], ['article_id', 'email', 'pub_date'], unique=True) + assertDetails(constraints['email_pub_date_idx'], ['email', 'pub_date'], index=True) + # Test field constraints + field_constraints = set() + for name, details in constraints.items(): + if name in custom_constraints: + continue + elif details['columns'] == ['up_votes'] and details['check']: + assertDetails(details, ['up_votes'], check=True) + field_constraints.add(name) + elif details['columns'] == ['ref'] and details['unique']: + assertDetails(details, ['ref'], unique=True) + field_constraints.add(name) + elif details['columns'] == ['article_id'] and details['index']: + assertDetails(details, ['article_id'], index=True) + field_constraints.add(name) + elif details['columns'] == ['id'] and details['primary_key']: + assertDetails(details, ['id'], primary_key=True, unique=True) + field_constraints.add(name) + elif details['columns'] == ['article_id'] and details['foreign_key']: + assertDetails(details, ['article_id'], foreign_key=('introspection_article', 'id')) + field_constraints.add(name) + elif details['check']: + # Some databases (e.g. Oracle) include additional check + # constraints. + field_constraints.add(name) + # All constraints are accounted for. + self.assertEqual(constraints.keys() ^ (custom_constraints | field_constraints), set()) diff --git a/tests/schema/tests.py b/tests/schema/tests.py index 9b40a43523..00ce2e494e 100644 --- a/tests/schema/tests.py +++ b/tests/schema/tests.py @@ -129,6 +129,14 @@ class SchemaTests(TransactionTestCase): if c['index'] and len(c['columns']) == 1 ] + def get_uniques(self, table): + with connection.cursor() as cursor: + return [ + c['columns'][0] + for c in connection.introspection.get_constraints(cursor, table).values() + if c['unique'] and len(c['columns']) == 1 + ] + def get_constraints(self, table): """ Get the constraints on a table using a new cursor. @@ -1971,7 +1979,7 @@ class SchemaTests(TransactionTestCase): editor.add_field(Book, new_field3) self.assertIn( "slug", - self.get_indexes(Book._meta.db_table), + self.get_uniques(Book._meta.db_table), ) # Remove the unique, check the index goes with it new_field4 = CharField(max_length=20, unique=False) @@ -1980,7 +1988,7 @@ class SchemaTests(TransactionTestCase): editor.alter_field(BookWithSlug, new_field3, new_field4, strict=True) self.assertNotIn( "slug", - self.get_indexes(Book._meta.db_table), + self.get_uniques(Book._meta.db_table), ) def test_text_field_with_db_index(self):