Fixed #21174 -- transaction control in related manager methods

This commit is contained in:
Anssi Kääriäinen 2013-09-26 21:16:50 +03:00
parent 93cc6dcdac
commit 1df3c49a1a
4 changed files with 25 additions and 17 deletions

View File

@ -372,14 +372,12 @@ def create_generic_related_manager(superclass):
def remove(self, *objs): def remove(self, *objs):
db = router.db_for_write(self.model, instance=self.instance) db = router.db_for_write(self.model, instance=self.instance)
for obj in objs: self.using(db).filter(pk__in=[o.pk for o in objs]).delete()
obj.delete(using=db)
remove.alters_data = True remove.alters_data = True
def clear(self): def clear(self):
db = router.db_for_write(self.model, instance=self.instance) db = router.db_for_write(self.model, instance=self.instance)
for obj in self.all(): self.using(db).delete()
obj.delete(using=db)
clear.alters_data = True clear.alters_data = True
def create(self, **kwargs): def create(self, **kwargs):

View File

@ -1,6 +1,6 @@
from operator import attrgetter from operator import attrgetter
from django.db import connection, connections, router from django.db import connection, connections, router, transaction
from django.db.backends import utils from django.db.backends import utils
from django.db.models import signals from django.db.models import signals
from django.db.models.fields import (AutoField, Field, IntegerField, from django.db.models.fields import (AutoField, Field, IntegerField,
@ -18,7 +18,6 @@ from django import forms
RECURSIVE_RELATIONSHIP_CONSTANT = 'self' RECURSIVE_RELATIONSHIP_CONSTANT = 'self'
def add_lazy_relation(cls, field, relation, operation): def add_lazy_relation(cls, field, relation, operation):
""" """
Adds a lookup on ``cls`` when a related field is defined using a string, Adds a lookup on ``cls`` when a related field is defined using a string,
@ -416,9 +415,14 @@ def create_foreign_related_manager(superclass, rel_field, rel_model):
return qs, rel_obj_attr, instance_attr, False, cache_name return qs, rel_obj_attr, instance_attr, False, cache_name
def add(self, *objs): def add(self, *objs):
objs = list(objs)
db = router.db_for_write(self.model, instance=self.instance)
with transaction.commit_on_success_unless_managed(
using=db, savepoint=False):
for obj in objs: for obj in objs:
if not isinstance(obj, self.model): if not isinstance(obj, self.model):
raise TypeError("'%s' instance expected, got %r" % (self.model._meta.object_name, obj)) raise TypeError("'%s' instance expected, got %r" %
(self.model._meta.object_name, obj))
setattr(obj, rel_field.name, self.instance) setattr(obj, rel_field.name, self.instance)
obj.save() obj.save()
add.alters_data = True add.alters_data = True

View File

@ -2,6 +2,7 @@ from copy import deepcopy
import datetime import datetime
from django.core.exceptions import MultipleObjectsReturned, FieldError from django.core.exceptions import MultipleObjectsReturned, FieldError
from django.db import transaction
from django.test import TestCase from django.test import TestCase
from django.utils import six from django.utils import six
from django.utils.translation import ugettext_lazy from django.utils.translation import ugettext_lazy
@ -68,7 +69,9 @@ class ManyToOneTests(TestCase):
self.assertQuerysetEqual(self.r2.article_set.all(), ["<Article: Paul's story>"]) self.assertQuerysetEqual(self.r2.article_set.all(), ["<Article: Paul's story>"])
# Adding an object of the wrong type raises TypeError. # Adding an object of the wrong type raises TypeError.
with six.assertRaisesRegex(self, TypeError, "'Article' instance expected, got <Reporter.*"): with transaction.atomic():
with six.assertRaisesRegex(self, TypeError,
"'Article' instance expected, got <Reporter.*"):
self.r.article_set.add(self.r2) self.r.article_set.add(self.r2)
self.assertQuerysetEqual(self.r.article_set.all(), self.assertQuerysetEqual(self.r.article_set.all(),
[ [

View File

@ -7,7 +7,7 @@ from operator import attrgetter
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
from django.core import management from django.core import management
from django.db import connections, router, DEFAULT_DB_ALIAS from django.db import connections, router, DEFAULT_DB_ALIAS, transaction
from django.db.models import signals from django.db.models import signals
from django.db.utils import ConnectionRouter from django.db.utils import ConnectionRouter
from django.test import TestCase from django.test import TestCase
@ -490,6 +490,7 @@ class QueryTestCase(TestCase):
# Set a foreign key with an object from a different database # Set a foreign key with an object from a different database
try: try:
with transaction.atomic(using='default'):
dive.editor = marty dive.editor = marty
self.fail("Shouldn't be able to assign across databases") self.fail("Shouldn't be able to assign across databases")
except ValueError: except ValueError:
@ -497,6 +498,7 @@ class QueryTestCase(TestCase):
# Set a foreign key set with an object from a different database # Set a foreign key set with an object from a different database
try: try:
with transaction.atomic(using='default'):
marty.edited = [pro, dive] marty.edited = [pro, dive]
self.fail("Shouldn't be able to assign across databases") self.fail("Shouldn't be able to assign across databases")
except ValueError: except ValueError:
@ -504,6 +506,7 @@ class QueryTestCase(TestCase):
# Add to a foreign key set with an object from a different database # Add to a foreign key set with an object from a different database
try: try:
with transaction.atomic(using='default'):
marty.edited.add(dive) marty.edited.add(dive)
self.fail("Shouldn't be able to assign across databases") self.fail("Shouldn't be able to assign across databases")
except ValueError: except ValueError: