diff --git a/django/test/testcases.py b/django/test/testcases.py index 5787dc0115..5c9a5262b8 100644 --- a/django/test/testcases.py +++ b/django/test/testcases.py @@ -1229,6 +1229,21 @@ class TestCase(TransactionTestCase): not connection.needs_rollback and connection.is_usable() ) + @classmethod + @contextmanager + def captureOnCommitCallbacks(cls, *, using=DEFAULT_DB_ALIAS, execute=False): + """Context manager to capture transaction.on_commit() callbacks.""" + callbacks = [] + start_count = len(connections[using].run_on_commit) + try: + yield callbacks + finally: + run_on_commit = connections[using].run_on_commit[start_count:] + callbacks[:] = [func for sids, func in run_on_commit] + if execute: + for callback in callbacks: + callback() + class CheckCondition: """Descriptor class for deferred condition checking.""" diff --git a/docs/releases/3.2.txt b/docs/releases/3.2.txt index 653e14a3a1..e0e80323e3 100644 --- a/docs/releases/3.2.txt +++ b/docs/releases/3.2.txt @@ -276,6 +276,11 @@ Tests * :class:`~django.test.Client` now preserves the request query string when following 307 and 308 redirects. +* The new :meth:`.TestCase.captureOnCommitCallbacks` method captures callback + functions passed to :func:`transaction.on_commit() + ` in a list. This allows you to test such + callbacks without using the slower :class:`.TransactionTestCase`. + URLs ~~~~ diff --git a/docs/topics/db/transactions.txt b/docs/topics/db/transactions.txt index 3eace66c83..996dd7534d 100644 --- a/docs/topics/db/transactions.txt +++ b/docs/topics/db/transactions.txt @@ -394,9 +394,19 @@ Use in tests Django's :class:`~django.test.TestCase` class wraps each test in a transaction and rolls back that transaction after each test, in order to provide test isolation. This means that no transaction is ever actually committed, thus your -:func:`on_commit` callbacks will never be run. If you need to test the results -of an :func:`on_commit` callback, use a -:class:`~django.test.TransactionTestCase` instead. +:func:`on_commit` callbacks will never be run. + +You can overcome this limitation by using +:meth:`.TestCase.captureOnCommitCallbacks`. This captures your +:func:`on_commit` callbacks in a list, allowing you to make assertions on them, +or emulate the transaction committing by calling them. + +Another way to overcome the limitation is to use +:class:`~django.test.TransactionTestCase` instead of +:class:`~django.test.TestCase`. This will mean your transactions are committed, +and the callbacks will run. However +:class:`~django.test.TransactionTestCase` flushes the database between tests, +which is significantly slower than :class:`~django.test.TestCase`\'s isolation. Why no rollback hook? --------------------- diff --git a/docs/topics/testing/tools.txt b/docs/topics/testing/tools.txt index a22428962a..741acd604c 100644 --- a/docs/topics/testing/tools.txt +++ b/docs/topics/testing/tools.txt @@ -881,6 +881,42 @@ It also provides an additional method: previous versions of Django these objects were reused and changes made to them were persisted between test methods. +.. classmethod:: TestCase.captureOnCommitCallbacks(using=DEFAULT_DB_ALIAS, execute=False) + + .. versionadded:: 3.2 + + Returns a context manager that captures :func:`transaction.on_commit() + ` callbacks for the given database + connection. It returns a list that contains, on exit of the context, the + captured callback functions. From this list you can make assertions on the + callbacks or call them to invoke their side effects, emulating a commit. + + ``using`` is the alias of the database connection to capture callbacks for. + + If ``execute`` is ``True``, all the callbacks will be called as the context + manager exits, if no exception occurred. This emulates a commit after the + wrapped block of code. + + For example:: + + from django.core import mail + from django.test import TestCase + + + class ContactTests(TestCase): + def test_post(self): + with self.captureOnCommitCallbacks(execute=True) as callbacks: + response = self.client.post( + '/contact/', + {'message': 'I like your site'}, + ) + + self.assertEqual(response.status_code, 200) + self.assertEqual(len(callbacks), 1) + self.assertEqual(len(mail.outbox), 1) + self.assertEqual(mail.outbox[0].subject, 'Contact Form') + self.assertEqual(mail.outbox[0].body, 'I like your site') + .. _live-test-server: ``LiveServerTestCase`` diff --git a/tests/test_utils/tests.py b/tests/test_utils/tests.py index 4e90d720ee..a82dadceaa 100644 --- a/tests/test_utils/tests.py +++ b/tests/test_utils/tests.py @@ -9,7 +9,9 @@ from django.contrib.staticfiles.finders import get_finder, get_finders from django.contrib.staticfiles.storage import staticfiles_storage from django.core.exceptions import ImproperlyConfigured from django.core.files.storage import default_storage -from django.db import connection, connections, models, router +from django.db import ( + IntegrityError, connection, connections, models, router, transaction, +) from django.forms import EmailField, IntegerField from django.http import HttpResponse from django.template.loader import render_to_string @@ -1273,6 +1275,71 @@ class TestBadSetUpTestData(TestCase): self.assertFalse(self._in_atomic_block) +class CaptureOnCommitCallbacksTests(TestCase): + databases = {'default', 'other'} + callback_called = False + + def enqueue_callback(self, using='default'): + def hook(): + self.callback_called = True + + transaction.on_commit(hook, using=using) + + def test_no_arguments(self): + with self.captureOnCommitCallbacks() as callbacks: + self.enqueue_callback() + + self.assertEqual(len(callbacks), 1) + self.assertIs(self.callback_called, False) + callbacks[0]() + self.assertIs(self.callback_called, True) + + def test_using(self): + with self.captureOnCommitCallbacks(using='other') as callbacks: + self.enqueue_callback(using='other') + + self.assertEqual(len(callbacks), 1) + self.assertIs(self.callback_called, False) + callbacks[0]() + self.assertIs(self.callback_called, True) + + def test_different_using(self): + with self.captureOnCommitCallbacks(using='default') as callbacks: + self.enqueue_callback(using='other') + + self.assertEqual(callbacks, []) + + def test_execute(self): + with self.captureOnCommitCallbacks(execute=True) as callbacks: + self.enqueue_callback() + + self.assertEqual(len(callbacks), 1) + self.assertIs(self.callback_called, True) + + def test_pre_callback(self): + def pre_hook(): + pass + + transaction.on_commit(pre_hook, using='default') + with self.captureOnCommitCallbacks() as callbacks: + self.enqueue_callback() + + self.assertEqual(len(callbacks), 1) + self.assertNotEqual(callbacks[0], pre_hook) + + def test_with_rolled_back_savepoint(self): + with self.captureOnCommitCallbacks() as callbacks: + try: + with transaction.atomic(): + self.enqueue_callback() + raise IntegrityError + except IntegrityError: + # Inner transaction.atomic() has been rolled back. + pass + + self.assertEqual(callbacks, []) + + class DisallowedDatabaseQueriesTests(SimpleTestCase): def test_disallowed_database_connections(self): expected_message = (