import unittest
from unittest import mock

from migrations.test_base import OperationTestBase

from django.db import (
    IntegrityError, NotSupportedError, connection, transaction,
)
from django.db.migrations.state import ProjectState
from django.db.models import CheckConstraint, Index, Q, UniqueConstraint
from django.db.utils import ProgrammingError
from django.test import modify_settings, override_settings, skipUnlessDBFeature
from django.test.utils import CaptureQueriesContext

from . import PostgreSQLTestCase

try:
    from django.contrib.postgres.indexes import BrinIndex, BTreeIndex
    from django.contrib.postgres.operations import (
        AddConstraintNotValid, AddIndexConcurrently, BloomExtension,
        CreateCollation, CreateExtension, RemoveCollation,
        RemoveIndexConcurrently, ValidateConstraint,
    )
except ImportError:
    # NOTE(review): presumably these imports fail when the PostgreSQL driver
    # is not installed. Every test class below is skipped unless
    # connection.vendor == 'postgresql', so the missing names are never used.
    pass


@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL specific tests.')
@modify_settings(INSTALLED_APPS={'append': 'migrations'})
class AddIndexConcurrentlyTests(OperationTestBase):
    """Tests for the AddIndexConcurrently migration operation."""

    app_label = 'test_add_concurrently'

    def test_requires_atomic_false(self):
        project_state = self.set_up_test_model(self.app_label)
        new_state = project_state.clone()
        operation = AddIndexConcurrently(
            'Pony',
            Index(fields=['pink'], name='pony_pink_idx'),
        )
        msg = (
            'The AddIndexConcurrently operation cannot be executed inside '
            'a transaction (set atomic = False on the migration).'
        )
        # Running inside an atomic schema editor must be rejected.
        with self.assertRaisesMessage(NotSupportedError, msg):
            with connection.schema_editor(atomic=True) as editor:
                operation.database_forwards(self.app_label, editor, project_state, new_state)

    def test_add(self):
        project_state = self.set_up_test_model(self.app_label, index=False)
        table_name = '%s_pony' % self.app_label
        index = Index(fields=['pink'], name='pony_pink_idx')
        new_state = project_state.clone()
        operation = AddIndexConcurrently('Pony', index)
        self.assertEqual(
            operation.describe(),
            'Concurrently create index pony_pink_idx on field(s) pink of '
            'model Pony'
        )
        operation.state_forwards(self.app_label, new_state)
        self.assertEqual(len(new_state.models[self.app_label, 'pony'].options['indexes']), 1)
        self.assertIndexNotExists(table_name, ['pink'])
        # Add index.
        with connection.schema_editor(atomic=False) as editor:
            operation.database_forwards(self.app_label, editor, project_state, new_state)
        self.assertIndexExists(table_name, ['pink'])
        # Reversal.
        with connection.schema_editor(atomic=False) as editor:
            operation.database_backwards(self.app_label, editor, new_state, project_state)
        self.assertIndexNotExists(table_name, ['pink'])
        # Deconstruction.
        name, args, kwargs = operation.deconstruct()
        self.assertEqual(name, 'AddIndexConcurrently')
        self.assertEqual(args, [])
        self.assertEqual(kwargs, {'model_name': 'Pony', 'index': index})

    def test_add_other_index_type(self):
        project_state = self.set_up_test_model(self.app_label, index=False)
        table_name = '%s_pony' % self.app_label
        new_state = project_state.clone()
        operation = AddIndexConcurrently(
            'Pony',
            BrinIndex(fields=['pink'], name='pony_pink_brin_idx'),
        )
        self.assertIndexNotExists(table_name, ['pink'])
        # Add index.
        with connection.schema_editor(atomic=False) as editor:
            operation.database_forwards(self.app_label, editor, project_state, new_state)
        self.assertIndexExists(table_name, ['pink'], index_type='brin')
        # Reversal.
        with connection.schema_editor(atomic=False) as editor:
            operation.database_backwards(self.app_label, editor, new_state, project_state)
        self.assertIndexNotExists(table_name, ['pink'])

    def test_add_with_options(self):
        project_state = self.set_up_test_model(self.app_label, index=False)
        table_name = '%s_pony' % self.app_label
        new_state = project_state.clone()
        index = BTreeIndex(fields=['pink'], name='pony_pink_btree_idx', fillfactor=70)
        operation = AddIndexConcurrently('Pony', index)
        self.assertIndexNotExists(table_name, ['pink'])
        # Add index.
        with connection.schema_editor(atomic=False) as editor:
            operation.database_forwards(self.app_label, editor, project_state, new_state)
        self.assertIndexExists(table_name, ['pink'], index_type='btree')
        # Reversal.
        with connection.schema_editor(atomic=False) as editor:
            operation.database_backwards(self.app_label, editor, new_state, project_state)
        self.assertIndexNotExists(table_name, ['pink'])


@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL specific tests.')
@modify_settings(INSTALLED_APPS={'append': 'migrations'})
class RemoveIndexConcurrentlyTests(OperationTestBase):
    """Tests for the RemoveIndexConcurrently migration operation."""

    app_label = 'test_rm_concurrently'

    def test_requires_atomic_false(self):
        project_state = self.set_up_test_model(self.app_label, index=True)
        new_state = project_state.clone()
        operation = RemoveIndexConcurrently('Pony', 'pony_pink_idx')
        msg = (
            'The RemoveIndexConcurrently operation cannot be executed inside '
            'a transaction (set atomic = False on the migration).'
        )
        # Running inside an atomic schema editor must be rejected.
        with self.assertRaisesMessage(NotSupportedError, msg):
            with connection.schema_editor(atomic=True) as editor:
                operation.database_forwards(self.app_label, editor, project_state, new_state)

    def test_remove(self):
        project_state = self.set_up_test_model(self.app_label, index=True)
        table_name = '%s_pony' % self.app_label
        self.assertTableExists(table_name)
        new_state = project_state.clone()
        operation = RemoveIndexConcurrently('Pony', 'pony_pink_idx')
        self.assertEqual(
            operation.describe(),
            'Concurrently remove index pony_pink_idx from Pony',
        )
        operation.state_forwards(self.app_label, new_state)
        self.assertEqual(len(new_state.models[self.app_label, 'pony'].options['indexes']), 0)
        self.assertIndexExists(table_name, ['pink'])
        # Remove index.
        with connection.schema_editor(atomic=False) as editor:
            operation.database_forwards(self.app_label, editor, project_state, new_state)
        self.assertIndexNotExists(table_name, ['pink'])
        # Reversal.
        with connection.schema_editor(atomic=False) as editor:
            operation.database_backwards(self.app_label, editor, new_state, project_state)
        self.assertIndexExists(table_name, ['pink'])
        # Deconstruction.
        name, args, kwargs = operation.deconstruct()
        self.assertEqual(name, 'RemoveIndexConcurrently')
        self.assertEqual(args, [])
        self.assertEqual(kwargs, {'model_name': 'Pony', 'name': 'pony_pink_idx'})


class NoMigrationRouter:
    """Database router that forbids migrating any app on any database."""

    def allow_migrate(self, db, app_label, **hints):
        return False


@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL specific tests.')
class CreateExtensionTests(PostgreSQLTestCase):
    """Tests for the CreateExtension operation (and its BloomExtension subclass)."""

    app_label = 'test_allow_create_extention'

    @override_settings(DATABASE_ROUTERS=[NoMigrationRouter()])
    def test_no_allow_migrate(self):
        operation = CreateExtension('tablefunc')
        project_state = ProjectState()
        new_state = project_state.clone()
        # Don't create an extension.
        with CaptureQueriesContext(connection) as captured_queries:
            with connection.schema_editor(atomic=False) as editor:
                operation.database_forwards(self.app_label, editor, project_state, new_state)
        self.assertEqual(len(captured_queries), 0)
        # Reversal.
        with CaptureQueriesContext(connection) as captured_queries:
            with connection.schema_editor(atomic=False) as editor:
                operation.database_backwards(self.app_label, editor, new_state, project_state)
        self.assertEqual(len(captured_queries), 0)

    def test_allow_migrate(self):
        operation = CreateExtension('tablefunc')
        self.assertEqual(operation.migration_name_fragment, 'create_extension_tablefunc')
        project_state = ProjectState()
        new_state = project_state.clone()
        # Create an extension.
        with CaptureQueriesContext(connection) as captured_queries:
            with connection.schema_editor(atomic=False) as editor:
                operation.database_forwards(self.app_label, editor, project_state, new_state)
        self.assertEqual(len(captured_queries), 4)
        self.assertIn('CREATE EXTENSION IF NOT EXISTS', captured_queries[1]['sql'])
        # Reversal.
        with CaptureQueriesContext(connection) as captured_queries:
            with connection.schema_editor(atomic=False) as editor:
                operation.database_backwards(self.app_label, editor, new_state, project_state)
        self.assertEqual(len(captured_queries), 2)
        self.assertIn('DROP EXTENSION IF EXISTS', captured_queries[1]['sql'])

    def test_create_existing_extension(self):
        operation = BloomExtension()
        self.assertEqual(operation.migration_name_fragment, 'create_extension_bloom')
        project_state = ProjectState()
        new_state = project_state.clone()
        # Don't create an existing extension.
        with CaptureQueriesContext(connection) as captured_queries:
            with connection.schema_editor(atomic=False) as editor:
                operation.database_forwards(self.app_label, editor, project_state, new_state)
        self.assertEqual(len(captured_queries), 3)
        self.assertIn('SELECT', captured_queries[0]['sql'])

    def test_drop_nonexistent_extension(self):
        operation = CreateExtension('tablefunc')
        project_state = ProjectState()
        new_state = project_state.clone()
        # Don't drop a nonexistent extension.
        with CaptureQueriesContext(connection) as captured_queries:
            with connection.schema_editor(atomic=False) as editor:
                operation.database_backwards(self.app_label, editor, project_state, new_state)
        self.assertEqual(len(captured_queries), 1)
        self.assertIn('SELECT', captured_queries[0]['sql'])


@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL specific tests.')
class CreateCollationTests(PostgreSQLTestCase):
    """Tests for the CreateCollation operation."""

    app_label = 'test_allow_create_collation'

    @override_settings(DATABASE_ROUTERS=[NoMigrationRouter()])
    def test_no_allow_migrate(self):
        operation = CreateCollation('C_test', locale='C')
        project_state = ProjectState()
        new_state = project_state.clone()
        # Don't create a collation.
        with CaptureQueriesContext(connection) as captured_queries:
            with connection.schema_editor(atomic=False) as editor:
                operation.database_forwards(self.app_label, editor, project_state, new_state)
        self.assertEqual(len(captured_queries), 0)
        # Reversal.
        with CaptureQueriesContext(connection) as captured_queries:
            with connection.schema_editor(atomic=False) as editor:
                operation.database_backwards(self.app_label, editor, new_state, project_state)
        self.assertEqual(len(captured_queries), 0)

    def test_create(self):
        operation = CreateCollation('C_test', locale='C')
        self.assertEqual(operation.migration_name_fragment, 'create_collation_c_test')
        self.assertEqual(operation.describe(), 'Create collation C_test')
        project_state = ProjectState()
        new_state = project_state.clone()
        # Create a collation.
        with CaptureQueriesContext(connection) as captured_queries:
            with connection.schema_editor(atomic=False) as editor:
                operation.database_forwards(self.app_label, editor, project_state, new_state)
        self.assertEqual(len(captured_queries), 1)
        self.assertIn('CREATE COLLATION', captured_queries[0]['sql'])
        # Creating the same collation raises an exception.
        with self.assertRaisesMessage(ProgrammingError, 'already exists'):
            with connection.schema_editor(atomic=True) as editor:
                operation.database_forwards(self.app_label, editor, project_state, new_state)
        # Reversal.
        with CaptureQueriesContext(connection) as captured_queries:
            with connection.schema_editor(atomic=False) as editor:
                operation.database_backwards(self.app_label, editor, new_state, project_state)
        self.assertEqual(len(captured_queries), 1)
        self.assertIn('DROP COLLATION', captured_queries[0]['sql'])
        # Deconstruction.
        name, args, kwargs = operation.deconstruct()
        self.assertEqual(name, 'CreateCollation')
        self.assertEqual(args, [])
        self.assertEqual(kwargs, {'name': 'C_test', 'locale': 'C'})

    @skipUnlessDBFeature('supports_non_deterministic_collations')
    def test_create_non_deterministic_collation(self):
        operation = CreateCollation(
            'case_insensitive_test',
            'und-u-ks-level2',
            provider='icu',
            deterministic=False,
        )
        project_state = ProjectState()
        new_state = project_state.clone()
        # Create a collation.
        with CaptureQueriesContext(connection) as captured_queries:
            with connection.schema_editor(atomic=False) as editor:
                operation.database_forwards(self.app_label, editor, project_state, new_state)
        self.assertEqual(len(captured_queries), 1)
        self.assertIn('CREATE COLLATION', captured_queries[0]['sql'])
        # Reversal.
        with CaptureQueriesContext(connection) as captured_queries:
            with connection.schema_editor(atomic=False) as editor:
                operation.database_backwards(self.app_label, editor, new_state, project_state)
        self.assertEqual(len(captured_queries), 1)
        self.assertIn('DROP COLLATION', captured_queries[0]['sql'])
        # Deconstruction.
        name, args, kwargs = operation.deconstruct()
        self.assertEqual(name, 'CreateCollation')
        self.assertEqual(args, [])
        self.assertEqual(kwargs, {
            'name': 'case_insensitive_test',
            'locale': 'und-u-ks-level2',
            'provider': 'icu',
            'deterministic': False,
        })

    def test_create_collation_alternate_provider(self):
        operation = CreateCollation(
            'german_phonebook_test',
            provider='icu',
            locale='de-u-co-phonebk',
        )
        project_state = ProjectState()
        new_state = project_state.clone()
        # Create a collation.
        with CaptureQueriesContext(connection) as captured_queries:
            with connection.schema_editor(atomic=False) as editor:
                operation.database_forwards(self.app_label, editor, project_state, new_state)
        self.assertEqual(len(captured_queries), 1)
        self.assertIn('CREATE COLLATION', captured_queries[0]['sql'])
        # Reversal.
        with CaptureQueriesContext(connection) as captured_queries:
            with connection.schema_editor(atomic=False) as editor:
                operation.database_backwards(self.app_label, editor, new_state, project_state)
        self.assertEqual(len(captured_queries), 1)
        self.assertIn('DROP COLLATION', captured_queries[0]['sql'])

    def test_nondeterministic_collation_not_supported(self):
        operation = CreateCollation(
            'case_insensitive_test',
            provider='icu',
            locale='und-u-ks-level2',
            deterministic=False,
        )
        project_state = ProjectState()
        new_state = project_state.clone()
        msg = 'Non-deterministic collations require PostgreSQL 12+.'
        with connection.schema_editor(atomic=False) as editor:
            # Simulate a server without non-deterministic collation support.
            with mock.patch(
                'django.db.backends.postgresql.features.DatabaseFeatures.'
                'supports_non_deterministic_collations',
                False,
            ):
                with self.assertRaisesMessage(NotSupportedError, msg):
                    operation.database_forwards(self.app_label, editor, project_state, new_state)


@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL specific tests.')
class RemoveCollationTests(PostgreSQLTestCase):
    """Tests for the RemoveCollation operation."""

    app_label = 'test_allow_remove_collation'

    @override_settings(DATABASE_ROUTERS=[NoMigrationRouter()])
    def test_no_allow_migrate(self):
        operation = RemoveCollation('C_test', locale='C')
        project_state = ProjectState()
        new_state = project_state.clone()
        # Don't remove a collation.
        with CaptureQueriesContext(connection) as captured_queries:
            with connection.schema_editor(atomic=False) as editor:
                operation.database_forwards(self.app_label, editor, project_state, new_state)
        self.assertEqual(len(captured_queries), 0)
        # Reversal.
        with CaptureQueriesContext(connection) as captured_queries:
            with connection.schema_editor(atomic=False) as editor:
                operation.database_backwards(self.app_label, editor, new_state, project_state)
        self.assertEqual(len(captured_queries), 0)

    def test_remove(self):
        # Create the collation that the operation under test will remove.
        operation = CreateCollation('C_test', locale='C')
        project_state = ProjectState()
        new_state = project_state.clone()
        with connection.schema_editor(atomic=False) as editor:
            operation.database_forwards(self.app_label, editor, project_state, new_state)

        operation = RemoveCollation('C_test', locale='C')
        self.assertEqual(operation.migration_name_fragment, 'remove_collation_c_test')
        self.assertEqual(operation.describe(), 'Remove collation C_test')
        project_state = ProjectState()
        new_state = project_state.clone()
        # Remove a collation.
        with CaptureQueriesContext(connection) as captured_queries:
            with connection.schema_editor(atomic=False) as editor:
                operation.database_forwards(self.app_label, editor, project_state, new_state)
        self.assertEqual(len(captured_queries), 1)
        self.assertIn('DROP COLLATION', captured_queries[0]['sql'])
        # Removing a nonexistent collation raises an exception.
        with self.assertRaisesMessage(ProgrammingError, 'does not exist'):
            with connection.schema_editor(atomic=True) as editor:
                operation.database_forwards(self.app_label, editor, project_state, new_state)
        # Reversal.
        with CaptureQueriesContext(connection) as captured_queries:
            with connection.schema_editor(atomic=False) as editor:
                operation.database_backwards(self.app_label, editor, new_state, project_state)
        self.assertEqual(len(captured_queries), 1)
        self.assertIn('CREATE COLLATION', captured_queries[0]['sql'])
        # Deconstruction.
        name, args, kwargs = operation.deconstruct()
        self.assertEqual(name, 'RemoveCollation')
        self.assertEqual(args, [])
        self.assertEqual(kwargs, {'name': 'C_test', 'locale': 'C'})


@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL specific tests.')
@modify_settings(INSTALLED_APPS={'append': 'migrations'})
class AddConstraintNotValidTests(OperationTestBase):
    """Tests for the AddConstraintNotValid operation."""

    app_label = 'test_add_constraint_not_valid'

    def test_non_check_constraint_not_supported(self):
        constraint = UniqueConstraint(fields=['pink'], name='pony_pink_uniq')
        msg = 'AddConstraintNotValid.constraint must be a check constraint.'
        with self.assertRaisesMessage(TypeError, msg):
            AddConstraintNotValid(model_name='pony', constraint=constraint)

    def test_add(self):
        table_name = f'{self.app_label}_pony'
        constraint_name = 'pony_pink_gte_check'
        constraint = CheckConstraint(check=Q(pink__gte=4), name=constraint_name)
        operation = AddConstraintNotValid('Pony', constraint=constraint)
        project_state, new_state = self.make_test_state(self.app_label, operation)
        self.assertEqual(
            operation.describe(),
            f'Create not valid constraint {constraint_name} on model Pony',
        )
        self.assertEqual(
            operation.migration_name_fragment,
            f'pony_{constraint_name}_not_valid',
        )
        self.assertEqual(
            len(new_state.models[self.app_label, 'pony'].options['constraints']),
            1,
        )
        self.assertConstraintNotExists(table_name, constraint_name)
        Pony = new_state.apps.get_model(self.app_label, 'Pony')
        self.assertEqual(len(Pony._meta.constraints), 1)
        # A violating row created before the constraint exists.
        Pony.objects.create(pink=2, weight=1.0)
        # Add constraint.
        with connection.schema_editor(atomic=True) as editor:
            operation.database_forwards(self.app_label, editor, project_state, new_state)
        # New violating rows are rejected even though the existing row is not
        # validated.
        msg = f'check constraint "{constraint_name}"'
        with self.assertRaisesMessage(IntegrityError, msg), transaction.atomic():
            Pony.objects.create(pink=3, weight=1.0)
        self.assertConstraintExists(table_name, constraint_name)
        # Reversal: go from the post-operation state back to the original one,
        # matching the (from_state, to_state) order used by the forward step.
        with connection.schema_editor(atomic=True) as editor:
            operation.database_backwards(self.app_label, editor, new_state, project_state)
        self.assertConstraintNotExists(table_name, constraint_name)
        Pony.objects.create(pink=3, weight=1.0)
        # Deconstruction.
        name, args, kwargs = operation.deconstruct()
        self.assertEqual(name, 'AddConstraintNotValid')
        self.assertEqual(args, [])
        self.assertEqual(kwargs, {'model_name': 'Pony', 'constraint': constraint})


@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL specific tests.')
@modify_settings(INSTALLED_APPS={'append': 'migrations'})
class ValidateConstraintTests(OperationTestBase):
    """Tests for the ValidateConstraint operation."""

    app_label = 'test_validate_constraint'

    def test_validate(self):
        constraint_name = 'pony_pink_gte_check'
        constraint = CheckConstraint(check=Q(pink__gte=4), name=constraint_name)
        operation = AddConstraintNotValid('Pony', constraint=constraint)
        project_state, new_state = self.make_test_state(self.app_label, operation)
        Pony = new_state.apps.get_model(self.app_label, 'Pony')
        # A row that violates the constraint; validation must fail until it is
        # fixed.
        obj = Pony.objects.create(pink=2, weight=1.0)
        # Add constraint.
        with connection.schema_editor(atomic=True) as editor:
            operation.database_forwards(self.app_label, editor, project_state, new_state)
        project_state = new_state
        new_state = new_state.clone()
        operation = ValidateConstraint('Pony', name=constraint_name)
        operation.state_forwards(self.app_label, new_state)
        self.assertEqual(
            operation.describe(),
            f'Validate constraint {constraint_name} on model Pony',
        )
        self.assertEqual(
            operation.migration_name_fragment,
            f'pony_validate_{constraint_name}',
        )
        # Validate constraint.
        with connection.schema_editor(atomic=True) as editor:
            msg = f'check constraint "{constraint_name}"'
            with self.assertRaisesMessage(IntegrityError, msg):
                operation.database_forwards(self.app_label, editor, project_state, new_state)
        obj.pink = 5
        obj.save()
        with connection.schema_editor(atomic=True) as editor:
            operation.database_forwards(self.app_label, editor, project_state, new_state)
        # Reversal is a noop.
        with connection.schema_editor() as editor:
            with self.assertNumQueries(0):
                operation.database_backwards(self.app_label, editor, new_state, project_state)
        # Deconstruction.
        name, args, kwargs = operation.deconstruct()
        self.assertEqual(name, 'ValidateConstraint')
        self.assertEqual(args, [])
        self.assertEqual(kwargs, {'model_name': 'Pony', 'name': constraint_name})