From cd124295d882e13cff556fdeb78e6278d10ac6d5 Mon Sep 17 00:00:00 2001 From: abhiabhi94 <13880786+abhiabhi94@users.noreply.github.com> Date: Sat, 26 Jun 2021 10:48:38 +0530 Subject: [PATCH] Fixed #32381 -- Made QuerySet.bulk_update() return the number of objects updated. Co-authored-by: Diego Lima --- django/db/models/query.py | 6 ++++-- docs/ref/models/querysets.txt | 12 +++++++++++- docs/releases/4.0.txt | 2 ++ tests/queries/test_bulk_update.py | 14 ++++++++++++-- 4 files changed, 29 insertions(+), 5 deletions(-) diff --git a/django/db/models/query.py b/django/db/models/query.py index 387deca527..f14ff8d094 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -541,7 +541,7 @@ class QuerySet: if any(f.primary_key for f in fields): raise ValueError('bulk_update() cannot be used with primary key fields.') if not objs: - return + return 0 # PK is used twice in the resulting update query, once in the filter # and once in the WHEN. Each field will also have one CAST. max_batch_size = connections[self.db].ops.bulk_batch_size(['pk', 'pk'] + fields, objs) @@ -563,9 +563,11 @@ class QuerySet: case_statement = Cast(case_statement, output_field=field) update_kwargs[field.attname] = case_statement updates.append(([obj.pk for obj in batch_objs], update_kwargs)) + rows_updated = 0 with transaction.atomic(using=self.db, savepoint=False): for pks, update_kwargs in updates: - self.filter(pk__in=pks).update(**update_kwargs) + rows_updated += self.filter(pk__in=pks).update(**update_kwargs) + return rows_updated bulk_update.alters_data = True def get_or_create(self, defaults=None, **kwargs): diff --git a/docs/ref/models/querysets.txt b/docs/ref/models/querysets.txt index 5dc7a6b5bc..1201800567 100644 --- a/docs/ref/models/querysets.txt +++ b/docs/ref/models/querysets.txt @@ -2221,7 +2221,8 @@ normally supports it). .. method:: bulk_update(objs, fields, batch_size=None) This method efficiently updates the given fields on the provided model -instances, generally with one query:: +instances, generally with one query, and returns the number of objects +updated:: >>> objs = [ ... Entry.objects.create(headline='Entry 1'), @@ -2230,6 +2231,11 @@ instances, generally with one query:: >>> objs[0].headline = 'This is entry 1' >>> objs[1].headline = 'This is entry 2' >>> Entry.objects.bulk_update(objs, ['headline']) + 2 + +.. versionchanged:: 4.0 + + The return value of the number of objects updated was added. :meth:`.QuerySet.update` is used to save the changes, so this is more efficient than iterating through the list of models and calling ``save()`` on each of @@ -2246,6 +2252,10 @@ them, but it has a few caveats: extra query per ancestor. * When an individual batch contains duplicates, only the first instance in that batch will result in an update. +* The number of objects updated returned by the function may be fewer than the + number of objects passed in. This can be due to duplicate objects passed in + which are updated in the same batch or race conditions such that objects are + no longer present in the database. The ``batch_size`` parameter controls how many objects are saved in a single query. The default is to update all objects in one batch, except for SQLite diff --git a/docs/releases/4.0.txt b/docs/releases/4.0.txt index 9b57a524aa..a9b29d7ce4 100644 --- a/docs/releases/4.0.txt +++ b/docs/releases/4.0.txt @@ -263,6 +263,8 @@ Models * :class:`~django.db.models.DurationField` now supports multiplying and dividing by scalar values on SQLite. +* :meth:`.QuerySet.bulk_update` now returns the number of objects updated. + Requests and Responses ~~~~~~~~~~~~~~~~~~~~~~ diff --git a/tests/queries/test_bulk_update.py b/tests/queries/test_bulk_update.py index be794df718..6ca7f201c1 100644 --- a/tests/queries/test_bulk_update.py +++ b/tests/queries/test_bulk_update.py @@ -125,7 +125,8 @@ class BulkUpdateTests(TestCase): def test_empty_objects(self): with self.assertNumQueries(0): - Note.objects.bulk_update([], ['note']) + rows_updated = Note.objects.bulk_update([], ['note']) + self.assertEqual(rows_updated, 0) def test_large_batch(self): Note.objects.bulk_create([ @@ -133,7 +134,16 @@ class BulkUpdateTests(TestCase): for i in range(0, 2000) ]) notes = list(Note.objects.all()) - Note.objects.bulk_update(notes, ['note']) + rows_updated = Note.objects.bulk_update(notes, ['note']) + self.assertEqual(rows_updated, 2000) + + def test_updated_rows_when_passing_duplicates(self): + note = Note.objects.create(note='test-note', misc='test') + rows_updated = Note.objects.bulk_update([note, note], ['note']) + self.assertEqual(rows_updated, 1) + # Duplicates in different batches. + rows_updated = Note.objects.bulk_update([note, note], ['note'], batch_size=1) + self.assertEqual(rows_updated, 2) def test_only_concrete_fields_allowed(self): obj = Valid.objects.create(valid='test')