Added savepoint protection to get_or_create() to avoid problems on PostgreSQL.

Fixed #7402.

Also made savepoint handling easier to use when wrapped around calls that might
commit a transaction. This is tested by the get_or_create tests.


git-svn-id: http://code.djangoproject.com/svn/django/trunk@8315 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Malcolm Tredinnick 2008-08-12 05:59:43 +00:00
parent 220993bcc5
commit 3eb8074808
2 changed files with 14 additions and 2 deletions

View File

@ -326,9 +326,12 @@ class QuerySet(object):
params = dict([(k, v) for k, v in kwargs.items() if '__' not in k]) params = dict([(k, v) for k, v in kwargs.items() if '__' not in k])
params.update(defaults) params.update(defaults)
obj = self.model(**params) obj = self.model(**params)
sid = transaction.savepoint()
obj.save() obj.save()
transaction.savepoint_commit(sid)
return obj, True return obj, True
except IntegrityError, e: except IntegrityError, e:
transaction.savepoint_rollback(sid)
return self.get(**kwargs), False return self.get(**kwargs), False
def latest(self, field_name=None): def latest(self, field_name=None):

View File

@ -105,6 +105,12 @@ def set_clean():
dirty[thread_ident] = False dirty[thread_ident] = False
else: else:
raise TransactionManagementError("This code isn't under transaction management") raise TransactionManagementError("This code isn't under transaction management")
clean_savepoints()
def clean_savepoints():
thread_ident = thread.get_ident()
if thread_ident in savepoint_state:
del savepoint_state[thread_ident]
def is_managed(): def is_managed():
""" """
@ -139,6 +145,7 @@ def commit_unless_managed():
""" """
if not is_managed(): if not is_managed():
connection._commit() connection._commit()
clean_savepoints()
else: else:
set_dirty() set_dirty()
@ -186,14 +193,16 @@ def savepoint_rollback(sid):
Rolls back the most recent savepoint (if one exists). Does nothing if Rolls back the most recent savepoint (if one exists). Does nothing if
savepoints are not supported. savepoints are not supported.
""" """
connection._savepoint_rollback(sid) if thread.get_ident() in savepoint_state:
connection._savepoint_rollback(sid)
def savepoint_commit(sid): def savepoint_commit(sid):
""" """
Commits the most recent savepoint (if one exists). Does nothing if Commits the most recent savepoint (if one exists). Does nothing if
savepoints are not supported. savepoints are not supported.
""" """
connection._savepoint_commit(sid) if thread.get_ident() in savepoint_state:
connection._savepoint_commit(sid)
############## ##############
# DECORATORS # # DECORATORS #