From 3eb80748087004608ead65834a766d257e9edfc9 Mon Sep 17 00:00:00 2001 From: Malcolm Tredinnick Date: Tue, 12 Aug 2008 05:59:43 +0000 Subject: [PATCH] Added savepoint protection to get_or_create() to avoid problems on PostgreSQL. Fixed #7402. Also made savepoint handling easier to use when wrapped around calls that might commit a transaction. This is tested by the get_or_create tests. git-svn-id: http://code.djangoproject.com/svn/django/trunk@8315 bcc190cf-cafb-0310-a4f2-bffc1f526a37 --- django/db/models/query.py | 3 +++ django/db/transaction.py | 13 +++++++++++-- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/django/db/models/query.py b/django/db/models/query.py index 14d89dacae..9d46b4046d 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -326,9 +326,12 @@ class QuerySet(object): params = dict([(k, v) for k, v in kwargs.items() if '__' not in k]) params.update(defaults) obj = self.model(**params) + sid = transaction.savepoint() obj.save() + transaction.savepoint_commit(sid) return obj, True except IntegrityError, e: + transaction.savepoint_rollback(sid) return self.get(**kwargs), False def latest(self, field_name=None): diff --git a/django/db/transaction.py b/django/db/transaction.py index 55fad9e457..e5e8890ee7 100644 --- a/django/db/transaction.py +++ b/django/db/transaction.py @@ -105,6 +105,12 @@ def set_clean(): dirty[thread_ident] = False else: raise TransactionManagementError("This code isn't under transaction management") + clean_savepoints() + +def clean_savepoints(): + thread_ident = thread.get_ident() + if thread_ident in savepoint_state: + del savepoint_state[thread_ident] def is_managed(): """ @@ -139,6 +145,7 @@ def commit_unless_managed(): """ if not is_managed(): connection._commit() + clean_savepoints() else: set_dirty() @@ -186,14 +193,16 @@ def savepoint_rollback(sid): Rolls back the most recent savepoint (if one exists). Does nothing if savepoints are not supported. """ - connection._savepoint_rollback(sid) + if thread.get_ident() in savepoint_state: + connection._savepoint_rollback(sid) def savepoint_commit(sid): """ Commits the most recent savepoint (if one exists). Does nothing if savepoints are not supported. """ - connection._savepoint_commit(sid) + if thread.get_ident() in savepoint_state: + connection._savepoint_commit(sid) ############## # DECORATORS #