diff --git a/django/contrib/postgres/operations.py b/django/contrib/postgres/operations.py index 0bb131ddf2..740911d059 100644 --- a/django/contrib/postgres/operations.py +++ b/django/contrib/postgres/operations.py @@ -21,7 +21,10 @@ class CreateExtension(Operation): not router.allow_migrate(schema_editor.connection.alias, app_label) ): return - schema_editor.execute("CREATE EXTENSION IF NOT EXISTS %s" % schema_editor.quote_name(self.name)) + if not self.extension_exists(schema_editor, self.name): + schema_editor.execute( + 'CREATE EXTENSION %s' % schema_editor.quote_name(self.name) + ) # Clear cached, stale oids. get_hstore_oids.cache_clear() get_citext_oids.cache_clear() @@ -33,11 +36,22 @@ class CreateExtension(Operation): def database_backwards(self, app_label, schema_editor, from_state, to_state): if not router.allow_migrate(schema_editor.connection.alias, app_label): return - schema_editor.execute("DROP EXTENSION %s" % schema_editor.quote_name(self.name)) + if self.extension_exists(schema_editor, self.name): + schema_editor.execute( + 'DROP EXTENSION %s' % schema_editor.quote_name(self.name) + ) # Clear cached, stale oids. get_hstore_oids.cache_clear() get_citext_oids.cache_clear() + def extension_exists(self, schema_editor, extension): + with schema_editor.connection.cursor() as cursor: + cursor.execute( + 'SELECT 1 FROM pg_extension WHERE extname = %s', + [extension], + ) + return bool(cursor.fetchone()) + def describe(self): return "Creates extension %s" % self.name diff --git a/docs/ref/contrib/postgres/operations.txt b/docs/ref/contrib/postgres/operations.txt index 8696e4e81f..8571cc84e8 100644 --- a/docs/ref/contrib/postgres/operations.txt +++ b/docs/ref/contrib/postgres/operations.txt @@ -30,12 +30,19 @@ For example:: ... ] +Django checks that the extension already exists in the database and skips the +migration if so. + For most extensions, this requires a database user with superuser privileges. If the Django database user doesn't have the appropriate privileges, you'll have to create the extension outside of Django migrations with a user that has them. In that case, connect to your Django database and run the query ``CREATE EXTENSION IF NOT EXISTS hstore;``. +.. versionchanged:: 3.2 + + In older versions, the existence of an extension isn't checked. + .. currentmodule:: django.contrib.postgres.operations ``CreateExtension`` diff --git a/docs/releases/3.2.txt b/docs/releases/3.2.txt index 84f757cd2b..bd8db085a9 100644 --- a/docs/releases/3.2.txt +++ b/docs/releases/3.2.txt @@ -63,7 +63,9 @@ Minor features :mod:`django.contrib.messages` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -* ... +* The :class:`~django.contrib.postgres.operations.CreateExtension` operation + now checks that the extension already exists in the database and skips the + migration if so. :mod:`django.contrib.postgres` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/tests/postgres_tests/test_operations.py b/tests/postgres_tests/test_operations.py index 0a9d8040ef..8cc9a2b66e 100644 --- a/tests/postgres_tests/test_operations.py +++ b/tests/postgres_tests/test_operations.py @@ -12,7 +12,8 @@ from . import PostgreSQLTestCase try: from django.contrib.postgres.operations import ( - AddIndexConcurrently, CreateExtension, RemoveIndexConcurrently, + AddIndexConcurrently, BloomExtension, CreateExtension, + RemoveIndexConcurrently, ) from django.contrib.postgres.indexes import BrinIndex, BTreeIndex except ImportError: @@ -180,9 +181,33 @@ class CreateExtensionTests(PostgreSQLTestCase): with CaptureQueriesContext(connection) as captured_queries: with connection.schema_editor(atomic=False) as editor: operation.database_forwards(self.app_label, editor, project_state, new_state) - self.assertIn('CREATE EXTENSION', captured_queries[0]['sql']) + self.assertEqual(len(captured_queries), 4) + self.assertIn('CREATE EXTENSION', captured_queries[1]['sql']) # Reversal. with CaptureQueriesContext(connection) as captured_queries: with connection.schema_editor(atomic=False) as editor: operation.database_backwards(self.app_label, editor, new_state, project_state) - self.assertIn('DROP EXTENSION', captured_queries[0]['sql']) + self.assertEqual(len(captured_queries), 2) + self.assertIn('DROP EXTENSION', captured_queries[1]['sql']) + + def test_create_existing_extension(self): + operation = BloomExtension() + project_state = ProjectState() + new_state = project_state.clone() + # Don't create an existing extension. + with CaptureQueriesContext(connection) as captured_queries: + with connection.schema_editor(atomic=False) as editor: + operation.database_forwards(self.app_label, editor, project_state, new_state) + self.assertEqual(len(captured_queries), 3) + self.assertIn('SELECT', captured_queries[0]['sql']) + + def test_drop_nonexistent_extension(self): + operation = CreateExtension('tablefunc') + project_state = ProjectState() + new_state = project_state.clone() + # Don't drop a nonexistent extension. + with CaptureQueriesContext(connection) as captured_queries: + with connection.schema_editor(atomic=False) as editor: + operation.database_backwards(self.app_label, editor, project_state, new_state) + self.assertEqual(len(captured_queries), 1) + self.assertIn('SELECT', captured_queries[0]['sql'])