diff --git a/django/db/backends/sqlite3/introspection.py b/django/db/backends/sqlite3/introspection.py index 85564110f87..32efd7ff6f2 100644 --- a/django/db/backends/sqlite3/introspection.py +++ b/django/db/backends/sqlite3/introspection.py @@ -214,7 +214,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): results = results[results.index('(') + 1:results.rindex(')')] for field_desc in results.split(','): field_desc = field_desc.strip() - m = re.search('"(.*)".*PRIMARY KEY( AUTOINCREMENT)?$', field_desc) + m = re.search('"(.*)".*PRIMARY KEY( AUTOINCREMENT)?', field_desc) if m: return m.groups()[0] return None diff --git a/tests/introspection/models.py b/tests/introspection/models.py index f989709e793..6d6650bc246 100644 --- a/tests/introspection/models.py +++ b/tests/introspection/models.py @@ -15,7 +15,7 @@ class City(models.Model): @python_2_unicode_compatible class District(models.Model): - city = models.ForeignKey(City, models.CASCADE) + city = models.ForeignKey(City, models.CASCADE, primary_key=True) name = models.CharField(max_length=50) def __str__(self): diff --git a/tests/introspection/tests.py b/tests/introspection/tests.py index f01f81e7183..5e2d6f1c992 100644 --- a/tests/introspection/tests.py +++ b/tests/introspection/tests.py @@ -6,7 +6,7 @@ from django.db import connection from django.db.utils import DatabaseError from django.test import TransactionTestCase, mock, skipUnlessDBFeature -from .models import Article, ArticleReporter, City, Reporter +from .models import Article, ArticleReporter, City, District, Reporter class IntrospectionTests(TransactionTestCase): @@ -165,7 +165,9 @@ class IntrospectionTests(TransactionTestCase): def test_get_primary_key_column(self): with connection.cursor() as cursor: primary_key_column = connection.introspection.get_primary_key_column(cursor, Article._meta.db_table) + pk_fk_column = connection.introspection.get_primary_key_column(cursor, District._meta.db_table) self.assertEqual(primary_key_column, 'id') + self.assertEqual(pk_fk_column, 'city_id') def test_get_indexes(self): with connection.cursor() as cursor: