diff --git a/django/db/backends/mysql/base.py b/django/db/backends/mysql/base.py index 93aaf9971a..8ce2c11344 100644 --- a/django/db/backends/mysql/base.py +++ b/django/db/backends/mysql/base.py @@ -150,18 +150,28 @@ class DatabaseFeatures(BaseDatabaseFeatures): requires_explicit_null_ordering_when_grouping = True allows_primary_key_0 = False + def __init__(self, connection): + super(DatabaseFeatures, self).__init__(connection) + self._storage_engine = None + + def _mysql_storage_engine(self): + "Internal method used in Django tests. Don't rely on this from your code" + if self._storage_engine is None: + cursor = self.connection.cursor() + cursor.execute('CREATE TABLE INTROSPECT_TEST (X INT)') + # This command is MySQL specific; the second column + # will tell you the default table type of the created + # table. Since all Django's test tables will have the same + # table type, that's enough to evaluate the feature. + cursor.execute("SHOW TABLE STATUS WHERE Name='INTROSPECT_TEST'") + result = cursor.fetchone() + cursor.execute('DROP TABLE INTROSPECT_TEST') + self._storage_engine = result[1] + return self._storage_engine + def _can_introspect_foreign_keys(self): "Confirm support for introspected foreign keys" - cursor = self.connection.cursor() - cursor.execute('CREATE TABLE INTROSPECT_TEST (X INT)') - # This command is MySQL specific; the second column - # will tell you the default table type of the created - # table. Since all Django's test tables will have the same - # table type, that's enough to evaluate the feature. - cursor.execute("SHOW TABLE STATUS WHERE Name='INTROSPECT_TEST'") - result = cursor.fetchone() - cursor.execute('DROP TABLE INTROSPECT_TEST') - return result[1] != 'MyISAM' + return self._mysql_storage_engine() != 'MyISAM' class DatabaseOperations(BaseDatabaseOperations): compiler_module = "django.db.backends.mysql.compiler" @@ -285,6 +295,15 @@ class DatabaseOperations(BaseDatabaseOperations): items_sql = "(%s)" % ", ".join(["%s"] * len(fields)) return "VALUES " + ", ".join([items_sql] * num_values) + def savepoint_create_sql(self, sid): + return "SAVEPOINT %s" % sid + + def savepoint_commit_sql(self, sid): + return "RELEASE SAVEPOINT %s" % sid + + def savepoint_rollback_sql(self, sid): + return "ROLLBACK TO SAVEPOINT %s" % sid + class DatabaseWrapper(BaseDatabaseWrapper): vendor = 'mysql' operators = { @@ -354,6 +373,8 @@ class DatabaseWrapper(BaseDatabaseWrapper): self.connection = Database.connect(**kwargs) self.connection.encoders[SafeUnicode] = self.connection.encoders[unicode] self.connection.encoders[SafeString] = self.connection.encoders[str] + self.features.uses_savepoints = \ + self.get_server_version() >= (5, 0, 3) connection_created.send(sender=self.__class__, connection=self) cursor = self.connection.cursor() if new_connection: diff --git a/docs/releases/1.4.txt b/docs/releases/1.4.txt index 0cb30d633a..de7afc83c3 100644 --- a/docs/releases/1.4.txt +++ b/docs/releases/1.4.txt @@ -553,6 +553,9 @@ Django 1.4 also includes several smaller improvements worth noting: password reset mechanism and making it available is now much easier. For details, see :ref:`auth_password_reset`. +* The MySQL database backend can now make use of the savepoint feature + implemented by MySQL version 5.0.3 or newer with the InnoDB storage engine. + Backwards incompatible changes in 1.4 ===================================== diff --git a/docs/topics/db/transactions.txt b/docs/topics/db/transactions.txt index b2ee26d6b0..6e6754a9d8 100644 --- a/docs/topics/db/transactions.txt +++ b/docs/topics/db/transactions.txt @@ -225,11 +225,14 @@ transaction middleware, and only modify selected functions as needed. Savepoints ========== -A savepoint is a marker within a transaction that enables you to roll back -part of a transaction, rather than the full transaction. Savepoints are -available to the PostgreSQL 8 and Oracle backends. Other backends will -provide the savepoint functions, but they are empty operations - they won't -actually do anything. +A savepoint is a marker within a transaction that enables you to roll back part +of a transaction, rather than the full transaction. Savepoints are available to +the PostgreSQL 8, Oracle and MySQL (version 5.0.3 and newer, when using the +InnoDB storage engine) backends. Other backends will provide the savepoint +functions, but they are empty operations - they won't actually do anything. + +.. versionchanged:: 1.4 + Savepoint support when using the MySQL backend was added in Django 1.4 Savepoints aren't especially useful if you are using the default ``autocommit`` behavior of Django. However, if you are using diff --git a/tests/regressiontests/transactions_regress/tests.py b/tests/regressiontests/transactions_regress/tests.py index bdc1b53600..22f1b0f911 100644 --- a/tests/regressiontests/transactions_regress/tests.py +++ b/tests/regressiontests/transactions_regress/tests.py @@ -4,6 +4,7 @@ from django.core.exceptions import ImproperlyConfigured from django.db import connection, transaction from django.db.transaction import commit_on_success, commit_manually, TransactionManagementError from django.test import TransactionTestCase, skipUnlessDBFeature +from django.utils.unittest import skipIf from .models import Mod, M2mA, M2mB @@ -165,6 +166,7 @@ class TestTransactionClosing(TransactionTestCase): except: self.fail("A transaction consisting of a failed operation was not closed.") + class TestManyToManyAddTransaction(TransactionTestCase): def test_manyrelated_add_commit(self): "Test for https://code.djangoproject.com/ticket/16818" @@ -178,3 +180,39 @@ class TestManyToManyAddTransaction(TransactionTestCase): # that the bulk insert was not auto-committed. transaction.rollback() self.assertEqual(a.others.count(), 1) + + +class SavepointTest(TransactionTestCase): + + @skipUnlessDBFeature('uses_savepoints') + def test_savepoint_commit(self): + @commit_manually + def work(): + mod = Mod.objects.create(fld=1) + pk = mod.pk + sid = transaction.savepoint() + mod1 = Mod.objects.filter(pk=pk).update(fld=10) + transaction.savepoint_commit(sid) + mod2 = Mod.objects.get(pk=pk) + transaction.commit() + self.assertEqual(mod2.fld, 10) + + work() + + @skipIf(connection.vendor == 'mysql' and \ + connection.features._mysql_storage_engine() == 'MyISAM', + "MyISAM MySQL storage engine doesn't support savepoints") + @skipUnlessDBFeature('uses_savepoints') + def test_savepoint_rollback(self): + @commit_manually + def work(): + mod = Mod.objects.create(fld=1) + pk = mod.pk + sid = transaction.savepoint() + mod1 = Mod.objects.filter(pk=pk).update(fld=20) + transaction.savepoint_rollback(sid) + mod2 = Mod.objects.get(pk=pk) + transaction.commit() + self.assertEqual(mod2.fld, 1) + + work()