Fixed #33410 -- Fixed recursive capturing of callbacks by TestCase.captureOnCommitCallbacks().

Regression in d89f976bdd.
This commit is contained in:
Petter Friberg 2022-01-04 23:06:46 +01:00 committed by Mariusz Felisiak
parent 806efe912b
commit bc174e6ea0
3 changed files with 59 additions and 8 deletions

View File

@ -1253,18 +1253,16 @@ class TestCase(TransactionTestCase):
try: try:
yield callbacks yield callbacks
finally: finally:
callback_count = len(connections[using].run_on_commit)
while True: while True:
run_on_commit = connections[using].run_on_commit[start_count:] callback_count = len(connections[using].run_on_commit)
callbacks[:] = [func for sids, func in run_on_commit] for _, callback in connections[using].run_on_commit[start_count:]:
callbacks.append(callback)
if execute: if execute:
for callback in callbacks:
callback() callback()
if callback_count == len(connections[using].run_on_commit): if callback_count == len(connections[using].run_on_commit):
break break
start_count = callback_count - 1 start_count = callback_count
callback_count = len(connections[using].run_on_commit)
class CheckCondition: class CheckCondition:

View File

@ -9,4 +9,5 @@ Django 4.0.2 fixes several bugs in 4.0.1.
Bugfixes Bugfixes
======== ========
* ... * Fixed a bug in Django 4.0 where ``TestCase.captureOnCommitCallbacks()`` could
execute callbacks multiple times (:ticket:`33410`).

View File

@ -1829,6 +1829,58 @@ class CaptureOnCommitCallbacksTests(TestCase):
self.assertEqual(len(callbacks), 2) self.assertEqual(len(callbacks), 2)
self.assertIs(self.callback_called, True) self.assertIs(self.callback_called, True)
def test_execute_tree(self):
"""
A visualisation of the callback tree tested. Each node is expected to
be visited only once:
branch_1
branch_2
leaf_1
leaf_2
leaf_3
"""
branch_1_call_counter = 0
branch_2_call_counter = 0
leaf_1_call_counter = 0
leaf_2_call_counter = 0
leaf_3_call_counter = 0
def leaf_1():
nonlocal leaf_1_call_counter
leaf_1_call_counter += 1
def leaf_2():
nonlocal leaf_2_call_counter
leaf_2_call_counter += 1
def leaf_3():
nonlocal leaf_3_call_counter
leaf_3_call_counter += 1
def branch_1():
nonlocal branch_1_call_counter
branch_1_call_counter += 1
transaction.on_commit(branch_2)
transaction.on_commit(leaf_3)
def branch_2():
nonlocal branch_2_call_counter
branch_2_call_counter += 1
transaction.on_commit(leaf_1)
transaction.on_commit(leaf_2)
with self.captureOnCommitCallbacks(execute=True) as callbacks:
transaction.on_commit(branch_1)
self.assertEqual(branch_1_call_counter, 1)
self.assertEqual(branch_2_call_counter, 1)
self.assertEqual(leaf_1_call_counter, 1)
self.assertEqual(leaf_2_call_counter, 1)
self.assertEqual(leaf_3_call_counter, 1)
self.assertEqual(callbacks, [branch_1, branch_2, leaf_3, leaf_1, leaf_2])
class DisallowedDatabaseQueriesTests(SimpleTestCase): class DisallowedDatabaseQueriesTests(SimpleTestCase):
def test_disallowed_database_connections(self): def test_disallowed_database_connections(self):