Fixed #33616 -- Allowed registering callbacks that can fail in transaction.on_commit().

Thanks David Wobrock and Mariusz Felisiak for reviews.
This commit is contained in:
SirAbhi13 2022-06-07 21:20:46 +05:30 committed by Mariusz Felisiak
parent be63c78760
commit 4a1150b41d
7 changed files with 140 additions and 15 deletions

View File

@ -1,6 +1,7 @@
import _thread
import copy
import datetime
import logging
import threading
import time
import warnings
@ -26,6 +27,8 @@ from django.utils.functional import cached_property
NO_DB_ALIAS = "__no_db__"
RAN_DB_VERSION_CHECK = set()
logger = logging.getLogger("django.db.backends.base")
# RemovedInDjango50Warning
def timezone_constructor(tzname):
@ -417,7 +420,9 @@ class BaseDatabaseWrapper:
# Remove any callbacks registered while this savepoint was active.
self.run_on_commit = [
(sids, func) for (sids, func) in self.run_on_commit if sid not in sids
(sids, func, robust)
for (sids, func, robust) in self.run_on_commit
if sid not in sids
]
@async_unsafe
@ -723,12 +728,12 @@ class BaseDatabaseWrapper:
)
return self.SchemaEditorClass(self, *args, **kwargs)
def on_commit(self, func):
def on_commit(self, func, robust=False):
if not callable(func):
raise TypeError("on_commit()'s callback must be a callable.")
if self.in_atomic_block:
# Transaction in progress; save for execution on commit.
self.run_on_commit.append((set(self.savepoint_ids), func))
self.run_on_commit.append((set(self.savepoint_ids), func, robust))
elif not self.get_autocommit():
raise TransactionManagementError(
"on_commit() cannot be used in manual transaction management"
@ -736,15 +741,36 @@ class BaseDatabaseWrapper:
else:
# No transaction in progress and in autocommit mode; execute
# immediately.
func()
if robust:
try:
func()
except Exception as e:
logger.error(
f"Error calling {func.__qualname__} in on_commit() (%s).",
e,
exc_info=True,
)
else:
func()
def run_and_clear_commit_hooks(self):
self.validate_no_atomic_block()
current_run_on_commit = self.run_on_commit
self.run_on_commit = []
while current_run_on_commit:
sids, func = current_run_on_commit.pop(0)
func()
_, func, robust = current_run_on_commit.pop(0)
if robust:
try:
func()
except Exception as e:
logger.error(
f"Error calling {func.__qualname__} in on_commit() during "
f"transaction (%s).",
e,
exc_info=True,
)
else:
func()
@contextmanager
def execute_wrapper(self, wrapper):

View File

@ -125,12 +125,12 @@ def mark_for_rollback_on_error(using=None):
raise
def on_commit(func, using=None):
def on_commit(func, using=None, robust=False):
"""
Register `func` to be called when the current transaction is committed.
If the current transaction is rolled back, `func` will not be called.
"""
get_connection(using).on_commit(func)
get_connection(using).on_commit(func, robust)
#################################

View File

@ -59,6 +59,8 @@ from django.utils.functional import classproperty
from django.utils.version import PY310
from django.views.static import serve
logger = logging.getLogger("django.test")
__all__ = (
"TestCase",
"TransactionTestCase",
@ -1510,10 +1512,23 @@ class TestCase(TransactionTestCase):
finally:
while True:
callback_count = len(connections[using].run_on_commit)
for _, callback in connections[using].run_on_commit[start_count:]:
for _, callback, robust in connections[using].run_on_commit[
start_count:
]:
callbacks.append(callback)
if execute:
callback()
if robust:
try:
callback()
except Exception as e:
logger.error(
f"Error calling {callback.__qualname__} in "
f"on_commit() (%s).",
e,
exc_info=True,
)
else:
callback()
if callback_count == len(connections[using].run_on_commit):
break

View File

@ -212,6 +212,10 @@ Models
* :ref:`Registering lookups <lookup-registration-api>` on
:class:`~django.db.models.Field` instances is now supported.
* The new ``robust`` argument for :func:`~django.db.transaction.on_commit`
allows performing actions that can fail after a database transaction is
successfully committed.
Requests and Responses
~~~~~~~~~~~~~~~~~~~~~~

View File

@ -297,7 +297,7 @@ include a `Celery`_ task, an email notification, or a cache invalidation.
Django provides the :func:`on_commit` function to register callback functions
that should be executed after a transaction is successfully committed:
.. function:: on_commit(func, using=None)
.. function:: on_commit(func, using=None, robust=False)
Pass any function (that takes no arguments) to :func:`on_commit`::
@ -325,6 +325,15 @@ If that hypothetical database write is instead rolled back (typically when an
unhandled exception is raised in an :func:`atomic` block), your function will
be discarded and never called.
It's sometimes useful to register callback functions that can fail. Passing
``robust=True`` allows the next functions to be executed even if the current
function throws an exception. All errors derived from Python's ``Exception``
class are caught and logged to the ``django.db.backends.base`` logger.
.. versionchanged:: 4.2
The ``robust`` argument was added.
Savepoints
----------
@ -366,10 +375,14 @@ registered.
Exception handling
------------------
If one on-commit function within a given transaction raises an uncaught
exception, no later registered functions in that same transaction will run.
This is the same behavior as if you'd executed the functions sequentially
yourself without :func:`on_commit`.
If one on-commit function registered with ``robust=False`` within a given
transaction raises an uncaught exception, no later registered functions in that
same transaction will run. This is the same behavior as if you'd executed the
functions sequentially yourself without :func:`on_commit`.
.. versionchanged:: 4.2
The ``robust`` argument was added.
Timing of execution
-------------------

View File

@ -2285,6 +2285,32 @@ class CaptureOnCommitCallbacksTests(TestCase):
self.assertEqual(callbacks, [branch_1, branch_2, leaf_3, leaf_1, leaf_2])
def test_execute_robust(self):
class MyException(Exception):
pass
def hook():
self.callback_called = True
raise MyException("robust callback")
with self.assertLogs("django.test", "ERROR") as cm:
with self.captureOnCommitCallbacks(execute=True) as callbacks:
transaction.on_commit(hook, robust=True)
self.assertEqual(len(callbacks), 1)
self.assertIs(self.callback_called, True)
log_record = cm.records[0]
self.assertEqual(
log_record.getMessage(),
"Error calling CaptureOnCommitCallbacksTests.test_execute_robust.<locals>."
"hook in on_commit() (robust callback).",
)
self.assertIsNotNone(log_record.exc_info)
raised_exception = log_record.exc_info[1]
self.assertIsInstance(raised_exception, MyException)
self.assertEqual(str(raised_exception), "robust callback")
class DisallowedDatabaseQueriesTests(SimpleTestCase):
def test_disallowed_database_connections(self):

View File

@ -43,6 +43,47 @@ class TestConnectionOnCommit(TransactionTestCase):
self.do(1)
self.assertDone([1])
def test_robust_if_no_transaction(self):
def robust_callback():
raise ForcedError("robust callback")
with self.assertLogs("django.db.backends.base", "ERROR") as cm:
transaction.on_commit(robust_callback, robust=True)
self.do(1)
self.assertDone([1])
log_record = cm.records[0]
self.assertEqual(
log_record.getMessage(),
"Error calling TestConnectionOnCommit.test_robust_if_no_transaction."
"<locals>.robust_callback in on_commit() (robust callback).",
)
self.assertIsNotNone(log_record.exc_info)
raised_exception = log_record.exc_info[1]
self.assertIsInstance(raised_exception, ForcedError)
self.assertEqual(str(raised_exception), "robust callback")
def test_robust_transaction(self):
def robust_callback():
raise ForcedError("robust callback")
with self.assertLogs("django.db.backends", "ERROR") as cm:
with transaction.atomic():
transaction.on_commit(robust_callback, robust=True)
self.do(1)
self.assertDone([1])
log_record = cm.records[0]
self.assertEqual(
log_record.getMessage(),
"Error calling TestConnectionOnCommit.test_robust_transaction.<locals>."
"robust_callback in on_commit() during transaction (robust callback).",
)
self.assertIsNotNone(log_record.exc_info)
raised_exception = log_record.exc_info[1]
self.assertIsInstance(raised_exception, ForcedError)
self.assertEqual(str(raised_exception), "robust callback")
def test_delays_execution_until_after_transaction_commit(self):
with transaction.atomic():
self.do(1)