diff --git a/AUTHORS b/AUTHORS index 6c3cd58b7c..a8bfe07f6b 100644 --- a/AUTHORS +++ b/AUTHORS @@ -347,6 +347,7 @@ answer newbie questions, and generally made Django that much better: Jeff Triplett Jens Diemer Jens Page + Jensen Cochran Jeong-Min Lee Jérémie Blaser Jeremy Carbaugh diff --git a/django/db/models/query.py b/django/db/models/query.py index 6e9741c9fe..cff7135ef6 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -482,15 +482,16 @@ class QuerySet(object): defaults = defaults or {} lookup, params = self._extract_model_params(defaults, **kwargs) self._for_write = True - try: - obj = self.get(**lookup) - except self.model.DoesNotExist: - obj, created = self._create_object_from_params(lookup, params) - if created: - return obj, created - for k, v in six.iteritems(defaults): - setattr(obj, k, v() if callable(v) else v) - obj.save(using=self.db) + with transaction.atomic(using=self.db): + try: + obj = self.select_for_update().get(**lookup) + except self.model.DoesNotExist: + obj, created = self._create_object_from_params(lookup, params) + if created: + return obj, created + for k, v in six.iteritems(defaults): + setattr(obj, k, v() if callable(v) else v) + obj.save(using=self.db) return obj, False def _create_object_from_params(self, lookup, params): diff --git a/tests/get_or_create/tests.py b/tests/get_or_create/tests.py index 0cbb9a1fff..0a774eff77 100644 --- a/tests/get_or_create/tests.py +++ b/tests/get_or_create/tests.py @@ -1,10 +1,14 @@ from __future__ import unicode_literals +import time import traceback -from datetime import date +from datetime import date, datetime, timedelta +from threading import Thread from django.db import DatabaseError, IntegrityError -from django.test import TestCase, TransactionTestCase, ignore_warnings +from django.test import ( + TestCase, TransactionTestCase, ignore_warnings, skipUnlessDBFeature, +) from django.utils.encoding import DjangoUnicodeDecodeError from .models import ( @@ -422,3 +426,48 @@ class UpdateOrCreateTests(TestCase): ) self.assertIs(created, False) self.assertEqual(obj.last_name, 'NotHarrison') + + +class UpdateOrCreateTransactionTests(TransactionTestCase): + available_apps = ['get_or_create'] + + @skipUnlessDBFeature('has_select_for_update') + @skipUnlessDBFeature('supports_transactions') + def test_updates_in_transaction(self): + """ + Objects are selected and updated in a transaction to avoid race + conditions. This test forces update_or_create() to hold the lock + in another thread for a relatively long time so that it can update + while it holds the lock. The updated field isn't a field in 'defaults', + so update_or_create() shouldn't have an effect on it. + """ + def birthday_sleep(): + time.sleep(0.3) + return date(1940, 10, 10) + + def update_birthday_slowly(): + Person.objects.update_or_create( + first_name='John', defaults={'birthday': birthday_sleep} + ) + + Person.objects.create(first_name='John', last_name='Lennon', birthday=date(1940, 10, 9)) + + # update_or_create in a separate thread + t = Thread(target=update_birthday_slowly) + before_start = datetime.now() + t.start() + + # Wait for lock to begin + time.sleep(0.05) + + # Update during lock + Person.objects.filter(first_name='John').update(last_name='NotLennon') + after_update = datetime.now() + + # Wait for thread to finish + t.join() + + # The update remains and it blocked. + updated_person = Person.objects.get(first_name='John') + self.assertGreater(after_update - before_start, timedelta(seconds=0.3)) + self.assertEqual(updated_person.last_name, 'NotLennon')