diff --git a/django/db/backends/__init__.py b/django/db/backends/__init__.py index f25a65bafb1..074fa0ed700 100644 --- a/django/db/backends/__init__.py +++ b/django/db/backends/__init__.py @@ -9,7 +9,7 @@ try: except NameError: # Python 2.3 compat from sets import Set as set - + from django.db.backends import util from django.utils import datetime_safe @@ -31,6 +31,21 @@ class BaseDatabaseWrapper(local): if self.connection is not None: return self.connection.rollback() + def _savepoint(self, sid): + if not self.features.uses_savepoints: + return + self.connection.cursor().execute(self.ops.savepoint_create_sql(sid)) + + def _savepoint_rollback(self, sid): + if not self.features.uses_savepoints: + return + self.connection.cursor().execute(self.ops.savepoint_rollback_sql(sid)) + + def _savepoint_commit(self, sid): + if not self.features.uses_savepoints: + return + self.connection.cursor().execute(self.ops.savepoint_commit_sql(sid)) + def close(self): if self.connection is not None: self.connection.close() @@ -55,6 +70,7 @@ class BaseDatabaseFeatures(object): update_can_self_select = True interprets_empty_strings_as_nulls = False can_use_chunked_reads = True + uses_savepoints = False class BaseDatabaseOperations(object): """ @@ -226,6 +242,26 @@ class BaseDatabaseOperations(object): """ raise NotImplementedError + def savepoint_create_sql(self, sid): + """ + Returns the SQL for starting a new savepoint. Only required if the + "uses_savepoints" feature is True. The "sid" parameter is a string + for the savepoint id. + """ + raise NotImplementedError + + def savepoint_commit_sql(self, sid): + """ + Returns the SQL for committing the given savepoint. + """ + raise NotImplementedError + + def savepoint_rollback_sql(self, sid): + """ + Returns the SQL for rolling back the given savepoint. + """ + raise NotImplementedError + def sql_flush(self, style, tables, sequences): """ Returns a list of SQL statements required to remove all data from @@ -259,7 +295,7 @@ class BaseDatabaseOperations(object): a tablespace. Returns '' if the backend doesn't use tablespaces. """ return '' - + def prep_for_like_query(self, x): """Prepares a value for use in a LIKE query.""" from django.utils.encoding import smart_unicode @@ -336,11 +372,11 @@ class BaseDatabaseIntrospection(object): def table_name_converter(self, name): """Apply a conversion to the name for the purposes of comparison. - + The default table name converter is for case sensitive comparison. """ return name - + def table_names(self): "Returns a list of names of all tables that exist in the database." cursor = self.connection.cursor() @@ -371,10 +407,10 @@ class BaseDatabaseIntrospection(object): for app in models.get_apps(): for model in models.get_models(app): all_models.append(model) - return set([m for m in all_models + return set([m for m in all_models if self.table_name_converter(m._meta.db_table) in map(self.table_name_converter, tables) ]) - + def sequence_list(self): "Returns a list of information about all DB sequences for all models in all apps." from django.db import models @@ -393,8 +429,7 @@ class BaseDatabaseIntrospection(object): sequence_list.append({'table': f.m2m_db_table(), 'column': None}) return sequence_list - - + class BaseDatabaseClient(object): """ This class encapsualtes all backend-specific methods for opening a diff --git a/django/db/backends/postgresql/base.py b/django/db/backends/postgresql/base.py index 4a8d6ebef04..792026530f2 100644 --- a/django/db/backends/postgresql/base.py +++ b/django/db/backends/postgresql/base.py @@ -63,6 +63,9 @@ class UnicodeCursorWrapper(object): def __iter__(self): return iter(self.cursor) +class DatabaseFeatures(BaseDatabaseFeatures): + uses_savepoints = True + class DatabaseWrapper(BaseDatabaseWrapper): operators = { 'exact': '= %s', @@ -83,8 +86,8 @@ class DatabaseWrapper(BaseDatabaseWrapper): def __init__(self, *args, **kwargs): super(DatabaseWrapper, self).__init__(*args, **kwargs) - - self.features = BaseDatabaseFeatures() + + self.features = DatabaseFeatures() self.ops = DatabaseOperations() self.client = DatabaseClient() self.creation = DatabaseCreation(self) diff --git a/django/db/backends/postgresql/operations.py b/django/db/backends/postgresql/operations.py index de7b5a95202..4eb5ead47c3 100644 --- a/django/db/backends/postgresql/operations.py +++ b/django/db/backends/postgresql/operations.py @@ -124,3 +124,13 @@ class DatabaseOperations(BaseDatabaseOperations): style.SQL_KEYWORD('FROM'), style.SQL_TABLE(qn(f.m2m_db_table())))) return output + + def savepoint_create_sql(self, sid): + return "SAVEPOINT %s" % sid + + def savepoint_commit_sql(self, sid): + return "RELEASE SAVEPOINT %s" % sid + + def savepoint_rollback_sql(self, sid): + return "ROLLBACK TO SAVEPOINT %s" % sid + diff --git a/django/db/backends/postgresql_psycopg2/base.py b/django/db/backends/postgresql_psycopg2/base.py index 139e36ba591..08014bd9936 100644 --- a/django/db/backends/postgresql_psycopg2/base.py +++ b/django/db/backends/postgresql_psycopg2/base.py @@ -26,6 +26,7 @@ psycopg2.extensions.register_adapter(SafeUnicode, psycopg2.extensions.QuotedStri class DatabaseFeatures(BaseDatabaseFeatures): needs_datetime_string_cast = False + uses_savepoints = True class DatabaseOperations(PostgresqlDatabaseOperations): def last_executed_query(self, cursor, sql, params): diff --git a/django/db/transaction.py b/django/db/transaction.py index cd27cf6044c..55fad9e4579 100644 --- a/django/db/transaction.py +++ b/django/db/transaction.py @@ -19,7 +19,7 @@ except ImportError: try: from functools import wraps except ImportError: - from django.utils.functional import wraps # Python 2.3, 2.4 fallback. + from django.utils.functional import wraps # Python 2.3, 2.4 fallback. from django.db import connection from django.conf import settings @@ -30,9 +30,10 @@ class TransactionManagementError(Exception): """ pass -# The state is a dictionary of lists. The key to the dict is the current +# The states are dictionaries of lists. The key to the dict is the current # thread and the list is handled as a stack of values. state = {} +savepoint_state = {} # The dirty flag is set by *_unless_managed functions to denote that the # code under transaction management has changed things to require a @@ -164,6 +165,36 @@ def rollback(): connection._rollback() set_clean() +def savepoint(): + """ + Creates a savepoint (if supported and required by the backend) inside the + current transaction. Returns an identifier for the savepoint that will be + used for the subsequent rollback or commit. + """ + thread_ident = thread.get_ident() + if thread_ident in savepoint_state: + savepoint_state[thread_ident].append(None) + else: + savepoint_state[thread_ident] = [None] + tid = str(thread_ident).replace('-', '') + sid = "s%s_x%d" % (tid, len(savepoint_state[thread_ident])) + connection._savepoint(sid) + return sid + +def savepoint_rollback(sid): + """ + Rolls back the most recent savepoint (if one exists). Does nothing if + savepoints are not supported. + """ + connection._savepoint_rollback(sid) + +def savepoint_commit(sid): + """ + Commits the most recent savepoint (if one exists). Does nothing if + savepoints are not supported. + """ + connection._savepoint_commit(sid) + ############## # DECORATORS # ############## diff --git a/tests/modeltests/force_insert_update/models.py b/tests/modeltests/force_insert_update/models.py index feffed5faf7..c9b9fe0c765 100644 --- a/tests/modeltests/force_insert_update/models.py +++ b/tests/modeltests/force_insert_update/models.py @@ -2,7 +2,7 @@ Tests for forcing insert and update queries (instead of Django's normal automatic behaviour). """ -from django.db import models +from django.db import models, transaction class Counter(models.Model): name = models.CharField(max_length = 10) @@ -40,15 +40,13 @@ ValueError: Cannot force an update in save() with no primary key. >>> c1.save(force_insert=True) # Won't work because we can't insert a pk of the same value. +>>> sid = transaction.savepoint() >>> c.value = 5 >>> c.save(force_insert=True) Traceback (most recent call last): ... IntegrityError: ... - -# Work around transaction failure cleaning up for PostgreSQL. ->>> from django.db import connection ->>> connection.close() +>>> transaction.savepoint_rollback(sid) # Trying to update should still fail, even with manual primary keys, if the # data isn't in the database already. diff --git a/tests/modeltests/one_to_one/models.py b/tests/modeltests/one_to_one/models.py index 6fa4dd8c183..348a543c528 100644 --- a/tests/modeltests/one_to_one/models.py +++ b/tests/modeltests/one_to_one/models.py @@ -6,7 +6,7 @@ To define a one-to-one relationship, use ``OneToOneField()``. In this example, a ``Place`` optionally can be a ``Restaurant``. """ -from django.db import models, connection +from django.db import models, transaction class Place(models.Model): name = models.CharField(max_length=50) @@ -178,13 +178,11 @@ DoesNotExist: Restaurant matching query does not exist. # This will fail because each one-to-one field must be unique (and link2=o1 was # used for x1, above). +>>> sid = transaction.savepoint() >>> MultiModel(link1=p2, link2=o1, name="x1").save() Traceback (most recent call last): ... IntegrityError: ... +>>> transaction.savepoint_rollback(sid) -# Because the unittests all use a single connection, we need to force a -# reconnect here to ensure the connection is clean (after the previous -# IntegrityError). ->>> connection.close() """}