Fixed #30511 -- Used identity columns instead of serials on PostgreSQL.

This commit is contained in:
Florian Apolloner 2022-03-24 16:46:19 +01:00 committed by Mariusz Felisiak
parent 62ffc9883a
commit 2eea361eff
7 changed files with 90 additions and 95 deletions

View File

@ -904,6 +904,9 @@ class BaseDatabaseSchemaEditor:
actions = []
null_actions = []
post_actions = []
# Type suffix change? (e.g. auto increment).
old_type_suffix = old_field.db_type_suffix(connection=self.connection)
new_type_suffix = new_field.db_type_suffix(connection=self.connection)
# Collation change?
old_collation = getattr(old_field, "db_collation", None)
new_collation = getattr(new_field, "db_collation", None)
@ -914,7 +917,7 @@ class BaseDatabaseSchemaEditor:
)
actions.append(fragment)
# Type change?
elif old_type != new_type:
elif (old_type, old_type_suffix) != (new_type, new_type_suffix):
fragment, other_actions = self._alter_column_type_sql(
model, old_field, new_field, new_type
)

View File

@ -72,8 +72,8 @@ class DatabaseWrapper(BaseDatabaseWrapper):
# be interpolated against the values of Field.__dict__ before being output.
# If a column type is set to None, it won't be included in the output.
data_types = {
"AutoField": "serial",
"BigAutoField": "bigserial",
"AutoField": "integer",
"BigAutoField": "bigint",
"BinaryField": "bytea",
"BooleanField": "boolean",
"CharField": "varchar(%(max_length)s)",
@ -94,7 +94,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
"PositiveIntegerField": "integer",
"PositiveSmallIntegerField": "smallint",
"SlugField": "varchar(%(max_length)s)",
"SmallAutoField": "smallserial",
"SmallAutoField": "smallint",
"SmallIntegerField": "smallint",
"TextField": "text",
"TimeField": "time",
@ -105,6 +105,11 @@ class DatabaseWrapper(BaseDatabaseWrapper):
"PositiveIntegerField": '"%(column)s" >= 0',
"PositiveSmallIntegerField": '"%(column)s" >= 0',
}
data_types_suffix = {
"AutoField": "GENERATED BY DEFAULT AS IDENTITY",
"BigAutoField": "GENERATED BY DEFAULT AS IDENTITY",
"SmallAutoField": "GENERATED BY DEFAULT AS IDENTITY",
}
operators = {
"exact": "= %s",
"iexact": "= UPPER(%s)",

View File

@ -1,10 +1,12 @@
from django.db.backends.base.introspection import (
BaseDatabaseIntrospection,
FieldInfo,
TableInfo,
)
from collections import namedtuple
from django.db.backends.base.introspection import BaseDatabaseIntrospection
from django.db.backends.base.introspection import FieldInfo as BaseFieldInfo
from django.db.backends.base.introspection import TableInfo
from django.db.models import Index
FieldInfo = namedtuple("FieldInfo", BaseFieldInfo._fields + ("is_autofield",))
class DatabaseIntrospection(BaseDatabaseIntrospection):
# Maps type codes to Django Field types.
@ -37,7 +39,11 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
def get_field_type(self, data_type, description):
field_type = super().get_field_type(data_type, description)
if description.default and "nextval" in description.default:
if description.is_autofield or (
# Required for pre-Django 4.1 serial columns.
description.default
and "nextval" in description.default
):
if field_type == "IntegerField":
return "AutoField"
elif field_type == "BigIntegerField":
@ -84,7 +90,8 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
a.attname AS column_name,
NOT (a.attnotnull OR (t.typtype = 'd' AND t.typnotnull)) AS is_nullable,
pg_get_expr(ad.adbin, ad.adrelid) AS column_default,
CASE WHEN collname = 'default' THEN NULL ELSE collname END AS collation
CASE WHEN collname = 'default' THEN NULL ELSE collname END AS collation,
a.attidentity != '' AS is_autofield
FROM pg_attribute a
LEFT JOIN pg_attrdef ad ON a.attrelid = ad.adrelid AND a.attnum = ad.adnum
LEFT JOIN pg_collation co ON a.attcollation = co.oid
@ -118,23 +125,21 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
def get_sequences(self, cursor, table_name, table_fields=()):
cursor.execute(
"""
SELECT s.relname as sequence_name, col.attname
FROM pg_class s
JOIN pg_namespace sn ON sn.oid = s.relnamespace
JOIN
pg_depend d ON d.refobjid = s.oid
SELECT
s.relname AS sequence_name,
a.attname AS colname
FROM
pg_class s
JOIN pg_depend d ON d.objid = s.oid
AND d.classid = 'pg_class'::regclass
AND d.refclassid = 'pg_class'::regclass
JOIN
pg_attrdef ad ON ad.oid = d.objid
AND d.classid = 'pg_attrdef'::regclass
JOIN
pg_attribute col ON col.attrelid = ad.adrelid
AND col.attnum = ad.adnum
JOIN pg_class tbl ON tbl.oid = ad.adrelid
WHERE s.relkind = 'S'
AND d.deptype in ('a', 'n')
AND pg_catalog.pg_table_is_visible(tbl.oid)
AND tbl.relname = %s
JOIN pg_attribute a ON d.refobjid = a.attrelid
AND d.refobjsubid = a.attnum
JOIN pg_class tbl ON tbl.oid = d.refobjid
AND tbl.relname = %s
AND pg_catalog.pg_table_is_visible(tbl.oid)
WHERE
s.relkind = 'S';
""",
[table_name],
)

View File

@ -7,12 +7,7 @@ from django.db.backends.utils import strip_quotes
class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
sql_create_sequence = "CREATE SEQUENCE %(sequence)s"
sql_delete_sequence = "DROP SEQUENCE IF EXISTS %(sequence)s CASCADE"
sql_set_sequence_max = (
"SELECT setval('%(sequence)s', MAX(%(column)s)) FROM %(table)s"
)
sql_set_sequence_owner = "ALTER SEQUENCE %(sequence)s OWNED BY %(table)s.%(column)s"
sql_create_index = (
"CREATE INDEX %(name)s ON %(table)s%(using)s "
@ -39,6 +34,14 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
)
sql_delete_procedure = "DROP FUNCTION %(procedure)s(%(param_types)s)"
sql_add_identity = (
"ALTER TABLE %(table)s ALTER COLUMN %(column)s ADD "
"GENERATED BY DEFAULT AS IDENTITY"
)
sql_drop_indentity = (
"ALTER TABLE %(table)s ALTER COLUMN %(column)s DROP IDENTITY IF EXISTS"
)
def quote_value(self, value):
if isinstance(value, str):
value = value.replace("%", "%%")
@ -116,78 +119,47 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
self.sql_alter_column_type += using_sql
elif self._field_data_type(old_field) != self._field_data_type(new_field):
self.sql_alter_column_type += using_sql
# Make ALTER TYPE with SERIAL make sense.
# Make ALTER TYPE with IDENTITY make sense.
table = strip_quotes(model._meta.db_table)
serial_fields_map = {
"bigserial": "bigint",
"serial": "integer",
"smallserial": "smallint",
auto_field_types = {
"AutoField",
"BigAutoField",
"SmallAutoField",
}
if new_type.lower() in serial_fields_map:
old_is_auto = old_internal_type in auto_field_types
new_is_auto = new_internal_type in auto_field_types
if new_is_auto and not old_is_auto:
column = strip_quotes(new_field.column)
sequence_name = "%s_%s_seq" % (table, column)
return (
(
self.sql_alter_column_type
% {
"column": self.quote_name(column),
"type": serial_fields_map[new_type.lower()],
"type": new_type,
},
[],
),
[
(
self.sql_delete_sequence
% {
"sequence": self.quote_name(sequence_name),
},
[],
),
(
self.sql_create_sequence
% {
"sequence": self.quote_name(sequence_name),
},
[],
),
(
self.sql_alter_column
% {
"table": self.quote_name(table),
"changes": self.sql_alter_column_default
% {
"column": self.quote_name(column),
"default": "nextval('%s')"
% self.quote_name(sequence_name),
},
},
[],
),
(
self.sql_set_sequence_max
self.sql_add_identity
% {
"table": self.quote_name(table),
"column": self.quote_name(column),
"sequence": self.quote_name(sequence_name),
},
[],
),
(
self.sql_set_sequence_owner
% {
"table": self.quote_name(table),
"column": self.quote_name(column),
"sequence": self.quote_name(sequence_name),
},
[],
),
],
)
elif (
old_field.db_parameters(connection=self.connection)["type"]
in serial_fields_map
):
# Drop the sequence if migrating away from AutoField.
elif old_is_auto and not new_is_auto:
# Drop IDENTITY if exists (pre-Django 4.1 serial columns don't have
# it).
self.execute(
self.sql_drop_indentity
% {
"table": self.quote_name(table),
"column": self.quote_name(strip_quotes(old_field.column)),
}
)
column = strip_quotes(new_field.column)
sequence_name = "%s_%s_seq" % (table, column)
fragment, _ = super()._alter_column_type_sql(
@ -195,6 +167,8 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
)
return fragment, [
(
# Drop the sequence if exists (Django 4.1+ identity columns
# don't have it).
self.sql_delete_sequence
% {
"sequence": self.quote_name(sequence_name),

View File

@ -309,6 +309,9 @@ Models
allows customizing attributes of fields that don't affect a column
definition.
* On PostgreSQL, ``AutoField``, ``BigAutoField``, and ``SmallAutoField`` are
now created as identity columns rather than serial columns with sequences.
Requests and Responses
~~~~~~~~~~~~~~~~~~~~~~

View File

@ -27,3 +27,18 @@ class DatabaseSequenceTests(TestCase):
seqs,
[{"table": Person._meta.db_table, "column": "id", "name": "pers_seq"}],
)
def test_get_sequences_old_serial(self):
with connection.cursor() as cursor:
cursor.execute("CREATE TABLE testing (serial_field SERIAL);")
seqs = connection.introspection.get_sequences(cursor, "testing")
self.assertEqual(
seqs,
[
{
"table": "testing",
"column": "serial_field",
"name": "testing_serial_field_seq",
}
],
)

View File

@ -1570,11 +1570,7 @@ class SchemaTests(TransactionTestCase):
Author.objects.create(name="Foo")
Author.objects.create(name="Bar")
def test_alter_autofield_pk_to_bigautofield_pk_sequence_owner(self):
"""
Converting an implicit PK to BigAutoField(primary_key=True) should keep
a sequence owner on PostgreSQL.
"""
def test_alter_autofield_pk_to_bigautofield_pk(self):
with connection.schema_editor() as editor:
editor.create_model(Author)
old_field = Author._meta.get_field("id")
@ -1591,14 +1587,9 @@ class SchemaTests(TransactionTestCase):
)
if sequence_reset_sqls:
cursor.execute(sequence_reset_sqls[0])
# Fail on PostgreSQL if sequence is missing an owner.
self.assertIsNotNone(Author.objects.create(name="Bar"))
def test_alter_autofield_pk_to_smallautofield_pk_sequence_owner(self):
"""
Converting an implicit PK to SmallAutoField(primary_key=True) should
keep a sequence owner on PostgreSQL.
"""
def test_alter_autofield_pk_to_smallautofield_pk(self):
with connection.schema_editor() as editor:
editor.create_model(Author)
old_field = Author._meta.get_field("id")
@ -1615,7 +1606,6 @@ class SchemaTests(TransactionTestCase):
)
if sequence_reset_sqls:
cursor.execute(sequence_reset_sqls[0])
# Fail on PostgreSQL if sequence is missing an owner.
self.assertIsNotNone(Author.objects.create(name="Bar"))
def test_alter_int_pk_to_autofield_pk(self):