From d35ce682e31ea4a86c2079c60721fae171f03d7c Mon Sep 17 00:00:00 2001 From: Simon Charette Date: Wed, 9 Feb 2022 01:34:45 -0500 Subject: [PATCH] Fixed #33506 -- Made QuerySet.bulk_update() perform atomic writes against write database. The lack of _for_write = True assignment in bulk_update prior to accessing self.db resulted in the db_for_read database being used to wrap batched UPDATEs in a transaction. Also tweaked the batch queryset creation to also ensure they are executed against the same database as the opened transaction under all circumstances. Refs #23646, #33501. --- django/db/models/query.py | 4 +++- tests/queries/test_bulk_update.py | 28 +++++++++++++++++++++++++++- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/django/db/models/query.py b/django/db/models/query.py index aa10176dc06..0cebcc70d6a 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -725,6 +725,7 @@ class QuerySet: ) # PK is used twice in the resulting update query, once in the filter # and once in the WHEN. Each field will also have one CAST. + self._for_write = True connection = connections[self.db] max_batch_size = connection.ops.bulk_batch_size(["pk", "pk"] + fields, objs) batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size @@ -746,9 +747,10 @@ class QuerySet: update_kwargs[field.attname] = case_statement updates.append(([obj.pk for obj in batch_objs], update_kwargs)) rows_updated = 0 + queryset = self.using(self.db) with transaction.atomic(using=self.db, savepoint=False): for pks, update_kwargs in updates: - rows_updated += self.filter(pk__in=pks).update(**update_kwargs) + rows_updated += queryset.filter(pk__in=pks).update(**update_kwargs) return rows_updated bulk_update.alters_data = True diff --git a/tests/queries/test_bulk_update.py b/tests/queries/test_bulk_update.py index 389d6c1c41b..bc252c21c62 100644 --- a/tests/queries/test_bulk_update.py +++ b/tests/queries/test_bulk_update.py @@ -3,13 +3,15 @@ import datetime from django.core.exceptions import FieldDoesNotExist from django.db.models import F from django.db.models.functions import Lower -from django.test import TestCase, skipUnlessDBFeature +from django.db.utils import IntegrityError +from django.test import TestCase, override_settings, skipUnlessDBFeature from .models import ( Article, CustomDbColumn, CustomPk, Detail, + Food, Individual, JSONFieldNullable, Member, @@ -25,6 +27,11 @@ from .models import ( ) +class WriteToOtherRouter: + def db_for_write(self, model, **hints): + return "other" + + class BulkUpdateNoteTests(TestCase): @classmethod def setUpTestData(cls): @@ -107,6 +114,8 @@ class BulkUpdateNoteTests(TestCase): class BulkUpdateTests(TestCase): + databases = {"default", "other"} + def test_no_fields(self): msg = "Field names must be given to bulk_update()." with self.assertRaisesMessage(ValueError, msg): @@ -302,3 +311,20 @@ class BulkUpdateTests(TestCase): parent.refresh_from_db() self.assertEqual(parent.f, 42) self.assertIsNone(parent.single) + + @override_settings(DATABASE_ROUTERS=[WriteToOtherRouter()]) + def test_database_routing(self): + note = Note.objects.create(note="create") + note.note = "bulk_update" + with self.assertNumQueries(1, using="other"): + Note.objects.bulk_update([note], fields=["note"]) + + @override_settings(DATABASE_ROUTERS=[WriteToOtherRouter()]) + def test_database_routing_batch_atomicity(self): + f1 = Food.objects.create(name="Banana") + f2 = Food.objects.create(name="Apple") + f1.name = "Kiwi" + f2.name = "Kiwi" + with self.assertRaises(IntegrityError): + Food.objects.bulk_update([f1, f2], fields=["name"], batch_size=1) + self.assertIs(Food.objects.filter(name="Kiwi").exists(), False)