diff --git a/django/test/testcases.py b/django/test/testcases.py index 685cbe2bd4..9759680f0f 100644 --- a/django/test/testcases.py +++ b/django/test/testcases.py @@ -1253,18 +1253,16 @@ class TestCase(TransactionTestCase): try: yield callbacks finally: - callback_count = len(connections[using].run_on_commit) while True: - run_on_commit = connections[using].run_on_commit[start_count:] - callbacks[:] = [func for sids, func in run_on_commit] - if execute: - for callback in callbacks: + callback_count = len(connections[using].run_on_commit) + for _, callback in connections[using].run_on_commit[start_count:]: + callbacks.append(callback) + if execute: callback() if callback_count == len(connections[using].run_on_commit): break - start_count = callback_count - 1 - callback_count = len(connections[using].run_on_commit) + start_count = callback_count class CheckCondition: diff --git a/docs/releases/4.0.2.txt b/docs/releases/4.0.2.txt index 002e891017..b1f1fb9c76 100644 --- a/docs/releases/4.0.2.txt +++ b/docs/releases/4.0.2.txt @@ -9,4 +9,5 @@ Django 4.0.2 fixes several bugs in 4.0.1. Bugfixes ======== -* ... +* Fixed a bug in Django 4.0 where ``TestCase.captureOnCommitCallbacks()`` could + execute callbacks multiple times (:ticket:`33410`). diff --git a/tests/test_utils/tests.py b/tests/test_utils/tests.py index 2addc1fdde..3380a1c67e 100644 --- a/tests/test_utils/tests.py +++ b/tests/test_utils/tests.py @@ -1829,6 +1829,58 @@ class CaptureOnCommitCallbacksTests(TestCase): self.assertEqual(len(callbacks), 2) self.assertIs(self.callback_called, True) + def test_execute_tree(self): + """ + A visualisation of the callback tree tested. Each node is expected to + be visited only once: + + └─branch_1 + ├─branch_2 + │ ├─leaf_1 + │ └─leaf_2 + └─leaf_3 + """ + branch_1_call_counter = 0 + branch_2_call_counter = 0 + leaf_1_call_counter = 0 + leaf_2_call_counter = 0 + leaf_3_call_counter = 0 + + def leaf_1(): + nonlocal leaf_1_call_counter + leaf_1_call_counter += 1 + + def leaf_2(): + nonlocal leaf_2_call_counter + leaf_2_call_counter += 1 + + def leaf_3(): + nonlocal leaf_3_call_counter + leaf_3_call_counter += 1 + + def branch_1(): + nonlocal branch_1_call_counter + branch_1_call_counter += 1 + transaction.on_commit(branch_2) + transaction.on_commit(leaf_3) + + def branch_2(): + nonlocal branch_2_call_counter + branch_2_call_counter += 1 + transaction.on_commit(leaf_1) + transaction.on_commit(leaf_2) + + with self.captureOnCommitCallbacks(execute=True) as callbacks: + transaction.on_commit(branch_1) + + self.assertEqual(branch_1_call_counter, 1) + self.assertEqual(branch_2_call_counter, 1) + self.assertEqual(leaf_1_call_counter, 1) + self.assertEqual(leaf_2_call_counter, 1) + self.assertEqual(leaf_3_call_counter, 1) + + self.assertEqual(callbacks, [branch_1, branch_2, leaf_3, leaf_1, leaf_2]) + class DisallowedDatabaseQueriesTests(SimpleTestCase): def test_disallowed_database_connections(self):