diff --git a/django/db/__init__.py b/django/db/__init__.py index b1980488df..94eca13d41 100644 --- a/django/db/__init__.py +++ b/django/db/__init__.py @@ -42,8 +42,17 @@ backend = load_backend(connection.settings_dict['ENGINE']) # Register an event that closes the database connection # when a Django request is finished. def close_connection(**kwargs): - for conn in connections.all(): - conn.close() + # Avoid circular imports + from django.db import transaction + for conn in connections: + try: + transaction.abort(conn) + connections[conn].close() + except Exception: + # The connection's state is unknown, so it has to be + # abandoned. This could happen for example if the network + # connection has a failure. + del connections[conn] signals.request_finished.connect(close_connection) # Register an event that resets connection.queries diff --git a/django/db/backends/__init__.py b/django/db/backends/__init__.py index 7dc5456827..bbb5a5b294 100644 --- a/django/db/backends/__init__.py +++ b/django/db/backends/__init__.py @@ -88,6 +88,17 @@ class BaseDatabaseWrapper(object): return self.cursor().execute(self.ops.savepoint_commit_sql(sid)) + def abort(self): + """ + Roll back any ongoing transaction and clean the transaction state + stack. + """ + if self._dirty: + self._rollback() + self._dirty = False + while self.transaction_state: + self.leave_transaction_management() + def enter_transaction_management(self, managed=True): """ Enters transaction management for a running thread. It must be balanced with diff --git a/django/db/transaction.py b/django/db/transaction.py index f3ce2b2335..dd7e2f4dcb 100644 --- a/django/db/transaction.py +++ b/django/db/transaction.py @@ -24,6 +24,21 @@ class TransactionManagementError(Exception): """ pass +def abort(using=None): + """ + Roll back any ongoing transactions and clean the transaction management + state of the connection. + + This method is to be used only in cases where using balanced + leave_transaction_management() calls isn't possible. For example after a + request has finished, the transaction state isn't known, yet the connection + must be cleaned up for the next request. + """ + if using is None: + using = DEFAULT_DB_ALIAS + connection = connections[using] + connection.abort() + def enter_transaction_management(managed=True, using=None): """ Enters transaction management for a running thread. It must be balanced with diff --git a/django/db/utils.py b/django/db/utils.py index 91fa774ed4..943e3e3f73 100644 --- a/django/db/utils.py +++ b/django/db/utils.py @@ -99,6 +99,9 @@ class ConnectionHandler(object): def __setitem__(self, key, value): setattr(self._connections, key, value) + def __delitem__(self, key): + delattr(self._connections, key) + def __iter__(self): return iter(self.databases) diff --git a/django/middleware/transaction.py b/django/middleware/transaction.py index 96b1538d9d..4440f377a7 100644 --- a/django/middleware/transaction.py +++ b/django/middleware/transaction.py @@ -15,6 +15,10 @@ class TransactionMiddleware(object): def process_exception(self, request, exception): """Rolls back the database and leaves transaction management""" if transaction.is_dirty(): + # This rollback might fail because of network failure for example. + # If rollback isn't possible it is impossible to clean the + # connection's state. So leave the connection in dirty state and + # let request_finished signal deal with cleaning the connection. transaction.rollback() transaction.leave_transaction_management() @@ -22,6 +26,21 @@ class TransactionMiddleware(object): """Commits and leaves transaction management.""" if transaction.is_managed(): if transaction.is_dirty(): - transaction.commit() + # Note: it is possible that the commit fails. If the reason is + # closed connection or some similar reason, then there is + # little hope to proceed nicely. However, in some cases ( + # deferred foreign key checks for exampl) it is still possible + # to rollback(). + try: + transaction.commit() + except Exception: + # If the rollback fails, the transaction state will be + # messed up. It doesn't matter, the connection will be set + # to clean state after the request finishes. And, we can't + # clean the state here properly even if we wanted to, the + # connection is in transaction but we can't rollback... + transaction.rollback() + transaction.leave_transaction_management() + raise transaction.leave_transaction_management() return response diff --git a/django/test/testcases.py b/django/test/testcases.py index 3aa0afa35e..f7c34a9f25 100644 --- a/django/test/testcases.py +++ b/django/test/testcases.py @@ -70,6 +70,7 @@ real_rollback = transaction.rollback real_enter_transaction_management = transaction.enter_transaction_management real_leave_transaction_management = transaction.leave_transaction_management real_managed = transaction.managed +real_abort = transaction.abort def nop(*args, **kwargs): return @@ -80,6 +81,7 @@ def disable_transaction_methods(): transaction.enter_transaction_management = nop transaction.leave_transaction_management = nop transaction.managed = nop + transaction.abort = nop def restore_transaction_methods(): transaction.commit = real_commit @@ -87,6 +89,7 @@ def restore_transaction_methods(): transaction.enter_transaction_management = real_enter_transaction_management transaction.leave_transaction_management = real_leave_transaction_management transaction.managed = real_managed + transaction.abort = real_abort def assert_and_parse_html(self, html, user_msg, msg): diff --git a/tests/regressiontests/middleware/tests.py b/tests/regressiontests/middleware/tests.py index a9a45c99ba..6c436415ab 100644 --- a/tests/regressiontests/middleware/tests.py +++ b/tests/regressiontests/middleware/tests.py @@ -9,9 +9,9 @@ import warnings from django.conf import settings from django.core import mail -from django.db import transaction -from django.http import HttpRequest -from django.http import HttpResponse, StreamingHttpResponse +from django.db import (transaction, connections, DEFAULT_DB_ALIAS, + IntegrityError) +from django.http import HttpRequest, HttpResponse, StreamingHttpResponse from django.middleware.clickjacking import XFrameOptionsMiddleware from django.middleware.common import CommonMiddleware, BrokenLinkEmailsMiddleware from django.middleware.http import ConditionalGetMiddleware @@ -710,3 +710,22 @@ class TransactionMiddlewareTest(TransactionTestCase): TransactionMiddleware().process_exception(self.request, None) self.assertEqual(Band.objects.count(), 0) self.assertFalse(transaction.is_dirty()) + + def test_failing_commit(self): + # It is possible that connection.commit() fails. Check that + # TransactionMiddleware handles such cases correctly. + try: + def raise_exception(): + raise IntegrityError() + connections[DEFAULT_DB_ALIAS].commit = raise_exception + transaction.enter_transaction_management() + transaction.managed(True) + Band.objects.create(name='The Beatles') + self.assertTrue(transaction.is_dirty()) + with self.assertRaises(IntegrityError): + TransactionMiddleware().process_response(self.request, None) + self.assertEqual(Band.objects.count(), 0) + self.assertFalse(transaction.is_dirty()) + self.assertFalse(transaction.is_managed()) + finally: + del connections[DEFAULT_DB_ALIAS].commit diff --git a/tests/regressiontests/requests/tests.py b/tests/regressiontests/requests/tests.py index 799cd9b302..d89f6d68be 100644 --- a/tests/regressiontests/requests/tests.py +++ b/tests/regressiontests/requests/tests.py @@ -6,9 +6,12 @@ import warnings from datetime import datetime, timedelta from io import BytesIO +from django.db import connection, connections, DEFAULT_DB_ALIAS +from django.core import signals from django.core.exceptions import SuspiciousOperation from django.core.handlers.wsgi import WSGIRequest, LimitedStream from django.http import HttpRequest, HttpResponse, parse_cookie, build_request_repr, UnreadablePostError +from django.test import TransactionTestCase from django.test.client import FakePayload from django.test.utils import override_settings, str_prefix from django.utils import six @@ -524,3 +527,42 @@ class RequestsTests(unittest.TestCase): with self.assertRaises(UnreadablePostError): request.body + +class TransactionRequestTests(TransactionTestCase): + def test_request_finished_db_state(self): + # The GET below will not succeed, but it will give a response with + # defined ._handler_class. That is needed for sending the + # request_finished signal. + response = self.client.get('/') + # Make sure there is an open connection + connection.cursor() + connection.enter_transaction_management() + connection.managed(True) + signals.request_finished.send(sender=response._handler_class) + # In-memory sqlite doesn't actually close connections. + if connection.vendor != 'sqlite': + self.assertIs(connection.connection, None) + self.assertEqual(len(connection.transaction_state), 0) + + @unittest.skipIf(connection.vendor == 'sqlite', + 'This test will close the connection, in-memory ' + 'sqlite connections must not be closed.') + def test_request_finished_failed_connection(self): + # See comments in test_request_finished_db_state() for the self.client + # usage. + response = self.client.get('/') + conn = connections[DEFAULT_DB_ALIAS] + conn.enter_transaction_management() + conn.managed(True) + conn.set_dirty() + # Test that the rollback doesn't succeed (for example network failure + # could cause this). + def fail_horribly(): + raise Exception("Horrible failure!") + conn._rollback = fail_horribly + signals.request_finished.send(sender=response._handler_class) + # As even rollback wasn't possible the connection wrapper itself was + # abandoned. Accessing the connections[alias] will create a new + # connection wrapper, whch must be different than the original one. + self.assertIsNot(conn, connections[DEFAULT_DB_ALIAS]) + self.assertEqual(len(connection.transaction_state), 0)