Fixed #33506 -- Made QuerySet.bulk_update() perform atomic writes against write database.

The lack of _for_write = True assignment in bulk_update prior to
accessing self.db resulted in the db_for_read database being used to
wrap batched UPDATEs in a transaction.

Also tweaked the batch queryset creation to also ensure they are
executed against the same database as the opened transaction under all
circumstances.

Refs #23646, #33501.
This commit is contained in:
Simon Charette 2022-02-09 01:34:45 -05:00 committed by Mariusz Felisiak
parent d70b4bea18
commit d35ce682e3
2 changed files with 30 additions and 2 deletions

View File

@ -725,6 +725,7 @@ class QuerySet:
) )
# PK is used twice in the resulting update query, once in the filter # PK is used twice in the resulting update query, once in the filter
# and once in the WHEN. Each field will also have one CAST. # and once in the WHEN. Each field will also have one CAST.
self._for_write = True
connection = connections[self.db] connection = connections[self.db]
max_batch_size = connection.ops.bulk_batch_size(["pk", "pk"] + fields, objs) max_batch_size = connection.ops.bulk_batch_size(["pk", "pk"] + fields, objs)
batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size
@ -746,9 +747,10 @@ class QuerySet:
update_kwargs[field.attname] = case_statement update_kwargs[field.attname] = case_statement
updates.append(([obj.pk for obj in batch_objs], update_kwargs)) updates.append(([obj.pk for obj in batch_objs], update_kwargs))
rows_updated = 0 rows_updated = 0
queryset = self.using(self.db)
with transaction.atomic(using=self.db, savepoint=False): with transaction.atomic(using=self.db, savepoint=False):
for pks, update_kwargs in updates: for pks, update_kwargs in updates:
rows_updated += self.filter(pk__in=pks).update(**update_kwargs) rows_updated += queryset.filter(pk__in=pks).update(**update_kwargs)
return rows_updated return rows_updated
bulk_update.alters_data = True bulk_update.alters_data = True

View File

@ -3,13 +3,15 @@ import datetime
from django.core.exceptions import FieldDoesNotExist from django.core.exceptions import FieldDoesNotExist
from django.db.models import F from django.db.models import F
from django.db.models.functions import Lower from django.db.models.functions import Lower
from django.test import TestCase, skipUnlessDBFeature from django.db.utils import IntegrityError
from django.test import TestCase, override_settings, skipUnlessDBFeature
from .models import ( from .models import (
Article, Article,
CustomDbColumn, CustomDbColumn,
CustomPk, CustomPk,
Detail, Detail,
Food,
Individual, Individual,
JSONFieldNullable, JSONFieldNullable,
Member, Member,
@ -25,6 +27,11 @@ from .models import (
) )
class WriteToOtherRouter:
def db_for_write(self, model, **hints):
return "other"
class BulkUpdateNoteTests(TestCase): class BulkUpdateNoteTests(TestCase):
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
@ -107,6 +114,8 @@ class BulkUpdateNoteTests(TestCase):
class BulkUpdateTests(TestCase): class BulkUpdateTests(TestCase):
databases = {"default", "other"}
def test_no_fields(self): def test_no_fields(self):
msg = "Field names must be given to bulk_update()." msg = "Field names must be given to bulk_update()."
with self.assertRaisesMessage(ValueError, msg): with self.assertRaisesMessage(ValueError, msg):
@ -302,3 +311,20 @@ class BulkUpdateTests(TestCase):
parent.refresh_from_db() parent.refresh_from_db()
self.assertEqual(parent.f, 42) self.assertEqual(parent.f, 42)
self.assertIsNone(parent.single) self.assertIsNone(parent.single)
@override_settings(DATABASE_ROUTERS=[WriteToOtherRouter()])
def test_database_routing(self):
note = Note.objects.create(note="create")
note.note = "bulk_update"
with self.assertNumQueries(1, using="other"):
Note.objects.bulk_update([note], fields=["note"])
@override_settings(DATABASE_ROUTERS=[WriteToOtherRouter()])
def test_database_routing_batch_atomicity(self):
f1 = Food.objects.create(name="Banana")
f2 = Food.objects.create(name="Apple")
f1.name = "Kiwi"
f2.name = "Kiwi"
with self.assertRaises(IntegrityError):
Food.objects.bulk_update([f1, f2], fields=["name"], batch_size=1)
self.assertIs(Food.objects.filter(name="Kiwi").exists(), False)