Refs #28478 -- Prevented connection attempts against disallowed databases in tests.

Mocking connect as well as cursor methods makes sure an appropriate error
message is surfaced when running a subset of test attempting to access a
a disallowed database.
This commit is contained in:
Simon Charette 2019-01-12 14:33:50 -05:00 committed by Tim Graham
parent a96b901932
commit f5b635086a
3 changed files with 60 additions and 27 deletions

View File

@ -135,7 +135,7 @@ class _AssertTemplateNotUsedContext(_AssertTemplateUsedContext):
return '%s was rendered.' % self.template_name
class _CursorFailure:
class _DatabaseFailure:
def __init__(self, wrapped, message):
self.wrapped = wrapped
self.message = message
@ -173,11 +173,17 @@ class SimpleTestCase(unittest.TestCase):
databases = _SimpleTestCaseDatabasesDescriptor()
_disallowed_database_msg = (
'Database queries are not allowed in SimpleTestCase subclasses. '
'Either subclass TestCase or TransactionTestCase to ensure proper '
'test isolation or add %(alias)r to %(test)s.databases to silence '
'Database %(operation)s to %(alias)r are not allowed in SimpleTestCase '
'subclasses. Either subclass TestCase or TransactionTestCase to ensure '
'proper test isolation or add %(alias)r to %(test)s.databases to silence '
'this failure.'
)
_disallowed_connection_methods = [
('connect', 'connections'),
('temporary_connection', 'connections'),
('cursor', 'queries'),
('chunked_cursor', 'queries'),
]
@classmethod
def setUpClass(cls):
@ -188,7 +194,7 @@ class SimpleTestCase(unittest.TestCase):
if cls._modified_settings:
cls._cls_modified_context = modify_settings(cls._modified_settings)
cls._cls_modified_context.enable()
cls._add_cursor_failures()
cls._add_databases_failures()
@classmethod
def _validate_databases(cls):
@ -208,31 +214,34 @@ class SimpleTestCase(unittest.TestCase):
return frozenset(cls.databases)
@classmethod
def _add_cursor_failures(cls):
def _add_databases_failures(cls):
cls.databases = cls._validate_databases()
for alias in connections:
if alias in cls.databases:
continue
connection = connections[alias]
message = cls._disallowed_database_msg % {
'test': '%s.%s' % (cls.__module__, cls.__qualname__),
'alias': alias,
}
connection.cursor = _CursorFailure(connection.cursor, message)
connection.chunked_cursor = _CursorFailure(connection.chunked_cursor, message)
for name, operation in cls._disallowed_connection_methods:
message = cls._disallowed_database_msg % {
'test': '%s.%s' % (cls.__module__, cls.__qualname__),
'alias': alias,
'operation': operation,
}
method = getattr(connection, name)
setattr(connection, name, _DatabaseFailure(method, message))
@classmethod
def _remove_cursor_failures(cls):
def _remove_databases_failures(cls):
for alias in connections:
if alias in cls.databases:
continue
connection = connections[alias]
connection.cursor = connection.cursor.wrapped
connection.chunked_cursor = connection.chunked_cursor.wrapped
for name, _ in cls._disallowed_connection_methods:
method = getattr(connection, name)
setattr(connection, name, method.wrapped)
@classmethod
def tearDownClass(cls):
cls._remove_cursor_failures()
cls._remove_databases_failures()
if hasattr(cls, '_cls_modified_context'):
cls._cls_modified_context.disable()
delattr(cls, '_cls_modified_context')
@ -894,8 +903,8 @@ class TransactionTestCase(SimpleTestCase):
databases = _TransactionTestCaseDatabasesDescriptor()
_disallowed_database_msg = (
'Database queries to %(alias)r are not allowed in this test. Add '
'%(alias)r to %(test)s.databases to ensure proper test isolation '
'Database %(operation)s to %(alias)r are not allowed in this test. '
'Add %(alias)r to %(test)s.databases to ensure proper test isolation '
'and silence this failure.'
)
@ -1121,13 +1130,13 @@ class TestCase(TransactionTestCase):
call_command('loaddata', *cls.fixtures, **{'verbosity': 0, 'database': db_name})
except Exception:
cls._rollback_atomics(cls.cls_atomics)
cls._remove_cursor_failures()
cls._remove_databases_failures()
raise
try:
cls.setUpTestData()
except Exception:
cls._rollback_atomics(cls.cls_atomics)
cls._remove_cursor_failures()
cls._remove_databases_failures()
raise
@classmethod

View File

@ -1,4 +1,4 @@
from django.db import IntegrityError, transaction
from django.db import IntegrityError, connections, transaction
from django.test import TestCase, skipUnlessDBFeature
from .models import Car, PossessedCar
@ -19,6 +19,17 @@ class TestTestCase(TestCase):
finally:
self._rollback_atomics = rollback_atomics
def test_disallowed_database_connection(self):
message = (
"Database connections to 'other' are not allowed in this test. "
"Add 'other' to test_utils.test_testcase.TestTestCase.databases to "
"ensure proper test isolation and silence this failure."
)
with self.assertRaisesMessage(AssertionError, message):
connections['other'].connect()
with self.assertRaisesMessage(AssertionError, message):
connections['other'].temporary_connection()
def test_disallowed_database_queries(self):
message = (
"Database queries to 'other' are not allowed in this test. "

View File

@ -1159,11 +1159,24 @@ class TestBadSetUpTestData(TestCase):
class DisallowedDatabaseQueriesTests(SimpleTestCase):
def test_disallowed_database_connections(self):
expected_message = (
"Database connections to 'default' are not allowed in SimpleTestCase "
"subclasses. Either subclass TestCase or TransactionTestCase to "
"ensure proper test isolation or add 'default' to "
"test_utils.tests.DisallowedDatabaseQueriesTests.databases to "
"silence this failure."
)
with self.assertRaisesMessage(AssertionError, expected_message):
connection.connect()
with self.assertRaisesMessage(AssertionError, expected_message):
connection.temporary_connection()
def test_disallowed_database_queries(self):
expected_message = (
"Database queries are not allowed in SimpleTestCase subclasses. "
"Either subclass TestCase or TransactionTestCase to ensure proper "
"test isolation or add 'default' to "
"Database queries to 'default' are not allowed in SimpleTestCase "
"subclasses. Either subclass TestCase or TransactionTestCase to "
"ensure proper test isolation or add 'default' to "
"test_utils.tests.DisallowedDatabaseQueriesTests.databases to "
"silence this failure."
)
@ -1172,9 +1185,9 @@ class DisallowedDatabaseQueriesTests(SimpleTestCase):
def test_disallowed_database_chunked_cursor_queries(self):
expected_message = (
"Database queries are not allowed in SimpleTestCase subclasses. "
"Either subclass TestCase or TransactionTestCase to ensure proper "
"test isolation or add 'default' to "
"Database queries to 'default' are not allowed in SimpleTestCase "
"subclasses. Either subclass TestCase or TransactionTestCase to "
"ensure proper test isolation or add 'default' to "
"test_utils.tests.DisallowedDatabaseQueriesTests.databases to "
"silence this failure."
)