diff --git a/django/contrib/postgres/operations.py b/django/contrib/postgres/operations.py index 12c850a073..0bb131ddf2 100644 --- a/django/contrib/postgres/operations.py +++ b/django/contrib/postgres/operations.py @@ -1,7 +1,7 @@ from django.contrib.postgres.signals import ( get_citext_oids, get_hstore_oids, register_type_handlers, ) -from django.db import NotSupportedError +from django.db import NotSupportedError, router from django.db.migrations import AddIndex, RemoveIndex from django.db.migrations.operations.base import Operation @@ -16,7 +16,10 @@ class CreateExtension(Operation): pass def database_forwards(self, app_label, schema_editor, from_state, to_state): - if schema_editor.connection.vendor != 'postgresql': + if ( + schema_editor.connection.vendor != 'postgresql' or + not router.allow_migrate(schema_editor.connection.alias, app_label) + ): return schema_editor.execute("CREATE EXTENSION IF NOT EXISTS %s" % schema_editor.quote_name(self.name)) # Clear cached, stale oids. @@ -28,6 +31,8 @@ class CreateExtension(Operation): register_type_handlers(schema_editor.connection) def database_backwards(self, app_label, schema_editor, from_state, to_state): + if not router.allow_migrate(schema_editor.connection.alias, app_label): + return schema_editor.execute("DROP EXTENSION %s" % schema_editor.quote_name(self.name)) # Clear cached, stale oids. get_hstore_oids.cache_clear() diff --git a/tests/postgres_tests/test_operations.py b/tests/postgres_tests/test_operations.py index 95c88d5fe0..7bcf6b2300 100644 --- a/tests/postgres_tests/test_operations.py +++ b/tests/postgres_tests/test_operations.py @@ -3,12 +3,16 @@ import unittest from migrations.test_base import OperationTestBase from django.db import NotSupportedError, connection +from django.db.migrations.state import ProjectState from django.db.models import Index -from django.test import modify_settings +from django.test import modify_settings, override_settings +from django.test.utils import CaptureQueriesContext + +from . import PostgreSQLTestCase try: from django.contrib.postgres.operations import ( - AddIndexConcurrently, RemoveIndexConcurrently, + AddIndexConcurrently, CreateExtension, RemoveIndexConcurrently, ) from django.contrib.postgres.indexes import BrinIndex, BTreeIndex except ImportError: @@ -141,3 +145,44 @@ class RemoveIndexConcurrentlyTests(OperationTestBase): self.assertEqual(name, 'RemoveIndexConcurrently') self.assertEqual(args, []) self.assertEqual(kwargs, {'model_name': 'Pony', 'name': 'pony_pink_idx'}) + + +class NoExtensionRouter(): + def allow_migrate(self, db, app_label, **hints): + return False + + +@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL specific tests.') +class CreateExtensionTests(PostgreSQLTestCase): + app_label = 'test_allow_create_extention' + + @override_settings(DATABASE_ROUTERS=[NoExtensionRouter()]) + def test_no_allow_migrate(self): + operation = CreateExtension('uuid-ossp') + project_state = ProjectState() + new_state = project_state.clone() + # Don't create an extension. + with CaptureQueriesContext(connection) as captured_queries: + with connection.schema_editor(atomic=False) as editor: + operation.database_forwards(self.app_label, editor, project_state, new_state) + self.assertEqual(len(captured_queries), 0) + # Reversal. + with CaptureQueriesContext(connection) as captured_queries: + with connection.schema_editor(atomic=False) as editor: + operation.database_backwards(self.app_label, editor, new_state, project_state) + self.assertEqual(len(captured_queries), 0) + + def test_allow_migrate(self): + operation = CreateExtension('uuid-ossp') + project_state = ProjectState() + new_state = project_state.clone() + # Create an extension. + with CaptureQueriesContext(connection) as captured_queries: + with connection.schema_editor(atomic=False) as editor: + operation.database_forwards(self.app_label, editor, project_state, new_state) + self.assertIn('CREATE EXTENSION', captured_queries[0]['sql']) + # Reversal. + with CaptureQueriesContext(connection) as captured_queries: + with connection.schema_editor(atomic=False) as editor: + operation.database_backwards(self.app_label, editor, new_state, project_state) + self.assertIn('DROP EXTENSION', captured_queries[0]['sql'])