From f5b635086a07c9df5897e69685ebdc3a04049ec6 Mon Sep 17 00:00:00 2001 From: Simon Charette Date: Sat, 12 Jan 2019 14:33:50 -0500 Subject: [PATCH] Refs #28478 -- Prevented connection attempts against disallowed databases in tests. Mocking connect as well as cursor methods makes sure an appropriate error message is surfaced when running a subset of test attempting to access a a disallowed database. --- django/test/testcases.py | 49 ++++++++++++++++++------------- tests/test_utils/test_testcase.py | 13 +++++++- tests/test_utils/tests.py | 25 ++++++++++++---- 3 files changed, 60 insertions(+), 27 deletions(-) diff --git a/django/test/testcases.py b/django/test/testcases.py index f820684c878..0e1da33064a 100644 --- a/django/test/testcases.py +++ b/django/test/testcases.py @@ -135,7 +135,7 @@ class _AssertTemplateNotUsedContext(_AssertTemplateUsedContext): return '%s was rendered.' % self.template_name -class _CursorFailure: +class _DatabaseFailure: def __init__(self, wrapped, message): self.wrapped = wrapped self.message = message @@ -173,11 +173,17 @@ class SimpleTestCase(unittest.TestCase): databases = _SimpleTestCaseDatabasesDescriptor() _disallowed_database_msg = ( - 'Database queries are not allowed in SimpleTestCase subclasses. ' - 'Either subclass TestCase or TransactionTestCase to ensure proper ' - 'test isolation or add %(alias)r to %(test)s.databases to silence ' + 'Database %(operation)s to %(alias)r are not allowed in SimpleTestCase ' + 'subclasses. Either subclass TestCase or TransactionTestCase to ensure ' + 'proper test isolation or add %(alias)r to %(test)s.databases to silence ' 'this failure.' ) + _disallowed_connection_methods = [ + ('connect', 'connections'), + ('temporary_connection', 'connections'), + ('cursor', 'queries'), + ('chunked_cursor', 'queries'), + ] @classmethod def setUpClass(cls): @@ -188,7 +194,7 @@ class SimpleTestCase(unittest.TestCase): if cls._modified_settings: cls._cls_modified_context = modify_settings(cls._modified_settings) cls._cls_modified_context.enable() - cls._add_cursor_failures() + cls._add_databases_failures() @classmethod def _validate_databases(cls): @@ -208,31 +214,34 @@ class SimpleTestCase(unittest.TestCase): return frozenset(cls.databases) @classmethod - def _add_cursor_failures(cls): + def _add_databases_failures(cls): cls.databases = cls._validate_databases() for alias in connections: if alias in cls.databases: continue connection = connections[alias] - message = cls._disallowed_database_msg % { - 'test': '%s.%s' % (cls.__module__, cls.__qualname__), - 'alias': alias, - } - connection.cursor = _CursorFailure(connection.cursor, message) - connection.chunked_cursor = _CursorFailure(connection.chunked_cursor, message) + for name, operation in cls._disallowed_connection_methods: + message = cls._disallowed_database_msg % { + 'test': '%s.%s' % (cls.__module__, cls.__qualname__), + 'alias': alias, + 'operation': operation, + } + method = getattr(connection, name) + setattr(connection, name, _DatabaseFailure(method, message)) @classmethod - def _remove_cursor_failures(cls): + def _remove_databases_failures(cls): for alias in connections: if alias in cls.databases: continue connection = connections[alias] - connection.cursor = connection.cursor.wrapped - connection.chunked_cursor = connection.chunked_cursor.wrapped + for name, _ in cls._disallowed_connection_methods: + method = getattr(connection, name) + setattr(connection, name, method.wrapped) @classmethod def tearDownClass(cls): - cls._remove_cursor_failures() + cls._remove_databases_failures() if hasattr(cls, '_cls_modified_context'): cls._cls_modified_context.disable() delattr(cls, '_cls_modified_context') @@ -894,8 +903,8 @@ class TransactionTestCase(SimpleTestCase): databases = _TransactionTestCaseDatabasesDescriptor() _disallowed_database_msg = ( - 'Database queries to %(alias)r are not allowed in this test. Add ' - '%(alias)r to %(test)s.databases to ensure proper test isolation ' + 'Database %(operation)s to %(alias)r are not allowed in this test. ' + 'Add %(alias)r to %(test)s.databases to ensure proper test isolation ' 'and silence this failure.' ) @@ -1121,13 +1130,13 @@ class TestCase(TransactionTestCase): call_command('loaddata', *cls.fixtures, **{'verbosity': 0, 'database': db_name}) except Exception: cls._rollback_atomics(cls.cls_atomics) - cls._remove_cursor_failures() + cls._remove_databases_failures() raise try: cls.setUpTestData() except Exception: cls._rollback_atomics(cls.cls_atomics) - cls._remove_cursor_failures() + cls._remove_databases_failures() raise @classmethod diff --git a/tests/test_utils/test_testcase.py b/tests/test_utils/test_testcase.py index f374549400f..853aba7c228 100644 --- a/tests/test_utils/test_testcase.py +++ b/tests/test_utils/test_testcase.py @@ -1,4 +1,4 @@ -from django.db import IntegrityError, transaction +from django.db import IntegrityError, connections, transaction from django.test import TestCase, skipUnlessDBFeature from .models import Car, PossessedCar @@ -19,6 +19,17 @@ class TestTestCase(TestCase): finally: self._rollback_atomics = rollback_atomics + def test_disallowed_database_connection(self): + message = ( + "Database connections to 'other' are not allowed in this test. " + "Add 'other' to test_utils.test_testcase.TestTestCase.databases to " + "ensure proper test isolation and silence this failure." + ) + with self.assertRaisesMessage(AssertionError, message): + connections['other'].connect() + with self.assertRaisesMessage(AssertionError, message): + connections['other'].temporary_connection() + def test_disallowed_database_queries(self): message = ( "Database queries to 'other' are not allowed in this test. " diff --git a/tests/test_utils/tests.py b/tests/test_utils/tests.py index c7e55e0711c..680924b8380 100644 --- a/tests/test_utils/tests.py +++ b/tests/test_utils/tests.py @@ -1159,11 +1159,24 @@ class TestBadSetUpTestData(TestCase): class DisallowedDatabaseQueriesTests(SimpleTestCase): + def test_disallowed_database_connections(self): + expected_message = ( + "Database connections to 'default' are not allowed in SimpleTestCase " + "subclasses. Either subclass TestCase or TransactionTestCase to " + "ensure proper test isolation or add 'default' to " + "test_utils.tests.DisallowedDatabaseQueriesTests.databases to " + "silence this failure." + ) + with self.assertRaisesMessage(AssertionError, expected_message): + connection.connect() + with self.assertRaisesMessage(AssertionError, expected_message): + connection.temporary_connection() + def test_disallowed_database_queries(self): expected_message = ( - "Database queries are not allowed in SimpleTestCase subclasses. " - "Either subclass TestCase or TransactionTestCase to ensure proper " - "test isolation or add 'default' to " + "Database queries to 'default' are not allowed in SimpleTestCase " + "subclasses. Either subclass TestCase or TransactionTestCase to " + "ensure proper test isolation or add 'default' to " "test_utils.tests.DisallowedDatabaseQueriesTests.databases to " "silence this failure." ) @@ -1172,9 +1185,9 @@ class DisallowedDatabaseQueriesTests(SimpleTestCase): def test_disallowed_database_chunked_cursor_queries(self): expected_message = ( - "Database queries are not allowed in SimpleTestCase subclasses. " - "Either subclass TestCase or TransactionTestCase to ensure proper " - "test isolation or add 'default' to " + "Database queries to 'default' are not allowed in SimpleTestCase " + "subclasses. Either subclass TestCase or TransactionTestCase to " + "ensure proper test isolation or add 'default' to " "test_utils.tests.DisallowedDatabaseQueriesTests.databases to " "silence this failure." )