From 728548e483a5a3486939b0c8e62520296587482e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 22 Sep 2013 22:14:17 +0200 Subject: [PATCH] Fixed #21134 -- Prevented queries in broken transactions. Squashed commit of the following: commit 63ddb271a44df389b2c302e421fc17b7f0529755 Author: Aymeric Augustin Date: Sun Sep 29 22:51:00 2013 +0200 Clarified interactions between atomic and exceptions. commit 2899ec299228217c876ba3aa4024e523a41c8504 Author: Aymeric Augustin Date: Sun Sep 22 22:45:32 2013 +0200 Fixed TransactionManagementError in tests. Previous commit introduced an additional check to prevent running queries in transactions that will be rolled back, which triggered a few failures in the tests. In practice using transaction.atomic instead of the low-level savepoint APIs was enough to fix the problems. commit 4a639b059ea80aeb78f7f160a7d4b9f609b9c238 Author: Aymeric Augustin Date: Tue Sep 24 22:24:17 2013 +0200 Allowed nesting constraint_checks_disabled inside atomic. Since MySQL handles transactions loosely, this isn't a problem. commit 2a4ab1cb6e83391ff7e25d08479e230ca564bfef Author: Aymeric Augustin Date: Sat Sep 21 18:43:12 2013 +0200 Prevented running queries in transactions that will be rolled back. This avoids a counter-intuitive behavior in an edge case on databases with non-atomic transaction semantics. It prevents using savepoint_rollback() inside an atomic block without calling set_rollback(False) first, which is backwards-incompatible in tests. Refs #21134. commit 8e3db393853c7ac64a445b66e57f3620a3fde7b0 Author: Aymeric Augustin Date: Sun Sep 22 22:14:17 2013 +0200 Replaced manual savepoints by atomic blocks. This ensures the rollback flag is handled consistently in internal APIs. --- django/contrib/sessions/backends/db.py | 5 +- django/db/backends/__init__.py | 9 ++++ django/db/backends/mysql/base.py | 9 +++- django/db/backends/oracle/base.py | 1 + django/db/backends/sqlite3/base.py | 1 + django/db/backends/utils.py | 47 +++++++++++------ django/db/models/query.py | 26 ++++------ django/db/transaction.py | 9 ++-- docs/topics/db/transactions.txt | 27 +++++++--- tests/custom_pk/tests.py | 16 +++--- tests/expressions/tests.py | 11 ++-- tests/force_insert_update/tests.py | 17 +++--- tests/one_to_one/tests.py | 6 +-- tests/transactions/tests.py | 71 ++++++++++++++++---------- 14 files changed, 156 insertions(+), 99 deletions(-) diff --git a/django/contrib/sessions/backends/db.py b/django/contrib/sessions/backends/db.py index 206fca2700..7be99c3e16 100644 --- a/django/contrib/sessions/backends/db.py +++ b/django/contrib/sessions/backends/db.py @@ -58,12 +58,11 @@ class SessionStore(SessionBase): expire_date=self.get_expiry_date() ) using = router.db_for_write(Session, instance=obj) - sid = transaction.savepoint(using=using) try: - obj.save(force_insert=must_create, using=using) + with transaction.atomic(using=using): + obj.save(force_insert=must_create, using=using) except IntegrityError: if must_create: - transaction.savepoint_rollback(sid, using=using) raise CreateError raise diff --git a/django/db/backends/__init__.py b/django/db/backends/__init__.py index 74046a0d9b..8dd15dfee1 100644 --- a/django/db/backends/__init__.py +++ b/django/db/backends/__init__.py @@ -361,6 +361,12 @@ class BaseDatabaseWrapper(object): raise TransactionManagementError( "This is forbidden when an 'atomic' block is active.") + def validate_no_broken_transaction(self): + if self.needs_rollback: + raise TransactionManagementError( + "An error occurred in the current transaction. You can't " + "execute queries until the end of the 'atomic' block.") + def abort(self): """ Roll back any ongoing transaction and clean the transaction state @@ -638,6 +644,9 @@ class BaseDatabaseFeatures(object): # when autocommit is disabled? http://bugs.python.org/issue8145#msg109965 autocommits_when_autocommit_is_off = False + # Does the backend prevent running SQL queries in broken transactions? + atomic_transactions = True + # Can we roll back DDL in a transaction? can_rollback_ddl = False diff --git a/django/db/backends/mysql/base.py b/django/db/backends/mysql/base.py index 69e32fe627..f7fb8217a9 100644 --- a/django/db/backends/mysql/base.py +++ b/django/db/backends/mysql/base.py @@ -172,6 +172,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): requires_explicit_null_ordering_when_grouping = True allows_primary_key_0 = False uses_savepoints = True + atomic_transactions = False supports_check_constraints = False def __init__(self, connection): @@ -484,7 +485,13 @@ class DatabaseWrapper(BaseDatabaseWrapper): """ Re-enable foreign key checks after they have been disabled. """ - self.cursor().execute('SET foreign_key_checks=1') + # Override needs_rollback in case constraint_checks_disabled is + # nested inside transaction.atomic. + self.needs_rollback, needs_rollback = False, self.needs_rollback + try: + self.cursor().execute('SET foreign_key_checks=1') + finally: + self.needs_rollback = needs_rollback def check_constraints(self, table_names=None): """ diff --git a/django/db/backends/oracle/base.py b/django/db/backends/oracle/base.py index eead105b24..cb4de1dd56 100644 --- a/django/db/backends/oracle/base.py +++ b/django/db/backends/oracle/base.py @@ -96,6 +96,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): has_bulk_insert = True supports_tablespaces = True supports_sequence_reset = False + atomic_transactions = False supports_combined_alters = False max_index_name_length = 30 nulls_order_largest = True diff --git a/django/db/backends/sqlite3/base.py b/django/db/backends/sqlite3/base.py index 80d88c5cab..ac2376a819 100644 --- a/django/db/backends/sqlite3/base.py +++ b/django/db/backends/sqlite3/base.py @@ -105,6 +105,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): supports_foreign_keys = False supports_check_constraints = False autocommits_when_autocommit_is_off = True + atomic_transactions = False supports_paramstyle_pyformat = False supports_sequence_reset = False diff --git a/django/db/backends/utils.py b/django/db/backends/utils.py index ccda77f658..095cb4efe4 100644 --- a/django/db/backends/utils.py +++ b/django/db/backends/utils.py @@ -19,14 +19,9 @@ class CursorWrapper(object): self.cursor = cursor self.db = db - SET_DIRTY_ATTRS = frozenset(['execute', 'executemany', 'callproc']) - WRAP_ERROR_ATTRS = frozenset([ - 'callproc', 'close', 'execute', 'executemany', - 'fetchone', 'fetchmany', 'fetchall', 'nextset']) + WRAP_ERROR_ATTRS = frozenset(['fetchone', 'fetchmany', 'fetchall', 'nextset']) def __getattr__(self, attr): - if attr in CursorWrapper.SET_DIRTY_ATTRS: - self.db.set_dirty() cursor_attr = getattr(self.cursor, attr) if attr in CursorWrapper.WRAP_ERROR_ATTRS: return self.db.wrap_database_errors(cursor_attr) @@ -44,18 +39,42 @@ class CursorWrapper(object): # specific behavior. self.close() + # The following methods cannot be implemented in __getattr__, because the + # code must run when the method is invoked, not just when it is accessed. + + def callproc(self, procname, params=None): + self.db.validate_no_broken_transaction() + self.db.set_dirty() + with self.db.wrap_database_errors: + if params is None: + return self.cursor.callproc(procname) + else: + return self.cursor.callproc(procname, params) + + def execute(self, sql, params=None): + self.db.validate_no_broken_transaction() + self.db.set_dirty() + with self.db.wrap_database_errors: + if params is None: + return self.cursor.execute(sql) + else: + return self.cursor.execute(sql, params) + + def executemany(self, sql, param_list): + self.db.validate_no_broken_transaction() + self.db.set_dirty() + with self.db.wrap_database_errors: + return self.cursor.executemany(sql, param_list) + class CursorDebugWrapper(CursorWrapper): + # XXX callproc isn't instrumented at this time. + def execute(self, sql, params=None): - self.db.set_dirty() start = time() try: - with self.db.wrap_database_errors: - if params is None: - # params default might be backend specific - return self.cursor.execute(sql) - return self.cursor.execute(sql, params) + return super(CursorDebugWrapper, self).execute(sql, params) finally: stop = time() duration = stop - start @@ -69,11 +88,9 @@ class CursorDebugWrapper(CursorWrapper): ) def executemany(self, sql, param_list): - self.db.set_dirty() start = time() try: - with self.db.wrap_database_errors: - return self.cursor.executemany(sql, param_list) + return super(CursorDebugWrapper, self).executemany(sql, param_list) finally: stop = time() duration = stop - start diff --git a/django/db/models/query.py b/django/db/models/query.py index 0aff4d89ad..baf436791a 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -436,14 +436,9 @@ class QuerySet(object): for k, v in six.iteritems(defaults): setattr(obj, k, v) - sid = transaction.savepoint(using=self.db) - try: + with transaction.atomic(using=self.db): obj.save(using=self.db) - transaction.savepoint_commit(sid, using=self.db) - return obj, False - except DatabaseError: - transaction.savepoint_rollback(sid, using=self.db) - six.reraise(*sys.exc_info()) + return obj, False def _create_object_from_params(self, lookup, params): """ @@ -451,19 +446,16 @@ class QuerySet(object): Used by get_or_create and update_or_create """ obj = self.model(**params) - sid = transaction.savepoint(using=self.db) try: - obj.save(force_insert=True, using=self.db) - transaction.savepoint_commit(sid, using=self.db) + with transaction.atomic(using=self.db): + obj.save(force_insert=True, using=self.db) return obj, True - except DatabaseError as e: - transaction.savepoint_rollback(sid, using=self.db) + except IntegrityError: exc_info = sys.exc_info() - if isinstance(e, IntegrityError): - try: - return self.get(**lookup), False - except self.model.DoesNotExist: - pass + try: + return self.get(**lookup), False + except self.model.DoesNotExist: + pass six.reraise(*exc_info) def _extract_model_params(self, defaults, **kwargs): diff --git a/django/db/transaction.py b/django/db/transaction.py index 7509ad3788..86a357f1cc 100644 --- a/django/db/transaction.py +++ b/django/db/transaction.py @@ -16,14 +16,15 @@ import warnings from functools import wraps -from django.db import connections, DatabaseError, DEFAULT_DB_ALIAS +from django.db import ( + connections, DEFAULT_DB_ALIAS, + DatabaseError, ProgrammingError) from django.utils.decorators import available_attrs -class TransactionManagementError(Exception): +class TransactionManagementError(ProgrammingError): """ - This exception is thrown when something bad happens with transaction - management. + This exception is thrown when transaction management is used improperly. """ pass diff --git a/docs/topics/db/transactions.txt b/docs/topics/db/transactions.txt index 1483bddd0b..c7adb9e191 100644 --- a/docs/topics/db/transactions.txt +++ b/docs/topics/db/transactions.txt @@ -163,20 +163,31 @@ Django provides a single API to control database transactions. called, so the exception handler can also operate on the database if necessary. - .. admonition:: Don't catch database exceptions inside ``atomic``! + .. admonition:: Avoid catching exceptions inside ``atomic``! - If you catch :exc:`~django.db.DatabaseError` or a subclass such as - :exc:`~django.db.IntegrityError` inside an ``atomic`` block, you will - hide from Django the fact that an error has occurred and that the - transaction is broken. At this point, Django's behavior is unspecified - and database-dependent. It will usually result in a rollback, which - may break your expectations, since you caught the exception. + When exiting an ``atomic`` block, Django looks at whether it's exited + normally or with an exception to determine whether to commit or roll + back. If you catch and handle exceptions inside an ``atomic`` block, + you may hide from Django the fact that a problem has happened. This + can result in unexpected behavior. + + This is mostly a concern for :exc:`~django.db.DatabaseError` and its + subclasses such as :exc:`~django.db.IntegrityError`. After such an + error, the transaction is broken and Django will perform a rollback at + the end of the ``atomic`` block. If you attempt to run database + queries before the rollback happens, Django will raise a + :class:`~django.db.transaction.TransactionManagementError`. You may + also encounter this behavior when an ORM-related signal handler raises + an exception. The correct way to catch database errors is around an ``atomic`` block as shown above. If necessary, add an extra ``atomic`` block for this - purpose -- it's cheap! This pattern is useful to delimit explicitly + purpose. This pattern has another advantage: it delimits explicitly which operations will be rolled back if an exception occurs. + If you catch exceptions raised by raw SQL queries, Django's behavior + is unspecified and database-dependent. + In order to guarantee atomicity, ``atomic`` disables some APIs. Attempting to commit, roll back, or change the autocommit state of the database connection within an ``atomic`` block will raise an exception. diff --git a/tests/custom_pk/tests.py b/tests/custom_pk/tests.py index a452561edb..22369747a9 100644 --- a/tests/custom_pk/tests.py +++ b/tests/custom_pk/tests.py @@ -149,11 +149,9 @@ class CustomPKTests(TestCase): Employee.objects.create( employee_code=123, first_name="Frank", last_name="Jones" ) - sid = transaction.savepoint() - self.assertRaises(IntegrityError, - Employee.objects.create, employee_code=123, first_name="Fred", last_name="Jones" - ) - transaction.savepoint_rollback(sid) + with self.assertRaises(IntegrityError): + with transaction.atomic(): + Employee.objects.create(employee_code=123, first_name="Fred", last_name="Jones") def test_custom_field_pk(self): # Regression for #10785 -- Custom fields can be used for primary keys. @@ -175,8 +173,6 @@ class CustomPKTests(TestCase): def test_required_pk(self): # The primary key must be specified, so an error is raised if you # try to create an object without it. - sid = transaction.savepoint() - self.assertRaises(IntegrityError, - Employee.objects.create, first_name="Tom", last_name="Smith" - ) - transaction.savepoint_rollback(sid) + with self.assertRaises(IntegrityError): + with transaction.atomic(): + Employee.objects.create(first_name="Tom", last_name="Smith") diff --git a/tests/expressions/tests.py b/tests/expressions/tests.py index 9801d0acbb..a24c2fbc10 100644 --- a/tests/expressions/tests.py +++ b/tests/expressions/tests.py @@ -2,6 +2,7 @@ from __future__ import unicode_literals from django.core.exceptions import FieldError from django.db.models import F +from django.db import transaction from django.test import TestCase from django.utils import six @@ -185,11 +186,11 @@ class ExpressionsTests(TestCase): "foo", ) - self.assertRaises(FieldError, - lambda: Company.objects.exclude( - ceo__firstname=F('point_of_contact__firstname') - ).update(name=F('point_of_contact__lastname')) - ) + with transaction.atomic(): + with self.assertRaises(FieldError): + Company.objects.exclude( + ceo__firstname=F('point_of_contact__firstname') + ).update(name=F('point_of_contact__lastname')) # F expressions can be used to update attributes on single objects test_gmbh = Company.objects.get(name="Test GmbH") diff --git a/tests/force_insert_update/tests.py b/tests/force_insert_update/tests.py index 706a099872..6d87151d83 100644 --- a/tests/force_insert_update/tests.py +++ b/tests/force_insert_update/tests.py @@ -21,24 +21,29 @@ class ForceTests(TestCase): # Won't work because force_update and force_insert are mutually # exclusive c.value = 4 - self.assertRaises(ValueError, c.save, force_insert=True, force_update=True) + with self.assertRaises(ValueError): + c.save(force_insert=True, force_update=True) # Try to update something that doesn't have a primary key in the first # place. c1 = Counter(name="two", value=2) - self.assertRaises(ValueError, c1.save, force_update=True) + with self.assertRaises(ValueError): + with transaction.atomic(): + c1.save(force_update=True) c1.save(force_insert=True) # Won't work because we can't insert a pk of the same value. - sid = transaction.savepoint() c.value = 5 - self.assertRaises(IntegrityError, c.save, force_insert=True) - transaction.savepoint_rollback(sid) + with self.assertRaises(IntegrityError): + with transaction.atomic(): + c.save(force_insert=True) # Trying to update should still fail, even with manual primary keys, if # the data isn't in the database already. obj = WithCustomPK(name=1, value=1) - self.assertRaises(DatabaseError, obj.save, force_update=True) + with self.assertRaises(DatabaseError): + with transaction.atomic(): + obj.save(force_update=True) class InheritanceTests(TestCase): diff --git a/tests/one_to_one/tests.py b/tests/one_to_one/tests.py index 6e16d81cea..45a72f0df8 100644 --- a/tests/one_to_one/tests.py +++ b/tests/one_to_one/tests.py @@ -118,7 +118,7 @@ class OneToOneTests(TestCase): self.assertEqual(repr(o1.multimodel), '') # This will fail because each one-to-one field must be unique (and # link2=o1 was used for x1, above). - sid = transaction.savepoint() mm = MultiModel(link1=self.p2, link2=o1, name="x1") - self.assertRaises(IntegrityError, mm.save) - transaction.savepoint_rollback(sid) + with self.assertRaises(IntegrityError): + with transaction.atomic(): + mm.save() diff --git a/tests/transactions/tests.py b/tests/transactions/tests.py index 70a77b719e..0cef7a545b 100644 --- a/tests/transactions/tests.py +++ b/tests/transactions/tests.py @@ -4,7 +4,7 @@ import sys from unittest import skipIf, skipUnless from django.db import connection, transaction, DatabaseError, IntegrityError -from django.test import TransactionTestCase, skipUnlessDBFeature +from django.test import TransactionTestCase, skipIfDBFeature, skipUnlessDBFeature from django.test.utils import IgnoreDeprecationWarningsMixin from django.utils import six @@ -204,10 +204,10 @@ class AtomicTests(TransactionTestCase): with transaction.atomic(savepoint=False): connection.cursor().execute( "SELECT no_such_col FROM transactions_reporter") - transaction.savepoint_rollback(sid) - # atomic block should rollback, but prevent it, as we just did it. + # prevent atomic from rolling back since we're recovering manually self.assertTrue(transaction.get_rollback()) transaction.set_rollback(False) + transaction.savepoint_rollback(sid) self.assertQuerysetEqual(Reporter.objects.all(), ['']) @@ -267,11 +267,19 @@ class AtomicMergeTests(TransactionTestCase): with transaction.atomic(savepoint=False): Reporter.objects.create(first_name="Calculus") raise Exception("Oops, that's his last name") - # It wasn't possible to roll back + # The third insert couldn't be roll back. Temporarily mark the + # connection as not needing rollback to check it. + self.assertTrue(transaction.get_rollback()) + transaction.set_rollback(False) self.assertEqual(Reporter.objects.count(), 3) - # It wasn't possible to roll back + transaction.set_rollback(True) + # The second insert couldn't be roll back. Temporarily mark the + # connection as not needing rollback to check it. + self.assertTrue(transaction.get_rollback()) + transaction.set_rollback(False) self.assertEqual(Reporter.objects.count(), 3) - # The outer block must roll back + transaction.set_rollback(True) + # The first block has a savepoint and must roll back. self.assertQuerysetEqual(Reporter.objects.all(), []) def test_merged_inner_savepoint_rollback(self): @@ -283,36 +291,22 @@ class AtomicMergeTests(TransactionTestCase): with transaction.atomic(savepoint=False): Reporter.objects.create(first_name="Calculus") raise Exception("Oops, that's his last name") - # It wasn't possible to roll back + # The third insert couldn't be roll back. Temporarily mark the + # connection as not needing rollback to check it. + self.assertTrue(transaction.get_rollback()) + transaction.set_rollback(False) self.assertEqual(Reporter.objects.count(), 3) - # The first block with a savepoint must roll back + transaction.set_rollback(True) + # The second block has a savepoint and must roll back. self.assertEqual(Reporter.objects.count(), 1) self.assertQuerysetEqual(Reporter.objects.all(), ['']) - def test_merged_outer_rollback_after_inner_failure_and_inner_success(self): - with transaction.atomic(): - Reporter.objects.create(first_name="Tintin") - # Inner block without a savepoint fails - with six.assertRaisesRegex(self, Exception, "Oops"): - with transaction.atomic(savepoint=False): - Reporter.objects.create(first_name="Haddock") - raise Exception("Oops, that's his last name") - # It wasn't possible to roll back - self.assertEqual(Reporter.objects.count(), 2) - # Inner block with a savepoint succeeds - with transaction.atomic(savepoint=False): - Reporter.objects.create(first_name="Archibald", last_name="Haddock") - # It still wasn't possible to roll back - self.assertEqual(Reporter.objects.count(), 3) - # The outer block must rollback - self.assertQuerysetEqual(Reporter.objects.all(), []) - @skipUnless(connection.features.uses_savepoints, "'atomic' requires transactions and savepoints.") class AtomicErrorsTests(TransactionTestCase): - available_apps = [] + available_apps = ['transactions'] def test_atomic_prevents_setting_autocommit(self): autocommit = transaction.get_autocommit() @@ -336,6 +330,29 @@ class AtomicErrorsTests(TransactionTestCase): with self.assertRaises(transaction.TransactionManagementError): transaction.leave_transaction_management() + def test_atomic_prevents_queries_in_broken_transaction(self): + r1 = Reporter.objects.create(first_name="Archibald", last_name="Haddock") + with transaction.atomic(): + r2 = Reporter(first_name="Cuthbert", last_name="Calculus", id=r1.id) + with self.assertRaises(IntegrityError): + r2.save(force_insert=True) + # The transaction is marked as needing rollback. + with self.assertRaises(transaction.TransactionManagementError): + r2.save(force_update=True) + self.assertEqual(Reporter.objects.get(pk=r1.pk).last_name, "Haddock") + + @skipIfDBFeature('atomic_transactions') + def test_atomic_allows_queries_after_fixing_transaction(self): + r1 = Reporter.objects.create(first_name="Archibald", last_name="Haddock") + with transaction.atomic(): + r2 = Reporter(first_name="Cuthbert", last_name="Calculus", id=r1.id) + with self.assertRaises(IntegrityError): + r2.save(force_insert=True) + # Mark the transaction as no longer needing rollback. + transaction.set_rollback(False) + r2.save(force_update=True) + self.assertEqual(Reporter.objects.get(pk=r1.pk).last_name, "Calculus") + class AtomicMiscTests(TransactionTestCase):