diff --git a/django/db/backends/base/introspection.py b/django/db/backends/base/introspection.py index 5184e39263..a3107c3034 100644 --- a/django/db/backends/base/introspection.py +++ b/django/db/backends/base/introspection.py @@ -113,9 +113,10 @@ class BaseDatabaseIntrospection: all apps. """ from django.apps import apps - from django.db import models, router + from django.db import router sequence_list = [] + cursor = self.connection.cursor() for app_config in apps.get_app_configs(): for model in router.get_migratable_models(app_config, self.connection.alias): @@ -123,19 +124,23 @@ class BaseDatabaseIntrospection: continue if model._meta.swapped: continue - for f in model._meta.local_fields: - if isinstance(f, models.AutoField): - sequence_list.append({'table': model._meta.db_table, 'column': f.column}) - break # Only one AutoField is allowed per model, so don't bother continuing. - + sequence_list.extend(self.get_sequences(cursor, model._meta.db_table, model._meta.local_fields)) for f in model._meta.local_many_to_many: # If this is an m2m using an intermediate table, # we don't need to reset the sequence. if f.remote_field.through is None: - sequence_list.append({'table': f.m2m_db_table(), 'column': None}) - + sequence = self.get_sequences(cursor, f.m2m_db_table()) + sequence_list.extend(sequence if sequence else [{'table': f.m2m_db_table(), 'column': None}]) return sequence_list + def get_sequences(self, cursor, table_name, table_fields=()): + """ + Return a list of introspected sequences for table_name. Each sequence + is a dict: {'table': , 'column': }. An optional + 'name' key can be added if the backend supports named sequences. + """ + raise NotImplementedError('subclasses of BaseDatabaseIntrospection may require a get_sequences() method') + def get_key_columns(self, cursor, table_name): """ Backends can override this to return a list of: diff --git a/django/db/backends/mysql/introspection.py b/django/db/backends/mysql/introspection.py index 8f61fa3061..caa5826e55 100644 --- a/django/db/backends/mysql/introspection.py +++ b/django/db/backends/mysql/introspection.py @@ -105,6 +105,13 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): ) return fields + def get_sequences(self, cursor, table_name, table_fields=()): + for field_info in self.get_table_description(cursor, table_name): + if 'auto_increment' in field_info.extra: + # MySQL allows only one auto-increment column per table. + return [{'table': table_name, 'column': field_info.name}] + return [] + def get_relations(self, cursor, table_name): """ Return a dictionary of {field_name: (field_name_other_table, other_table)} diff --git a/django/db/backends/oracle/introspection.py b/django/db/backends/oracle/introspection.py index 4bd1c0f422..7b873fd0d0 100644 --- a/django/db/backends/oracle/introspection.py +++ b/django/db/backends/oracle/introspection.py @@ -3,6 +3,7 @@ from collections import namedtuple import cx_Oracle +from django.db import models from django.db.backends.base.introspection import ( BaseDatabaseIntrospection, FieldInfo as BaseFieldInfo, TableInfo, ) @@ -98,6 +99,33 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): """Table name comparison is case insensitive under Oracle.""" return name.lower() + def get_sequences(self, cursor, table_name, table_fields=()): + cursor.execute(""" + SELECT + user_tab_identity_cols.sequence_name, + user_tab_identity_cols.column_name + FROM + user_tab_identity_cols, + user_constraints, + user_cons_columns cols + WHERE + user_constraints.constraint_name = cols.constraint_name + AND user_constraints.table_name = user_tab_identity_cols.table_name + AND cols.column_name = user_tab_identity_cols.column_name + AND user_constraints.constraint_type = 'P' + AND user_tab_identity_cols.table_name = UPPER(%s) + """, [table_name]) + # Oracle allows only one identity column per table. + row = cursor.fetchone() + if row: + return [{'name': row[0].lower(), 'table': table_name, 'column': row[1].lower()}] + # To keep backward compatibility for AutoFields that aren't Oracle + # identity columns. + for f in table_fields: + if isinstance(f, models.AutoField): + return [{'table': table_name, 'column': f.column}] + return [] + def get_relations(self, cursor, table_name): """ Return a dictionary of {field_name: (field_name_other_table, other_table)} diff --git a/django/db/backends/postgresql/introspection.py b/django/db/backends/postgresql/introspection.py index 5be80f3edd..1e987d1779 100644 --- a/django/db/backends/postgresql/introspection.py +++ b/django/db/backends/postgresql/introspection.py @@ -82,6 +82,26 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): for line in cursor.description ] + def get_sequences(self, cursor, table_name, table_fields=()): + sequences = [] + cursor.execute(""" + SELECT s.relname as sequence_name, col.attname + FROM pg_class s + JOIN pg_namespace sn ON sn.oid = s.relnamespace + JOIN pg_depend d ON d.refobjid = s.oid AND d.refclassid='pg_class'::regclass + JOIN pg_attrdef ad ON ad.oid = d.objid AND d.classid = 'pg_attrdef'::regclass + JOIN pg_attribute col ON col.attrelid = ad.adrelid AND col.attnum = ad.adnum + JOIN pg_class tbl ON tbl.oid = ad.adrelid + JOIN pg_namespace n ON n.oid = tbl.relnamespace + WHERE s.relkind = 'S' + AND d.deptype in ('a', 'n') + AND n.nspname = 'public' + AND tbl.relname = %s + """, [table_name]) + for row in cursor.fetchall(): + sequences.append({'name': row[0], 'table': table_name, 'column': row[1]}) + return sequences + def get_relations(self, cursor, table_name): """ Return a dictionary of {field_name: (field_name_other_table, other_table)} diff --git a/django/db/backends/sqlite3/introspection.py b/django/db/backends/sqlite3/introspection.py index 49c2095a1c..0518641344 100644 --- a/django/db/backends/sqlite3/introspection.py +++ b/django/db/backends/sqlite3/introspection.py @@ -85,6 +85,10 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): ) for info in self._table_info(cursor, table_name) ] + def get_sequences(self, cursor, table_name, table_fields=()): + pk_col = self.get_primary_key_column(cursor, table_name) + return [{'table': table_name, 'column': pk_col}] + def column_name_converter(self, name): """ SQLite will in some cases, e.g. when returning columns from views and diff --git a/docs/releases/2.0.txt b/docs/releases/2.0.txt index d1fb42c5de..d25f25eff9 100644 --- a/docs/releases/2.0.txt +++ b/docs/releases/2.0.txt @@ -410,6 +410,10 @@ backends. for consistency). ``django.test`` also now passes those values as strings rather than as integers. +* Third-party database backends should add a + ``DatabaseIntrospection.get_sequences()`` method based on the stub in + ``BaseDatabaseIntrospection``. + Dropped support for Oracle 11.2 ------------------------------- diff --git a/tests/backends/oracle/test_introspection.py b/tests/backends/oracle/test_introspection.py new file mode 100644 index 0000000000..ac084e93df --- /dev/null +++ b/tests/backends/oracle/test_introspection.py @@ -0,0 +1,29 @@ +import unittest + +from django.db import connection +from django.test import TransactionTestCase + +from ..models import Person + + +@unittest.skipUnless(connection.vendor == 'oracle', 'Oracle tests') +class DatabaseSequenceTests(TransactionTestCase): + available_apps = [] + + def test_get_sequences(self): + with connection.cursor() as cursor: + seqs = connection.introspection.get_sequences(cursor, Person._meta.db_table, Person._meta.local_fields) + self.assertEqual(len(seqs), 1) + self.assertIsNotNone(seqs[0]['name']) + self.assertEqual(seqs[0]['table'], Person._meta.db_table) + self.assertEqual(seqs[0]['column'], 'id') + + def test_get_sequences_manually_created_index(self): + with connection.cursor() as cursor: + with connection.schema_editor() as editor: + editor._drop_identity(Person._meta.db_table, 'id') + seqs = connection.introspection.get_sequences(cursor, Person._meta.db_table, Person._meta.local_fields) + self.assertEqual(seqs, [{'table': Person._meta.db_table, 'column': 'id'}]) + # Recreate model, because adding identity is impossible. + editor.delete_model(Person) + editor.create_model(Person) diff --git a/tests/backends/postgresql/test_introspection.py b/tests/backends/postgresql/test_introspection.py new file mode 100644 index 0000000000..cfa801a77f --- /dev/null +++ b/tests/backends/postgresql/test_introspection.py @@ -0,0 +1,23 @@ +import unittest + +from django.db import connection +from django.test import TestCase + +from ..models import Person + + +@unittest.skipUnless(connection.vendor == 'postgresql', "Test only for PostgreSQL") +class DatabaseSequenceTests(TestCase): + def test_get_sequences(self): + cursor = connection.cursor() + seqs = connection.introspection.get_sequences(cursor, Person._meta.db_table) + self.assertEqual( + seqs, + [{'table': Person._meta.db_table, 'column': 'id', 'name': 'backends_person_id_seq'}] + ) + cursor.execute('ALTER SEQUENCE backends_person_id_seq RENAME TO pers_seq') + seqs = connection.introspection.get_sequences(cursor, Person._meta.db_table) + self.assertEqual( + seqs, + [{'table': Person._meta.db_table, 'column': 'id', 'name': 'pers_seq'}] + ) diff --git a/tests/introspection/tests.py b/tests/introspection/tests.py index c449b6de82..ed10927b2a 100644 --- a/tests/introspection/tests.py +++ b/tests/introspection/tests.py @@ -61,8 +61,9 @@ class IntrospectionTests(TransactionTestCase): def test_sequence_list(self): sequences = connection.introspection.sequence_list() - expected = {'table': Reporter._meta.db_table, 'column': 'id'} - self.assertIn(expected, sequences, 'Reporter sequence not found in sequence_list()') + reporter_seqs = [seq for seq in sequences if seq['table'] == Reporter._meta.db_table] + self.assertEqual(len(reporter_seqs), 1, 'Reporter sequence not found in sequence_list()') + self.assertEqual(reporter_seqs[0]['column'], 'id') def test_get_table_description_names(self): with connection.cursor() as cursor: