diff --git a/django/db/models/deletion.py b/django/db/models/deletion.py index fd4d2c83a3..6cd42701d9 100644 --- a/django/db/models/deletion.py +++ b/django/db/models/deletion.py @@ -64,7 +64,7 @@ class Collector(object): self.field_updates = {} # {model: {(field, value): set([instances])}} self.dependencies = {} # {model: set([models])} - def add(self, objs, source=None, nullable=False): + def add(self, objs, source=None, nullable=False, reverse_dependency=False): """ Adds 'objs' to the collection of objects to be deleted. If the call is the result of a cascade, 'source' should be the model that caused it @@ -85,6 +85,8 @@ class Collector(object): # deleting, and therefore do not affect the order in which objects have # to be deleted. if new_objs and source is not None and not nullable: + if reverse_dependency: + source, model = model, source self.dependencies.setdefault(source, set()).add(model) return new_objs @@ -108,7 +110,7 @@ class Collector(object): (field, value), set()).update(objs) def collect(self, objs, source=None, nullable=False, collect_related=True, - source_attr=None): + source_attr=None, reverse_dependency=False): """ Adds 'objs' to the collection of objects to be deleted as well as all parent instances. 'objs' must be a homogenous iterable collection of @@ -118,9 +120,14 @@ class Collector(object): If the call is the result of a cascade, 'source' should be the model that caused it and 'nullable' should be set to True, if the relation can be null. - """ - new_objs = self.add(objs, source, nullable) + If 'reverse_dependency' is True, 'source' will be deleted before the + current model, rather than after. (Needed for cascading to parent + models, the one case in which the cascade follows the forwards + direction of an FK rather than the reverse direction.) + """ + new_objs = self.add(objs, source, nullable, + reverse_dependency=reverse_dependency) if not new_objs: return model = new_objs[0].__class__ @@ -132,7 +139,8 @@ class Collector(object): parent_objs = [getattr(obj, ptr.name) for obj in new_objs] self.collect(parent_objs, source=model, source_attr=ptr.rel.related_name, - collect_related=False) + collect_related=False, + reverse_dependency=True) if collect_related: for related in model._meta.get_all_related_objects(include_hidden=True):