diff --git a/django/db/models/deletion.py b/django/db/models/deletion.py index ae82d263401..02dfc43e406 100644 --- a/django/db/models/deletion.py +++ b/django/db/models/deletion.py @@ -1,4 +1,5 @@ -from collections import Counter +from collections import Counter, defaultdict +from functools import partial from itertools import chain from operator import attrgetter @@ -65,8 +66,9 @@ class Collector: def __init__(self, using): self.using = using # Initially, {model: {instances}}, later values become lists. - self.data = {} - self.field_updates = {} # {model: {(field, value): {instances}}} + self.data = defaultdict(set) + # {model: {(field, value): {instances}}} + self.field_updates = defaultdict(partial(defaultdict, set)) # fast_deletes is a list of queryset-likes that can be deleted without # fetching the objects into memory. self.fast_deletes = [] @@ -76,7 +78,7 @@ class Collector: # should be included, as the dependencies exist only between actual # database tables; proxy models are represented here by their concrete # parent. - self.dependencies = {} # {model: {models}} + self.dependencies = defaultdict(set) # {model: {models}} def add(self, objs, source=None, nullable=False, reverse_dependency=False): """ @@ -90,7 +92,7 @@ class Collector: return [] new_objs = [] model = objs[0].__class__ - instances = self.data.setdefault(model, set()) + instances = self.data[model] for obj in objs: if obj not in instances: new_objs.append(obj) @@ -101,8 +103,7 @@ class Collector: if source is not None and not nullable: if reverse_dependency: source, model = model, source - self.dependencies.setdefault( - source._meta.concrete_model, set()).add(model._meta.concrete_model) + self.dependencies[source._meta.concrete_model].add(model._meta.concrete_model) return new_objs def add_field_update(self, field, value, objs): @@ -113,9 +114,7 @@ class Collector: if not objs: return model = objs[0].__class__ - self.field_updates.setdefault( - model, {}).setdefault( - (field, value), set()).update(objs) + self.field_updates[model][field, value].update(objs) def _has_signal_listeners(self, model): return (