Fixed #33161 -- Enabled durability check for nested atomic blocks in TestCase.

Co-Authored-By: Adam Johnson <me@adamj.eu>
This commit is contained in:
Krzysztof Jagiello 2021-09-30 19:13:56 +02:00 committed by Mariusz Felisiak
parent adb4100e58
commit 8d9827c06c
7 changed files with 52 additions and 70 deletions

View File

@ -544,6 +544,7 @@ answer newbie questions, and generally made Django that much better:
Kowito Charoenratchatabhan <kowito@felspar.com> Kowito Charoenratchatabhan <kowito@felspar.com>
Krišjānis Vaiders <krisjanisvaiders@gmail.com> Krišjānis Vaiders <krisjanisvaiders@gmail.com>
krzysiek.pawlik@silvermedia.pl krzysiek.pawlik@silvermedia.pl
Krzysztof Jagiello <me@kjagiello.com>
Krzysztof Jurewicz <krzysztof.jurewicz@gmail.com> Krzysztof Jurewicz <krzysztof.jurewicz@gmail.com>
Krzysztof Kulewski <kulewski@gmail.com> Krzysztof Kulewski <kulewski@gmail.com>
kurtiss@meetro.com kurtiss@meetro.com

View File

@ -79,6 +79,8 @@ class BaseDatabaseWrapper:
self.savepoint_state = 0 self.savepoint_state = 0
# List of savepoints created by 'atomic'. # List of savepoints created by 'atomic'.
self.savepoint_ids = [] self.savepoint_ids = []
# Stack of active 'atomic' blocks.
self.atomic_blocks = []
# Tracks if the outermost 'atomic' block should commit on exit, # Tracks if the outermost 'atomic' block should commit on exit,
# ie. if autocommit was active on entry. # ie. if autocommit was active on entry.
self.commit_on_exit = True self.commit_on_exit = True
@ -200,6 +202,7 @@ class BaseDatabaseWrapper:
# In case the previous connection was closed while in an atomic block # In case the previous connection was closed while in an atomic block
self.in_atomic_block = False self.in_atomic_block = False
self.savepoint_ids = [] self.savepoint_ids = []
self.atomic_blocks = []
self.needs_rollback = False self.needs_rollback = False
# Reset parameters defining when to close the connection # Reset parameters defining when to close the connection
max_age = self.settings_dict['CONN_MAX_AGE'] max_age = self.settings_dict['CONN_MAX_AGE']

View File

@ -165,19 +165,21 @@ class Atomic(ContextDecorator):
This is a private API. This is a private API.
""" """
# This private flag is provided only to disable the durability checks in
# TestCase.
_ensure_durability = True
def __init__(self, using, savepoint, durable): def __init__(self, using, savepoint, durable):
self.using = using self.using = using
self.savepoint = savepoint self.savepoint = savepoint
self.durable = durable self.durable = durable
self._from_testcase = False
def __enter__(self): def __enter__(self):
connection = get_connection(self.using) connection = get_connection(self.using)
if self.durable and self._ensure_durability and connection.in_atomic_block: if (
self.durable and
connection.atomic_blocks and
not connection.atomic_blocks[-1]._from_testcase
):
raise RuntimeError( raise RuntimeError(
'A durable atomic block cannot be nested within another ' 'A durable atomic block cannot be nested within another '
'atomic block.' 'atomic block.'
@ -207,9 +209,15 @@ class Atomic(ContextDecorator):
connection.set_autocommit(False, force_begin_transaction_with_broken_autocommit=True) connection.set_autocommit(False, force_begin_transaction_with_broken_autocommit=True)
connection.in_atomic_block = True connection.in_atomic_block = True
if connection.in_atomic_block:
connection.atomic_blocks.append(self)
def __exit__(self, exc_type, exc_value, traceback): def __exit__(self, exc_type, exc_value, traceback):
connection = get_connection(self.using) connection = get_connection(self.using)
if connection.in_atomic_block:
connection.atomic_blocks.pop()
if connection.savepoint_ids: if connection.savepoint_ids:
sid = connection.savepoint_ids.pop() sid = connection.savepoint_ids.pop()
else: else:

View File

@ -1146,8 +1146,10 @@ class TestCase(TransactionTestCase):
"""Open atomic blocks for multiple databases.""" """Open atomic blocks for multiple databases."""
atomics = {} atomics = {}
for db_name in cls._databases_names(): for db_name in cls._databases_names():
atomics[db_name] = transaction.atomic(using=db_name) atomic = transaction.atomic(using=db_name)
atomics[db_name].__enter__() atomic._from_testcase = True
atomic.__enter__()
atomics[db_name] = atomic
return atomics return atomics
@classmethod @classmethod
@ -1166,35 +1168,27 @@ class TestCase(TransactionTestCase):
super().setUpClass() super().setUpClass()
if not cls._databases_support_transactions(): if not cls._databases_support_transactions():
return return
# Disable the durability check to allow testing durable atomic blocks cls.cls_atomics = cls._enter_atomics()
# in a transaction for performance reasons.
transaction.Atomic._ensure_durability = False
try:
cls.cls_atomics = cls._enter_atomics()
if cls.fixtures: if cls.fixtures:
for db_name in cls._databases_names(include_mirrors=False): for db_name in cls._databases_names(include_mirrors=False):
try: try:
call_command('loaddata', *cls.fixtures, **{'verbosity': 0, 'database': db_name}) call_command('loaddata', *cls.fixtures, **{'verbosity': 0, 'database': db_name})
except Exception: except Exception:
cls._rollback_atomics(cls.cls_atomics) cls._rollback_atomics(cls.cls_atomics)
raise raise
pre_attrs = cls.__dict__.copy() pre_attrs = cls.__dict__.copy()
try: try:
cls.setUpTestData() cls.setUpTestData()
except Exception:
cls._rollback_atomics(cls.cls_atomics)
raise
for name, value in cls.__dict__.items():
if value is not pre_attrs.get(name):
setattr(cls, name, TestData(name, value))
except Exception: except Exception:
transaction.Atomic._ensure_durability = True cls._rollback_atomics(cls.cls_atomics)
raise raise
for name, value in cls.__dict__.items():
if value is not pre_attrs.get(name):
setattr(cls, name, TestData(name, value))
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
transaction.Atomic._ensure_durability = True
if cls._databases_support_transactions(): if cls._databases_support_transactions():
cls._rollback_atomics(cls.cls_atomics) cls._rollback_atomics(cls.cls_atomics)
for conn in connections.all(): for conn in connections.all():

View File

@ -215,7 +215,8 @@ Templates
Tests Tests
~~~~~ ~~~~~
* ... * A nested atomic block marked as durable in :class:`django.test.TestCase` now
raises a ``RuntimeError``, the same as outside of tests.
URLs URLs
~~~~ ~~~~

View File

@ -238,11 +238,10 @@ Django provides a single API to control database transactions.
is especially important if you're using :func:`atomic` in long-running is especially important if you're using :func:`atomic` in long-running
processes, outside of Django's request / response cycle. processes, outside of Django's request / response cycle.
.. warning:: .. versionchanged:: 4.1
:class:`django.test.TestCase` disables the durability check to allow In older versions, the durability check was disabled in
testing durable atomic blocks in a transaction for performance reasons. Use :class:`django.test.TestCase`.
:class:`django.test.TransactionTestCase` for testing durability.
Autocommit Autocommit
========== ==========

View File

@ -501,7 +501,7 @@ class NonAutocommitTests(TransactionTestCase):
Reporter.objects.create(first_name="Tintin") Reporter.objects.create(first_name="Tintin")
class DurableTests(TransactionTestCase): class DurableTestsBase:
available_apps = ['transactions'] available_apps = ['transactions']
def test_commit(self): def test_commit(self):
@ -533,42 +533,18 @@ class DurableTests(TransactionTestCase):
with transaction.atomic(durable=True): with transaction.atomic(durable=True):
pass pass
def test_sequence_of_durables(self):
class DisableDurabiltityCheckTests(TestCase):
"""
TestCase runs all tests in a transaction by default. Code using
durable=True would always fail when run from TestCase. This would mean
these tests would be forced to use the slower TransactionTestCase even when
not testing durability. For this reason, TestCase disables the durability
check.
"""
available_apps = ['transactions']
def test_commit(self):
with transaction.atomic(durable=True): with transaction.atomic(durable=True):
reporter = Reporter.objects.create(first_name='Tintin') reporter = Reporter.objects.create(first_name='Tintin 1')
self.assertEqual(Reporter.objects.get(), reporter) self.assertEqual(Reporter.objects.get(first_name='Tintin 1'), reporter)
def test_nested_outer_durable(self):
with transaction.atomic(durable=True): with transaction.atomic(durable=True):
reporter1 = Reporter.objects.create(first_name='Tintin') reporter = Reporter.objects.create(first_name='Tintin 2')
with transaction.atomic(): self.assertEqual(Reporter.objects.get(first_name='Tintin 2'), reporter)
reporter2 = Reporter.objects.create(
first_name='Archibald',
last_name='Haddock',
)
self.assertSequenceEqual(Reporter.objects.all(), [reporter2, reporter1])
def test_nested_both_durable(self):
with transaction.atomic(durable=True):
# Error is not raised.
with transaction.atomic(durable=True):
reporter = Reporter.objects.create(first_name='Tintin')
self.assertEqual(Reporter.objects.get(), reporter)
def test_nested_inner_durable(self): class DurableTransactionTests(DurableTestsBase, TransactionTestCase):
with transaction.atomic(): pass
# Error is not raised.
with transaction.atomic(durable=True):
reporter = Reporter.objects.create(first_name='Tintin') class DurableTests(DurableTestsBase, TestCase):
self.assertEqual(Reporter.objects.get(), reporter) pass