Fixed #32672 -- Fixed introspection of primary key constraints on SQLite.

Thanks Simon Charette for the implementation idea.
This commit is contained in:
Anv3sh 2021-09-21 22:49:58 +05:30 committed by Mariusz Felisiak
parent d446f8ba08
commit 69af4d09ba
2 changed files with 23 additions and 17 deletions

View File

@ -201,25 +201,12 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
def get_primary_key_column(self, cursor, table_name): def get_primary_key_column(self, cursor, table_name):
"""Return the column name of the primary key for the given table.""" """Return the column name of the primary key for the given table."""
# Don't use PRAGMA because that causes issues with some transactions
cursor.execute( cursor.execute(
"SELECT sql, type FROM sqlite_master " 'PRAGMA table_info(%s)' % self.connection.ops.quote_name(table_name)
"WHERE tbl_name = %s AND type IN ('table', 'view')",
[table_name]
) )
row = cursor.fetchone() for _, name, *_, pk in cursor.fetchall():
if row is None: if pk:
raise ValueError("Table %s does not exist" % table_name) return name
create_sql, table_type = row
if table_type == 'view':
# Views don't have a primary key.
return None
fields_sql = create_sql[create_sql.index('(') + 1:create_sql.rindex(')')]
for field_desc in fields_sql.split(','):
field_desc = field_desc.strip()
m = re.match(r'(?:(?:["`\[])(.*)(?:["`\]])|(\w+)).*PRIMARY KEY.*', field_desc)
if m:
return m[1] if m[1] else m[2]
return None return None
def _get_foreign_key_constraints(self, cursor, table_name): def _get_foreign_key_constraints(self, cursor, table_name):

View File

@ -28,6 +28,25 @@ class IntrospectionTests(TestCase):
finally: finally:
cursor.execute('DROP TABLE test_primary') cursor.execute('DROP TABLE test_primary')
def test_get_primary_key_column_pk_constraint(self):
sql = """
CREATE TABLE test_primary(
id INTEGER NOT NULL,
created DATE,
PRIMARY KEY(id)
)
"""
with connection.cursor() as cursor:
try:
cursor.execute(sql)
field = connection.introspection.get_primary_key_column(
cursor,
'test_primary',
)
self.assertEqual(field, 'id')
finally:
cursor.execute('DROP TABLE test_primary')
@unittest.skipUnless(connection.vendor == 'sqlite', 'SQLite tests') @unittest.skipUnless(connection.vendor == 'sqlite', 'SQLite tests')
class ParsingTests(TestCase): class ParsingTests(TestCase):