Fixed #21134 -- Prevented queries in broken transactions.

Backport of 728548e4 from master.

Squashed commit of the following:

commit 63ddb271a44df389b2c302e421fc17b7f0529755
Author: Aymeric Augustin <aymeric.augustin@m4x.org>
Date:   Sun Sep 29 22:51:00 2013 +0200

    Clarified interactions between atomic and exceptions.

commit 2899ec299228217c876ba3aa4024e523a41c8504
Author: Aymeric Augustin <aymeric.augustin@m4x.org>
Date:   Sun Sep 22 22:45:32 2013 +0200

    Fixed TransactionManagementError in tests.

    Previous commit introduced an additional check to prevent running
    queries in transactions that will be rolled back, which triggered a few
    failures in the tests. In practice using transaction.atomic instead of
    the low-level savepoint APIs was enough to fix the problems.

commit 4a639b059ea80aeb78f7f160a7d4b9f609b9c238
Author: Aymeric Augustin <aymeric.augustin@m4x.org>
Date:   Tue Sep 24 22:24:17 2013 +0200

    Allowed nesting constraint_checks_disabled inside atomic.

    Since MySQL handles transactions loosely, this isn't a problem.

commit 2a4ab1cb6e83391ff7e25d08479e230ca564bfef
Author: Aymeric Augustin <aymeric.augustin@m4x.org>
Date:   Sat Sep 21 18:43:12 2013 +0200

    Prevented running queries in transactions that will be rolled back.

    This avoids a counter-intuitive behavior in an edge case on databases
    with non-atomic transaction semantics.

    It prevents using savepoint_rollback() inside an atomic block without
    calling set_rollback(False) first, which is backwards-incompatible in
    tests.

    Refs #21134.

commit 8e3db393853c7ac64a445b66e57f3620a3fde7b0
Author: Aymeric Augustin <aymeric.augustin@m4x.org>
Date:   Sun Sep 22 22:14:17 2013 +0200

    Replaced manual savepoints by atomic blocks.

    This ensures the rollback flag is handled consistently in internal APIs.
This commit is contained in:
Aymeric Augustin 2013-09-30 10:14:22 +02:00
parent c4468e0619
commit 0d74bdaf0c
14 changed files with 149 additions and 86 deletions

View File

@ -58,12 +58,11 @@ class SessionStore(SessionBase):
expire_date=self.get_expiry_date() expire_date=self.get_expiry_date()
) )
using = router.db_for_write(Session, instance=obj) using = router.db_for_write(Session, instance=obj)
sid = transaction.savepoint(using=using)
try: try:
with transaction.atomic(using=using):
obj.save(force_insert=must_create, using=using) obj.save(force_insert=must_create, using=using)
except IntegrityError: except IntegrityError:
if must_create: if must_create:
transaction.savepoint_rollback(sid, using=using)
raise CreateError raise CreateError
raise raise

View File

@ -359,6 +359,12 @@ class BaseDatabaseWrapper(object):
raise TransactionManagementError( raise TransactionManagementError(
"This is forbidden when an 'atomic' block is active.") "This is forbidden when an 'atomic' block is active.")
def validate_no_broken_transaction(self):
if self.needs_rollback:
raise TransactionManagementError(
"An error occurred in the current transaction. You can't "
"execute queries until the end of the 'atomic' block.")
def abort(self): def abort(self):
""" """
Roll back any ongoing transaction and clean the transaction state Roll back any ongoing transaction and clean the transaction state
@ -626,6 +632,9 @@ class BaseDatabaseFeatures(object):
# when autocommit is disabled? http://bugs.python.org/issue8145#msg109965 # when autocommit is disabled? http://bugs.python.org/issue8145#msg109965
autocommits_when_autocommit_is_off = False autocommits_when_autocommit_is_off = False
# Does the backend prevent running SQL queries in broken transactions?
atomic_transactions = True
# Does the backend support 'pyformat' style ("... %(name)s ...", {'name': value}) # Does the backend support 'pyformat' style ("... %(name)s ...", {'name': value})
# parameter passing? Note this can be provided by the backend even if not # parameter passing? Note this can be provided by the backend even if not
# supported by the Python driver # supported by the Python driver

View File

@ -166,6 +166,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
requires_explicit_null_ordering_when_grouping = True requires_explicit_null_ordering_when_grouping = True
allows_primary_key_0 = False allows_primary_key_0 = False
uses_savepoints = True uses_savepoints = True
atomic_transactions = False
def __init__(self, connection): def __init__(self, connection):
super(DatabaseFeatures, self).__init__(connection) super(DatabaseFeatures, self).__init__(connection)
@ -470,7 +471,13 @@ class DatabaseWrapper(BaseDatabaseWrapper):
""" """
Re-enable foreign key checks after they have been disabled. Re-enable foreign key checks after they have been disabled.
""" """
# Override needs_rollback in case constraint_checks_disabled is
# nested inside transaction.atomic.
self.needs_rollback, needs_rollback = False, self.needs_rollback
try:
self.cursor().execute('SET foreign_key_checks=1') self.cursor().execute('SET foreign_key_checks=1')
finally:
self.needs_rollback = needs_rollback
def check_constraints(self, table_names=None): def check_constraints(self, table_names=None):
""" """

View File

@ -89,6 +89,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
has_bulk_insert = True has_bulk_insert = True
supports_tablespaces = True supports_tablespaces = True
supports_sequence_reset = False supports_sequence_reset = False
atomic_transactions = False
class DatabaseOperations(BaseDatabaseOperations): class DatabaseOperations(BaseDatabaseOperations):
compiler_module = "django.db.backends.oracle.compiler" compiler_module = "django.db.backends.oracle.compiler"

View File

@ -101,6 +101,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
has_bulk_insert = True has_bulk_insert = True
can_combine_inserts_with_and_without_auto_increment_pk = False can_combine_inserts_with_and_without_auto_increment_pk = False
autocommits_when_autocommit_is_off = True autocommits_when_autocommit_is_off = True
atomic_transactions = False
supports_paramstyle_pyformat = False supports_paramstyle_pyformat = False
@cached_property @cached_property

View File

@ -19,14 +19,9 @@ class CursorWrapper(object):
self.cursor = cursor self.cursor = cursor
self.db = db self.db = db
SET_DIRTY_ATTRS = frozenset(['execute', 'executemany', 'callproc']) WRAP_ERROR_ATTRS = frozenset(['fetchone', 'fetchmany', 'fetchall', 'nextset'])
WRAP_ERROR_ATTRS = frozenset([
'callproc', 'close', 'execute', 'executemany',
'fetchone', 'fetchmany', 'fetchall', 'nextset'])
def __getattr__(self, attr): def __getattr__(self, attr):
if attr in CursorWrapper.SET_DIRTY_ATTRS:
self.db.set_dirty()
cursor_attr = getattr(self.cursor, attr) cursor_attr = getattr(self.cursor, attr)
if attr in CursorWrapper.WRAP_ERROR_ATTRS: if attr in CursorWrapper.WRAP_ERROR_ATTRS:
return self.db.wrap_database_errors(cursor_attr) return self.db.wrap_database_errors(cursor_attr)
@ -36,18 +31,42 @@ class CursorWrapper(object):
def __iter__(self): def __iter__(self):
return iter(self.cursor) return iter(self.cursor)
# The following methods cannot be implemented in __getattr__, because the
# code must run when the method is invoked, not just when it is accessed.
def callproc(self, procname, params=None):
self.db.validate_no_broken_transaction()
self.db.set_dirty()
with self.db.wrap_database_errors:
if params is None:
return self.cursor.callproc(procname)
else:
return self.cursor.callproc(procname, params)
def execute(self, sql, params=None):
self.db.validate_no_broken_transaction()
self.db.set_dirty()
with self.db.wrap_database_errors:
if params is None:
return self.cursor.execute(sql)
else:
return self.cursor.execute(sql, params)
def executemany(self, sql, param_list):
self.db.validate_no_broken_transaction()
self.db.set_dirty()
with self.db.wrap_database_errors:
return self.cursor.executemany(sql, param_list)
class CursorDebugWrapper(CursorWrapper): class CursorDebugWrapper(CursorWrapper):
# XXX callproc isn't instrumented at this time.
def execute(self, sql, params=None): def execute(self, sql, params=None):
self.db.set_dirty()
start = time() start = time()
try: try:
with self.db.wrap_database_errors: return super(CursorDebugWrapper, self).execute(sql, params)
if params is None:
# params default might be backend specific
return self.cursor.execute(sql)
return self.cursor.execute(sql, params)
finally: finally:
stop = time() stop = time()
duration = stop - start duration = stop - start
@ -61,11 +80,9 @@ class CursorDebugWrapper(CursorWrapper):
) )
def executemany(self, sql, param_list): def executemany(self, sql, param_list):
self.db.set_dirty()
start = time() start = time()
try: try:
with self.db.wrap_database_errors: return super(CursorDebugWrapper, self).executemany(sql, param_list)
return self.cursor.executemany(sql, param_list)
finally: finally:
stop = time() stop = time()
duration = stop - start duration = stop - start

View File

@ -376,12 +376,10 @@ class QuerySet(object):
params = dict((k, v) for k, v in kwargs.items() if LOOKUP_SEP not in k) params = dict((k, v) for k, v in kwargs.items() if LOOKUP_SEP not in k)
params.update(defaults) params.update(defaults)
obj = self.model(**params) obj = self.model(**params)
sid = transaction.savepoint(using=self.db) with transaction.atomic(using=self.db):
obj.save(force_insert=True, using=self.db) obj.save(force_insert=True, using=self.db)
transaction.savepoint_commit(sid, using=self.db)
return obj, True return obj, True
except DatabaseError: except DatabaseError:
transaction.savepoint_rollback(sid, using=self.db)
exc_info = sys.exc_info() exc_info = sys.exc_info()
try: try:
return self.get(**lookup), False return self.get(**lookup), False

View File

@ -16,14 +16,15 @@ import warnings
from functools import wraps from functools import wraps
from django.db import connections, DatabaseError, DEFAULT_DB_ALIAS from django.db import (
connections, DEFAULT_DB_ALIAS,
DatabaseError, ProgrammingError)
from django.utils.decorators import available_attrs from django.utils.decorators import available_attrs
class TransactionManagementError(Exception): class TransactionManagementError(ProgrammingError):
""" """
This exception is thrown when something bad happens with transaction This exception is thrown when transaction management is used improperly.
management.
""" """
pass pass

View File

@ -163,20 +163,31 @@ Django provides a single API to control database transactions.
called, so the exception handler can also operate on the database if called, so the exception handler can also operate on the database if
necessary. necessary.
.. admonition:: Don't catch database exceptions inside ``atomic``! .. admonition:: Avoid catching exceptions inside ``atomic``!
If you catch :exc:`~django.db.DatabaseError` or a subclass such as When exiting an ``atomic`` block, Django looks at whether it's exited
:exc:`~django.db.IntegrityError` inside an ``atomic`` block, you will normally or with an exception to determine whether to commit or roll
hide from Django the fact that an error has occurred and that the back. If you catch and handle exceptions inside an ``atomic`` block,
transaction is broken. At this point, Django's behavior is unspecified you may hide from Django the fact that a problem has happened. This
and database-dependent. It will usually result in a rollback, which can result in unexpected behavior.
may break your expectations, since you caught the exception.
This is mostly a concern for :exc:`~django.db.DatabaseError` and its
subclasses such as :exc:`~django.db.IntegrityError`. After such an
error, the transaction is broken and Django will perform a rollback at
the end of the ``atomic`` block. If you attempt to run database
queries before the rollback happens, Django will raise a
:class:`~django.db.transaction.TransactionManagementError`. You may
also encounter this behavior when an ORM-related signal handler raises
an exception.
The correct way to catch database errors is around an ``atomic`` block The correct way to catch database errors is around an ``atomic`` block
as shown above. If necessary, add an extra ``atomic`` block for this as shown above. If necessary, add an extra ``atomic`` block for this
purpose -- it's cheap! This pattern is useful to delimit explicitly purpose. This pattern has another advantage: it delimits explicitly
which operations will be rolled back if an exception occurs. which operations will be rolled back if an exception occurs.
If you catch exceptions raised by raw SQL queries, Django's behavior
is unspecified and database-dependent.
In order to guarantee atomicity, ``atomic`` disables some APIs. Attempting In order to guarantee atomicity, ``atomic`` disables some APIs. Attempting
to commit, roll back, or change the autocommit state of the database to commit, roll back, or change the autocommit state of the database
connection within an ``atomic`` block will raise an exception. connection within an ``atomic`` block will raise an exception.

View File

@ -149,11 +149,9 @@ class CustomPKTests(TestCase):
e = Employee.objects.create( e = Employee.objects.create(
employee_code=123, first_name="Frank", last_name="Jones" employee_code=123, first_name="Frank", last_name="Jones"
) )
sid = transaction.savepoint() with self.assertRaises(IntegrityError):
self.assertRaises(IntegrityError, with transaction.atomic():
Employee.objects.create, employee_code=123, first_name="Fred", last_name="Jones" Employee.objects.create(employee_code=123, first_name="Fred", last_name="Jones")
)
transaction.savepoint_rollback(sid)
def test_custom_field_pk(self): def test_custom_field_pk(self):
# Regression for #10785 -- Custom fields can be used for primary keys. # Regression for #10785 -- Custom fields can be used for primary keys.
@ -175,8 +173,6 @@ class CustomPKTests(TestCase):
def test_required_pk(self): def test_required_pk(self):
# The primary key must be specified, so an error is raised if you # The primary key must be specified, so an error is raised if you
# try to create an object without it. # try to create an object without it.
sid = transaction.savepoint() with self.assertRaises(IntegrityError):
self.assertRaises(IntegrityError, with transaction.atomic():
Employee.objects.create, first_name="Tom", last_name="Smith" Employee.objects.create(first_name="Tom", last_name="Smith")
)
transaction.savepoint_rollback(sid)

View File

@ -2,6 +2,7 @@ from __future__ import absolute_import, unicode_literals
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db.models import F from django.db.models import F
from django.db import transaction
from django.test import TestCase from django.test import TestCase
from django.utils import six from django.utils import six
@ -185,11 +186,11 @@ class ExpressionsTests(TestCase):
"foo", "foo",
) )
self.assertRaises(FieldError, with transaction.atomic():
lambda: Company.objects.exclude( with self.assertRaises(FieldError):
Company.objects.exclude(
ceo__firstname=F('point_of_contact__firstname') ceo__firstname=F('point_of_contact__firstname')
).update(name=F('point_of_contact__lastname')) ).update(name=F('point_of_contact__lastname'))
)
# F expressions can be used to update attributes on single objects # F expressions can be used to update attributes on single objects
test_gmbh = Company.objects.get(name="Test GmbH") test_gmbh = Company.objects.get(name="Test GmbH")

View File

@ -21,24 +21,29 @@ class ForceTests(TestCase):
# Won't work because force_update and force_insert are mutually # Won't work because force_update and force_insert are mutually
# exclusive # exclusive
c.value = 4 c.value = 4
self.assertRaises(ValueError, c.save, force_insert=True, force_update=True) with self.assertRaises(ValueError):
c.save(force_insert=True, force_update=True)
# Try to update something that doesn't have a primary key in the first # Try to update something that doesn't have a primary key in the first
# place. # place.
c1 = Counter(name="two", value=2) c1 = Counter(name="two", value=2)
self.assertRaises(ValueError, c1.save, force_update=True) with self.assertRaises(ValueError):
with transaction.atomic():
c1.save(force_update=True)
c1.save(force_insert=True) c1.save(force_insert=True)
# Won't work because we can't insert a pk of the same value. # Won't work because we can't insert a pk of the same value.
sid = transaction.savepoint()
c.value = 5 c.value = 5
self.assertRaises(IntegrityError, c.save, force_insert=True) with self.assertRaises(IntegrityError):
transaction.savepoint_rollback(sid) with transaction.atomic():
c.save(force_insert=True)
# Trying to update should still fail, even with manual primary keys, if # Trying to update should still fail, even with manual primary keys, if
# the data isn't in the database already. # the data isn't in the database already.
obj = WithCustomPK(name=1, value=1) obj = WithCustomPK(name=1, value=1)
self.assertRaises(DatabaseError, obj.save, force_update=True) with self.assertRaises(DatabaseError):
with transaction.atomic():
obj.save(force_update=True)
class InheritanceTests(TestCase): class InheritanceTests(TestCase):

View File

@ -118,7 +118,7 @@ class OneToOneTests(TestCase):
self.assertEqual(repr(o1.multimodel), '<MultiModel: Multimodel x1>') self.assertEqual(repr(o1.multimodel), '<MultiModel: Multimodel x1>')
# This will fail because each one-to-one field must be unique (and # This will fail because each one-to-one field must be unique (and
# link2=o1 was used for x1, above). # link2=o1 was used for x1, above).
sid = transaction.savepoint()
mm = MultiModel(link1=self.p2, link2=o1, name="x1") mm = MultiModel(link1=self.p2, link2=o1, name="x1")
self.assertRaises(IntegrityError, mm.save) with self.assertRaises(IntegrityError):
transaction.savepoint_rollback(sid) with transaction.atomic():
mm.save()

View File

@ -3,7 +3,7 @@ from __future__ import absolute_import
import sys import sys
from django.db import connection, transaction, DatabaseError, IntegrityError from django.db import connection, transaction, DatabaseError, IntegrityError
from django.test import TransactionTestCase, skipUnlessDBFeature from django.test import TransactionTestCase, skipIfDBFeature, skipUnlessDBFeature
from django.test.utils import IgnorePendingDeprecationWarningsMixin from django.test.utils import IgnorePendingDeprecationWarningsMixin
from django.utils import six from django.utils import six
from django.utils.unittest import skipIf, skipUnless from django.utils.unittest import skipIf, skipUnless
@ -204,10 +204,10 @@ class AtomicTests(TransactionTestCase):
with transaction.atomic(savepoint=False): with transaction.atomic(savepoint=False):
connection.cursor().execute( connection.cursor().execute(
"SELECT no_such_col FROM transactions_reporter") "SELECT no_such_col FROM transactions_reporter")
transaction.savepoint_rollback(sid) # prevent atomic from rolling back since we're recovering manually
# atomic block should rollback, but prevent it, as we just did it.
self.assertTrue(transaction.get_rollback()) self.assertTrue(transaction.get_rollback())
transaction.set_rollback(False) transaction.set_rollback(False)
transaction.savepoint_rollback(sid)
self.assertQuerysetEqual(Reporter.objects.all(), ['<Reporter: Tintin>']) self.assertQuerysetEqual(Reporter.objects.all(), ['<Reporter: Tintin>'])
@ -267,11 +267,19 @@ class AtomicMergeTests(TransactionTestCase):
with transaction.atomic(savepoint=False): with transaction.atomic(savepoint=False):
Reporter.objects.create(first_name="Tournesol") Reporter.objects.create(first_name="Tournesol")
raise Exception("Oops, that's his last name") raise Exception("Oops, that's his last name")
# It wasn't possible to roll back # The third insert couldn't be roll back. Temporarily mark the
# connection as not needing rollback to check it.
self.assertTrue(transaction.get_rollback())
transaction.set_rollback(False)
self.assertEqual(Reporter.objects.count(), 3) self.assertEqual(Reporter.objects.count(), 3)
# It wasn't possible to roll back transaction.set_rollback(True)
# The second insert couldn't be roll back. Temporarily mark the
# connection as not needing rollback to check it.
self.assertTrue(transaction.get_rollback())
transaction.set_rollback(False)
self.assertEqual(Reporter.objects.count(), 3) self.assertEqual(Reporter.objects.count(), 3)
# The outer block must roll back transaction.set_rollback(True)
# The first block has a savepoint and must roll back.
self.assertQuerysetEqual(Reporter.objects.all(), []) self.assertQuerysetEqual(Reporter.objects.all(), [])
def test_merged_inner_savepoint_rollback(self): def test_merged_inner_savepoint_rollback(self):
@ -283,36 +291,22 @@ class AtomicMergeTests(TransactionTestCase):
with transaction.atomic(savepoint=False): with transaction.atomic(savepoint=False):
Reporter.objects.create(first_name="Tournesol") Reporter.objects.create(first_name="Tournesol")
raise Exception("Oops, that's his last name") raise Exception("Oops, that's his last name")
# It wasn't possible to roll back # The third insert couldn't be roll back. Temporarily mark the
# connection as not needing rollback to check it.
self.assertTrue(transaction.get_rollback())
transaction.set_rollback(False)
self.assertEqual(Reporter.objects.count(), 3) self.assertEqual(Reporter.objects.count(), 3)
# The first block with a savepoint must roll back transaction.set_rollback(True)
# The second block has a savepoint and must roll back.
self.assertEqual(Reporter.objects.count(), 1) self.assertEqual(Reporter.objects.count(), 1)
self.assertQuerysetEqual(Reporter.objects.all(), ['<Reporter: Tintin>']) self.assertQuerysetEqual(Reporter.objects.all(), ['<Reporter: Tintin>'])
def test_merged_outer_rollback_after_inner_failure_and_inner_success(self):
with transaction.atomic():
Reporter.objects.create(first_name="Tintin")
# Inner block without a savepoint fails
with six.assertRaisesRegex(self, Exception, "Oops"):
with transaction.atomic(savepoint=False):
Reporter.objects.create(first_name="Haddock")
raise Exception("Oops, that's his last name")
# It wasn't possible to roll back
self.assertEqual(Reporter.objects.count(), 2)
# Inner block with a savepoint succeeds
with transaction.atomic(savepoint=False):
Reporter.objects.create(first_name="Archibald", last_name="Haddock")
# It still wasn't possible to roll back
self.assertEqual(Reporter.objects.count(), 3)
# The outer block must rollback
self.assertQuerysetEqual(Reporter.objects.all(), [])
@skipUnless(connection.features.uses_savepoints, @skipUnless(connection.features.uses_savepoints,
"'atomic' requires transactions and savepoints.") "'atomic' requires transactions and savepoints.")
class AtomicErrorsTests(TransactionTestCase): class AtomicErrorsTests(TransactionTestCase):
available_apps = [] available_apps = ['transactions']
def test_atomic_prevents_setting_autocommit(self): def test_atomic_prevents_setting_autocommit(self):
autocommit = transaction.get_autocommit() autocommit = transaction.get_autocommit()
@ -336,6 +330,29 @@ class AtomicErrorsTests(TransactionTestCase):
with self.assertRaises(transaction.TransactionManagementError): with self.assertRaises(transaction.TransactionManagementError):
transaction.leave_transaction_management() transaction.leave_transaction_management()
def test_atomic_prevents_queries_in_broken_transaction(self):
r1 = Reporter.objects.create(first_name="Archibald", last_name="Haddock")
with transaction.atomic():
r2 = Reporter(first_name="Cuthbert", last_name="Calculus", id=r1.id)
with self.assertRaises(IntegrityError):
r2.save(force_insert=True)
# The transaction is marked as needing rollback.
with self.assertRaises(transaction.TransactionManagementError):
r2.save(force_update=True)
self.assertEqual(Reporter.objects.get(pk=r1.pk).last_name, "Haddock")
@skipIfDBFeature('atomic_transactions')
def test_atomic_allows_queries_after_fixing_transaction(self):
r1 = Reporter.objects.create(first_name="Archibald", last_name="Haddock")
with transaction.atomic():
r2 = Reporter(first_name="Cuthbert", last_name="Calculus", id=r1.id)
with self.assertRaises(IntegrityError):
r2.save(force_insert=True)
# Mark the transaction as no longer needing rollback.
transaction.set_rollback(False)
r2.save(force_update=True)
self.assertEqual(Reporter.objects.get(pk=r1.pk).last_name, "Calculus")
class AtomicMiscTests(TransactionTestCase): class AtomicMiscTests(TransactionTestCase):