diff --git a/django/db/backends/base/features.py b/django/db/backends/base/features.py index d3cb181e4c..64d58e20f2 100644 --- a/django/db/backends/base/features.py +++ b/django/db/backends/base/features.py @@ -265,6 +265,10 @@ class BaseDatabaseFeatures: # INSERT? supports_ignore_conflicts = True + # Does this backend require casting the results of CASE expressions used + # in UPDATE statements to ensure the expression has the correct type? + requires_casted_case_in_updates = False + def __init__(self, connection): self.connection = connection diff --git a/django/db/backends/postgresql/features.py b/django/db/backends/postgresql/features.py index 5d6ebc9d15..eddca77239 100644 --- a/django/db/backends/postgresql/features.py +++ b/django/db/backends/postgresql/features.py @@ -48,6 +48,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): V_I := P_I; END; $$ LANGUAGE plpgsql;""" + requires_casted_case_in_updates = True supports_over_clause = True supports_aggregate_filter_clause = True supported_explain_formats = {'JSON', 'TEXT', 'XML', 'YAML'} diff --git a/django/db/models/query.py b/django/db/models/query.py index 00e505d08e..db1dc998fa 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -18,9 +18,9 @@ from django.db import ( from django.db.models import DateField, DateTimeField, sql from django.db.models.constants import LOOKUP_SEP from django.db.models.deletion import Collector -from django.db.models.expressions import F +from django.db.models.expressions import Case, Expression, F, Value, When from django.db.models.fields import AutoField -from django.db.models.functions import Trunc +from django.db.models.functions import Cast, Trunc from django.db.models.query_utils import FilteredRelation, InvalidQuery, Q from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE from django.db.utils import NotSupportedError @@ -473,6 +473,50 @@ class QuerySet: return objs + def bulk_update(self, objs, fields, batch_size=None): + """ + 
Update the given fields in each of the given objects in the database. + """ + if batch_size is not None and batch_size < 0: + raise ValueError('Batch size must be a positive integer.') + if not fields: + raise ValueError('Field names must be given to bulk_update().') + objs = tuple(objs) + if any(obj.pk is None for obj in objs): + raise ValueError('All bulk_update() objects must have a primary key set.') + fields = [self.model._meta.get_field(name) for name in fields] + if any(not f.concrete or f.many_to_many for f in fields): + raise ValueError('bulk_update() can only be used with concrete fields.') + if any(f.primary_key for f in fields): + raise ValueError('bulk_update() cannot be used with primary key fields.') + if not objs: + return + # PK is used twice in the resulting update query, once in the filter + # and once in the WHEN. Each field will also have one CAST. + max_batch_size = connections[self.db].ops.bulk_batch_size(['pk', 'pk'] + fields, objs) + batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size + requires_casting = connections[self.db].features.requires_casted_case_in_updates + batches = (objs[i:i + batch_size] for i in range(0, len(objs), batch_size)) + updates = [] + for batch_objs in batches: + update_kwargs = {} + for field in fields: + when_statements = [] + for obj in batch_objs: + attr = getattr(obj, field.attname) + if not isinstance(attr, Expression): + attr = Value(attr, output_field=field) + when_statements.append(When(pk=obj.pk, then=attr)) + case_statement = Case(*when_statements, output_field=field) + if requires_casting: + case_statement = Cast(case_statement, output_field=field) + update_kwargs[field.attname] = case_statement + updates.append(([obj.pk for obj in batch_objs], update_kwargs)) + with transaction.atomic(using=self.db, savepoint=False): + for pks, update_kwargs in updates: + self.filter(pk__in=pks).update(**update_kwargs) + bulk_update.alters_data = True + def get_or_create(self, defaults=None, 
**kwargs): """ Look up an object with the given kwargs, creating one if necessary. diff --git a/docs/ref/models/querysets.txt b/docs/ref/models/querysets.txt index d2e20261a7..e5d178d34e 100644 --- a/docs/ref/models/querysets.txt +++ b/docs/ref/models/querysets.txt @@ -2089,6 +2089,42 @@ instance (if the database normally supports it). The ``ignore_conflicts`` parameter was added. +``bulk_update()`` +~~~~~~~~~~~~~~~~~ + +.. versionadded:: 2.2 + +.. method:: bulk_update(objs, fields, batch_size=None) + +This method efficiently updates the given fields on the provided model +instances, generally with one query:: + + >>> objs = [ + ... Entry.objects.create(headline='Entry 1'), + ... Entry.objects.create(headline='Entry 2'), + ... ] + >>> objs[0].headline = 'This is entry 1' + >>> objs[1].headline = 'This is entry 2' + >>> Entry.objects.bulk_update(objs, ['headline']) + +:meth:`.QuerySet.update` is used to save the changes, so this is more efficient +than iterating through the list of models and calling ``save()`` on each of +them, but it has a few caveats: + +* You cannot update the model's primary key. +* Each model's ``save()`` method isn't called, and the + :attr:`~django.db.models.signals.pre_save` and + :attr:`~django.db.models.signals.post_save` signals aren't sent. +* If updating a large number of columns in a large number of rows, the SQL + generated can be very large. Avoid this by specifying a suitable + ``batch_size``. +* Updating fields defined on multi-table inheritance ancestors will incur an + extra query per ancestor. + +The ``batch_size`` parameter controls how many objects are saved in a single +query. The default is to create all objects in one batch, except for SQLite +and Oracle which have restrictions on the number of variables used in a query. 
+ ``count()`` ~~~~~~~~~~~ diff --git a/docs/releases/2.2.txt b/docs/releases/2.2.txt index 8a94ef5e98..c36ebab229 100644 --- a/docs/releases/2.2.txt +++ b/docs/releases/2.2.txt @@ -199,6 +199,9 @@ Models :class:`~django.db.models.DateTimeField`, and the new :lookup:`iso_year` lookup allows querying by an ISO-8601 week-numbering year. +* The new :meth:`.QuerySet.bulk_update` method allows efficiently updating + specific fields on multiple model instances. + Requests and Responses ~~~~~~~~~~~~~~~~~~~~~~ diff --git a/tests/basic/tests.py b/tests/basic/tests.py index 2ec6ace638..d12322b705 100644 --- a/tests/basic/tests.py +++ b/tests/basic/tests.py @@ -532,6 +532,7 @@ class ManagerTest(SimpleTestCase): 'update_or_create', 'create', 'bulk_create', + 'bulk_update', 'filter', 'aggregate', 'annotate', diff --git a/tests/postgres_tests/migrations/0002_create_test_models.py b/tests/postgres_tests/migrations/0002_create_test_models.py index 9f4417b58d..0e7ba938ca 100644 --- a/tests/postgres_tests/migrations/0002_create_test_models.py +++ b/tests/postgres_tests/migrations/0002_create_test_models.py @@ -56,9 +56,9 @@ class Migration(migrations.Migration): name='OtherTypesArrayModel', fields=[ ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('ips', ArrayField(models.GenericIPAddressField(), size=None)), - ('uuids', ArrayField(models.UUIDField(), size=None)), - ('decimals', ArrayField(models.DecimalField(max_digits=5, decimal_places=2), size=None)), + ('ips', ArrayField(models.GenericIPAddressField(), size=None, default=list)), + ('uuids', ArrayField(models.UUIDField(), size=None, default=list)), + ('decimals', ArrayField(models.DecimalField(max_digits=5, decimal_places=2), size=None, default=list)), ('tags', ArrayField(TagField(), blank=True, null=True, size=None)), ('json', ArrayField(JSONField(default={}), default=[])), ('int_ranges', ArrayField(IntegerRangeField(), null=True, blank=True)), diff --git 
a/tests/postgres_tests/models.py b/tests/postgres_tests/models.py index cd1646a3e6..841f246c6a 100644 --- a/tests/postgres_tests/models.py +++ b/tests/postgres_tests/models.py @@ -63,9 +63,9 @@ class NestedIntegerArrayModel(PostgreSQLModel): class OtherTypesArrayModel(PostgreSQLModel): - ips = ArrayField(models.GenericIPAddressField()) - uuids = ArrayField(models.UUIDField()) - decimals = ArrayField(models.DecimalField(max_digits=5, decimal_places=2)) + ips = ArrayField(models.GenericIPAddressField(), default=list) + uuids = ArrayField(models.UUIDField(), default=list) + decimals = ArrayField(models.DecimalField(max_digits=5, decimal_places=2), default=list) tags = ArrayField(TagField(), blank=True, null=True) json = ArrayField(JSONField(default=dict), default=list) int_ranges = ArrayField(IntegerRangeField(), blank=True, null=True) diff --git a/tests/postgres_tests/test_bulk_update.py b/tests/postgres_tests/test_bulk_update.py new file mode 100644 index 0000000000..6dd7036a9b --- /dev/null +++ b/tests/postgres_tests/test_bulk_update.py @@ -0,0 +1,34 @@ +from datetime import date + +from . import PostgreSQLTestCase +from .models import ( + HStoreModel, IntegerArrayModel, JSONModel, NestedIntegerArrayModel, + NullableIntegerArrayModel, OtherTypesArrayModel, RangesModel, +) + +try: + from psycopg2.extras import NumericRange, DateRange +except ImportError: + pass # psycopg2 isn't installed. 
+ + +class BulkSaveTests(PostgreSQLTestCase): + def test_bulk_update(self): + test_data = [ + (IntegerArrayModel, 'field', [], [1, 2, 3]), + (NullableIntegerArrayModel, 'field', [1, 2, 3], None), + (JSONModel, 'field', {'a': 'b'}, {'c': 'd'}), + (NestedIntegerArrayModel, 'field', [], [[1, 2, 3]]), + (HStoreModel, 'field', {}, {1: 2}), + (RangesModel, 'ints', None, NumericRange(lower=1, upper=10)), + (RangesModel, 'dates', None, DateRange(lower=date.today(), upper=date.today())), + (OtherTypesArrayModel, 'ips', [], ['1.2.3.4']), + (OtherTypesArrayModel, 'json', [], [{'a': 'b'}]) + ] + for Model, field, initial, new in test_data: + with self.subTest(model=Model, field=field): + instances = Model.objects.bulk_create(Model(**{field: initial}) for _ in range(20)) + for instance in instances: + setattr(instance, field, new) + Model.objects.bulk_update(instances, [field]) + self.assertSequenceEqual(Model.objects.filter(**{field: new}), instances) diff --git a/tests/queries/models.py b/tests/queries/models.py index 587d2e683e..ead8439118 100644 --- a/tests/queries/models.py +++ b/tests/queries/models.py @@ -718,3 +718,8 @@ class RelatedIndividual(models.Model): class Meta: db_table = 'RelatedIndividual' + + +class CustomDbColumn(models.Model): + custom_column = models.IntegerField(db_column='custom_name', null=True) + ip_address = models.GenericIPAddressField(null=True) diff --git a/tests/queries/test_bulk_update.py b/tests/queries/test_bulk_update.py new file mode 100644 index 0000000000..ab2bda289c --- /dev/null +++ b/tests/queries/test_bulk_update.py @@ -0,0 +1,223 @@ +import datetime + +from django.core.exceptions import FieldDoesNotExist +from django.db.models import F +from django.db.models.functions import Lower +from django.test import TestCase + +from .models import ( + Article, CustomDbColumn, CustomPk, Detail, Individual, Member, Note, + Number, Paragraph, SpecialCategory, Tag, Valid, +) + + +class BulkUpdateNoteTests(TestCase): + def setUp(self): + self.notes = 
[ + Note.objects.create(note=str(i), misc=str(i)) + for i in range(10) + ] + + def create_tags(self): + self.tags = [ + Tag.objects.create(name=str(i)) + for i in range(10) + ] + + def test_simple(self): + for note in self.notes: + note.note = 'test-%s' % note.id + with self.assertNumQueries(1): + Note.objects.bulk_update(self.notes, ['note']) + self.assertCountEqual( + Note.objects.values_list('note', flat=True), + [cat.note for cat in self.notes] + ) + + def test_multiple_fields(self): + for note in self.notes: + note.note = 'test-%s' % note.id + note.misc = 'misc-%s' % note.id + with self.assertNumQueries(1): + Note.objects.bulk_update(self.notes, ['note', 'misc']) + self.assertCountEqual( + Note.objects.values_list('note', flat=True), + [cat.note for cat in self.notes] + ) + self.assertCountEqual( + Note.objects.values_list('misc', flat=True), + [cat.misc for cat in self.notes] + ) + + def test_batch_size(self): + with self.assertNumQueries(len(self.notes)): + Note.objects.bulk_update(self.notes, fields=['note'], batch_size=1) + + def test_unsaved_models(self): + objs = self.notes + [Note(note='test', misc='test')] + msg = 'All bulk_update() objects must have a primary key set.' 
+ with self.assertRaisesMessage(ValueError, msg): + Note.objects.bulk_update(objs, fields=['note']) + + def test_foreign_keys_do_not_lookup(self): + self.create_tags() + for note, tag in zip(self.notes, self.tags): + note.tag = tag + with self.assertNumQueries(1): + Note.objects.bulk_update(self.notes, ['tag']) + self.assertSequenceEqual(Note.objects.filter(tag__isnull=False), self.notes) + + def test_set_field_to_null(self): + self.create_tags() + Note.objects.update(tag=self.tags[0]) + for note in self.notes: + note.tag = None + Note.objects.bulk_update(self.notes, ['tag']) + self.assertCountEqual(Note.objects.filter(tag__isnull=True), self.notes) + + def test_set_mixed_fields_to_null(self): + self.create_tags() + midpoint = len(self.notes) // 2 + top, bottom = self.notes[:midpoint], self.notes[midpoint:] + for note in top: + note.tag = None + for note in bottom: + note.tag = self.tags[0] + Note.objects.bulk_update(self.notes, ['tag']) + self.assertCountEqual(Note.objects.filter(tag__isnull=True), top) + self.assertCountEqual(Note.objects.filter(tag__isnull=False), bottom) + + def test_functions(self): + Note.objects.update(note='TEST') + for note in self.notes: + note.note = Lower('note') + Note.objects.bulk_update(self.notes, ['note']) + self.assertEqual(set(Note.objects.values_list('note', flat=True)), {'test'}) + + # Tests that use self.notes go here, otherwise put them in another class. + + +class BulkUpdateTests(TestCase): + def test_no_fields(self): + msg = 'Field names must be given to bulk_update().' + with self.assertRaisesMessage(ValueError, msg): + Note.objects.bulk_update([], fields=[]) + + def test_invalid_batch_size(self): + msg = 'Batch size must be a positive integer.' 
+ with self.assertRaisesMessage(ValueError, msg): + Note.objects.bulk_update([], fields=['note'], batch_size=-1) + + def test_nonexistent_field(self): + with self.assertRaisesMessage(FieldDoesNotExist, "Note has no field named 'nonexistent'"): + Note.objects.bulk_update([], ['nonexistent']) + + pk_fields_error = 'bulk_update() cannot be used with primary key fields.' + + def test_update_primary_key(self): + with self.assertRaisesMessage(ValueError, self.pk_fields_error): + Note.objects.bulk_update([], ['id']) + + def test_update_custom_primary_key(self): + with self.assertRaisesMessage(ValueError, self.pk_fields_error): + CustomPk.objects.bulk_update([], ['name']) + + def test_empty_objects(self): + with self.assertNumQueries(0): + Note.objects.bulk_update([], ['note']) + + def test_large_batch(self): + Note.objects.bulk_create([ + Note(note=str(i), misc=str(i)) + for i in range(0, 2000) + ]) + notes = list(Note.objects.all()) + Note.objects.bulk_update(notes, ['note']) + + def test_only_concrete_fields_allowed(self): + obj = Valid.objects.create(valid='test') + detail = Detail.objects.create(data='test') + paragraph = Paragraph.objects.create(text='test') + Member.objects.create(name='test', details=detail) + msg = 'bulk_update() can only be used with concrete fields.' 
+ with self.assertRaisesMessage(ValueError, msg): + Detail.objects.bulk_update([detail], fields=['member']) + with self.assertRaisesMessage(ValueError, msg): + Paragraph.objects.bulk_update([paragraph], fields=['page']) + with self.assertRaisesMessage(ValueError, msg): + Valid.objects.bulk_update([obj], fields=['parent']) + + def test_custom_db_columns(self): + model = CustomDbColumn.objects.create(custom_column=1) + model.custom_column = 2 + CustomDbColumn.objects.bulk_update([model], fields=['custom_column']) + model.refresh_from_db() + self.assertEqual(model.custom_column, 2) + + def test_custom_pk(self): + custom_pks = [ + CustomPk.objects.create(name='pk-%s' % i, extra='') + for i in range(10) + ] + for model in custom_pks: + model.extra = 'extra-%s' % model.pk + CustomPk.objects.bulk_update(custom_pks, ['extra']) + self.assertCountEqual( + CustomPk.objects.values_list('extra', flat=True), + [cat.extra for cat in custom_pks] + ) + + def test_inherited_fields(self): + special_categories = [ + SpecialCategory.objects.create(name=str(i), special_name=str(i)) + for i in range(10) + ] + for category in special_categories: + category.name = 'test-%s' % category.id + category.special_name = 'special-test-%s' % category.special_name + SpecialCategory.objects.bulk_update(special_categories, ['name', 'special_name']) + self.assertCountEqual( + SpecialCategory.objects.values_list('name', flat=True), + [cat.name for cat in special_categories] + ) + self.assertCountEqual( + SpecialCategory.objects.values_list('special_name', flat=True), + [cat.special_name for cat in special_categories] + ) + + def test_field_references(self): + numbers = [Number.objects.create(num=0) for _ in range(10)] + for number in numbers: + number.num = F('num') + 1 + Number.objects.bulk_update(numbers, ['num']) + self.assertCountEqual(Number.objects.filter(num=1), numbers) + + def test_booleanfield(self): + individuals = [Individual.objects.create(alive=False) for _ in range(10)] + for individual 
in individuals: + individual.alive = True + Individual.objects.bulk_update(individuals, ['alive']) + self.assertCountEqual(Individual.objects.filter(alive=True), individuals) + + def test_ipaddressfield(self): + for ip in ('2001::1', '1.2.3.4'): + with self.subTest(ip=ip): + models = [ + CustomDbColumn.objects.create(ip_address='0.0.0.0') + for _ in range(10) + ] + for model in models: + model.ip_address = ip + CustomDbColumn.objects.bulk_update(models, ['ip_address']) + self.assertCountEqual(CustomDbColumn.objects.filter(ip_address=ip), models) + + def test_datetime_field(self): + articles = [ + Article.objects.create(name=str(i), created=datetime.datetime.today()) + for i in range(10) + ] + point_in_time = datetime.datetime(1991, 10, 31) + for article in articles: + article.created = point_in_time + Article.objects.bulk_update(articles, ['created']) + self.assertCountEqual(Article.objects.filter(created=point_in_time), articles)