Fixed #33288 -- Made SQLite introspection use information schema for relations.

Previous solution was using brittle and complex parsing rules to
extract them from the SQL used to define the tables.

Removed a now unnecessary unit test that ensured the removed parsing
logic accounted for optional spacing.
This commit is contained in:
Simon Charette 2021-11-13 14:43:04 -05:00 committed by Mariusz Felisiak
parent 30ec7fe89a
commit 483e30c3d5
2 changed files with 6 additions and 70 deletions

View File

@ -1,4 +1,3 @@
import re
from collections import namedtuple from collections import namedtuple
import sqlparse import sqlparse
@ -117,61 +116,16 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
def get_relations(self, cursor, table_name): def get_relations(self, cursor, table_name):
""" """
Return a dictionary of {field_name: (field_name_other_table, other_table)} Return a dictionary of {column_name: (ref_column_name, ref_table_name)}
representing all foreign keys in the given table. representing all foreign keys in the given table.
""" """
# Dictionary of relations to return
relations = {}
# Schema for this table
cursor.execute( cursor.execute(
"SELECT sql, type FROM sqlite_master " 'PRAGMA foreign_key_list(%s)' % self.connection.ops.quote_name(table_name)
"WHERE tbl_name = %s AND type IN ('table', 'view')",
[table_name]
) )
create_sql, table_type = cursor.fetchone() return {
if table_type == 'view': column_name: (ref_column_name, ref_table_name)
# It might be a view, then no results will be returned for _, _, ref_table_name, column_name, ref_column_name, *_ in cursor.fetchall()
return relations }
results = create_sql[create_sql.index('(') + 1:create_sql.rindex(')')]
# Walk through and look for references to other tables. SQLite doesn't
# really have enforced references, but since it echoes out the SQL used
# to create the table we can look for REFERENCES statements used there.
for field_desc in results.split(','):
field_desc = field_desc.strip()
if field_desc.startswith("UNIQUE"):
continue
m = re.search(r'references (\S*) ?\(["|]?(.*)["|]?\)', field_desc, re.I)
if not m:
continue
table, column = [s.strip('"') for s in m.groups()]
if field_desc.startswith("FOREIGN KEY"):
# Find name of the target FK field
m = re.match(r'FOREIGN KEY\s*\(([^\)]*)\).*', field_desc, re.I)
field_name = m[1].strip('"')
else:
field_name = field_desc.split()[0].strip('"')
cursor.execute("SELECT sql FROM sqlite_master WHERE tbl_name = %s", [table])
result = cursor.fetchall()[0]
other_table_results = result[0].strip()
li, ri = other_table_results.index('('), other_table_results.rindex(')')
other_table_results = other_table_results[li + 1:ri]
for other_desc in other_table_results.split(','):
other_desc = other_desc.strip()
if other_desc.startswith('UNIQUE'):
continue
other_name = other_desc.split(' ', 1)[0].strip('"')
if other_name == column:
relations[field_name] = (other_name, table)
break
return relations
def get_primary_key_column(self, cursor, table_name): def get_primary_key_column(self, cursor, table_name):
"""Return the column name of the primary key for the given table.""" """Return the column name of the primary key for the given table."""

View File

@ -1,5 +1,3 @@
from unittest import mock, skipUnless
from django.db import DatabaseError, connection from django.db import DatabaseError, connection
from django.db.models import Index from django.db.models import Index
from django.test import TransactionTestCase, skipUnlessDBFeature from django.test import TransactionTestCase, skipUnlessDBFeature
@ -152,22 +150,6 @@ class IntrospectionTests(TransactionTestCase):
editor.add_field(Article, body) editor.add_field(Article, body)
self.assertEqual(relations, expected_relations) self.assertEqual(relations, expected_relations)
@skipUnless(connection.vendor == 'sqlite', "This is an sqlite-specific issue")
def test_get_relations_alt_format(self):
"""
With SQLite, foreign keys can be added with different syntaxes and
formatting.
"""
create_table_statements = [
"CREATE TABLE track(id, art_id INTEGER, FOREIGN KEY(art_id) REFERENCES {}(id));",
"CREATE TABLE track(id, art_id INTEGER, FOREIGN KEY (art_id) REFERENCES {}(id));"
]
for statement in create_table_statements:
with connection.cursor() as cursor:
cursor.fetchone = mock.Mock(return_value=[statement.format(Article._meta.db_table), 'table'])
relations = connection.introspection.get_relations(cursor, 'mocked_table')
self.assertEqual(relations, {'art_id': ('id', Article._meta.db_table)})
def test_get_primary_key_column(self): def test_get_primary_key_column(self):
with connection.cursor() as cursor: with connection.cursor() as cursor:
primary_key_column = connection.introspection.get_primary_key_column(cursor, Article._meta.db_table) primary_key_column = connection.introspection.get_primary_key_column(cursor, Article._meta.db_table)