mirror of https://github.com/django/django.git
Refs #27090 -- Added real database sequence introspection.
Thanks Mariusz Felisiak for the Oracle part and Tim Graham for the review.
This commit is contained in:
parent
c2ecef869c
commit
c6a1faecc3
|
@ -113,9 +113,10 @@ class BaseDatabaseIntrospection:
|
|||
all apps.
|
||||
"""
|
||||
from django.apps import apps
|
||||
from django.db import models, router
|
||||
from django.db import router
|
||||
|
||||
sequence_list = []
|
||||
cursor = self.connection.cursor()
|
||||
|
||||
for app_config in apps.get_app_configs():
|
||||
for model in router.get_migratable_models(app_config, self.connection.alias):
|
||||
|
@ -123,19 +124,23 @@ class BaseDatabaseIntrospection:
|
|||
continue
|
||||
if model._meta.swapped:
|
||||
continue
|
||||
for f in model._meta.local_fields:
|
||||
if isinstance(f, models.AutoField):
|
||||
sequence_list.append({'table': model._meta.db_table, 'column': f.column})
|
||||
break # Only one AutoField is allowed per model, so don't bother continuing.
|
||||
|
||||
sequence_list.extend(self.get_sequences(cursor, model._meta.db_table, model._meta.local_fields))
|
||||
for f in model._meta.local_many_to_many:
|
||||
# If this is an m2m using an intermediate table,
|
||||
# we don't need to reset the sequence.
|
||||
if f.remote_field.through is None:
|
||||
sequence_list.append({'table': f.m2m_db_table(), 'column': None})
|
||||
|
||||
sequence = self.get_sequences(cursor, f.m2m_db_table())
|
||||
sequence_list.extend(sequence if sequence else [{'table': f.m2m_db_table(), 'column': None}])
|
||||
return sequence_list
|
||||
|
||||
def get_sequences(self, cursor, table_name, table_fields=()):
|
||||
"""
|
||||
Return a list of introspected sequences for table_name. Each sequence
|
||||
is a dict: {'table': <table_name>, 'column': <column_name>}. An optional
|
||||
'name' key can be added if the backend supports named sequences.
|
||||
"""
|
||||
raise NotImplementedError('subclasses of BaseDatabaseIntrospection may require a get_sequences() method')
|
||||
|
||||
def get_key_columns(self, cursor, table_name):
|
||||
"""
|
||||
Backends can override this to return a list of:
|
||||
|
|
|
@ -105,6 +105,13 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
|
|||
)
|
||||
return fields
|
||||
|
||||
def get_sequences(self, cursor, table_name, table_fields=()):
|
||||
for field_info in self.get_table_description(cursor, table_name):
|
||||
if 'auto_increment' in field_info.extra:
|
||||
# MySQL allows only one auto-increment column per table.
|
||||
return [{'table': table_name, 'column': field_info.name}]
|
||||
return []
|
||||
|
||||
def get_relations(self, cursor, table_name):
|
||||
"""
|
||||
Return a dictionary of {field_name: (field_name_other_table, other_table)}
|
||||
|
|
|
@ -3,6 +3,7 @@ from collections import namedtuple
|
|||
|
||||
import cx_Oracle
|
||||
|
||||
from django.db import models
|
||||
from django.db.backends.base.introspection import (
|
||||
BaseDatabaseIntrospection, FieldInfo as BaseFieldInfo, TableInfo,
|
||||
)
|
||||
|
@ -98,6 +99,33 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
|
|||
"""Table name comparison is case insensitive under Oracle."""
|
||||
return name.lower()
|
||||
|
||||
def get_sequences(self, cursor, table_name, table_fields=()):
|
||||
cursor.execute("""
|
||||
SELECT
|
||||
user_tab_identity_cols.sequence_name,
|
||||
user_tab_identity_cols.column_name
|
||||
FROM
|
||||
user_tab_identity_cols,
|
||||
user_constraints,
|
||||
user_cons_columns cols
|
||||
WHERE
|
||||
user_constraints.constraint_name = cols.constraint_name
|
||||
AND user_constraints.table_name = user_tab_identity_cols.table_name
|
||||
AND cols.column_name = user_tab_identity_cols.column_name
|
||||
AND user_constraints.constraint_type = 'P'
|
||||
AND user_tab_identity_cols.table_name = UPPER(%s)
|
||||
""", [table_name])
|
||||
# Oracle allows only one identity column per table.
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
return [{'name': row[0].lower(), 'table': table_name, 'column': row[1].lower()}]
|
||||
# To keep backward compatibility for AutoFields that aren't Oracle
|
||||
# identity columns.
|
||||
for f in table_fields:
|
||||
if isinstance(f, models.AutoField):
|
||||
return [{'table': table_name, 'column': f.column}]
|
||||
return []
|
||||
|
||||
def get_relations(self, cursor, table_name):
|
||||
"""
|
||||
Return a dictionary of {field_name: (field_name_other_table, other_table)}
|
||||
|
|
|
@ -82,6 +82,26 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
|
|||
for line in cursor.description
|
||||
]
|
||||
|
||||
def get_sequences(self, cursor, table_name, table_fields=()):
|
||||
sequences = []
|
||||
cursor.execute("""
|
||||
SELECT s.relname as sequence_name, col.attname
|
||||
FROM pg_class s
|
||||
JOIN pg_namespace sn ON sn.oid = s.relnamespace
|
||||
JOIN pg_depend d ON d.refobjid = s.oid AND d.refclassid='pg_class'::regclass
|
||||
JOIN pg_attrdef ad ON ad.oid = d.objid AND d.classid = 'pg_attrdef'::regclass
|
||||
JOIN pg_attribute col ON col.attrelid = ad.adrelid AND col.attnum = ad.adnum
|
||||
JOIN pg_class tbl ON tbl.oid = ad.adrelid
|
||||
JOIN pg_namespace n ON n.oid = tbl.relnamespace
|
||||
WHERE s.relkind = 'S'
|
||||
AND d.deptype in ('a', 'n')
|
||||
AND n.nspname = 'public'
|
||||
AND tbl.relname = %s
|
||||
""", [table_name])
|
||||
for row in cursor.fetchall():
|
||||
sequences.append({'name': row[0], 'table': table_name, 'column': row[1]})
|
||||
return sequences
|
||||
|
||||
def get_relations(self, cursor, table_name):
|
||||
"""
|
||||
Return a dictionary of {field_name: (field_name_other_table, other_table)}
|
||||
|
|
|
@ -85,6 +85,10 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
|
|||
) for info in self._table_info(cursor, table_name)
|
||||
]
|
||||
|
||||
def get_sequences(self, cursor, table_name, table_fields=()):
|
||||
pk_col = self.get_primary_key_column(cursor, table_name)
|
||||
return [{'table': table_name, 'column': pk_col}]
|
||||
|
||||
def column_name_converter(self, name):
|
||||
"""
|
||||
SQLite will in some cases, e.g. when returning columns from views and
|
||||
|
|
|
@ -410,6 +410,10 @@ backends.
|
|||
for consistency). ``django.test`` also now passes those values as strings
|
||||
rather than as integers.
|
||||
|
||||
* Third-party database backends should add a
|
||||
``DatabaseIntrospection.get_sequences()`` method based on the stub in
|
||||
``BaseDatabaseIntrospection``.
|
||||
|
||||
Dropped support for Oracle 11.2
|
||||
-------------------------------
|
||||
|
||||
|
|
|
@ -0,0 +1,29 @@
|
|||
import unittest
|
||||
|
||||
from django.db import connection
|
||||
from django.test import TransactionTestCase
|
||||
|
||||
from ..models import Person
|
||||
|
||||
|
||||
@unittest.skipUnless(connection.vendor == 'oracle', 'Oracle tests')
|
||||
class DatabaseSequenceTests(TransactionTestCase):
|
||||
available_apps = []
|
||||
|
||||
def test_get_sequences(self):
|
||||
with connection.cursor() as cursor:
|
||||
seqs = connection.introspection.get_sequences(cursor, Person._meta.db_table, Person._meta.local_fields)
|
||||
self.assertEqual(len(seqs), 1)
|
||||
self.assertIsNotNone(seqs[0]['name'])
|
||||
self.assertEqual(seqs[0]['table'], Person._meta.db_table)
|
||||
self.assertEqual(seqs[0]['column'], 'id')
|
||||
|
||||
def test_get_sequences_manually_created_index(self):
|
||||
with connection.cursor() as cursor:
|
||||
with connection.schema_editor() as editor:
|
||||
editor._drop_identity(Person._meta.db_table, 'id')
|
||||
seqs = connection.introspection.get_sequences(cursor, Person._meta.db_table, Person._meta.local_fields)
|
||||
self.assertEqual(seqs, [{'table': Person._meta.db_table, 'column': 'id'}])
|
||||
# Recreate model, because adding identity is impossible.
|
||||
editor.delete_model(Person)
|
||||
editor.create_model(Person)
|
|
@ -0,0 +1,23 @@
|
|||
import unittest
|
||||
|
||||
from django.db import connection
|
||||
from django.test import TestCase
|
||||
|
||||
from ..models import Person
|
||||
|
||||
|
||||
@unittest.skipUnless(connection.vendor == 'postgresql', "Test only for PostgreSQL")
|
||||
class DatabaseSequenceTests(TestCase):
|
||||
def test_get_sequences(self):
|
||||
cursor = connection.cursor()
|
||||
seqs = connection.introspection.get_sequences(cursor, Person._meta.db_table)
|
||||
self.assertEqual(
|
||||
seqs,
|
||||
[{'table': Person._meta.db_table, 'column': 'id', 'name': 'backends_person_id_seq'}]
|
||||
)
|
||||
cursor.execute('ALTER SEQUENCE backends_person_id_seq RENAME TO pers_seq')
|
||||
seqs = connection.introspection.get_sequences(cursor, Person._meta.db_table)
|
||||
self.assertEqual(
|
||||
seqs,
|
||||
[{'table': Person._meta.db_table, 'column': 'id', 'name': 'pers_seq'}]
|
||||
)
|
|
@ -61,8 +61,9 @@ class IntrospectionTests(TransactionTestCase):
|
|||
|
||||
def test_sequence_list(self):
|
||||
sequences = connection.introspection.sequence_list()
|
||||
expected = {'table': Reporter._meta.db_table, 'column': 'id'}
|
||||
self.assertIn(expected, sequences, 'Reporter sequence not found in sequence_list()')
|
||||
reporter_seqs = [seq for seq in sequences if seq['table'] == Reporter._meta.db_table]
|
||||
self.assertEqual(len(reporter_seqs), 1, 'Reporter sequence not found in sequence_list()')
|
||||
self.assertEqual(reporter_seqs[0]['column'], 'id')
|
||||
|
||||
def test_get_table_description_names(self):
|
||||
with connection.cursor() as cursor:
|
||||
|
|
Loading…
Reference in New Issue