Fixed #33616 -- Allowed registering callbacks that can fail in transaction.on_commit().

Thanks David Wobrock and Mariusz Felisiak for reviews.
This commit is contained in:
SirAbhi13 2022-06-07 21:20:46 +05:30 committed by Mariusz Felisiak
parent be63c78760
commit 4a1150b41d
7 changed files with 140 additions and 15 deletions

View File

@ -1,6 +1,7 @@
import _thread import _thread
import copy import copy
import datetime import datetime
import logging
import threading import threading
import time import time
import warnings import warnings
@ -26,6 +27,8 @@ from django.utils.functional import cached_property
NO_DB_ALIAS = "__no_db__" NO_DB_ALIAS = "__no_db__"
RAN_DB_VERSION_CHECK = set() RAN_DB_VERSION_CHECK = set()
logger = logging.getLogger("django.db.backends.base")
# RemovedInDjango50Warning # RemovedInDjango50Warning
def timezone_constructor(tzname): def timezone_constructor(tzname):
@ -417,7 +420,9 @@ class BaseDatabaseWrapper:
# Remove any callbacks registered while this savepoint was active. # Remove any callbacks registered while this savepoint was active.
self.run_on_commit = [ self.run_on_commit = [
(sids, func) for (sids, func) in self.run_on_commit if sid not in sids (sids, func, robust)
for (sids, func, robust) in self.run_on_commit
if sid not in sids
] ]
@async_unsafe @async_unsafe
@ -723,12 +728,12 @@ class BaseDatabaseWrapper:
) )
return self.SchemaEditorClass(self, *args, **kwargs) return self.SchemaEditorClass(self, *args, **kwargs)
def on_commit(self, func): def on_commit(self, func, robust=False):
if not callable(func): if not callable(func):
raise TypeError("on_commit()'s callback must be a callable.") raise TypeError("on_commit()'s callback must be a callable.")
if self.in_atomic_block: if self.in_atomic_block:
# Transaction in progress; save for execution on commit. # Transaction in progress; save for execution on commit.
self.run_on_commit.append((set(self.savepoint_ids), func)) self.run_on_commit.append((set(self.savepoint_ids), func, robust))
elif not self.get_autocommit(): elif not self.get_autocommit():
raise TransactionManagementError( raise TransactionManagementError(
"on_commit() cannot be used in manual transaction management" "on_commit() cannot be used in manual transaction management"
@ -736,15 +741,36 @@ class BaseDatabaseWrapper:
else: else:
# No transaction in progress and in autocommit mode; execute # No transaction in progress and in autocommit mode; execute
# immediately. # immediately.
func() if robust:
try:
func()
except Exception as e:
logger.error(
f"Error calling {func.__qualname__} in on_commit() (%s).",
e,
exc_info=True,
)
else:
func()
def run_and_clear_commit_hooks(self): def run_and_clear_commit_hooks(self):
self.validate_no_atomic_block() self.validate_no_atomic_block()
current_run_on_commit = self.run_on_commit current_run_on_commit = self.run_on_commit
self.run_on_commit = [] self.run_on_commit = []
while current_run_on_commit: while current_run_on_commit:
sids, func = current_run_on_commit.pop(0) _, func, robust = current_run_on_commit.pop(0)
func() if robust:
try:
func()
except Exception as e:
logger.error(
f"Error calling {func.__qualname__} in on_commit() during "
f"transaction (%s).",
e,
exc_info=True,
)
else:
func()
@contextmanager @contextmanager
def execute_wrapper(self, wrapper): def execute_wrapper(self, wrapper):

View File

@ -125,12 +125,12 @@ def mark_for_rollback_on_error(using=None):
raise raise
def on_commit(func, using=None): def on_commit(func, using=None, robust=False):
""" """
Register `func` to be called when the current transaction is committed. Register `func` to be called when the current transaction is committed.
If the current transaction is rolled back, `func` will not be called. If the current transaction is rolled back, `func` will not be called.
""" """
get_connection(using).on_commit(func) get_connection(using).on_commit(func, robust)
################################# #################################

View File

@ -59,6 +59,8 @@ from django.utils.functional import classproperty
from django.utils.version import PY310 from django.utils.version import PY310
from django.views.static import serve from django.views.static import serve
logger = logging.getLogger("django.test")
__all__ = ( __all__ = (
"TestCase", "TestCase",
"TransactionTestCase", "TransactionTestCase",
@ -1510,10 +1512,23 @@ class TestCase(TransactionTestCase):
finally: finally:
while True: while True:
callback_count = len(connections[using].run_on_commit) callback_count = len(connections[using].run_on_commit)
for _, callback in connections[using].run_on_commit[start_count:]: for _, callback, robust in connections[using].run_on_commit[
start_count:
]:
callbacks.append(callback) callbacks.append(callback)
if execute: if execute:
callback() if robust:
try:
callback()
except Exception as e:
logger.error(
f"Error calling {callback.__qualname__} in "
f"on_commit() (%s).",
e,
exc_info=True,
)
else:
callback()
if callback_count == len(connections[using].run_on_commit): if callback_count == len(connections[using].run_on_commit):
break break

View File

@ -212,6 +212,10 @@ Models
* :ref:`Registering lookups <lookup-registration-api>` on * :ref:`Registering lookups <lookup-registration-api>` on
:class:`~django.db.models.Field` instances is now supported. :class:`~django.db.models.Field` instances is now supported.
* The new ``robust`` argument for :func:`~django.db.transaction.on_commit`
allows performing actions that can fail after a database transaction is
successfully committed.
Requests and Responses Requests and Responses
~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~

View File

@ -297,7 +297,7 @@ include a `Celery`_ task, an email notification, or a cache invalidation.
Django provides the :func:`on_commit` function to register callback functions Django provides the :func:`on_commit` function to register callback functions
that should be executed after a transaction is successfully committed: that should be executed after a transaction is successfully committed:
.. function:: on_commit(func, using=None) .. function:: on_commit(func, using=None, robust=False)
Pass any function (that takes no arguments) to :func:`on_commit`:: Pass any function (that takes no arguments) to :func:`on_commit`::
@ -325,6 +325,15 @@ If that hypothetical database write is instead rolled back (typically when an
unhandled exception is raised in an :func:`atomic` block), your function will unhandled exception is raised in an :func:`atomic` block), your function will
be discarded and never called. be discarded and never called.
It's sometimes useful to register callback functions that can fail. Passing
``robust=True`` allows the next functions to be executed even if the current
function throws an exception. All errors derived from Python's ``Exception``
class are caught and logged to the ``django.db.backends.base`` logger.
.. versionchanged:: 4.2
The ``robust`` argument was added.
Savepoints Savepoints
---------- ----------
@ -366,10 +375,14 @@ registered.
Exception handling Exception handling
------------------ ------------------
If one on-commit function within a given transaction raises an uncaught If one on-commit function registered with ``robust=False`` within a given
exception, no later registered functions in that same transaction will run. transaction raises an uncaught exception, no later registered functions in that
This is the same behavior as if you'd executed the functions sequentially same transaction will run. This is the same behavior as if you'd executed the
yourself without :func:`on_commit`. functions sequentially yourself without :func:`on_commit`.
.. versionchanged:: 4.2
The ``robust`` argument was added.
Timing of execution Timing of execution
------------------- -------------------

View File

@ -2285,6 +2285,32 @@ class CaptureOnCommitCallbacksTests(TestCase):
self.assertEqual(callbacks, [branch_1, branch_2, leaf_3, leaf_1, leaf_2]) self.assertEqual(callbacks, [branch_1, branch_2, leaf_3, leaf_1, leaf_2])
def test_execute_robust(self):
class MyException(Exception):
pass
def hook():
self.callback_called = True
raise MyException("robust callback")
with self.assertLogs("django.test", "ERROR") as cm:
with self.captureOnCommitCallbacks(execute=True) as callbacks:
transaction.on_commit(hook, robust=True)
self.assertEqual(len(callbacks), 1)
self.assertIs(self.callback_called, True)
log_record = cm.records[0]
self.assertEqual(
log_record.getMessage(),
"Error calling CaptureOnCommitCallbacksTests.test_execute_robust.<locals>."
"hook in on_commit() (robust callback).",
)
self.assertIsNotNone(log_record.exc_info)
raised_exception = log_record.exc_info[1]
self.assertIsInstance(raised_exception, MyException)
self.assertEqual(str(raised_exception), "robust callback")
class DisallowedDatabaseQueriesTests(SimpleTestCase): class DisallowedDatabaseQueriesTests(SimpleTestCase):
def test_disallowed_database_connections(self): def test_disallowed_database_connections(self):

View File

@ -43,6 +43,47 @@ class TestConnectionOnCommit(TransactionTestCase):
self.do(1) self.do(1)
self.assertDone([1]) self.assertDone([1])
def test_robust_if_no_transaction(self):
def robust_callback():
raise ForcedError("robust callback")
with self.assertLogs("django.db.backends.base", "ERROR") as cm:
transaction.on_commit(robust_callback, robust=True)
self.do(1)
self.assertDone([1])
log_record = cm.records[0]
self.assertEqual(
log_record.getMessage(),
"Error calling TestConnectionOnCommit.test_robust_if_no_transaction."
"<locals>.robust_callback in on_commit() (robust callback).",
)
self.assertIsNotNone(log_record.exc_info)
raised_exception = log_record.exc_info[1]
self.assertIsInstance(raised_exception, ForcedError)
self.assertEqual(str(raised_exception), "robust callback")
def test_robust_transaction(self):
def robust_callback():
raise ForcedError("robust callback")
with self.assertLogs("django.db.backends", "ERROR") as cm:
with transaction.atomic():
transaction.on_commit(robust_callback, robust=True)
self.do(1)
self.assertDone([1])
log_record = cm.records[0]
self.assertEqual(
log_record.getMessage(),
"Error calling TestConnectionOnCommit.test_robust_transaction.<locals>."
"robust_callback in on_commit() during transaction (robust callback).",
)
self.assertIsNotNone(log_record.exc_info)
raised_exception = log_record.exc_info[1]
self.assertIsInstance(raised_exception, ForcedError)
self.assertEqual(str(raised_exception), "robust callback")
def test_delays_execution_until_after_transaction_commit(self): def test_delays_execution_until_after_transaction_commit(self):
with transaction.atomic(): with transaction.atomic():
self.do(1) self.do(1)