mirror of https://github.com/django/django.git
Fixed #33506 -- Made QuerySet.bulk_update() perform atomic writes against write database.
The lack of _for_write = True assignment in bulk_update prior to accessing self.db resulted in the db_for_read database being used to wrap batched UPDATEs in a transaction. Also tweaked the batch queryset creation to also ensure they are executed against the same database as the opened transaction under all circumstances. Refs #23646, #33501.
This commit is contained in:
parent
d70b4bea18
commit
d35ce682e3
|
@ -725,6 +725,7 @@ class QuerySet:
|
||||||
)
|
)
|
||||||
# PK is used twice in the resulting update query, once in the filter
|
# PK is used twice in the resulting update query, once in the filter
|
||||||
# and once in the WHEN. Each field will also have one CAST.
|
# and once in the WHEN. Each field will also have one CAST.
|
||||||
|
self._for_write = True
|
||||||
connection = connections[self.db]
|
connection = connections[self.db]
|
||||||
max_batch_size = connection.ops.bulk_batch_size(["pk", "pk"] + fields, objs)
|
max_batch_size = connection.ops.bulk_batch_size(["pk", "pk"] + fields, objs)
|
||||||
batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size
|
batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size
|
||||||
|
@ -746,9 +747,10 @@ class QuerySet:
|
||||||
update_kwargs[field.attname] = case_statement
|
update_kwargs[field.attname] = case_statement
|
||||||
updates.append(([obj.pk for obj in batch_objs], update_kwargs))
|
updates.append(([obj.pk for obj in batch_objs], update_kwargs))
|
||||||
rows_updated = 0
|
rows_updated = 0
|
||||||
|
queryset = self.using(self.db)
|
||||||
with transaction.atomic(using=self.db, savepoint=False):
|
with transaction.atomic(using=self.db, savepoint=False):
|
||||||
for pks, update_kwargs in updates:
|
for pks, update_kwargs in updates:
|
||||||
rows_updated += self.filter(pk__in=pks).update(**update_kwargs)
|
rows_updated += queryset.filter(pk__in=pks).update(**update_kwargs)
|
||||||
return rows_updated
|
return rows_updated
|
||||||
|
|
||||||
bulk_update.alters_data = True
|
bulk_update.alters_data = True
|
||||||
|
|
|
@ -3,13 +3,15 @@ import datetime
|
||||||
from django.core.exceptions import FieldDoesNotExist
|
from django.core.exceptions import FieldDoesNotExist
|
||||||
from django.db.models import F
|
from django.db.models import F
|
||||||
from django.db.models.functions import Lower
|
from django.db.models.functions import Lower
|
||||||
from django.test import TestCase, skipUnlessDBFeature
|
from django.db.utils import IntegrityError
|
||||||
|
from django.test import TestCase, override_settings, skipUnlessDBFeature
|
||||||
|
|
||||||
from .models import (
|
from .models import (
|
||||||
Article,
|
Article,
|
||||||
CustomDbColumn,
|
CustomDbColumn,
|
||||||
CustomPk,
|
CustomPk,
|
||||||
Detail,
|
Detail,
|
||||||
|
Food,
|
||||||
Individual,
|
Individual,
|
||||||
JSONFieldNullable,
|
JSONFieldNullable,
|
||||||
Member,
|
Member,
|
||||||
|
@ -25,6 +27,11 @@ from .models import (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class WriteToOtherRouter:
|
||||||
|
def db_for_write(self, model, **hints):
|
||||||
|
return "other"
|
||||||
|
|
||||||
|
|
||||||
class BulkUpdateNoteTests(TestCase):
|
class BulkUpdateNoteTests(TestCase):
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpTestData(cls):
|
def setUpTestData(cls):
|
||||||
|
@ -107,6 +114,8 @@ class BulkUpdateNoteTests(TestCase):
|
||||||
|
|
||||||
|
|
||||||
class BulkUpdateTests(TestCase):
|
class BulkUpdateTests(TestCase):
|
||||||
|
databases = {"default", "other"}
|
||||||
|
|
||||||
def test_no_fields(self):
|
def test_no_fields(self):
|
||||||
msg = "Field names must be given to bulk_update()."
|
msg = "Field names must be given to bulk_update()."
|
||||||
with self.assertRaisesMessage(ValueError, msg):
|
with self.assertRaisesMessage(ValueError, msg):
|
||||||
|
@ -302,3 +311,20 @@ class BulkUpdateTests(TestCase):
|
||||||
parent.refresh_from_db()
|
parent.refresh_from_db()
|
||||||
self.assertEqual(parent.f, 42)
|
self.assertEqual(parent.f, 42)
|
||||||
self.assertIsNone(parent.single)
|
self.assertIsNone(parent.single)
|
||||||
|
|
||||||
|
@override_settings(DATABASE_ROUTERS=[WriteToOtherRouter()])
|
||||||
|
def test_database_routing(self):
|
||||||
|
note = Note.objects.create(note="create")
|
||||||
|
note.note = "bulk_update"
|
||||||
|
with self.assertNumQueries(1, using="other"):
|
||||||
|
Note.objects.bulk_update([note], fields=["note"])
|
||||||
|
|
||||||
|
@override_settings(DATABASE_ROUTERS=[WriteToOtherRouter()])
|
||||||
|
def test_database_routing_batch_atomicity(self):
|
||||||
|
f1 = Food.objects.create(name="Banana")
|
||||||
|
f2 = Food.objects.create(name="Apple")
|
||||||
|
f1.name = "Kiwi"
|
||||||
|
f2.name = "Kiwi"
|
||||||
|
with self.assertRaises(IntegrityError):
|
||||||
|
Food.objects.bulk_update([f1, f2], fields=["name"], batch_size=1)
|
||||||
|
self.assertIs(Food.objects.filter(name="Kiwi").exists(), False)
|
||||||
|
|
Loading…
Reference in New Issue