Refs #27090 -- Added real database sequence introspection.

Thanks Mariusz Felisiak for the Oracle part and Tim Graham for the
review.
This commit is contained in:
Mariusz Felisiak 2017-09-13 20:12:32 +02:00 committed by GitHub
parent c2ecef869c
commit c6a1faecc3
9 changed files with 131 additions and 10 deletions

View File

@ -113,9 +113,10 @@ class BaseDatabaseIntrospection:
all apps. all apps.
""" """
from django.apps import apps from django.apps import apps
from django.db import models, router from django.db import router
sequence_list = [] sequence_list = []
cursor = self.connection.cursor()
for app_config in apps.get_app_configs(): for app_config in apps.get_app_configs():
for model in router.get_migratable_models(app_config, self.connection.alias): for model in router.get_migratable_models(app_config, self.connection.alias):
@ -123,19 +124,23 @@ class BaseDatabaseIntrospection:
continue continue
if model._meta.swapped: if model._meta.swapped:
continue continue
for f in model._meta.local_fields: sequence_list.extend(self.get_sequences(cursor, model._meta.db_table, model._meta.local_fields))
if isinstance(f, models.AutoField):
sequence_list.append({'table': model._meta.db_table, 'column': f.column})
break # Only one AutoField is allowed per model, so don't bother continuing.
for f in model._meta.local_many_to_many: for f in model._meta.local_many_to_many:
# If this is an m2m using an intermediate table, # If this is an m2m using an intermediate table,
# we don't need to reset the sequence. # we don't need to reset the sequence.
if f.remote_field.through is None: if f.remote_field.through is None:
sequence_list.append({'table': f.m2m_db_table(), 'column': None}) sequence = self.get_sequences(cursor, f.m2m_db_table())
sequence_list.extend(sequence if sequence else [{'table': f.m2m_db_table(), 'column': None}])
return sequence_list return sequence_list
def get_sequences(self, cursor, table_name, table_fields=()):
"""
Return a list of introspected sequences for table_name. Each sequence
is a dict: {'table': <table_name>, 'column': <column_name>}. An optional
'name' key can be added if the backend supports named sequences.
"""
raise NotImplementedError('subclasses of BaseDatabaseIntrospection may require a get_sequences() method')
def get_key_columns(self, cursor, table_name): def get_key_columns(self, cursor, table_name):
""" """
Backends can override this to return a list of: Backends can override this to return a list of:

View File

@ -105,6 +105,13 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
) )
return fields return fields
def get_sequences(self, cursor, table_name, table_fields=()):
for field_info in self.get_table_description(cursor, table_name):
if 'auto_increment' in field_info.extra:
# MySQL allows only one auto-increment column per table.
return [{'table': table_name, 'column': field_info.name}]
return []
def get_relations(self, cursor, table_name): def get_relations(self, cursor, table_name):
""" """
Return a dictionary of {field_name: (field_name_other_table, other_table)} Return a dictionary of {field_name: (field_name_other_table, other_table)}

View File

@ -3,6 +3,7 @@ from collections import namedtuple
import cx_Oracle import cx_Oracle
from django.db import models
from django.db.backends.base.introspection import ( from django.db.backends.base.introspection import (
BaseDatabaseIntrospection, FieldInfo as BaseFieldInfo, TableInfo, BaseDatabaseIntrospection, FieldInfo as BaseFieldInfo, TableInfo,
) )
@ -98,6 +99,33 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
"""Table name comparison is case insensitive under Oracle.""" """Table name comparison is case insensitive under Oracle."""
return name.lower() return name.lower()
def get_sequences(self, cursor, table_name, table_fields=()):
cursor.execute("""
SELECT
user_tab_identity_cols.sequence_name,
user_tab_identity_cols.column_name
FROM
user_tab_identity_cols,
user_constraints,
user_cons_columns cols
WHERE
user_constraints.constraint_name = cols.constraint_name
AND user_constraints.table_name = user_tab_identity_cols.table_name
AND cols.column_name = user_tab_identity_cols.column_name
AND user_constraints.constraint_type = 'P'
AND user_tab_identity_cols.table_name = UPPER(%s)
""", [table_name])
# Oracle allows only one identity column per table.
row = cursor.fetchone()
if row:
return [{'name': row[0].lower(), 'table': table_name, 'column': row[1].lower()}]
# To keep backward compatibility for AutoFields that aren't Oracle
# identity columns.
for f in table_fields:
if isinstance(f, models.AutoField):
return [{'table': table_name, 'column': f.column}]
return []
def get_relations(self, cursor, table_name): def get_relations(self, cursor, table_name):
""" """
Return a dictionary of {field_name: (field_name_other_table, other_table)} Return a dictionary of {field_name: (field_name_other_table, other_table)}

View File

@ -82,6 +82,26 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
for line in cursor.description for line in cursor.description
] ]
def get_sequences(self, cursor, table_name, table_fields=()):
sequences = []
cursor.execute("""
SELECT s.relname as sequence_name, col.attname
FROM pg_class s
JOIN pg_namespace sn ON sn.oid = s.relnamespace
JOIN pg_depend d ON d.refobjid = s.oid AND d.refclassid='pg_class'::regclass
JOIN pg_attrdef ad ON ad.oid = d.objid AND d.classid = 'pg_attrdef'::regclass
JOIN pg_attribute col ON col.attrelid = ad.adrelid AND col.attnum = ad.adnum
JOIN pg_class tbl ON tbl.oid = ad.adrelid
JOIN pg_namespace n ON n.oid = tbl.relnamespace
WHERE s.relkind = 'S'
AND d.deptype in ('a', 'n')
AND n.nspname = 'public'
AND tbl.relname = %s
""", [table_name])
for row in cursor.fetchall():
sequences.append({'name': row[0], 'table': table_name, 'column': row[1]})
return sequences
def get_relations(self, cursor, table_name): def get_relations(self, cursor, table_name):
""" """
Return a dictionary of {field_name: (field_name_other_table, other_table)} Return a dictionary of {field_name: (field_name_other_table, other_table)}

View File

@ -85,6 +85,10 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
) for info in self._table_info(cursor, table_name) ) for info in self._table_info(cursor, table_name)
] ]
def get_sequences(self, cursor, table_name, table_fields=()):
pk_col = self.get_primary_key_column(cursor, table_name)
return [{'table': table_name, 'column': pk_col}]
def column_name_converter(self, name): def column_name_converter(self, name):
""" """
SQLite will in some cases, e.g. when returning columns from views and SQLite will in some cases, e.g. when returning columns from views and

View File

@ -410,6 +410,10 @@ backends.
for consistency). ``django.test`` also now passes those values as strings for consistency). ``django.test`` also now passes those values as strings
rather than as integers. rather than as integers.
* Third-party database backends should add a
``DatabaseIntrospection.get_sequences()`` method based on the stub in
``BaseDatabaseIntrospection``.
Dropped support for Oracle 11.2 Dropped support for Oracle 11.2
------------------------------- -------------------------------

View File

@ -0,0 +1,29 @@
import unittest
from django.db import connection
from django.test import TransactionTestCase
from ..models import Person
@unittest.skipUnless(connection.vendor == 'oracle', 'Oracle tests')
class DatabaseSequenceTests(TransactionTestCase):
available_apps = []
def test_get_sequences(self):
with connection.cursor() as cursor:
seqs = connection.introspection.get_sequences(cursor, Person._meta.db_table, Person._meta.local_fields)
self.assertEqual(len(seqs), 1)
self.assertIsNotNone(seqs[0]['name'])
self.assertEqual(seqs[0]['table'], Person._meta.db_table)
self.assertEqual(seqs[0]['column'], 'id')
def test_get_sequences_manually_created_index(self):
with connection.cursor() as cursor:
with connection.schema_editor() as editor:
editor._drop_identity(Person._meta.db_table, 'id')
seqs = connection.introspection.get_sequences(cursor, Person._meta.db_table, Person._meta.local_fields)
self.assertEqual(seqs, [{'table': Person._meta.db_table, 'column': 'id'}])
# Recreate model, because adding identity is impossible.
editor.delete_model(Person)
editor.create_model(Person)

View File

@ -0,0 +1,23 @@
import unittest
from django.db import connection
from django.test import TestCase
from ..models import Person
@unittest.skipUnless(connection.vendor == 'postgresql', "Test only for PostgreSQL")
class DatabaseSequenceTests(TestCase):
def test_get_sequences(self):
cursor = connection.cursor()
seqs = connection.introspection.get_sequences(cursor, Person._meta.db_table)
self.assertEqual(
seqs,
[{'table': Person._meta.db_table, 'column': 'id', 'name': 'backends_person_id_seq'}]
)
cursor.execute('ALTER SEQUENCE backends_person_id_seq RENAME TO pers_seq')
seqs = connection.introspection.get_sequences(cursor, Person._meta.db_table)
self.assertEqual(
seqs,
[{'table': Person._meta.db_table, 'column': 'id', 'name': 'pers_seq'}]
)

View File

@ -61,8 +61,9 @@ class IntrospectionTests(TransactionTestCase):
def test_sequence_list(self): def test_sequence_list(self):
sequences = connection.introspection.sequence_list() sequences = connection.introspection.sequence_list()
expected = {'table': Reporter._meta.db_table, 'column': 'id'} reporter_seqs = [seq for seq in sequences if seq['table'] == Reporter._meta.db_table]
self.assertIn(expected, sequences, 'Reporter sequence not found in sequence_list()') self.assertEqual(len(reporter_seqs), 1, 'Reporter sequence not found in sequence_list()')
self.assertEqual(reporter_seqs[0]['column'], 'id')
def test_get_table_description_names(self): def test_get_table_description_names(self):
with connection.cursor() as cursor: with connection.cursor() as cursor: