diff --git a/tests/backends/test_mysql.py b/tests/backends/test_mysql.py index d7030975f6..637c3e377c 100644 --- a/tests/backends/test_mysql.py +++ b/tests/backends/test_mysql.py @@ -1,14 +1,40 @@ import unittest +from contextlib import contextmanager from django.core.exceptions import ImproperlyConfigured from django.db import connection from django.test import TestCase, override_settings +@contextmanager +def get_connection(): + new_connection = connection.copy() + yield new_connection + new_connection.close() + + @override_settings(DEBUG=True) @unittest.skipUnless(connection.vendor == 'mysql', 'MySQL specific test.') class MySQLTests(TestCase): + read_committed = 'read committed' + repeatable_read = 'repeatable read' + isolation_values = { + level: level.replace(' ', '-').upper() + for level in (read_committed, repeatable_read) + } + + @classmethod + def setUpClass(cls): + super().setUpClass() + configured_isolation_level = connection.isolation_level or cls.isolation_values[cls.repeatable_read] + cls.configured_isolation_level = configured_isolation_level.upper() + cls.other_isolation_level = ( + cls.read_committed + if configured_isolation_level != cls.isolation_values[cls.read_committed] + else cls.repeatable_read + ) + @staticmethod def get_isolation_level(connection): with connection.cursor() as cursor: @@ -25,32 +51,24 @@ class MySQLTests(TestCase): self.assertNotIn(query, last_query) def test_connect_isolation_level(self): - read_committed = 'read committed' - repeatable_read = 'repeatable read' - isolation_values = { - level: level.replace(' ', '-').upper() - for level in (read_committed, repeatable_read) - } - configured_level = connection.isolation_level or isolation_values[repeatable_read] - configured_level = configured_level.upper() - other_level = read_committed if configured_level != isolation_values[read_committed] else repeatable_read + self.assertEqual(self.get_isolation_level(connection), self.configured_isolation_level) - self.assertEqual(self.get_isolation_level(connection), configured_level) - - new_connection = connection.copy() - new_connection.settings_dict['OPTIONS']['isolation_level'] = other_level - try: - self.assertEqual(self.get_isolation_level(new_connection), isolation_values[other_level]) - finally: - new_connection.close() + def test_setting_isolation_level(self): + with get_connection() as new_connection: + new_connection.settings_dict['OPTIONS']['isolation_level'] = self.other_isolation_level + self.assertEqual( + self.get_isolation_level(new_connection), + self.isolation_values[self.other_isolation_level] + ) + def test_uppercase_isolation_level(self): # Upper case values are also accepted in 'isolation_level'. - new_connection = connection.copy() - new_connection.settings_dict['OPTIONS']['isolation_level'] = other_level.upper() - try: - self.assertEqual(self.get_isolation_level(new_connection), isolation_values[other_level]) - finally: - new_connection.close() + with get_connection() as new_connection: + new_connection.settings_dict['OPTIONS']['isolation_level'] = self.other_isolation_level.upper() + self.assertEqual( + self.get_isolation_level(new_connection), + self.isolation_values[self.other_isolation_level] + ) def test_isolation_level_validation(self): new_connection = connection.copy()