Fixed #31615 -- Made migrations skip extension operations if not needed.

- Don't try to create an existing extension.
- Don't try to drop a nonexistent extension.
This commit is contained in:
Frantisek Holop 2020-05-22 11:40:34 +02:00 committed by Mariusz Felisiak
parent f3ed42c8ad
commit d693a086de
4 changed files with 54 additions and 6 deletions

View File

@ -21,7 +21,10 @@ class CreateExtension(Operation):
not router.allow_migrate(schema_editor.connection.alias, app_label) not router.allow_migrate(schema_editor.connection.alias, app_label)
): ):
return return
schema_editor.execute("CREATE EXTENSION IF NOT EXISTS %s" % schema_editor.quote_name(self.name)) if not self.extension_exists(schema_editor, self.name):
schema_editor.execute(
'CREATE EXTENSION %s' % schema_editor.quote_name(self.name)
)
# Clear cached, stale oids. # Clear cached, stale oids.
get_hstore_oids.cache_clear() get_hstore_oids.cache_clear()
get_citext_oids.cache_clear() get_citext_oids.cache_clear()
@ -33,11 +36,22 @@ class CreateExtension(Operation):
def database_backwards(self, app_label, schema_editor, from_state, to_state): def database_backwards(self, app_label, schema_editor, from_state, to_state):
if not router.allow_migrate(schema_editor.connection.alias, app_label): if not router.allow_migrate(schema_editor.connection.alias, app_label):
return return
schema_editor.execute("DROP EXTENSION %s" % schema_editor.quote_name(self.name)) if self.extension_exists(schema_editor, self.name):
schema_editor.execute(
'DROP EXTENSION %s' % schema_editor.quote_name(self.name)
)
# Clear cached, stale oids. # Clear cached, stale oids.
get_hstore_oids.cache_clear() get_hstore_oids.cache_clear()
get_citext_oids.cache_clear() get_citext_oids.cache_clear()
def extension_exists(self, schema_editor, extension):
with schema_editor.connection.cursor() as cursor:
cursor.execute(
'SELECT 1 FROM pg_extension WHERE extname = %s',
[extension],
)
return bool(cursor.fetchone())
def describe(self): def describe(self):
return "Creates extension %s" % self.name return "Creates extension %s" % self.name

View File

@ -30,12 +30,19 @@ For example::
... ...
] ]
Django checks that the extension already exists in the database and skips the
migration if so.
For most extensions, this requires a database user with superuser privileges. For most extensions, this requires a database user with superuser privileges.
If the Django database user doesn't have the appropriate privileges, you'll If the Django database user doesn't have the appropriate privileges, you'll
have to create the extension outside of Django migrations with a user that has have to create the extension outside of Django migrations with a user that has
them. In that case, connect to your Django database and run the query them. In that case, connect to your Django database and run the query
``CREATE EXTENSION IF NOT EXISTS hstore;``. ``CREATE EXTENSION IF NOT EXISTS hstore;``.
.. versionchanged:: 3.2
In older versions, the existence of an extension isn't checked.
.. currentmodule:: django.contrib.postgres.operations .. currentmodule:: django.contrib.postgres.operations
``CreateExtension`` ``CreateExtension``

View File

@ -63,7 +63,9 @@ Minor features
:mod:`django.contrib.messages` :mod:`django.contrib.messages`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
* ... * The :class:`~django.contrib.postgres.operations.CreateExtension` operation
now checks that the extension already exists in the database and skips the
migration if so.
:mod:`django.contrib.postgres` :mod:`django.contrib.postgres`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -12,7 +12,8 @@ from . import PostgreSQLTestCase
try: try:
from django.contrib.postgres.operations import ( from django.contrib.postgres.operations import (
AddIndexConcurrently, CreateExtension, RemoveIndexConcurrently, AddIndexConcurrently, BloomExtension, CreateExtension,
RemoveIndexConcurrently,
) )
from django.contrib.postgres.indexes import BrinIndex, BTreeIndex from django.contrib.postgres.indexes import BrinIndex, BTreeIndex
except ImportError: except ImportError:
@ -180,9 +181,33 @@ class CreateExtensionTests(PostgreSQLTestCase):
with CaptureQueriesContext(connection) as captured_queries: with CaptureQueriesContext(connection) as captured_queries:
with connection.schema_editor(atomic=False) as editor: with connection.schema_editor(atomic=False) as editor:
operation.database_forwards(self.app_label, editor, project_state, new_state) operation.database_forwards(self.app_label, editor, project_state, new_state)
self.assertIn('CREATE EXTENSION', captured_queries[0]['sql']) self.assertEqual(len(captured_queries), 4)
self.assertIn('CREATE EXTENSION', captured_queries[1]['sql'])
# Reversal. # Reversal.
with CaptureQueriesContext(connection) as captured_queries: with CaptureQueriesContext(connection) as captured_queries:
with connection.schema_editor(atomic=False) as editor: with connection.schema_editor(atomic=False) as editor:
operation.database_backwards(self.app_label, editor, new_state, project_state) operation.database_backwards(self.app_label, editor, new_state, project_state)
self.assertIn('DROP EXTENSION', captured_queries[0]['sql']) self.assertEqual(len(captured_queries), 2)
self.assertIn('DROP EXTENSION', captured_queries[1]['sql'])
def test_create_existing_extension(self):
operation = BloomExtension()
project_state = ProjectState()
new_state = project_state.clone()
# Don't create an existing extension.
with CaptureQueriesContext(connection) as captured_queries:
with connection.schema_editor(atomic=False) as editor:
operation.database_forwards(self.app_label, editor, project_state, new_state)
self.assertEqual(len(captured_queries), 3)
self.assertIn('SELECT', captured_queries[0]['sql'])
def test_drop_nonexistent_extension(self):
operation = CreateExtension('tablefunc')
project_state = ProjectState()
new_state = project_state.clone()
# Don't drop a nonexistent extension.
with CaptureQueriesContext(connection) as captured_queries:
with connection.schema_editor(atomic=False) as editor:
operation.database_backwards(self.app_label, editor, project_state, new_state)
self.assertEqual(len(captured_queries), 1)
self.assertIn('SELECT', captured_queries[0]['sql'])