Fixed #30511 -- Used identity columns instead of serials on PostgreSQL.

This commit is contained in:
Florian Apolloner 2022-03-24 16:46:19 +01:00 committed by Mariusz Felisiak
parent 62ffc9883a
commit 2eea361eff
7 changed files with 90 additions and 95 deletions

View File

@ -904,6 +904,9 @@ class BaseDatabaseSchemaEditor:
actions = [] actions = []
null_actions = [] null_actions = []
post_actions = [] post_actions = []
# Type suffix change? (e.g. auto increment).
old_type_suffix = old_field.db_type_suffix(connection=self.connection)
new_type_suffix = new_field.db_type_suffix(connection=self.connection)
# Collation change? # Collation change?
old_collation = getattr(old_field, "db_collation", None) old_collation = getattr(old_field, "db_collation", None)
new_collation = getattr(new_field, "db_collation", None) new_collation = getattr(new_field, "db_collation", None)
@ -914,7 +917,7 @@ class BaseDatabaseSchemaEditor:
) )
actions.append(fragment) actions.append(fragment)
# Type change? # Type change?
elif old_type != new_type: elif (old_type, old_type_suffix) != (new_type, new_type_suffix):
fragment, other_actions = self._alter_column_type_sql( fragment, other_actions = self._alter_column_type_sql(
model, old_field, new_field, new_type model, old_field, new_field, new_type
) )

View File

@ -72,8 +72,8 @@ class DatabaseWrapper(BaseDatabaseWrapper):
# be interpolated against the values of Field.__dict__ before being output. # be interpolated against the values of Field.__dict__ before being output.
# If a column type is set to None, it won't be included in the output. # If a column type is set to None, it won't be included in the output.
data_types = { data_types = {
"AutoField": "serial", "AutoField": "integer",
"BigAutoField": "bigserial", "BigAutoField": "bigint",
"BinaryField": "bytea", "BinaryField": "bytea",
"BooleanField": "boolean", "BooleanField": "boolean",
"CharField": "varchar(%(max_length)s)", "CharField": "varchar(%(max_length)s)",
@ -94,7 +94,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
"PositiveIntegerField": "integer", "PositiveIntegerField": "integer",
"PositiveSmallIntegerField": "smallint", "PositiveSmallIntegerField": "smallint",
"SlugField": "varchar(%(max_length)s)", "SlugField": "varchar(%(max_length)s)",
"SmallAutoField": "smallserial", "SmallAutoField": "smallint",
"SmallIntegerField": "smallint", "SmallIntegerField": "smallint",
"TextField": "text", "TextField": "text",
"TimeField": "time", "TimeField": "time",
@ -105,6 +105,11 @@ class DatabaseWrapper(BaseDatabaseWrapper):
"PositiveIntegerField": '"%(column)s" >= 0', "PositiveIntegerField": '"%(column)s" >= 0',
"PositiveSmallIntegerField": '"%(column)s" >= 0', "PositiveSmallIntegerField": '"%(column)s" >= 0',
} }
data_types_suffix = {
"AutoField": "GENERATED BY DEFAULT AS IDENTITY",
"BigAutoField": "GENERATED BY DEFAULT AS IDENTITY",
"SmallAutoField": "GENERATED BY DEFAULT AS IDENTITY",
}
operators = { operators = {
"exact": "= %s", "exact": "= %s",
"iexact": "= UPPER(%s)", "iexact": "= UPPER(%s)",

View File

@ -1,10 +1,12 @@
from django.db.backends.base.introspection import ( from collections import namedtuple
BaseDatabaseIntrospection,
FieldInfo, from django.db.backends.base.introspection import BaseDatabaseIntrospection
TableInfo, from django.db.backends.base.introspection import FieldInfo as BaseFieldInfo
) from django.db.backends.base.introspection import TableInfo
from django.db.models import Index from django.db.models import Index
FieldInfo = namedtuple("FieldInfo", BaseFieldInfo._fields + ("is_autofield",))
class DatabaseIntrospection(BaseDatabaseIntrospection): class DatabaseIntrospection(BaseDatabaseIntrospection):
# Maps type codes to Django Field types. # Maps type codes to Django Field types.
@ -37,7 +39,11 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
def get_field_type(self, data_type, description): def get_field_type(self, data_type, description):
field_type = super().get_field_type(data_type, description) field_type = super().get_field_type(data_type, description)
if description.default and "nextval" in description.default: if description.is_autofield or (
# Required for pre-Django 4.1 serial columns.
description.default
and "nextval" in description.default
):
if field_type == "IntegerField": if field_type == "IntegerField":
return "AutoField" return "AutoField"
elif field_type == "BigIntegerField": elif field_type == "BigIntegerField":
@ -84,7 +90,8 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
a.attname AS column_name, a.attname AS column_name,
NOT (a.attnotnull OR (t.typtype = 'd' AND t.typnotnull)) AS is_nullable, NOT (a.attnotnull OR (t.typtype = 'd' AND t.typnotnull)) AS is_nullable,
pg_get_expr(ad.adbin, ad.adrelid) AS column_default, pg_get_expr(ad.adbin, ad.adrelid) AS column_default,
CASE WHEN collname = 'default' THEN NULL ELSE collname END AS collation CASE WHEN collname = 'default' THEN NULL ELSE collname END AS collation,
a.attidentity != '' AS is_autofield
FROM pg_attribute a FROM pg_attribute a
LEFT JOIN pg_attrdef ad ON a.attrelid = ad.adrelid AND a.attnum = ad.adnum LEFT JOIN pg_attrdef ad ON a.attrelid = ad.adrelid AND a.attnum = ad.adnum
LEFT JOIN pg_collation co ON a.attcollation = co.oid LEFT JOIN pg_collation co ON a.attcollation = co.oid
@ -118,23 +125,21 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
def get_sequences(self, cursor, table_name, table_fields=()): def get_sequences(self, cursor, table_name, table_fields=()):
cursor.execute( cursor.execute(
""" """
SELECT s.relname as sequence_name, col.attname SELECT
FROM pg_class s s.relname AS sequence_name,
JOIN pg_namespace sn ON sn.oid = s.relnamespace a.attname AS colname
JOIN FROM
pg_depend d ON d.refobjid = s.oid pg_class s
JOIN pg_depend d ON d.objid = s.oid
AND d.classid = 'pg_class'::regclass
AND d.refclassid = 'pg_class'::regclass AND d.refclassid = 'pg_class'::regclass
JOIN JOIN pg_attribute a ON d.refobjid = a.attrelid
pg_attrdef ad ON ad.oid = d.objid AND d.refobjsubid = a.attnum
AND d.classid = 'pg_attrdef'::regclass JOIN pg_class tbl ON tbl.oid = d.refobjid
JOIN AND tbl.relname = %s
pg_attribute col ON col.attrelid = ad.adrelid AND pg_catalog.pg_table_is_visible(tbl.oid)
AND col.attnum = ad.adnum WHERE
JOIN pg_class tbl ON tbl.oid = ad.adrelid s.relkind = 'S';
WHERE s.relkind = 'S'
AND d.deptype in ('a', 'n')
AND pg_catalog.pg_table_is_visible(tbl.oid)
AND tbl.relname = %s
""", """,
[table_name], [table_name],
) )

View File

@ -7,12 +7,7 @@ from django.db.backends.utils import strip_quotes
class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
sql_create_sequence = "CREATE SEQUENCE %(sequence)s"
sql_delete_sequence = "DROP SEQUENCE IF EXISTS %(sequence)s CASCADE" sql_delete_sequence = "DROP SEQUENCE IF EXISTS %(sequence)s CASCADE"
sql_set_sequence_max = (
"SELECT setval('%(sequence)s', MAX(%(column)s)) FROM %(table)s"
)
sql_set_sequence_owner = "ALTER SEQUENCE %(sequence)s OWNED BY %(table)s.%(column)s"
sql_create_index = ( sql_create_index = (
"CREATE INDEX %(name)s ON %(table)s%(using)s " "CREATE INDEX %(name)s ON %(table)s%(using)s "
@ -39,6 +34,14 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
) )
sql_delete_procedure = "DROP FUNCTION %(procedure)s(%(param_types)s)" sql_delete_procedure = "DROP FUNCTION %(procedure)s(%(param_types)s)"
sql_add_identity = (
"ALTER TABLE %(table)s ALTER COLUMN %(column)s ADD "
"GENERATED BY DEFAULT AS IDENTITY"
)
sql_drop_indentity = (
"ALTER TABLE %(table)s ALTER COLUMN %(column)s DROP IDENTITY IF EXISTS"
)
def quote_value(self, value): def quote_value(self, value):
if isinstance(value, str): if isinstance(value, str):
value = value.replace("%", "%%") value = value.replace("%", "%%")
@ -116,78 +119,47 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
self.sql_alter_column_type += using_sql self.sql_alter_column_type += using_sql
elif self._field_data_type(old_field) != self._field_data_type(new_field): elif self._field_data_type(old_field) != self._field_data_type(new_field):
self.sql_alter_column_type += using_sql self.sql_alter_column_type += using_sql
# Make ALTER TYPE with SERIAL make sense. # Make ALTER TYPE with IDENTITY make sense.
table = strip_quotes(model._meta.db_table) table = strip_quotes(model._meta.db_table)
serial_fields_map = { auto_field_types = {
"bigserial": "bigint", "AutoField",
"serial": "integer", "BigAutoField",
"smallserial": "smallint", "SmallAutoField",
} }
if new_type.lower() in serial_fields_map: old_is_auto = old_internal_type in auto_field_types
new_is_auto = new_internal_type in auto_field_types
if new_is_auto and not old_is_auto:
column = strip_quotes(new_field.column) column = strip_quotes(new_field.column)
sequence_name = "%s_%s_seq" % (table, column)
return ( return (
( (
self.sql_alter_column_type self.sql_alter_column_type
% { % {
"column": self.quote_name(column), "column": self.quote_name(column),
"type": serial_fields_map[new_type.lower()], "type": new_type,
}, },
[], [],
), ),
[ [
( (
self.sql_delete_sequence self.sql_add_identity
% {
"sequence": self.quote_name(sequence_name),
},
[],
),
(
self.sql_create_sequence
% {
"sequence": self.quote_name(sequence_name),
},
[],
),
(
self.sql_alter_column
% {
"table": self.quote_name(table),
"changes": self.sql_alter_column_default
% {
"column": self.quote_name(column),
"default": "nextval('%s')"
% self.quote_name(sequence_name),
},
},
[],
),
(
self.sql_set_sequence_max
% { % {
"table": self.quote_name(table), "table": self.quote_name(table),
"column": self.quote_name(column), "column": self.quote_name(column),
"sequence": self.quote_name(sequence_name),
},
[],
),
(
self.sql_set_sequence_owner
% {
"table": self.quote_name(table),
"column": self.quote_name(column),
"sequence": self.quote_name(sequence_name),
}, },
[], [],
), ),
], ],
) )
elif ( elif old_is_auto and not new_is_auto:
old_field.db_parameters(connection=self.connection)["type"] # Drop IDENTITY if exists (pre-Django 4.1 serial columns don't have
in serial_fields_map # it).
): self.execute(
# Drop the sequence if migrating away from AutoField. self.sql_drop_indentity
% {
"table": self.quote_name(table),
"column": self.quote_name(strip_quotes(old_field.column)),
}
)
column = strip_quotes(new_field.column) column = strip_quotes(new_field.column)
sequence_name = "%s_%s_seq" % (table, column) sequence_name = "%s_%s_seq" % (table, column)
fragment, _ = super()._alter_column_type_sql( fragment, _ = super()._alter_column_type_sql(
@ -195,6 +167,8 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
) )
return fragment, [ return fragment, [
( (
# Drop the sequence if exists (Django 4.1+ identity columns
# don't have it).
self.sql_delete_sequence self.sql_delete_sequence
% { % {
"sequence": self.quote_name(sequence_name), "sequence": self.quote_name(sequence_name),

View File

@ -309,6 +309,9 @@ Models
allows customizing attributes of fields that don't affect a column allows customizing attributes of fields that don't affect a column
definition. definition.
* On PostgreSQL, ``AutoField``, ``BigAutoField``, and ``SmallAutoField`` are
now created as identity columns rather than serial columns with sequences.
Requests and Responses Requests and Responses
~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~

View File

@ -27,3 +27,18 @@ class DatabaseSequenceTests(TestCase):
seqs, seqs,
[{"table": Person._meta.db_table, "column": "id", "name": "pers_seq"}], [{"table": Person._meta.db_table, "column": "id", "name": "pers_seq"}],
) )
def test_get_sequences_old_serial(self):
with connection.cursor() as cursor:
cursor.execute("CREATE TABLE testing (serial_field SERIAL);")
seqs = connection.introspection.get_sequences(cursor, "testing")
self.assertEqual(
seqs,
[
{
"table": "testing",
"column": "serial_field",
"name": "testing_serial_field_seq",
}
],
)

View File

@ -1570,11 +1570,7 @@ class SchemaTests(TransactionTestCase):
Author.objects.create(name="Foo") Author.objects.create(name="Foo")
Author.objects.create(name="Bar") Author.objects.create(name="Bar")
def test_alter_autofield_pk_to_bigautofield_pk_sequence_owner(self): def test_alter_autofield_pk_to_bigautofield_pk(self):
"""
Converting an implicit PK to BigAutoField(primary_key=True) should keep
a sequence owner on PostgreSQL.
"""
with connection.schema_editor() as editor: with connection.schema_editor() as editor:
editor.create_model(Author) editor.create_model(Author)
old_field = Author._meta.get_field("id") old_field = Author._meta.get_field("id")
@ -1591,14 +1587,9 @@ class SchemaTests(TransactionTestCase):
) )
if sequence_reset_sqls: if sequence_reset_sqls:
cursor.execute(sequence_reset_sqls[0]) cursor.execute(sequence_reset_sqls[0])
# Fail on PostgreSQL if sequence is missing an owner.
self.assertIsNotNone(Author.objects.create(name="Bar")) self.assertIsNotNone(Author.objects.create(name="Bar"))
def test_alter_autofield_pk_to_smallautofield_pk_sequence_owner(self): def test_alter_autofield_pk_to_smallautofield_pk(self):
"""
Converting an implicit PK to SmallAutoField(primary_key=True) should
keep a sequence owner on PostgreSQL.
"""
with connection.schema_editor() as editor: with connection.schema_editor() as editor:
editor.create_model(Author) editor.create_model(Author)
old_field = Author._meta.get_field("id") old_field = Author._meta.get_field("id")
@ -1615,7 +1606,6 @@ class SchemaTests(TransactionTestCase):
) )
if sequence_reset_sqls: if sequence_reset_sqls:
cursor.execute(sequence_reset_sqls[0]) cursor.execute(sequence_reset_sqls[0])
# Fail on PostgreSQL if sequence is missing an owner.
self.assertIsNotNone(Author.objects.create(name="Bar")) self.assertIsNotNone(Author.objects.create(name="Bar"))
def test_alter_int_pk_to_autofield_pk(self): def test_alter_int_pk_to_autofield_pk(self):