Refs #28478 -- Prevented connection attempts against disallowed databases in tests.

Mocking connect as well as cursor methods makes sure an appropriate error
message is surfaced when running a subset of test attempting to access a
a disallowed database.
This commit is contained in:
Simon Charette 2019-01-12 14:33:50 -05:00 committed by Tim Graham
parent a96b901932
commit f5b635086a
3 changed files with 60 additions and 27 deletions

View File

@ -135,7 +135,7 @@ class _AssertTemplateNotUsedContext(_AssertTemplateUsedContext):
return '%s was rendered.' % self.template_name return '%s was rendered.' % self.template_name
class _CursorFailure: class _DatabaseFailure:
def __init__(self, wrapped, message): def __init__(self, wrapped, message):
self.wrapped = wrapped self.wrapped = wrapped
self.message = message self.message = message
@ -173,11 +173,17 @@ class SimpleTestCase(unittest.TestCase):
databases = _SimpleTestCaseDatabasesDescriptor() databases = _SimpleTestCaseDatabasesDescriptor()
_disallowed_database_msg = ( _disallowed_database_msg = (
'Database queries are not allowed in SimpleTestCase subclasses. ' 'Database %(operation)s to %(alias)r are not allowed in SimpleTestCase '
'Either subclass TestCase or TransactionTestCase to ensure proper ' 'subclasses. Either subclass TestCase or TransactionTestCase to ensure '
'test isolation or add %(alias)r to %(test)s.databases to silence ' 'proper test isolation or add %(alias)r to %(test)s.databases to silence '
'this failure.' 'this failure.'
) )
_disallowed_connection_methods = [
('connect', 'connections'),
('temporary_connection', 'connections'),
('cursor', 'queries'),
('chunked_cursor', 'queries'),
]
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
@ -188,7 +194,7 @@ class SimpleTestCase(unittest.TestCase):
if cls._modified_settings: if cls._modified_settings:
cls._cls_modified_context = modify_settings(cls._modified_settings) cls._cls_modified_context = modify_settings(cls._modified_settings)
cls._cls_modified_context.enable() cls._cls_modified_context.enable()
cls._add_cursor_failures() cls._add_databases_failures()
@classmethod @classmethod
def _validate_databases(cls): def _validate_databases(cls):
@ -208,31 +214,34 @@ class SimpleTestCase(unittest.TestCase):
return frozenset(cls.databases) return frozenset(cls.databases)
@classmethod @classmethod
def _add_cursor_failures(cls): def _add_databases_failures(cls):
cls.databases = cls._validate_databases() cls.databases = cls._validate_databases()
for alias in connections: for alias in connections:
if alias in cls.databases: if alias in cls.databases:
continue continue
connection = connections[alias] connection = connections[alias]
for name, operation in cls._disallowed_connection_methods:
message = cls._disallowed_database_msg % { message = cls._disallowed_database_msg % {
'test': '%s.%s' % (cls.__module__, cls.__qualname__), 'test': '%s.%s' % (cls.__module__, cls.__qualname__),
'alias': alias, 'alias': alias,
'operation': operation,
} }
connection.cursor = _CursorFailure(connection.cursor, message) method = getattr(connection, name)
connection.chunked_cursor = _CursorFailure(connection.chunked_cursor, message) setattr(connection, name, _DatabaseFailure(method, message))
@classmethod @classmethod
def _remove_cursor_failures(cls): def _remove_databases_failures(cls):
for alias in connections: for alias in connections:
if alias in cls.databases: if alias in cls.databases:
continue continue
connection = connections[alias] connection = connections[alias]
connection.cursor = connection.cursor.wrapped for name, _ in cls._disallowed_connection_methods:
connection.chunked_cursor = connection.chunked_cursor.wrapped method = getattr(connection, name)
setattr(connection, name, method.wrapped)
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
cls._remove_cursor_failures() cls._remove_databases_failures()
if hasattr(cls, '_cls_modified_context'): if hasattr(cls, '_cls_modified_context'):
cls._cls_modified_context.disable() cls._cls_modified_context.disable()
delattr(cls, '_cls_modified_context') delattr(cls, '_cls_modified_context')
@ -894,8 +903,8 @@ class TransactionTestCase(SimpleTestCase):
databases = _TransactionTestCaseDatabasesDescriptor() databases = _TransactionTestCaseDatabasesDescriptor()
_disallowed_database_msg = ( _disallowed_database_msg = (
'Database queries to %(alias)r are not allowed in this test. Add ' 'Database %(operation)s to %(alias)r are not allowed in this test. '
'%(alias)r to %(test)s.databases to ensure proper test isolation ' 'Add %(alias)r to %(test)s.databases to ensure proper test isolation '
'and silence this failure.' 'and silence this failure.'
) )
@ -1121,13 +1130,13 @@ class TestCase(TransactionTestCase):
call_command('loaddata', *cls.fixtures, **{'verbosity': 0, 'database': db_name}) call_command('loaddata', *cls.fixtures, **{'verbosity': 0, 'database': db_name})
except Exception: except Exception:
cls._rollback_atomics(cls.cls_atomics) cls._rollback_atomics(cls.cls_atomics)
cls._remove_cursor_failures() cls._remove_databases_failures()
raise raise
try: try:
cls.setUpTestData() cls.setUpTestData()
except Exception: except Exception:
cls._rollback_atomics(cls.cls_atomics) cls._rollback_atomics(cls.cls_atomics)
cls._remove_cursor_failures() cls._remove_databases_failures()
raise raise
@classmethod @classmethod

View File

@ -1,4 +1,4 @@
from django.db import IntegrityError, transaction from django.db import IntegrityError, connections, transaction
from django.test import TestCase, skipUnlessDBFeature from django.test import TestCase, skipUnlessDBFeature
from .models import Car, PossessedCar from .models import Car, PossessedCar
@ -19,6 +19,17 @@ class TestTestCase(TestCase):
finally: finally:
self._rollback_atomics = rollback_atomics self._rollback_atomics = rollback_atomics
def test_disallowed_database_connection(self):
message = (
"Database connections to 'other' are not allowed in this test. "
"Add 'other' to test_utils.test_testcase.TestTestCase.databases to "
"ensure proper test isolation and silence this failure."
)
with self.assertRaisesMessage(AssertionError, message):
connections['other'].connect()
with self.assertRaisesMessage(AssertionError, message):
connections['other'].temporary_connection()
def test_disallowed_database_queries(self): def test_disallowed_database_queries(self):
message = ( message = (
"Database queries to 'other' are not allowed in this test. " "Database queries to 'other' are not allowed in this test. "

View File

@ -1159,11 +1159,24 @@ class TestBadSetUpTestData(TestCase):
class DisallowedDatabaseQueriesTests(SimpleTestCase): class DisallowedDatabaseQueriesTests(SimpleTestCase):
def test_disallowed_database_connections(self):
expected_message = (
"Database connections to 'default' are not allowed in SimpleTestCase "
"subclasses. Either subclass TestCase or TransactionTestCase to "
"ensure proper test isolation or add 'default' to "
"test_utils.tests.DisallowedDatabaseQueriesTests.databases to "
"silence this failure."
)
with self.assertRaisesMessage(AssertionError, expected_message):
connection.connect()
with self.assertRaisesMessage(AssertionError, expected_message):
connection.temporary_connection()
def test_disallowed_database_queries(self): def test_disallowed_database_queries(self):
expected_message = ( expected_message = (
"Database queries are not allowed in SimpleTestCase subclasses. " "Database queries to 'default' are not allowed in SimpleTestCase "
"Either subclass TestCase or TransactionTestCase to ensure proper " "subclasses. Either subclass TestCase or TransactionTestCase to "
"test isolation or add 'default' to " "ensure proper test isolation or add 'default' to "
"test_utils.tests.DisallowedDatabaseQueriesTests.databases to " "test_utils.tests.DisallowedDatabaseQueriesTests.databases to "
"silence this failure." "silence this failure."
) )
@ -1172,9 +1185,9 @@ class DisallowedDatabaseQueriesTests(SimpleTestCase):
def test_disallowed_database_chunked_cursor_queries(self): def test_disallowed_database_chunked_cursor_queries(self):
expected_message = ( expected_message = (
"Database queries are not allowed in SimpleTestCase subclasses. " "Database queries to 'default' are not allowed in SimpleTestCase "
"Either subclass TestCase or TransactionTestCase to ensure proper " "subclasses. Either subclass TestCase or TransactionTestCase to "
"test isolation or add 'default' to " "ensure proper test isolation or add 'default' to "
"test_utils.tests.DisallowedDatabaseQueriesTests.databases to " "test_utils.tests.DisallowedDatabaseQueriesTests.databases to "
"silence this failure." "silence this failure."
) )