Fixed #31615 -- Made migrations skip extension operations if not needed.

- Don't try to create an existing extension.
- Don't try to drop a nonexistent extension.
This commit is contained in:
Frantisek Holop 2020-05-22 11:40:34 +02:00 committed by Mariusz Felisiak
parent f3ed42c8ad
commit d693a086de
4 changed files with 54 additions and 6 deletions

View File

@ -21,7 +21,10 @@ class CreateExtension(Operation):
not router.allow_migrate(schema_editor.connection.alias, app_label)
):
return
schema_editor.execute("CREATE EXTENSION IF NOT EXISTS %s" % schema_editor.quote_name(self.name))
if not self.extension_exists(schema_editor, self.name):
schema_editor.execute(
'CREATE EXTENSION %s' % schema_editor.quote_name(self.name)
)
# Clear cached, stale oids.
get_hstore_oids.cache_clear()
get_citext_oids.cache_clear()
@ -33,11 +36,22 @@ class CreateExtension(Operation):
def database_backwards(self, app_label, schema_editor, from_state, to_state):
if not router.allow_migrate(schema_editor.connection.alias, app_label):
return
schema_editor.execute("DROP EXTENSION %s" % schema_editor.quote_name(self.name))
if self.extension_exists(schema_editor, self.name):
schema_editor.execute(
'DROP EXTENSION %s' % schema_editor.quote_name(self.name)
)
# Clear cached, stale oids.
get_hstore_oids.cache_clear()
get_citext_oids.cache_clear()
def extension_exists(self, schema_editor, extension):
with schema_editor.connection.cursor() as cursor:
cursor.execute(
'SELECT 1 FROM pg_extension WHERE extname = %s',
[extension],
)
return bool(cursor.fetchone())
def describe(self):
return "Creates extension %s" % self.name

View File

@ -30,12 +30,19 @@ For example::
...
]
Django checks that the extension already exists in the database and skips the
migration if so.
For most extensions, this requires a database user with superuser privileges.
If the Django database user doesn't have the appropriate privileges, you'll
have to create the extension outside of Django migrations with a user that has
them. In that case, connect to your Django database and run the query
``CREATE EXTENSION IF NOT EXISTS hstore;``.
.. versionchanged:: 3.2
In older versions, the existence of an extension isn't checked.
.. currentmodule:: django.contrib.postgres.operations
``CreateExtension``

View File

@ -63,7 +63,9 @@ Minor features
:mod:`django.contrib.messages`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
* ...
* The :class:`~django.contrib.postgres.operations.CreateExtension` operation
now checks that the extension already exists in the database and skips the
migration if so.
:mod:`django.contrib.postgres`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -12,7 +12,8 @@ from . import PostgreSQLTestCase
try:
from django.contrib.postgres.operations import (
AddIndexConcurrently, CreateExtension, RemoveIndexConcurrently,
AddIndexConcurrently, BloomExtension, CreateExtension,
RemoveIndexConcurrently,
)
from django.contrib.postgres.indexes import BrinIndex, BTreeIndex
except ImportError:
@ -180,9 +181,33 @@ class CreateExtensionTests(PostgreSQLTestCase):
with CaptureQueriesContext(connection) as captured_queries:
with connection.schema_editor(atomic=False) as editor:
operation.database_forwards(self.app_label, editor, project_state, new_state)
self.assertIn('CREATE EXTENSION', captured_queries[0]['sql'])
self.assertEqual(len(captured_queries), 4)
self.assertIn('CREATE EXTENSION', captured_queries[1]['sql'])
# Reversal.
with CaptureQueriesContext(connection) as captured_queries:
with connection.schema_editor(atomic=False) as editor:
operation.database_backwards(self.app_label, editor, new_state, project_state)
self.assertIn('DROP EXTENSION', captured_queries[0]['sql'])
self.assertEqual(len(captured_queries), 2)
self.assertIn('DROP EXTENSION', captured_queries[1]['sql'])
def test_create_existing_extension(self):
operation = BloomExtension()
project_state = ProjectState()
new_state = project_state.clone()
# Don't create an existing extension.
with CaptureQueriesContext(connection) as captured_queries:
with connection.schema_editor(atomic=False) as editor:
operation.database_forwards(self.app_label, editor, project_state, new_state)
self.assertEqual(len(captured_queries), 3)
self.assertIn('SELECT', captured_queries[0]['sql'])
def test_drop_nonexistent_extension(self):
operation = CreateExtension('tablefunc')
project_state = ProjectState()
new_state = project_state.clone()
# Don't drop a nonexistent extension.
with CaptureQueriesContext(connection) as captured_queries:
with connection.schema_editor(atomic=False) as editor:
operation.database_backwards(self.app_label, editor, project_state, new_state)
self.assertEqual(len(captured_queries), 1)
self.assertIn('SELECT', captured_queries[0]['sql'])