Fixed #33054 -- Made TestCase.captureOnCommitCallbacks() capture callbacks recursively.

This commit is contained in:
Eugene Morozov 2021-08-24 23:29:55 +03:00 committed by Mariusz Felisiak
parent 02bc7161ec
commit d89f976bdd
4 changed files with 27 additions and 5 deletions

View File

@ -1282,12 +1282,19 @@ class TestCase(TransactionTestCase):
try: try:
yield callbacks yield callbacks
finally: finally:
callback_count = len(connections[using].run_on_commit)
while True:
run_on_commit = connections[using].run_on_commit[start_count:] run_on_commit = connections[using].run_on_commit[start_count:]
callbacks[:] = [func for sids, func in run_on_commit] callbacks[:] = [func for sids, func in run_on_commit]
if execute: if execute:
for callback in callbacks: for callback in callbacks:
callback() callback()
if callback_count == len(connections[using].run_on_commit):
break
start_count = callback_count - 1
callback_count = len(connections[using].run_on_commit)
class CheckCondition: class CheckCondition:
"""Descriptor class for deferred condition checking.""" """Descriptor class for deferred condition checking."""

View File

@ -368,6 +368,9 @@ Tests
* The :option:`test --parallel` option now supports the value ``auto`` to run * The :option:`test --parallel` option now supports the value ``auto`` to run
one test process for each processor core. one test process for each processor core.
* :meth:`.TestCase.captureOnCommitCallbacks` now captures new callbacks added
while executing :func:`.transaction.on_commit` callbacks.
URLs URLs
~~~~ ~~~~

View File

@ -912,6 +912,11 @@ It also provides an additional method:
self.assertEqual(mail.outbox[0].subject, 'Contact Form') self.assertEqual(mail.outbox[0].subject, 'Contact Form')
self.assertEqual(mail.outbox[0].body, 'I like your site') self.assertEqual(mail.outbox[0].body, 'I like your site')
.. versionchanged:: 4.0
In older versions, new callbacks added while executing
:func:`.transaction.on_commit` callbacks were not captured.
.. _live-test-server: .. _live-test-server:
``LiveServerTestCase`` ``LiveServerTestCase``

View File

@ -1502,6 +1502,13 @@ class CaptureOnCommitCallbacksTests(TestCase):
self.assertEqual(callbacks, []) self.assertEqual(callbacks, [])
def test_execute_recursive(self):
with self.captureOnCommitCallbacks(execute=True) as callbacks:
transaction.on_commit(self.enqueue_callback)
self.assertEqual(len(callbacks), 2)
self.assertIs(self.callback_called, True)
class DisallowedDatabaseQueriesTests(SimpleTestCase): class DisallowedDatabaseQueriesTests(SimpleTestCase):
def test_disallowed_database_connections(self): def test_disallowed_database_connections(self):