Refs #27090 -- Added real database sequence introspection.
Thanks Mariusz Felisiak for the Oracle part and Tim Graham for the review.
This commit is contained in:
parent
c2ecef869c
commit
c6a1faecc3
|
@ -113,9 +113,10 @@ class BaseDatabaseIntrospection:
|
||||||
all apps.
|
all apps.
|
||||||
"""
|
"""
|
||||||
from django.apps import apps
|
from django.apps import apps
|
||||||
from django.db import models, router
|
from django.db import router
|
||||||
|
|
||||||
sequence_list = []
|
sequence_list = []
|
||||||
|
cursor = self.connection.cursor()
|
||||||
|
|
||||||
for app_config in apps.get_app_configs():
|
for app_config in apps.get_app_configs():
|
||||||
for model in router.get_migratable_models(app_config, self.connection.alias):
|
for model in router.get_migratable_models(app_config, self.connection.alias):
|
||||||
|
@ -123,19 +124,23 @@ class BaseDatabaseIntrospection:
|
||||||
continue
|
continue
|
||||||
if model._meta.swapped:
|
if model._meta.swapped:
|
||||||
continue
|
continue
|
||||||
for f in model._meta.local_fields:
|
sequence_list.extend(self.get_sequences(cursor, model._meta.db_table, model._meta.local_fields))
|
||||||
if isinstance(f, models.AutoField):
|
|
||||||
sequence_list.append({'table': model._meta.db_table, 'column': f.column})
|
|
||||||
break # Only one AutoField is allowed per model, so don't bother continuing.
|
|
||||||
|
|
||||||
for f in model._meta.local_many_to_many:
|
for f in model._meta.local_many_to_many:
|
||||||
# If this is an m2m using an intermediate table,
|
# If this is an m2m using an intermediate table,
|
||||||
# we don't need to reset the sequence.
|
# we don't need to reset the sequence.
|
||||||
if f.remote_field.through is None:
|
if f.remote_field.through is None:
|
||||||
sequence_list.append({'table': f.m2m_db_table(), 'column': None})
|
sequence = self.get_sequences(cursor, f.m2m_db_table())
|
||||||
|
sequence_list.extend(sequence if sequence else [{'table': f.m2m_db_table(), 'column': None}])
|
||||||
return sequence_list
|
return sequence_list
|
||||||
|
|
||||||
|
def get_sequences(self, cursor, table_name, table_fields=()):
|
||||||
|
"""
|
||||||
|
Return a list of introspected sequences for table_name. Each sequence
|
||||||
|
is a dict: {'table': <table_name>, 'column': <column_name>}. An optional
|
||||||
|
'name' key can be added if the backend supports named sequences.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError('subclasses of BaseDatabaseIntrospection may require a get_sequences() method')
|
||||||
|
|
||||||
def get_key_columns(self, cursor, table_name):
|
def get_key_columns(self, cursor, table_name):
|
||||||
"""
|
"""
|
||||||
Backends can override this to return a list of:
|
Backends can override this to return a list of:
|
||||||
|
|
|
@ -105,6 +105,13 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
|
||||||
)
|
)
|
||||||
return fields
|
return fields
|
||||||
|
|
||||||
|
def get_sequences(self, cursor, table_name, table_fields=()):
|
||||||
|
for field_info in self.get_table_description(cursor, table_name):
|
||||||
|
if 'auto_increment' in field_info.extra:
|
||||||
|
# MySQL allows only one auto-increment column per table.
|
||||||
|
return [{'table': table_name, 'column': field_info.name}]
|
||||||
|
return []
|
||||||
|
|
||||||
def get_relations(self, cursor, table_name):
|
def get_relations(self, cursor, table_name):
|
||||||
"""
|
"""
|
||||||
Return a dictionary of {field_name: (field_name_other_table, other_table)}
|
Return a dictionary of {field_name: (field_name_other_table, other_table)}
|
||||||
|
|
|
@ -3,6 +3,7 @@ from collections import namedtuple
|
||||||
|
|
||||||
import cx_Oracle
|
import cx_Oracle
|
||||||
|
|
||||||
|
from django.db import models
|
||||||
from django.db.backends.base.introspection import (
|
from django.db.backends.base.introspection import (
|
||||||
BaseDatabaseIntrospection, FieldInfo as BaseFieldInfo, TableInfo,
|
BaseDatabaseIntrospection, FieldInfo as BaseFieldInfo, TableInfo,
|
||||||
)
|
)
|
||||||
|
@ -98,6 +99,33 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
|
||||||
"""Table name comparison is case insensitive under Oracle."""
|
"""Table name comparison is case insensitive under Oracle."""
|
||||||
return name.lower()
|
return name.lower()
|
||||||
|
|
||||||
|
def get_sequences(self, cursor, table_name, table_fields=()):
|
||||||
|
cursor.execute("""
|
||||||
|
SELECT
|
||||||
|
user_tab_identity_cols.sequence_name,
|
||||||
|
user_tab_identity_cols.column_name
|
||||||
|
FROM
|
||||||
|
user_tab_identity_cols,
|
||||||
|
user_constraints,
|
||||||
|
user_cons_columns cols
|
||||||
|
WHERE
|
||||||
|
user_constraints.constraint_name = cols.constraint_name
|
||||||
|
AND user_constraints.table_name = user_tab_identity_cols.table_name
|
||||||
|
AND cols.column_name = user_tab_identity_cols.column_name
|
||||||
|
AND user_constraints.constraint_type = 'P'
|
||||||
|
AND user_tab_identity_cols.table_name = UPPER(%s)
|
||||||
|
""", [table_name])
|
||||||
|
# Oracle allows only one identity column per table.
|
||||||
|
row = cursor.fetchone()
|
||||||
|
if row:
|
||||||
|
return [{'name': row[0].lower(), 'table': table_name, 'column': row[1].lower()}]
|
||||||
|
# To keep backward compatibility for AutoFields that aren't Oracle
|
||||||
|
# identity columns.
|
||||||
|
for f in table_fields:
|
||||||
|
if isinstance(f, models.AutoField):
|
||||||
|
return [{'table': table_name, 'column': f.column}]
|
||||||
|
return []
|
||||||
|
|
||||||
def get_relations(self, cursor, table_name):
|
def get_relations(self, cursor, table_name):
|
||||||
"""
|
"""
|
||||||
Return a dictionary of {field_name: (field_name_other_table, other_table)}
|
Return a dictionary of {field_name: (field_name_other_table, other_table)}
|
||||||
|
|
|
@ -82,6 +82,26 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
|
||||||
for line in cursor.description
|
for line in cursor.description
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def get_sequences(self, cursor, table_name, table_fields=()):
|
||||||
|
sequences = []
|
||||||
|
cursor.execute("""
|
||||||
|
SELECT s.relname as sequence_name, col.attname
|
||||||
|
FROM pg_class s
|
||||||
|
JOIN pg_namespace sn ON sn.oid = s.relnamespace
|
||||||
|
JOIN pg_depend d ON d.refobjid = s.oid AND d.refclassid='pg_class'::regclass
|
||||||
|
JOIN pg_attrdef ad ON ad.oid = d.objid AND d.classid = 'pg_attrdef'::regclass
|
||||||
|
JOIN pg_attribute col ON col.attrelid = ad.adrelid AND col.attnum = ad.adnum
|
||||||
|
JOIN pg_class tbl ON tbl.oid = ad.adrelid
|
||||||
|
JOIN pg_namespace n ON n.oid = tbl.relnamespace
|
||||||
|
WHERE s.relkind = 'S'
|
||||||
|
AND d.deptype in ('a', 'n')
|
||||||
|
AND n.nspname = 'public'
|
||||||
|
AND tbl.relname = %s
|
||||||
|
""", [table_name])
|
||||||
|
for row in cursor.fetchall():
|
||||||
|
sequences.append({'name': row[0], 'table': table_name, 'column': row[1]})
|
||||||
|
return sequences
|
||||||
|
|
||||||
def get_relations(self, cursor, table_name):
|
def get_relations(self, cursor, table_name):
|
||||||
"""
|
"""
|
||||||
Return a dictionary of {field_name: (field_name_other_table, other_table)}
|
Return a dictionary of {field_name: (field_name_other_table, other_table)}
|
||||||
|
|
|
@ -85,6 +85,10 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
|
||||||
) for info in self._table_info(cursor, table_name)
|
) for info in self._table_info(cursor, table_name)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def get_sequences(self, cursor, table_name, table_fields=()):
|
||||||
|
pk_col = self.get_primary_key_column(cursor, table_name)
|
||||||
|
return [{'table': table_name, 'column': pk_col}]
|
||||||
|
|
||||||
def column_name_converter(self, name):
|
def column_name_converter(self, name):
|
||||||
"""
|
"""
|
||||||
SQLite will in some cases, e.g. when returning columns from views and
|
SQLite will in some cases, e.g. when returning columns from views and
|
||||||
|
|
|
@ -410,6 +410,10 @@ backends.
|
||||||
for consistency). ``django.test`` also now passes those values as strings
|
for consistency). ``django.test`` also now passes those values as strings
|
||||||
rather than as integers.
|
rather than as integers.
|
||||||
|
|
||||||
|
* Third-party database backends should add a
|
||||||
|
``DatabaseIntrospection.get_sequences()`` method based on the stub in
|
||||||
|
``BaseDatabaseIntrospection``.
|
||||||
|
|
||||||
Dropped support for Oracle 11.2
|
Dropped support for Oracle 11.2
|
||||||
-------------------------------
|
-------------------------------
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,29 @@
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from django.db import connection
|
||||||
|
from django.test import TransactionTestCase
|
||||||
|
|
||||||
|
from ..models import Person
|
||||||
|
|
||||||
|
|
||||||
|
@unittest.skipUnless(connection.vendor == 'oracle', 'Oracle tests')
|
||||||
|
class DatabaseSequenceTests(TransactionTestCase):
|
||||||
|
available_apps = []
|
||||||
|
|
||||||
|
def test_get_sequences(self):
|
||||||
|
with connection.cursor() as cursor:
|
||||||
|
seqs = connection.introspection.get_sequences(cursor, Person._meta.db_table, Person._meta.local_fields)
|
||||||
|
self.assertEqual(len(seqs), 1)
|
||||||
|
self.assertIsNotNone(seqs[0]['name'])
|
||||||
|
self.assertEqual(seqs[0]['table'], Person._meta.db_table)
|
||||||
|
self.assertEqual(seqs[0]['column'], 'id')
|
||||||
|
|
||||||
|
def test_get_sequences_manually_created_index(self):
|
||||||
|
with connection.cursor() as cursor:
|
||||||
|
with connection.schema_editor() as editor:
|
||||||
|
editor._drop_identity(Person._meta.db_table, 'id')
|
||||||
|
seqs = connection.introspection.get_sequences(cursor, Person._meta.db_table, Person._meta.local_fields)
|
||||||
|
self.assertEqual(seqs, [{'table': Person._meta.db_table, 'column': 'id'}])
|
||||||
|
# Recreate model, because adding identity is impossible.
|
||||||
|
editor.delete_model(Person)
|
||||||
|
editor.create_model(Person)
|
|
@ -0,0 +1,23 @@
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from django.db import connection
|
||||||
|
from django.test import TestCase
|
||||||
|
|
||||||
|
from ..models import Person
|
||||||
|
|
||||||
|
|
||||||
|
@unittest.skipUnless(connection.vendor == 'postgresql', "Test only for PostgreSQL")
|
||||||
|
class DatabaseSequenceTests(TestCase):
|
||||||
|
def test_get_sequences(self):
|
||||||
|
cursor = connection.cursor()
|
||||||
|
seqs = connection.introspection.get_sequences(cursor, Person._meta.db_table)
|
||||||
|
self.assertEqual(
|
||||||
|
seqs,
|
||||||
|
[{'table': Person._meta.db_table, 'column': 'id', 'name': 'backends_person_id_seq'}]
|
||||||
|
)
|
||||||
|
cursor.execute('ALTER SEQUENCE backends_person_id_seq RENAME TO pers_seq')
|
||||||
|
seqs = connection.introspection.get_sequences(cursor, Person._meta.db_table)
|
||||||
|
self.assertEqual(
|
||||||
|
seqs,
|
||||||
|
[{'table': Person._meta.db_table, 'column': 'id', 'name': 'pers_seq'}]
|
||||||
|
)
|
|
@ -61,8 +61,9 @@ class IntrospectionTests(TransactionTestCase):
|
||||||
|
|
||||||
def test_sequence_list(self):
|
def test_sequence_list(self):
|
||||||
sequences = connection.introspection.sequence_list()
|
sequences = connection.introspection.sequence_list()
|
||||||
expected = {'table': Reporter._meta.db_table, 'column': 'id'}
|
reporter_seqs = [seq for seq in sequences if seq['table'] == Reporter._meta.db_table]
|
||||||
self.assertIn(expected, sequences, 'Reporter sequence not found in sequence_list()')
|
self.assertEqual(len(reporter_seqs), 1, 'Reporter sequence not found in sequence_list()')
|
||||||
|
self.assertEqual(reporter_seqs[0]['column'], 'id')
|
||||||
|
|
||||||
def test_get_table_description_names(self):
|
def test_get_table_description_names(self):
|
||||||
with connection.cursor() as cursor:
|
with connection.cursor() as cursor:
|
||||||
|
|
Loading…
Reference in New Issue