diff --git a/django/db/models/deletion.py b/django/db/models/deletion.py index 631b455aa35..bc26d82e934 100644 --- a/django/db/models/deletion.py +++ b/django/db/models/deletion.py @@ -1,9 +1,9 @@ from collections import Counter, defaultdict -from functools import partial +from functools import partial, reduce from itertools import chain -from operator import attrgetter +from operator import attrgetter, or_ -from django.db import IntegrityError, connections, transaction +from django.db import IntegrityError, connections, models, transaction from django.db.models import query_utils, signals, sql @@ -61,6 +61,7 @@ def SET(value): collector.add_field_update(field, value, sub_objs) set_on_delete.deconstruct = lambda: ("django.db.models.SET", (value,), {}) + set_on_delete.lazy_sub_objs = True return set_on_delete @@ -68,10 +69,16 @@ def SET_NULL(collector, field, sub_objs, using): collector.add_field_update(field, None, sub_objs) +SET_NULL.lazy_sub_objs = True + + def SET_DEFAULT(collector, field, sub_objs, using): collector.add_field_update(field, field.get_default(), sub_objs) +SET_DEFAULT.lazy_sub_objs = True + + def DO_NOTHING(collector, field, sub_objs, using): pass @@ -93,8 +100,8 @@ class Collector: self.origin = origin # Initially, {model: {instances}}, later values become lists. self.data = defaultdict(set) - # {model: {(field, value): {instances}}} - self.field_updates = defaultdict(partial(defaultdict, set)) + # {(field, value): [instances, …]} + self.field_updates = defaultdict(list) # {model: {field: {instances}}} self.restricted_objects = defaultdict(partial(defaultdict, set)) # fast_deletes is a list of queryset-likes that can be deleted without @@ -145,10 +152,7 @@ class Collector: Schedule a field update. 'objs' must be a homogeneous iterable collection of model instances (e.g. a QuerySet). """ - if not objs: - return - model = objs[0].__class__ - self.field_updates[model][field, value].update(objs) + self.field_updates[field, value].append(objs) def add_restricted_objects(self, field, objs): if objs: @@ -312,7 +316,8 @@ class Collector: if keep_parents and related.model in parents: continue field = related.field - if field.remote_field.on_delete == DO_NOTHING: + on_delete = field.remote_field.on_delete + if on_delete == DO_NOTHING: continue related_model = related.related_model if self.can_fast_delete(related_model, from_field=field): @@ -340,9 +345,9 @@ class Collector: ) ) sub_objs = sub_objs.only(*tuple(referenced_fields)) - if sub_objs: + if getattr(on_delete, "lazy_sub_objs", False) or sub_objs: try: - field.remote_field.on_delete(self, field, sub_objs, self.using) + on_delete(self, field, sub_objs, self.using) except ProtectedError as error: key = "'%s.%s'" % (field.model.__name__, field.name) protected_objects[key] += error.protected_objects @@ -469,11 +474,25 @@ class Collector: deleted_counter[qs.model._meta.label] += count # update fields - for model, instances_for_fieldvalues in self.field_updates.items(): - for (field, value), instances in instances_for_fieldvalues.items(): + for (field, value), instances_list in self.field_updates.items(): + updates = [] + objs = [] + for instances in instances_list: + if ( + isinstance(instances, models.QuerySet) + and instances._result_cache is None + ): + updates.append(instances) + else: + objs.extend(instances) + if updates: + combined_updates = reduce(or_, updates) + combined_updates.update(**{field.name: value}) + if objs: + model = objs[0].__class__ query = sql.UpdateQuery(model) query.update_batch( - [obj.pk for obj in instances], {field.name: value}, self.using + list({obj.pk for obj in objs}), {field.name: value}, self.using ) # reverse instance collections diff --git a/tests/delete_regress/models.py b/tests/delete_regress/models.py index dc8658e6c40..cbe6fef3343 100644 --- a/tests/delete_regress/models.py +++ b/tests/delete_regress/models.py @@ -90,6 +90,12 @@ class Location(models.Model): class Item(models.Model): version = models.ForeignKey(Version, models.CASCADE) location = models.ForeignKey(Location, models.SET_NULL, blank=True, null=True) + location_value = models.ForeignKey( + Location, models.SET(42), default=1, db_constraint=False, related_name="+" + ) + location_default = models.ForeignKey( + Location, models.SET_DEFAULT, default=1, db_constraint=False, related_name="+" + ) # Models for #16128 diff --git a/tests/delete_regress/tests.py b/tests/delete_regress/tests.py index c9d0ff8d0ab..7dccbf555d5 100644 --- a/tests/delete_regress/tests.py +++ b/tests/delete_regress/tests.py @@ -399,3 +399,19 @@ class DeleteDistinct(SimpleTestCase): Book.objects.distinct().delete() with self.assertRaisesMessage(TypeError, msg): Book.objects.distinct("id").delete() + + +class SetQueryCountTests(TestCase): + def test_set_querycount(self): + policy = Policy.objects.create() + version = Version.objects.create(policy=policy) + location = Location.objects.create(version=version) + Item.objects.create( + version=version, + location=location, + location_default=location, + location_value=location, + ) + # 3 UPDATEs for SET of item values and one for DELETE locations. + with self.assertNumQueries(4): + location.delete()