Fixed #7539, #13067 -- Added on_delete argument to ForeignKey to control cascade behavior. Also refactored deletion for efficiency and code clarity. Many thanks to Johannes Dollinger and Michael Glassford for extensive work on the patch, and to Alex Gaynor, Russell Keith-Magee, and Jacob Kaplan-Moss for review.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@14507 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Carl Meyer 2010-11-09 16:46:42 +00:00
parent 3ba3294c6b
commit 616b30227d
28 changed files with 850 additions and 608 deletions

View File

@ -6,6 +6,7 @@ from django import template
from django.core.exceptions import PermissionDenied from django.core.exceptions import PermissionDenied
from django.contrib.admin import helpers from django.contrib.admin import helpers
from django.contrib.admin.util import get_deleted_objects, model_ngettext from django.contrib.admin.util import get_deleted_objects, model_ngettext
from django.db import router
from django.shortcuts import render_to_response from django.shortcuts import render_to_response
from django.utils.encoding import force_unicode from django.utils.encoding import force_unicode
from django.utils.translation import ugettext_lazy, ugettext as _ from django.utils.translation import ugettext_lazy, ugettext as _
@ -27,9 +28,12 @@ def delete_selected(modeladmin, request, queryset):
if not modeladmin.has_delete_permission(request): if not modeladmin.has_delete_permission(request):
raise PermissionDenied raise PermissionDenied
using = router.db_for_write(modeladmin.model)
# Populate deletable_objects, a data structure of all related objects that # Populate deletable_objects, a data structure of all related objects that
# will also be deleted. # will also be deleted.
deletable_objects, perms_needed = get_deleted_objects(queryset, opts, request.user, modeladmin.admin_site) deletable_objects, perms_needed = get_deleted_objects(
queryset, opts, request.user, modeladmin.admin_site, using)
# The user has already confirmed the deletion. # The user has already confirmed the deletion.
# Do the deletion and return a None to display the change list view again. # Do the deletion and return a None to display the change list view again.

View File

@ -9,7 +9,7 @@ from django.contrib.admin.util import unquote, flatten_fieldsets, get_deleted_ob
from django.contrib import messages from django.contrib import messages
from django.views.decorators.csrf import csrf_protect from django.views.decorators.csrf import csrf_protect
from django.core.exceptions import PermissionDenied, ValidationError from django.core.exceptions import PermissionDenied, ValidationError
from django.db import models, transaction from django.db import models, transaction, router
from django.db.models.fields import BLANK_CHOICE_DASH from django.db.models.fields import BLANK_CHOICE_DASH
from django.http import Http404, HttpResponse, HttpResponseRedirect from django.http import Http404, HttpResponse, HttpResponseRedirect
from django.shortcuts import get_object_or_404, render_to_response from django.shortcuts import get_object_or_404, render_to_response
@ -1110,9 +1110,12 @@ class ModelAdmin(BaseModelAdmin):
if obj is None: if obj is None:
raise Http404(_('%(name)s object with primary key %(key)r does not exist.') % {'name': force_unicode(opts.verbose_name), 'key': escape(object_id)}) raise Http404(_('%(name)s object with primary key %(key)r does not exist.') % {'name': force_unicode(opts.verbose_name), 'key': escape(object_id)})
using = router.db_for_write(self.model)
# Populate deleted_objects, a data structure of all related objects that # Populate deleted_objects, a data structure of all related objects that
# will also be deleted. # will also be deleted.
(deleted_objects, perms_needed) = get_deleted_objects((obj,), opts, request.user, self.admin_site) (deleted_objects, perms_needed) = get_deleted_objects(
[obj], opts, request.user, self.admin_site, using)
if request.POST: # The user has already confirmed the deletion. if request.POST: # The user has already confirmed the deletion.
if perms_needed: if perms_needed:

View File

@ -1,4 +1,5 @@
from django.db import models from django.db import models
from django.db.models.deletion import Collector
from django.db.models.related import RelatedObject from django.db.models.related import RelatedObject
from django.forms.forms import pretty_name from django.forms.forms import pretty_name
from django.utils import formats from django.utils import formats
@ -10,6 +11,7 @@ from django.utils.translation import ungettext
from django.core.urlresolvers import reverse, NoReverseMatch from django.core.urlresolvers import reverse, NoReverseMatch
from django.utils.datastructures import SortedDict from django.utils.datastructures import SortedDict
def quote(s): def quote(s):
""" """
Ensure that primary key values do not confuse the admin URLs by escaping Ensure that primary key values do not confuse the admin URLs by escaping
@ -26,6 +28,7 @@ def quote(s):
res[i] = '_%02X' % ord(c) res[i] = '_%02X' % ord(c)
return ''.join(res) return ''.join(res)
def unquote(s): def unquote(s):
""" """
Undo the effects of quote(). Based heavily on urllib.unquote(). Undo the effects of quote(). Based heavily on urllib.unquote().
@ -46,6 +49,7 @@ def unquote(s):
myappend('_' + item) myappend('_' + item)
return "".join(res) return "".join(res)
def flatten_fieldsets(fieldsets): def flatten_fieldsets(fieldsets):
"""Returns a list of field names from an admin fieldsets structure.""" """Returns a list of field names from an admin fieldsets structure."""
field_names = [] field_names = []
@ -58,144 +62,94 @@ def flatten_fieldsets(fieldsets):
field_names.append(field) field_names.append(field)
return field_names return field_names
def _format_callback(obj, user, admin_site, perms_needed):
has_admin = obj.__class__ in admin_site._registry
opts = obj._meta
if has_admin:
admin_url = reverse('%s:%s_%s_change'
% (admin_site.name,
opts.app_label,
opts.object_name.lower()),
None, (quote(obj._get_pk_val()),))
p = '%s.%s' % (opts.app_label,
opts.get_delete_permission())
if not user.has_perm(p):
perms_needed.add(opts.verbose_name)
# Display a link to the admin page.
return mark_safe(u'%s: <a href="%s">%s</a>' %
(escape(capfirst(opts.verbose_name)),
admin_url,
escape(obj)))
else:
# Don't display link to edit, because it either has no
# admin or is edited inline.
return u'%s: %s' % (capfirst(opts.verbose_name),
force_unicode(obj))
def get_deleted_objects(objs, opts, user, admin_site): def get_deleted_objects(objs, opts, user, admin_site, using):
""" """
Find all objects related to ``objs`` that should also be Find all objects related to ``objs`` that should also be deleted. ``objs``
deleted. ``objs`` should be an iterable of objects. must be a homogenous iterable of objects (e.g. a QuerySet).
Returns a nested list of strings suitable for display in the Returns a nested list of strings suitable for display in the
template with the ``unordered_list`` filter. template with the ``unordered_list`` filter.
""" """
collector = NestedObjects() collector = NestedObjects(using=using)
for obj in objs: collector.collect(objs)
# TODO using a private model API!
obj._collect_sub_objects(collector)
perms_needed = set() perms_needed = set()
to_delete = collector.nested(_format_callback, def format_callback(obj):
user=user, has_admin = obj.__class__ in admin_site._registry
admin_site=admin_site, opts = obj._meta
perms_needed=perms_needed)
if has_admin:
admin_url = reverse('%s:%s_%s_change'
% (admin_site.name,
opts.app_label,
opts.object_name.lower()),
None, (quote(obj._get_pk_val()),))
p = '%s.%s' % (opts.app_label,
opts.get_delete_permission())
if not user.has_perm(p):
perms_needed.add(opts.verbose_name)
# Display a link to the admin page.
return mark_safe(u'%s: <a href="%s">%s</a>' %
(escape(capfirst(opts.verbose_name)),
admin_url,
escape(obj)))
else:
# Don't display link to edit, because it either has no
# admin or is edited inline.
return u'%s: %s' % (capfirst(opts.verbose_name),
force_unicode(obj))
to_delete = collector.nested(format_callback)
return to_delete, perms_needed return to_delete, perms_needed
class NestedObjects(object): class NestedObjects(Collector):
""" def __init__(self, *args, **kwargs):
A directed acyclic graph collection that exposes the add() API super(NestedObjects, self).__init__(*args, **kwargs)
expected by Model._collect_sub_objects and can present its data as self.edges = {} # {from_instance: [to_instances]}
a nested list of objects.
""" def add_edge(self, source, target):
def __init__(self): self.edges.setdefault(source, []).append(target)
# Use object keys of the form (model, pk) because actual model
# objects may not be unique
# maps object key to list of child keys def collect(self, objs, source_attr=None, **kwargs):
self.children = SortedDict() for obj in objs:
if source_attr:
self.add_edge(getattr(obj, source_attr), obj)
else:
self.add_edge(None, obj)
return super(NestedObjects, self).collect(objs, source_attr=source_attr, **kwargs)
# maps object key to parent key def related_objects(self, related, objs):
self.parents = SortedDict() qs = super(NestedObjects, self).related_objects(related, objs)
return qs.select_related(related.field.name)
# maps object key to actual object def _nested(self, obj, seen, format_callback):
self.seen = SortedDict() if obj in seen:
return []
def add(self, model, pk, obj, seen.add(obj)
parent_model=None, parent_obj=None, nullable=False): children = []
""" for child in self.edges.get(obj, ()):
Add item ``obj`` to the graph. Returns True (and does nothing) children.extend(self._nested(child, seen, format_callback))
if the item has been seen already.
The ``parent_obj`` argument must already exist in the graph; if
not, it's ignored (but ``obj`` is still added with no
parent). In any case, Model._collect_sub_objects (for whom
this API exists) will never pass a parent that hasn't already
been added itself.
These restrictions in combination ensure the graph will remain
acyclic (but can have multiple roots).
``model``, ``pk``, and ``parent_model`` arguments are ignored
in favor of the appropriate lookups on ``obj`` and
``parent_obj``; unlike CollectedObjects, we can't maintain
independence from the knowledge that we're operating on model
instances, and we don't want to allow for inconsistency.
``nullable`` arg is ignored: it doesn't affect how the tree of
collected objects should be nested for display.
"""
model, pk = type(obj), obj._get_pk_val()
# auto-created M2M models don't interest us
if model._meta.auto_created:
return True
key = model, pk
if key in self.seen:
return True
self.seen.setdefault(key, obj)
if parent_obj is not None:
parent_model, parent_pk = (type(parent_obj),
parent_obj._get_pk_val())
parent_key = (parent_model, parent_pk)
if parent_key in self.seen:
self.children.setdefault(parent_key, list()).append(key)
self.parents.setdefault(key, parent_key)
def _nested(self, key, format_callback=None, **kwargs):
obj = self.seen[key]
if format_callback: if format_callback:
ret = [format_callback(obj, **kwargs)] ret = [format_callback(obj)]
else: else:
ret = [obj] ret = [obj]
children = []
for child in self.children.get(key, ()):
children.extend(self._nested(child, format_callback, **kwargs))
if children: if children:
ret.append(children) ret.append(children)
return ret return ret
def nested(self, format_callback=None, **kwargs): def nested(self, format_callback=None):
""" """
Return the graph as a nested list. Return the graph as a nested list.
Passes **kwargs back to the format_callback as kwargs.
""" """
seen = set()
roots = [] roots = []
for key in self.seen.keys(): for root in self.edges.get(None, ()):
if key not in self.parents: roots.extend(self._nested(root, seen, format_callback))
roots.extend(self._nested(key, format_callback, **kwargs))
return roots return roots
@ -218,6 +172,7 @@ def model_format_dict(obj):
'verbose_name_plural': force_unicode(opts.verbose_name_plural) 'verbose_name_plural': force_unicode(opts.verbose_name_plural)
} }
def model_ngettext(obj, n=None): def model_ngettext(obj, n=None):
""" """
Return the appropriate `verbose_name` or `verbose_name_plural` value for Return the appropriate `verbose_name` or `verbose_name_plural` value for
@ -236,6 +191,7 @@ def model_ngettext(obj, n=None):
singular, plural = d["verbose_name"], d["verbose_name_plural"] singular, plural = d["verbose_name"], d["verbose_name_plural"]
return ungettext(singular, plural, n or 0) return ungettext(singular, plural, n or 0)
def lookup_field(name, obj, model_admin=None): def lookup_field(name, obj, model_admin=None):
opts = obj._meta opts = obj._meta
try: try:
@ -262,6 +218,7 @@ def lookup_field(name, obj, model_admin=None):
value = getattr(obj, name) value = getattr(obj, name)
return f, attr, value return f, attr, value
def label_for_field(name, model, model_admin=None, return_attr=False): def label_for_field(name, model, model_admin=None, return_attr=False):
attr = None attr = None
try: try:

View File

@ -5,7 +5,7 @@ Classes allowing "generic" relations through ContentType and object-id fields.
from django.core.exceptions import ObjectDoesNotExist from django.core.exceptions import ObjectDoesNotExist
from django.db import connection from django.db import connection
from django.db.models import signals from django.db.models import signals
from django.db import models, router from django.db import models, router, DEFAULT_DB_ALIAS
from django.db.models.fields.related import RelatedField, Field, ManyToManyRel from django.db.models.fields.related import RelatedField, Field, ManyToManyRel
from django.db.models.loading import get_model from django.db.models.loading import get_model
from django.forms import ModelForm from django.forms import ModelForm
@ -13,6 +13,9 @@ from django.forms.models import BaseModelFormSet, modelformset_factory, save_ins
from django.contrib.admin.options import InlineModelAdmin, flatten_fieldsets from django.contrib.admin.options import InlineModelAdmin, flatten_fieldsets
from django.utils.encoding import smart_unicode from django.utils.encoding import smart_unicode
from django.contrib.contenttypes.models import ContentType
class GenericForeignKey(object): class GenericForeignKey(object):
""" """
Provides a generic relation to any object through content-type/object-id Provides a generic relation to any object through content-type/object-id
@ -167,6 +170,19 @@ class GenericRelation(RelatedField, Field):
return [("%s__%s" % (prefix, self.content_type_field_name), return [("%s__%s" % (prefix, self.content_type_field_name),
content_type)] content_type)]
def bulk_related_objects(self, objs, using=DEFAULT_DB_ALIAS):
"""
Return all objects related to ``objs`` via this ``GenericRelation``.
"""
return self.rel.to._base_manager.db_manager(using).filter(**{
"%s__pk" % self.content_type_field_name:
ContentType.objects.db_manager(using).get_for_model(self.model).pk,
"%s__in" % self.object_id_field_name:
[obj.pk for obj in objs]
})
class ReverseGenericRelatedObjectsDescriptor(object): class ReverseGenericRelatedObjectsDescriptor(object):
""" """
This class provides the functionality that makes the related-object This class provides the functionality that makes the related-object

View File

@ -22,6 +22,7 @@ def get_validation_errors(outfile, app=None):
from django.db import models, connection from django.db import models, connection
from django.db.models.loading import get_app_errors from django.db.models.loading import get_app_errors
from django.db.models.fields.related import RelatedObject from django.db.models.fields.related import RelatedObject
from django.db.models.deletion import SET_NULL, SET_DEFAULT
e = ModelErrorCollection(outfile) e = ModelErrorCollection(outfile)
@ -85,6 +86,13 @@ def get_validation_errors(outfile, app=None):
# Perform any backend-specific field validation. # Perform any backend-specific field validation.
connection.validation.validate_field(e, opts, f) connection.validation.validate_field(e, opts, f)
# Check if the on_delete behavior is sane
if f.rel and hasattr(f.rel, 'on_delete'):
if f.rel.on_delete == SET_NULL and not f.null:
e.add(opts, "'%s' specifies on_delete=SET_NULL, but cannot be null." % f.name)
elif f.rel.on_delete == SET_DEFAULT and not f.has_default():
e.add(opts, "'%s' specifies on_delete=SET_DEFAULT, but has no default value." % f.name)
# Check to see if the related field will clash with any existing # Check to see if the related field will clash with any existing
# fields, m2m fields, m2m related objects or related objects # fields, m2m fields, m2m related objects or related objects
if f.rel: if f.rel:

View File

@ -150,6 +150,10 @@ class BaseDatabaseFeatures(object):
# Can an object have a primary key of 0? MySQL says No. # Can an object have a primary key of 0? MySQL says No.
allows_primary_key_0 = True allows_primary_key_0 = True
# Do we need to NULL a ForeignKey out, or can the constraint check be
# deferred
can_defer_constraint_checks = False
# Features that need to be confirmed at runtime # Features that need to be confirmed at runtime
# Cache whether the confirmation has been performed. # Cache whether the confirmation has been performed.
_confirmed = False _confirmed = False

View File

@ -53,6 +53,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
supports_subqueries_in_group_by = True supports_subqueries_in_group_by = True
supports_timezones = False supports_timezones = False
supports_bitwise_or = False supports_bitwise_or = False
can_defer_constraint_checks = True
class DatabaseOperations(BaseDatabaseOperations): class DatabaseOperations(BaseDatabaseOperations):
compiler_module = "django.db.backends.oracle.compiler" compiler_module = "django.db.backends.oracle.compiler"

View File

@ -82,6 +82,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
uses_savepoints = True uses_savepoints = True
requires_rollback_on_dirty_transaction = True requires_rollback_on_dirty_transaction = True
has_real_datatype = True has_real_datatype = True
can_defer_constraint_checks = True
class DatabaseWrapper(BaseDatabaseWrapper): class DatabaseWrapper(BaseDatabaseWrapper):
vendor = 'postgresql' vendor = 'postgresql'

View File

@ -69,6 +69,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
can_return_id_from_insert = False can_return_id_from_insert = False
requires_rollback_on_dirty_transaction = True requires_rollback_on_dirty_transaction = True
has_real_datatype = True has_real_datatype = True
can_defer_constraint_checks = True
class DatabaseOperations(PostgresqlDatabaseOperations): class DatabaseOperations(PostgresqlDatabaseOperations):
def last_executed_query(self, cursor, sql, params): def last_executed_query(self, cursor, sql, params):

View File

@ -11,6 +11,7 @@ from django.db.models.fields import *
from django.db.models.fields.subclassing import SubfieldBase from django.db.models.fields.subclassing import SubfieldBase
from django.db.models.fields.files import FileField, ImageField from django.db.models.fields.files import FileField, ImageField
from django.db.models.fields.related import ForeignKey, OneToOneField, ManyToManyField, ManyToOneRel, ManyToManyRel, OneToOneRel from django.db.models.fields.related import ForeignKey, OneToOneField, ManyToManyField, ManyToOneRel, ManyToManyRel, OneToOneRel
from django.db.models.deletion import CASCADE, PROTECT, SET, SET_NULL, SET_DEFAULT, DO_NOTHING
from django.db.models import signals from django.db.models import signals
# Admin stages. # Admin stages.

View File

@ -7,10 +7,12 @@ from django.core import validators
from django.db.models.fields import AutoField, FieldDoesNotExist from django.db.models.fields import AutoField, FieldDoesNotExist
from django.db.models.fields.related import (OneToOneRel, ManyToOneRel, from django.db.models.fields.related import (OneToOneRel, ManyToOneRel,
OneToOneField, add_lazy_relation) OneToOneField, add_lazy_relation)
from django.db.models.query import delete_objects, Q from django.db.models.query import Q
from django.db.models.query_utils import CollectedObjects, DeferredAttribute from django.db.models.query_utils import DeferredAttribute
from django.db.models.deletion import Collector
from django.db.models.options import Options from django.db.models.options import Options
from django.db import connections, router, transaction, DatabaseError, DEFAULT_DB_ALIAS from django.db import (connections, router, transaction, DatabaseError,
DEFAULT_DB_ALIAS)
from django.db.models import signals from django.db.models import signals
from django.db.models.loading import register_models, get_model from django.db.models.loading import register_models, get_model
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
@ -561,99 +563,13 @@ class Model(object):
save_base.alters_data = True save_base.alters_data = True
def _collect_sub_objects(self, seen_objs, parent=None, nullable=False):
"""
Recursively populates seen_objs with all objects related to this
object.
When done, seen_objs.items() will be in the format:
[(model_class, {pk_val: obj, pk_val: obj, ...}),
(model_class, {pk_val: obj, pk_val: obj, ...}), ...]
"""
pk_val = self._get_pk_val()
if seen_objs.add(self.__class__, pk_val, self,
type(parent), parent, nullable):
return
for related in self._meta.get_all_related_objects():
rel_opts_name = related.get_accessor_name()
if not related.field.rel.multiple:
try:
sub_obj = getattr(self, rel_opts_name)
except ObjectDoesNotExist:
pass
else:
sub_obj._collect_sub_objects(seen_objs, self, related.field.null)
else:
# To make sure we can access all elements, we can't use the
# normal manager on the related object. So we work directly
# with the descriptor object.
for cls in self.__class__.mro():
if rel_opts_name in cls.__dict__:
rel_descriptor = cls.__dict__[rel_opts_name]
break
else:
# in the case of a hidden fkey just skip it, it'll get
# processed as an m2m
if not related.field.rel.is_hidden():
raise AssertionError("Should never get here.")
else:
continue
delete_qs = rel_descriptor.delete_manager(self).all()
for sub_obj in delete_qs:
sub_obj._collect_sub_objects(seen_objs, self, related.field.null)
for related in self._meta.get_all_related_many_to_many_objects():
if related.field.rel.through:
db = router.db_for_write(related.field.rel.through.__class__, instance=self)
opts = related.field.rel.through._meta
reverse_field_name = related.field.m2m_reverse_field_name()
nullable = opts.get_field(reverse_field_name).null
filters = {reverse_field_name: self}
for sub_obj in related.field.rel.through._base_manager.using(db).filter(**filters):
sub_obj._collect_sub_objects(seen_objs, self, nullable)
for f in self._meta.many_to_many:
if f.rel.through:
db = router.db_for_write(f.rel.through.__class__, instance=self)
opts = f.rel.through._meta
field_name = f.m2m_field_name()
nullable = opts.get_field(field_name).null
filters = {field_name: self}
for sub_obj in f.rel.through._base_manager.using(db).filter(**filters):
sub_obj._collect_sub_objects(seen_objs, self, nullable)
else:
# m2m-ish but with no through table? GenericRelation: cascade delete
for sub_obj in f.value_from_object(self).all():
# Generic relations not enforced by db constraints, thus we can set
# nullable=True, order does not matter
sub_obj._collect_sub_objects(seen_objs, self, True)
# Handle any ancestors (for the model-inheritance case). We do this by
# traversing to the most remote parent classes -- those with no parents
# themselves -- and then adding those instances to the collection. That
# will include all the child instances down to "self".
parent_stack = [p for p in self._meta.parents.values() if p is not None]
while parent_stack:
link = parent_stack.pop()
parent_obj = getattr(self, link.name)
if parent_obj._meta.parents:
parent_stack.extend(parent_obj._meta.parents.values())
continue
# At this point, parent_obj is base class (no ancestor models). So
# delete it and all its descendents.
parent_obj._collect_sub_objects(seen_objs)
def delete(self, using=None): def delete(self, using=None):
using = using or router.db_for_write(self.__class__, instance=self) using = using or router.db_for_write(self.__class__, instance=self)
assert self._get_pk_val() is not None, "%s object can't be deleted because its %s attribute is set to None." % (self._meta.object_name, self._meta.pk.attname) assert self._get_pk_val() is not None, "%s object can't be deleted because its %s attribute is set to None." % (self._meta.object_name, self._meta.pk.attname)
# Find all the objects than need to be deleted. collector = Collector(using=using)
seen_objs = CollectedObjects() collector.collect([self])
self._collect_sub_objects(seen_objs) collector.delete()
# Actually delete the objects.
delete_objects(seen_objs, using)
delete.alters_data = True delete.alters_data = True

View File

@ -0,0 +1,245 @@
from operator import attrgetter
from django.db import connections, transaction, IntegrityError
from django.db.models import signals, sql
from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE
from django.utils.datastructures import SortedDict
from django.utils.functional import wraps
def CASCADE(collector, field, sub_objs, using):
collector.collect(sub_objs, source=field.rel.to,
source_attr=field.name, nullable=field.null)
if field.null and not connections[using].features.can_defer_constraint_checks:
collector.add_field_update(field, None, sub_objs)
def PROTECT(collector, field, sub_objs, using):
raise IntegrityError("Cannot delete some instances of model '%s' because "
"they are referenced through a protected foreign key: '%s.%s'" % (
field.rel.to.__name__, sub_objs[0].__class__.__name__, field.name
))
def SET(value):
if callable(value):
def set_on_delete(collector, field, sub_objs, using):
collector.add_field_update(field, value(), sub_objs)
else:
def set_on_delete(collector, field, sub_objs, using):
collector.add_field_update(field, value, sub_objs)
return set_on_delete
SET_NULL = SET(None)
def SET_DEFAULT(collector, field, sub_objs, using):
collector.add_field_update(field, field.get_default(), sub_objs)
def DO_NOTHING(collector, field, sub_objs, using):
pass
def force_managed(func):
@wraps(func)
def decorated(self, *args, **kwargs):
if not transaction.is_managed(using=self.using):
transaction.enter_transaction_management(using=self.using)
forced_managed = True
else:
forced_managed = False
try:
func(self, *args, **kwargs)
if forced_managed:
transaction.commit(using=self.using)
else:
transaction.commit_unless_managed(using=self.using)
finally:
if forced_managed:
transaction.leave_transaction_management(using=self.using)
return decorated
class Collector(object):
def __init__(self, using):
self.using = using
self.data = {} # {model: [instances]}
self.batches = {} # {model: {field: set([instances])}}
self.field_updates = {} # {model: {(field, value): set([instances])}}
self.dependencies = {} # {model: set([models])}
def add(self, objs, source=None, nullable=False):
"""
Adds 'objs' to the collection of objects to be deleted. If the call is
the result of a cascade, 'source' should be the model that caused it
and 'nullable' should be set to True, if the relation can be null.
Returns a list of all objects that were not already collected.
"""
if not objs:
return []
new_objs = []
model = objs[0].__class__
instances = self.data.setdefault(model, [])
for obj in objs:
if obj not in instances:
new_objs.append(obj)
instances.extend(new_objs)
# Nullable relationships can be ignored -- they are nulled out before
# deleting, and therefore do not affect the order in which objects have
# to be deleted.
if new_objs and source is not None and not nullable:
self.dependencies.setdefault(source, set()).add(model)
return new_objs
def add_batch(self, model, field, objs):
"""
Schedules a batch delete. Every instance of 'model' that is related to
an instance of 'obj' through 'field' will be deleted.
"""
self.batches.setdefault(model, {}).setdefault(field, set()).update(objs)
def add_field_update(self, field, value, objs):
"""
Schedules a field update. 'objs' must be a homogenous iterable
collection of model instances (e.g. a QuerySet).
"""
if not objs:
return
model = objs[0].__class__
self.field_updates.setdefault(
model, {}).setdefault(
(field, value), set()).update(objs)
def collect(self, objs, source=None, nullable=False, collect_related=True,
source_attr=None):
"""
Adds 'objs' to the collection of objects to be deleted as well as all
parent instances. 'objs' must be a homogenous iterable collection of
model instances (e.g. a QuerySet). If 'collect_related' is True,
related objects will be handled by their respective on_delete handler.
If the call is the result of a cascade, 'source' should be the model
that caused it and 'nullable' should be set to True, if the relation
can be null.
"""
new_objs = self.add(objs, source, nullable)
if not new_objs:
return
model = new_objs[0].__class__
# Recursively collect parent models, but not their related objects.
# These will be found by meta.get_all_related_objects()
for parent_model, ptr in model._meta.parents.iteritems():
if ptr:
parent_objs = [getattr(obj, ptr.name) for obj in new_objs]
self.collect(parent_objs, source=model,
source_attr=ptr.rel.related_name,
collect_related=False)
if collect_related:
for related in model._meta.get_all_related_objects(include_hidden=True):
field = related.field
if related.model._meta.auto_created:
self.add_batch(related.model, field, new_objs)
else:
sub_objs = self.related_objects(related, new_objs)
if not sub_objs:
continue
field.rel.on_delete(self, field, sub_objs, self.using)
# TODO This entire block is only needed as a special case to
# support cascade-deletes for GenericRelation. It should be
# removed/fixed when the ORM gains a proper abstraction for virtual
# or composite fields, and GFKs are reworked to fit into that.
for relation in model._meta.many_to_many:
if not relation.rel.through:
sub_objs = relation.bulk_related_objects(new_objs, self.using)
self.collect(sub_objs,
source=model,
source_attr=relation.rel.related_name,
nullable=True)
def related_objects(self, related, objs):
"""
Gets a QuerySet of objects related to ``objs`` via the relation ``related``.
"""
return related.model._base_manager.using(self.using).filter(
**{"%s__in" % related.field.name: objs}
)
def instances_with_model(self):
for model, instances in self.data.iteritems():
for obj in instances:
yield model, obj
def sort(self):
sorted_models = []
models = self.data.keys()
while len(sorted_models) < len(models):
found = False
for model in models:
if model in sorted_models:
continue
dependencies = self.dependencies.get(model)
if not (dependencies and dependencies.difference(sorted_models)):
sorted_models.append(model)
found = True
if not found:
return
self.data = SortedDict([(model, self.data[model])
for model in sorted_models])
@force_managed
def delete(self):
# sort instance collections
for instances in self.data.itervalues():
instances.sort(key=attrgetter("pk"))
# if possible, bring the models in an order suitable for databases that
# don't support transactions or cannot defer contraint checks until the
# end of a transaction.
self.sort()
# send pre_delete signals
for model, obj in self.instances_with_model():
if not model._meta.auto_created:
signals.pre_delete.send(
sender=model, instance=obj, using=self.using
)
# update fields
for model, instances_for_fieldvalues in self.field_updates.iteritems():
query = sql.UpdateQuery(model)
for (field, value), instances in instances_for_fieldvalues.iteritems():
query.update_batch([obj.pk for obj in instances],
{field.name: value}, self.using)
# reverse instance collections
for instances in self.data.itervalues():
instances.reverse()
# delete batches
for model, batches in self.batches.iteritems():
query = sql.DeleteQuery(model)
for field, instances in batches.iteritems():
query.delete_batch([obj.pk for obj in instances], self.using, field)
# delete instances
for model, instances in self.data.iteritems():
query = sql.DeleteQuery(model)
pk_list = [obj.pk for obj in instances]
query.delete_batch(pk_list, self.using)
# send post_delete signals
for model, obj in self.instances_with_model():
if not model._meta.auto_created:
signals.post_delete.send(
sender=model, instance=obj, using=self.using
)
# update collected instances
for model, instances_for_fieldvalues in self.field_updates.iteritems():
for (field, value), instances in instances_for_fieldvalues.iteritems():
for obj in instances:
setattr(obj, field.attname, value)
for model, instances in self.data.iteritems():
for instance in instances:
setattr(instance, model._meta.pk.attname, None)

View File

@ -7,8 +7,10 @@ from django.db.models.fields import (AutoField, Field, IntegerField,
from django.db.models.related import RelatedObject from django.db.models.related import RelatedObject
from django.db.models.query import QuerySet from django.db.models.query import QuerySet
from django.db.models.query_utils import QueryWrapper from django.db.models.query_utils import QueryWrapper
from django.db.models.deletion import CASCADE
from django.utils.encoding import smart_unicode from django.utils.encoding import smart_unicode
from django.utils.translation import ugettext_lazy as _, string_concat, ungettext, ugettext from django.utils.translation import (ugettext_lazy as _, string_concat,
ungettext, ugettext)
from django.utils.functional import curry from django.utils.functional import curry
from django.core import exceptions from django.core import exceptions
from django import forms from django import forms
@ -733,8 +735,8 @@ class ReverseManyRelatedObjectsDescriptor(object):
manager.add(*value) manager.add(*value)
class ManyToOneRel(object): class ManyToOneRel(object):
def __init__(self, to, field_name, related_name=None, def __init__(self, to, field_name, related_name=None, limit_choices_to=None,
limit_choices_to=None, lookup_overrides=None, parent_link=False): parent_link=False, on_delete=None):
try: try:
to._meta to._meta
except AttributeError: # to._meta doesn't exist, so it must be RECURSIVE_RELATIONSHIP_CONSTANT except AttributeError: # to._meta doesn't exist, so it must be RECURSIVE_RELATIONSHIP_CONSTANT
@ -744,9 +746,9 @@ class ManyToOneRel(object):
if limit_choices_to is None: if limit_choices_to is None:
limit_choices_to = {} limit_choices_to = {}
self.limit_choices_to = limit_choices_to self.limit_choices_to = limit_choices_to
self.lookup_overrides = lookup_overrides or {}
self.multiple = True self.multiple = True
self.parent_link = parent_link self.parent_link = parent_link
self.on_delete = on_delete
def is_hidden(self): def is_hidden(self):
"Should the related object be hidden?" "Should the related object be hidden?"
@ -764,11 +766,12 @@ class ManyToOneRel(object):
return data[0] return data[0]
class OneToOneRel(ManyToOneRel): class OneToOneRel(ManyToOneRel):
def __init__(self, to, field_name, related_name=None, def __init__(self, to, field_name, related_name=None, limit_choices_to=None,
limit_choices_to=None, lookup_overrides=None, parent_link=False): parent_link=False, on_delete=None):
super(OneToOneRel, self).__init__(to, field_name, super(OneToOneRel, self).__init__(to, field_name,
related_name=related_name, limit_choices_to=limit_choices_to, related_name=related_name, limit_choices_to=limit_choices_to,
lookup_overrides=lookup_overrides, parent_link=parent_link) parent_link=parent_link, on_delete=on_delete
)
self.multiple = False self.multiple = False
class ManyToManyRel(object): class ManyToManyRel(object):
@ -820,8 +823,9 @@ class ForeignKey(RelatedField, Field):
kwargs['rel'] = rel_class(to, to_field, kwargs['rel'] = rel_class(to, to_field,
related_name=kwargs.pop('related_name', None), related_name=kwargs.pop('related_name', None),
limit_choices_to=kwargs.pop('limit_choices_to', None), limit_choices_to=kwargs.pop('limit_choices_to', None),
lookup_overrides=kwargs.pop('lookup_overrides', None), parent_link=kwargs.pop('parent_link', False),
parent_link=kwargs.pop('parent_link', False)) on_delete=kwargs.pop('on_delete', CASCADE),
)
Field.__init__(self, **kwargs) Field.__init__(self, **kwargs)
def validate(self, value, model_instance): def validate(self, value, model_instance):

View File

@ -11,6 +11,11 @@ from django.utils.translation import activate, deactivate_all, get_language, str
from django.utils.encoding import force_unicode, smart_str from django.utils.encoding import force_unicode, smart_str
from django.utils.datastructures import SortedDict from django.utils.datastructures import SortedDict
try:
all
except NameError:
from django.utils.itercompat import all
# Calculate the verbose_name by converting from InitialCaps to "lowercase with spaces". # Calculate the verbose_name by converting from InitialCaps to "lowercase with spaces".
get_verbose_name = lambda class_name: re.sub('(((?<=[a-z])[A-Z])|([A-Z](?![A-Z]|$)))', ' \\1', class_name).lower().strip() get_verbose_name = lambda class_name: re.sub('(((?<=[a-z])[A-Z])|([A-Z](?![A-Z]|$)))', ' \\1', class_name).lower().strip()
@ -339,16 +344,12 @@ class Options(object):
def get_delete_permission(self): def get_delete_permission(self):
return 'delete_%s' % self.object_name.lower() return 'delete_%s' % self.object_name.lower()
def get_all_related_objects(self, local_only=False): def get_all_related_objects(self, local_only=False, include_hidden=False):
try: return [k for k, v in self.get_all_related_objects_with_model(
self._related_objects_cache local_only=local_only, include_hidden=include_hidden)]
except AttributeError:
self._fill_related_objects_cache()
if local_only:
return [k for k, v in self._related_objects_cache.items() if not v]
return self._related_objects_cache.keys()
def get_all_related_objects_with_model(self): def get_all_related_objects_with_model(self, local_only=False,
include_hidden=False):
""" """
Returns a list of (related-object, model) pairs. Similar to Returns a list of (related-object, model) pairs. Similar to
get_fields_with_model(). get_fields_with_model().
@ -357,7 +358,13 @@ class Options(object):
self._related_objects_cache self._related_objects_cache
except AttributeError: except AttributeError:
self._fill_related_objects_cache() self._fill_related_objects_cache()
return self._related_objects_cache.items() predicates = []
if local_only:
predicates.append(lambda k, v: not v)
if not include_hidden:
predicates.append(lambda k, v: not k.field.rel.is_hidden())
return filter(lambda t: all([p(*t) for p in predicates]),
self._related_objects_cache.items())
def _fill_related_objects_cache(self): def _fill_related_objects_cache(self):
cache = SortedDict() cache = SortedDict()
@ -370,7 +377,7 @@ class Options(object):
cache[obj] = parent cache[obj] = parent
else: else:
cache[obj] = model cache[obj] = model
for klass in get_models(): for klass in get_models(include_auto_created=True):
for f in klass._meta.local_fields: for f in klass._meta.local_fields:
if f.rel and not isinstance(f.rel.to, str) and self == f.rel.to._meta: if f.rel and not isinstance(f.rel.to, str) and self == f.rel.to._meta:
cache[RelatedObject(f.rel.to, klass, f)] = None cache[RelatedObject(f.rel.to, klass, f)] = None

View File

@ -8,7 +8,8 @@ from django.db import connections, router, transaction, IntegrityError
from django.db.models.aggregates import Aggregate from django.db.models.aggregates import Aggregate
from django.db.models.fields import DateField from django.db.models.fields import DateField
from django.db.models.query_utils import (Q, select_related_descend, from django.db.models.query_utils import (Q, select_related_descend,
CollectedObjects, CyclicDependency, deferred_class_factory, InvalidQuery) deferred_class_factory, InvalidQuery)
from django.db.models.deletion import Collector
from django.db.models import signals, sql from django.db.models import signals, sql
from django.utils.copycompat import deepcopy from django.utils.copycompat import deepcopy
@ -427,22 +428,9 @@ class QuerySet(object):
del_query.query.select_related = False del_query.query.select_related = False
del_query.query.clear_ordering() del_query.query.clear_ordering()
# Delete objects in chunks to prevent the list of related objects from collector = Collector(using=del_query.db)
# becoming too long. collector.collect(del_query)
seen_objs = None collector.delete()
del_itr = iter(del_query)
while 1:
# Collect a chunk of objects to be deleted, and then all the
# objects that are related to the objects that are to be deleted.
# The chunking *isn't* done by slicing the del_query because we
# need to maintain the query cache on del_query (see #12328)
seen_objs = CollectedObjects(seen_objs)
for i, obj in izip(xrange(CHUNK_SIZE), del_itr):
obj._collect_sub_objects(seen_objs)
if not seen_objs:
break
delete_objects(seen_objs, del_query.db)
# Clear the result cache, in case this QuerySet gets reused. # Clear the result cache, in case this QuerySet gets reused.
self._result_cache = None self._result_cache = None
@ -1287,79 +1275,6 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
pass pass
return obj, index_end return obj, index_end
def delete_objects(seen_objs, using):
"""
Iterate through a list of seen classes, and remove any instances that are
referred to.
"""
connection = connections[using]
if not transaction.is_managed(using=using):
transaction.enter_transaction_management(using=using)
forced_managed = True
else:
forced_managed = False
try:
ordered_classes = seen_objs.keys()
except CyclicDependency:
# If there is a cyclic dependency, we cannot in general delete the
# objects. However, if an appropriate transaction is set up, or if the
# database is lax enough, it will succeed. So for now, we go ahead and
# try anyway.
ordered_classes = seen_objs.unordered_keys()
obj_pairs = {}
try:
for cls in ordered_classes:
items = seen_objs[cls].items()
items.sort()
obj_pairs[cls] = items
# Pre-notify all instances to be deleted.
for pk_val, instance in items:
if not cls._meta.auto_created:
signals.pre_delete.send(sender=cls, instance=instance,
using=using)
pk_list = [pk for pk,instance in items]
update_query = sql.UpdateQuery(cls)
for field, model in cls._meta.get_fields_with_model():
if (field.rel and field.null and field.rel.to in seen_objs and
filter(lambda f: f.column == field.rel.get_related_field().column,
field.rel.to._meta.fields)):
if model:
sql.UpdateQuery(model).clear_related(field, pk_list, using=using)
else:
update_query.clear_related(field, pk_list, using=using)
# Now delete the actual data.
for cls in ordered_classes:
items = obj_pairs[cls]
items.reverse()
pk_list = [pk for pk,instance in items]
del_query = sql.DeleteQuery(cls)
del_query.delete_batch(pk_list, using=using)
# Last cleanup; set NULLs where there once was a reference to the
# object, NULL the primary key of the found objects, and perform
# post-notification.
for pk_val, instance in items:
for field in cls._meta.fields:
if field.rel and field.null and field.rel.to in seen_objs:
setattr(instance, field.attname, None)
if not cls._meta.auto_created:
signals.post_delete.send(sender=cls, instance=instance, using=using)
setattr(instance, cls._meta.pk.attname, None)
if forced_managed:
transaction.commit(using=using)
else:
transaction.commit_unless_managed(using=using)
finally:
if forced_managed:
transaction.leave_transaction_management(using=using)
class RawQuerySet(object): class RawQuerySet(object):
""" """

View File

@ -14,13 +14,6 @@ from django.utils import tree
from django.utils.datastructures import SortedDict from django.utils.datastructures import SortedDict
class CyclicDependency(Exception):
"""
An error when dealing with a collection of objects that have a cyclic
dependency, i.e. when deleting multiple objects.
"""
pass
class InvalidQuery(Exception): class InvalidQuery(Exception):
""" """
The query passed to raw isn't a safe query to use with raw. The query passed to raw isn't a safe query to use with raw.
@ -28,107 +21,6 @@ class InvalidQuery(Exception):
pass pass
class CollectedObjects(object):
"""
A container that stores keys and lists of values along with remembering the
parent objects for all the keys.
This is used for the database object deletion routines so that we can
calculate the 'leaf' objects which should be deleted first.
previously_seen is an optional argument. It must be a CollectedObjects
instance itself; any previously_seen collected object will be blocked from
being added to this instance.
"""
def __init__(self, previously_seen=None):
self.data = {}
self.children = {}
if previously_seen:
self.blocked = previously_seen.blocked
for cls, seen in previously_seen.data.items():
self.blocked.setdefault(cls, SortedDict()).update(seen)
else:
self.blocked = {}
def add(self, model, pk, obj, parent_model, parent_obj=None, nullable=False):
"""
Adds an item to the container.
Arguments:
* model - the class of the object being added.
* pk - the primary key.
* obj - the object itself.
* parent_model - the model of the parent object that this object was
reached through.
* parent_obj - the parent object this object was reached
through (not used here, but needed in the API for use elsewhere)
* nullable - should be True if this relation is nullable.
Returns True if the item already existed in the structure and
False otherwise.
"""
if pk in self.blocked.get(model, {}):
return True
d = self.data.setdefault(model, SortedDict())
retval = pk in d
d[pk] = obj
# Nullable relationships can be ignored -- they are nulled out before
# deleting, and therefore do not affect the order in which objects
# have to be deleted.
if parent_model is not None and not nullable:
self.children.setdefault(parent_model, []).append(model)
return retval
def __contains__(self, key):
return self.data.__contains__(key)
def __getitem__(self, key):
return self.data[key]
def __nonzero__(self):
return bool(self.data)
def iteritems(self):
for k in self.ordered_keys():
yield k, self[k]
def items(self):
return list(self.iteritems())
def keys(self):
return self.ordered_keys()
def ordered_keys(self):
"""
Returns the models in the order that they should be dealt with (i.e.
models with no dependencies first).
"""
dealt_with = SortedDict()
# Start with items that have no children
models = self.data.keys()
while len(dealt_with) < len(models):
found = False
for model in models:
if model in dealt_with:
continue
children = self.children.setdefault(model, [])
if len([c for c in children if c not in dealt_with]) == 0:
dealt_with[model] = None
found = True
if not found:
raise CyclicDependency(
"There is a cyclic dependency of items to be processed.")
return dealt_with.keys()
def unordered_keys(self):
"""
Fallback for the case where is a cyclic dependency but we don't care.
"""
return self.data.keys()
class QueryWrapper(object): class QueryWrapper(object):
""" """
A type that indicates the contents are an SQL fragment and the associate A type that indicates the contents are an SQL fragment and the associate

View File

@ -27,16 +27,17 @@ class DeleteQuery(Query):
self.where = where self.where = where
self.get_compiler(using).execute_sql(None) self.get_compiler(using).execute_sql(None)
def delete_batch(self, pk_list, using): def delete_batch(self, pk_list, using, field=None):
""" """
Set up and execute delete queries for all the objects in pk_list. Set up and execute delete queries for all the objects in pk_list.
More than one physical query may be executed if there are a More than one physical query may be executed if there are a
lot of values in pk_list. lot of values in pk_list.
""" """
if not field:
field = self.model._meta.pk
for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
where = self.where_class() where = self.where_class()
field = self.model._meta.pk
where.add((Constraint(None, field.column, field), 'in', where.add((Constraint(None, field.column, field), 'in',
pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]), AND) pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]), AND)
self.do_query(self.model._meta.db_table, where, using=using) self.do_query(self.model._meta.db_table, where, using=using)
@ -68,20 +69,14 @@ class UpdateQuery(Query):
related_updates=self.related_updates.copy(), **kwargs) related_updates=self.related_updates.copy(), **kwargs)
def clear_related(self, related_field, pk_list, using): def update_batch(self, pk_list, values, using):
""" pk_field = self.model._meta.pk
Set up and execute an update query that clears related entries for the self.add_update_values(values)
keys in pk_list.
This is used by the QuerySet.delete_objects() method.
"""
for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
self.where = self.where_class() self.where = self.where_class()
f = self.model._meta.pk self.where.add((Constraint(None, pk_field.column, pk_field), 'in',
self.where.add((Constraint(None, f.column, f), 'in',
pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]), pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]),
AND) AND)
self.values = [(related_field, None, None)]
self.get_compiler(using).execute_sql(None) self.get_compiler(using).execute_sql(None)
def add_update_values(self, values): def add_update_values(self, values):

View File

@ -341,6 +341,16 @@ pointing at it will be deleted as well. In the example above, this means that
if a ``Bookmark`` object were deleted, any ``TaggedItem`` objects pointing at if a ``Bookmark`` object were deleted, any ``TaggedItem`` objects pointing at
it would be deleted at the same time. it would be deleted at the same time.
.. versionadded:: 1.3
Unlike :class:`~django.db.models.ForeignKey`,
:class:`~django.contrib.contenttypes.generic.GenericForeignKey` does not accept
an :attr:`~django.db.models.ForeignKey.on_delete` argument to customize this
behavior; if desired, you can avoid the cascade-deletion simply by not using
:class:`~django.contrib.contenttypes.generic.GenericRelation`, and alternate
behavior can be provided via the :data:`~django.db.models.signals.pre_delete`
signal.
Generic relations and aggregation Generic relations and aggregation
--------------------------------- ---------------------------------

View File

@ -930,7 +930,7 @@ define the details of how the relation works.
If you'd prefer Django didn't create a backwards relation, set ``related_name`` If you'd prefer Django didn't create a backwards relation, set ``related_name``
to ``'+'``. For example, this will ensure that the ``User`` model won't get a to ``'+'``. For example, this will ensure that the ``User`` model won't get a
backwards relation to this model:: backwards relation to this model::
user = models.ForeignKey(User, related_name='+') user = models.ForeignKey(User, related_name='+')
.. attribute:: ForeignKey.to_field .. attribute:: ForeignKey.to_field
@ -938,6 +938,51 @@ define the details of how the relation works.
The field on the related object that the relation is to. By default, Django The field on the related object that the relation is to. By default, Django
uses the primary key of the related object. uses the primary key of the related object.
.. versionadded:: 1.3
.. attribute:: ForeignKey.on_delete
When an object referenced by a :class:`ForeignKey` is deleted, Django by
default emulates the behavior of the SQL constraint ``ON DELETE CASCADE``
and also deletes the object containing the ``ForeignKey``. This behavior
can be overridden by specifying the :attr:`on_delete` argument. For
example, if you have a nullable :class:`ForeignKey` and you want it to be
set null when the referenced object is deleted::
user = models.ForeignKey(User, blank=True, null=True, on_delete=models.SET_NULL)
The possible values for :attr:`on_delete` are found in
:mod:`django.db.models`:
* :attr:`~django.db.models.CASCADE`: Cascade deletes; the default.
* :attr:`~django.db.models.PROTECT`: Prevent deletion of the referenced
object by raising :exc:`django.db.IntegrityError`.
* :attr:`~django.db.models.SET_NULL`: Set the :class:`ForeignKey` null;
this is only possible if :attr:`null` is ``True``.
* :attr:`~django.db.models.SET_DEFAULT`: Set the :class:`ForeignKey` to its
default value; a default for the :class:`ForeignKey` must be set.
* :func:`~django.db.models.SET()`: Set the :class:`ForeignKey` to the value
passed to :func:`~django.db.models.SET()`, or if a callable is passed in,
the result of calling it. In most cases, passing a callable will be
necessary to avoid executing queries at the time your models.py is
imported::
def get_sentinel_user():
return User.objects.get_or_create(username='deleted')[0]
class MyModel(models.Model):
user = models.ForeignKey(User, on_delete=models.SET(get_sentinel_user))
* :attr:`~django.db.models.DO_NOTHING`: Take no action. If your database
backend enforces referential integrity, this will cause an
:exc:`~django.db.IntegrityError` unless you manually add a SQL ``ON
DELETE`` constraint to the database field (perhaps using
:ref:`initial sql<initial-sql>`).
.. _ref-manytomany: .. _ref-manytomany:
``ManyToManyField`` ``ManyToManyField``

View File

@ -1263,14 +1263,20 @@ For example, to delete all the entries in a particular blog::
# Delete all the entries belonging to this Blog. # Delete all the entries belonging to this Blog.
>>> Entry.objects.filter(blog=b).delete() >>> Entry.objects.filter(blog=b).delete()
Django emulates the SQL constraint ``ON DELETE CASCADE`` -- in other words, any By default, Django's :class:`~django.db.models.ForeignKey` emulates the SQL
objects with foreign keys pointing at the objects to be deleted will be deleted constraint ``ON DELETE CASCADE`` -- in other words, any objects with foreign
along with them. For example:: keys pointing at the objects to be deleted will be deleted along with them.
For example::
blogs = Blog.objects.all() blogs = Blog.objects.all()
# This will delete all Blogs and all of their Entry objects. # This will delete all Blogs and all of their Entry objects.
blogs.delete() blogs.delete()
.. versionadded:: 1.3
This cascade behavior is customizable via the
:attr:`~django.db.models.ForeignKey.on_delete` argument to the
:class:`~django.db.models.ForeignKey`.
The ``delete()`` method does a bulk delete and does not call any ``delete()`` The ``delete()`` method does a bulk delete and does not call any ``delete()``
methods on your models. It does, however, emit the methods on your models. It does, however, emit the
:data:`~django.db.models.signals.pre_delete` and :data:`~django.db.models.signals.pre_delete` and

View File

@ -86,6 +86,19 @@ Users of Python 2.5 and above may now use :ref:`transaction management functions
For more information, see :ref:`transaction-management-functions`. For more information, see :ref:`transaction-management-functions`.
Configurable delete-cascade
~~~~~~~~~~~~~~~~~~~~~~~~~~~
:class:`~django.db.models.ForeignKey` and
:class:`~django.db.models.OneToOneField` now accept an
:attr:`~django.db.models.ForeignKey.on_delete` argument to customize behavior
when the referenced object is deleted. Previously, deletes were always
cascaded; available alternatives now include set null, set default, set to any
value, protect, or do nothing.
For more information, see the :attr:`~django.db.models.ForeignKey.on_delete`
documentation.
Contextual markers in translatable strings Contextual markers in translatable strings
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -749,15 +749,20 @@ model (e.g., by iterating over a ``QuerySet`` and calling ``delete()``
on each object individually) rather than using the bulk ``delete()`` on each object individually) rather than using the bulk ``delete()``
method of a ``QuerySet``. method of a ``QuerySet``.
When Django deletes an object, it emulates the behavior of the SQL When Django deletes an object, by default it emulates the behavior of the SQL
constraint ``ON DELETE CASCADE`` -- in other words, any objects which constraint ``ON DELETE CASCADE`` -- in other words, any objects which had
had foreign keys pointing at the object to be deleted will be deleted foreign keys pointing at the object to be deleted will be deleted along with
along with it. For example:: it. For example::
b = Blog.objects.get(pk=1) b = Blog.objects.get(pk=1)
# This will delete the Blog and all of its Entry objects. # This will delete the Blog and all of its Entry objects.
b.delete() b.delete()
.. versionadded:: 1.3
This cascade behavior is customizable via the
:attr:`~django.db.models.ForeignKey.on_delete` argument to the
:class:`~django.db.models.ForeignKey`.
Note that ``delete()`` is the only ``QuerySet`` method that is not exposed on a Note that ``delete()`` is the only ``QuerySet`` method that is not exposed on a
``Manager`` itself. This is a safety mechanism to prevent you from accidentally ``Manager`` itself. This is a safety mechanism to prevent you from accidentally
requesting ``Entry.objects.delete()``, and deleting *all* the entries. If you requesting ``Entry.objects.delete()``, and deleting *all* the entries. If you

View File

@ -1 +0,0 @@

View File

@ -1,42 +1,106 @@
# coding: utf-8 from django.db import models, IntegrityError
"""
Tests for some corner cases with deleting.
"""
from django.db import models
class DefaultRepr(object): class R(models.Model):
def __repr__(self): is_default = models.BooleanField(default=False)
return u"<%s: %s>" % (self.__class__.__name__, self.__dict__)
class A(DefaultRepr, models.Model): def __str__(self):
return "%s" % self.pk
get_default_r = lambda: R.objects.get_or_create(is_default=True)[0]
class S(models.Model):
r = models.ForeignKey(R)
class T(models.Model):
s = models.ForeignKey(S)
class U(models.Model):
t = models.ForeignKey(T)
class RChild(R):
pass pass
class B(DefaultRepr, models.Model):
a = models.ForeignKey(A)
class C(DefaultRepr, models.Model): class A(models.Model):
b = models.ForeignKey(B) name = models.CharField(max_length=30)
class D(DefaultRepr, models.Model): auto = models.ForeignKey(R, related_name="auto_set")
c = models.ForeignKey(C) auto_nullable = models.ForeignKey(R, null=True,
a = models.ForeignKey(A) related_name='auto_nullable_set')
setvalue = models.ForeignKey(R, on_delete=models.SET(get_default_r),
related_name='setvalue')
setnull = models.ForeignKey(R, on_delete=models.SET_NULL, null=True,
related_name='setnull_set')
setdefault = models.ForeignKey(R, on_delete=models.SET_DEFAULT,
default=get_default_r, related_name='setdefault_set')
setdefault_none = models.ForeignKey(R, on_delete=models.SET_DEFAULT,
default=None, null=True, related_name='setnull_nullable_set')
cascade = models.ForeignKey(R, on_delete=models.CASCADE,
related_name='cascade_set')
cascade_nullable = models.ForeignKey(R, on_delete=models.CASCADE, null=True,
related_name='cascade_nullable_set')
protect = models.ForeignKey(R, on_delete=models.PROTECT, null=True)
donothing = models.ForeignKey(R, on_delete=models.DO_NOTHING, null=True,
related_name='donothing_set')
child = models.ForeignKey(RChild, related_name="child")
child_setnull = models.ForeignKey(RChild, on_delete=models.SET_NULL, null=True,
related_name="child_setnull")
# Simplified, we have: # A OneToOneField is just a ForeignKey unique=True, so we don't duplicate
# A # all the tests; just one smoke test to ensure on_delete works for it as
# B -> A # well.
# C -> B o2o_setnull = models.ForeignKey(R, null=True,
# D -> C on_delete=models.SET_NULL, related_name="o2o_nullable_set")
# D -> A
# So, we must delete Ds first of all, then Cs then Bs then As.
# However, if we start at As, we might find Bs first (in which
# case things will be nice), or find Ds first.
# Some mutually dependent models, but nullable def create_a(name):
class E(DefaultRepr, models.Model): a = A(name=name)
f = models.ForeignKey('F', null=True, related_name='e_rel') for name in ('auto', 'auto_nullable', 'setvalue', 'setnull', 'setdefault',
'setdefault_none', 'cascade', 'cascade_nullable', 'protect',
'donothing', 'o2o_setnull'):
r = R.objects.create()
setattr(a, name, r)
a.child = RChild.objects.create()
a.child_setnull = RChild.objects.create()
a.save()
return a
class F(DefaultRepr, models.Model):
e = models.ForeignKey(E, related_name='f_rel')
class M(models.Model):
m2m = models.ManyToManyField(R, related_name="m_set")
m2m_through = models.ManyToManyField(R, through="MR",
related_name="m_through_set")
m2m_through_null = models.ManyToManyField(R, through="MRNull",
related_name="m_through_null_set")
class MR(models.Model):
m = models.ForeignKey(M)
r = models.ForeignKey(R)
class MRNull(models.Model):
m = models.ForeignKey(M)
r = models.ForeignKey(R, null=True, on_delete=models.SET_NULL)
class Avatar(models.Model):
pass
class User(models.Model):
avatar = models.ForeignKey(Avatar, null=True)
class HiddenUser(models.Model):
r = models.ForeignKey(R, related_name="+")
class HiddenUserProfile(models.Model):
user = models.ForeignKey(HiddenUser)

View File

@ -1,135 +1,253 @@
from django.db.models import sql from django.db import models, IntegrityError
from django.db.models.loading import cache from django.test import TestCase, skipUnlessDBFeature, skipIfDBFeature
from django.db.models.query import CollectedObjects
from django.db.models.query_utils import CyclicDependency
from django.test import TestCase
from models import A, B, C, D, E, F from modeltests.delete.models import (R, RChild, S, T, U, A, M, MR, MRNull,
create_a, get_default_r, User, Avatar, HiddenUser, HiddenUserProfile)
class DeleteTests(TestCase): class OnDeleteTests(TestCase):
def clear_rel_obj_caches(self, *models):
for m in models:
if hasattr(m._meta, '_related_objects_cache'):
del m._meta._related_objects_cache
def order_models(self, *models):
cache.app_models["delete"].keyOrder = models
def setUp(self): def setUp(self):
self.order_models("a", "b", "c", "d", "e", "f") self.DEFAULT = get_default_r()
self.clear_rel_obj_caches(A, B, C, D, E, F)
def tearDown(self): def test_auto(self):
self.order_models("a", "b", "c", "d", "e", "f") a = create_a('auto')
self.clear_rel_obj_caches(A, B, C, D, E, F) a.auto.delete()
self.assertFalse(A.objects.filter(name='auto').exists())
def test_collected_objects(self): def test_auto_nullable(self):
g = CollectedObjects() a = create_a('auto_nullable')
self.assertFalse(g.add("key1", 1, "item1", None)) a.auto_nullable.delete()
self.assertEqual(g["key1"], {1: "item1"}) self.assertFalse(A.objects.filter(name='auto_nullable').exists())
self.assertFalse(g.add("key2", 1, "item1", "key1")) def test_setvalue(self):
self.assertFalse(g.add("key2", 2, "item2", "key1")) a = create_a('setvalue')
a.setvalue.delete()
a = A.objects.get(pk=a.pk)
self.assertEqual(self.DEFAULT, a.setvalue)
self.assertEqual(g["key2"], {1: "item1", 2: "item2"}) def test_setnull(self):
a = create_a('setnull')
a.setnull.delete()
a = A.objects.get(pk=a.pk)
self.assertEqual(None, a.setnull)
self.assertFalse(g.add("key3", 1, "item1", "key1")) def test_setdefault(self):
self.assertTrue(g.add("key3", 1, "item1", "key2")) a = create_a('setdefault')
self.assertEqual(g.ordered_keys(), ["key3", "key2", "key1"]) a.setdefault.delete()
a = A.objects.get(pk=a.pk)
self.assertEqual(self.DEFAULT, a.setdefault)
self.assertTrue(g.add("key2", 1, "item1", "key3")) def test_setdefault_none(self):
self.assertRaises(CyclicDependency, g.ordered_keys) a = create_a('setdefault_none')
a.setdefault_none.delete()
a = A.objects.get(pk=a.pk)
self.assertEqual(None, a.setdefault_none)
def test_delete(self): def test_cascade(self):
## Second, test the usage of CollectedObjects by Model.delete() a = create_a('cascade')
a.cascade.delete()
self.assertFalse(A.objects.filter(name='cascade').exists())
# Due to the way that transactions work in the test harness, doing def test_cascade_nullable(self):
# m.delete() here can work but fail in a real situation, since it may a = create_a('cascade_nullable')
# delete all objects, but not in the right order. So we manually check a.cascade_nullable.delete()
# that the order of deletion is correct. self.assertFalse(A.objects.filter(name='cascade_nullable').exists())
# Also, it is possible that the order is correct 'accidentally', due def test_protect(self):
# solely to order of imports etc. To check this, we set the order that a = create_a('protect')
# 'get_models()' will retrieve to a known 'nice' order, and then try self.assertRaises(IntegrityError, a.protect.delete)
# again with a known 'tricky' order. Slightly naughty access to
# internals here :-)
# If implementation changes, then the tests may need to be simplified: def test_do_nothing(self):
# - remove the lines that set the .keyOrder and clear the related # Testing DO_NOTHING is a bit harder: It would raise IntegrityError for a normal model,
# object caches # so we connect to pre_delete and set the fk to a known value.
# - remove the second set of tests (with a2, b2 etc) replacement_r = R.objects.create()
def check_do_nothing(sender, **kwargs):
obj = kwargs['instance']
obj.donothing_set.update(donothing=replacement_r)
models.signals.pre_delete.connect(check_do_nothing)
a = create_a('do_nothing')
a.donothing.delete()
a = A.objects.get(pk=a.pk)
self.assertEqual(replacement_r, a.donothing)
models.signals.pre_delete.disconnect(check_do_nothing)
a1 = A.objects.create() def test_inheritance_cascade_up(self):
b1 = B.objects.create(a=a1) child = RChild.objects.create()
c1 = C.objects.create(b=b1) child.delete()
d1 = D.objects.create(c=c1, a=a1) self.assertFalse(R.objects.filter(pk=child.pk).exists())
o = CollectedObjects() def test_inheritance_cascade_down(self):
a1._collect_sub_objects(o) child = RChild.objects.create()
self.assertEqual(o.keys(), [D, C, B, A]) parent = child.r_ptr
a1.delete() parent.delete()
self.assertFalse(RChild.objects.filter(pk=child.pk).exists())
# Same again with a known bad order def test_cascade_from_child(self):
self.order_models("d", "c", "b", "a") a = create_a('child')
self.clear_rel_obj_caches(A, B, C, D) a.child.delete()
self.assertFalse(A.objects.filter(name='child').exists())
self.assertFalse(R.objects.filter(pk=a.child_id).exists())
a2 = A.objects.create() def test_cascade_from_parent(self):
b2 = B.objects.create(a=a2) a = create_a('child')
c2 = C.objects.create(b=b2) R.objects.get(pk=a.child_id).delete()
d2 = D.objects.create(c=c2, a=a2) self.assertFalse(A.objects.filter(name='child').exists())
self.assertFalse(RChild.objects.filter(pk=a.child_id).exists())
o = CollectedObjects() def test_setnull_from_child(self):
a2._collect_sub_objects(o) a = create_a('child_setnull')
self.assertEqual(o.keys(), [D, C, B, A]) a.child_setnull.delete()
a2.delete() self.assertFalse(R.objects.filter(pk=a.child_setnull_id).exists())
def test_collected_objects_null(self): a = A.objects.get(pk=a.pk)
g = CollectedObjects() self.assertEqual(None, a.child_setnull)
self.assertFalse(g.add("key1", 1, "item1", None))
self.assertFalse(g.add("key2", 1, "item1", "key1", nullable=True))
self.assertTrue(g.add("key1", 1, "item1", "key2"))
self.assertEqual(g.ordered_keys(), ["key1", "key2"])
def test_delete_nullable(self): def test_setnull_from_parent(self):
e1 = E.objects.create() a = create_a('child_setnull')
f1 = F.objects.create(e=e1) R.objects.get(pk=a.child_setnull_id).delete()
e1.f = f1 self.assertFalse(RChild.objects.filter(pk=a.child_setnull_id).exists())
e1.save()
# Since E.f is nullable, we should delete F first (after nulling out a = A.objects.get(pk=a.pk)
# the E.f field), then E. self.assertEqual(None, a.child_setnull)
o = CollectedObjects() def test_o2o_setnull(self):
e1._collect_sub_objects(o) a = create_a('o2o_setnull')
self.assertEqual(o.keys(), [F, E]) a.o2o_setnull.delete()
a = A.objects.get(pk=a.pk)
self.assertEqual(None, a.o2o_setnull)
# temporarily replace the UpdateQuery class to verify that E.f is
# actually nulled out first
logged = [] class DeletionTests(TestCase):
class LoggingUpdateQuery(sql.UpdateQuery): def test_m2m(self):
def clear_related(self, related_field, pk_list, using): m = M.objects.create()
logged.append(related_field.name) r = R.objects.create()
return super(LoggingUpdateQuery, self).clear_related(related_field, pk_list, using) MR.objects.create(m=m, r=r)
original = sql.UpdateQuery r.delete()
sql.UpdateQuery = LoggingUpdateQuery self.assertFalse(MR.objects.exists())
e1.delete() r = R.objects.create()
self.assertEqual(logged, ["f"]) MR.objects.create(m=m, r=r)
logged = [] m.delete()
self.assertFalse(MR.objects.exists())
e2 = E.objects.create() m = M.objects.create()
f2 = F.objects.create(e=e2) r = R.objects.create()
e2.f = f2 m.m2m.add(r)
e2.save() r.delete()
through = M._meta.get_field('m2m').rel.through
self.assertFalse(through.objects.exists())
# Same deal as before, though we are starting from the other object. r = R.objects.create()
o = CollectedObjects() m.m2m.add(r)
f2._collect_sub_objects(o) m.delete()
self.assertEqual(o.keys(), [F, E]) self.assertFalse(through.objects.exists())
f2.delete()
self.assertEqual(logged, ["f"])
logged = []
sql.UpdateQuery = original m = M.objects.create()
r = R.objects.create()
MRNull.objects.create(m=m, r=r)
r.delete()
self.assertFalse(not MRNull.objects.exists())
self.assertFalse(m.m2m_through_null.exists())
def test_bulk(self):
from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE
s = S.objects.create(r=R.objects.create())
for i in xrange(2*GET_ITERATOR_CHUNK_SIZE):
T.objects.create(s=s)
# 1 (select related `T` instances)
# + 1 (select related `U` instances)
# + 2 (delete `T` instances in batches)
# + 1 (delete `s`)
self.assertNumQueries(5, s.delete)
self.assertFalse(S.objects.exists())
def test_instance_update(self):
deleted = []
related_setnull_sets = []
def pre_delete(sender, **kwargs):
obj = kwargs['instance']
deleted.append(obj)
if isinstance(obj, R):
related_setnull_sets.append(list(a.pk for a in obj.setnull_set.all()))
models.signals.pre_delete.connect(pre_delete)
a = create_a('update_setnull')
a.setnull.delete()
a = create_a('update_cascade')
a.cascade.delete()
for obj in deleted:
self.assertEqual(None, obj.pk)
for pk_list in related_setnull_sets:
for a in A.objects.filter(id__in=pk_list):
self.assertEqual(None, a.setnull)
models.signals.pre_delete.disconnect(pre_delete)
def test_deletion_order(self):
pre_delete_order = []
post_delete_order = []
def log_post_delete(sender, **kwargs):
pre_delete_order.append((sender, kwargs['instance'].pk))
def log_pre_delete(sender, **kwargs):
post_delete_order.append((sender, kwargs['instance'].pk))
models.signals.post_delete.connect(log_post_delete)
models.signals.pre_delete.connect(log_pre_delete)
r = R.objects.create(pk=1)
s1 = S.objects.create(pk=1, r=r)
s2 = S.objects.create(pk=2, r=r)
t1 = T.objects.create(pk=1, s=s1)
t2 = T.objects.create(pk=2, s=s2)
r.delete()
self.assertEqual(
pre_delete_order, [(T, 2), (T, 1), (S, 2), (S, 1), (R, 1)]
)
self.assertEqual(
post_delete_order, [(T, 1), (T, 2), (S, 1), (S, 2), (R, 1)]
)
models.signals.post_delete.disconnect(log_post_delete)
models.signals.post_delete.disconnect(log_pre_delete)
@skipUnlessDBFeature("can_defer_constraint_checks")
def test_can_defer_constraint_checks(self):
u = User.objects.create(
avatar=Avatar.objects.create()
)
a = Avatar.objects.get(pk=u.avatar_id)
# 1 query to find the users for the avatar.
# 1 query to delete the user
# 1 query to delete the avatar
# The important thing is that when we can defer constraint checks there
# is no need to do an UPDATE on User.avatar to null it out.
self.assertNumQueries(3, a.delete)
self.assertFalse(User.objects.exists())
self.assertFalse(Avatar.objects.exists())
@skipIfDBFeature("can_defer_constraint_checks")
def test_cannot_defer_constraint_checks(self):
u = User.objects.create(
avatar=Avatar.objects.create()
)
a = Avatar.objects.get(pk=u.avatar_id)
# 1 query to find the users for the avatar.
# 1 query to delete the user
# 1 query to null out user.avatar, because we can't defer the constraint
# 1 query to delete the avatar
self.assertNumQueries(4, a.delete)
self.assertFalse(User.objects.exists())
self.assertFalse(Avatar.objects.exists())
def test_hidden_related(self):
r = R.objects.create()
h = HiddenUser.objects.create(r=r)
p = HiddenUserProfile.objects.create(user=h)
r.delete()
self.assertEqual(HiddenUserProfile.objects.count(), 0)

View File

@ -210,6 +210,13 @@ class NonExistingOrderingWithSingleUnderscore(models.Model):
class Meta: class Meta:
ordering = ("does_not_exist",) ordering = ("does_not_exist",)
class InvalidSetNull(models.Model):
fk = models.ForeignKey('self', on_delete=models.SET_NULL)
class InvalidSetDefault(models.Model):
fk = models.ForeignKey('self', on_delete=models.SET_DEFAULT)
model_errors = """invalid_models.fielderrors: "charfield": CharFields require a "max_length" attribute that is a positive integer. model_errors = """invalid_models.fielderrors: "charfield": CharFields require a "max_length" attribute that is a positive integer.
invalid_models.fielderrors: "charfield2": CharFields require a "max_length" attribute that is a positive integer. invalid_models.fielderrors: "charfield2": CharFields require a "max_length" attribute that is a positive integer.
invalid_models.fielderrors: "charfield3": CharFields require a "max_length" attribute that is a positive integer. invalid_models.fielderrors: "charfield3": CharFields require a "max_length" attribute that is a positive integer.
@ -315,4 +322,6 @@ invalid_models.uniquem2m: ManyToManyFields cannot be unique. Remove the unique
invalid_models.nonuniquefktarget1: Field 'bad' under model 'FKTarget' must have a unique=True constraint. invalid_models.nonuniquefktarget1: Field 'bad' under model 'FKTarget' must have a unique=True constraint.
invalid_models.nonuniquefktarget2: Field 'bad' under model 'FKTarget' must have a unique=True constraint. invalid_models.nonuniquefktarget2: Field 'bad' under model 'FKTarget' must have a unique=True constraint.
invalid_models.nonexistingorderingwithsingleunderscore: "ordering" refers to "does_not_exist", a field that doesn't exist. invalid_models.nonexistingorderingwithsingleunderscore: "ordering" refers to "does_not_exist", a field that doesn't exist.
invalid_models.invalidsetnull: 'fk' specifies on_delete=SET_NULL, but cannot be null.
invalid_models.invalidsetdefault: 'fk' specifies on_delete=SET_DEFAULT, but has no default value.
""" """

View File

@ -18,6 +18,10 @@ class Article(models.Model):
class Count(models.Model): class Count(models.Model):
num = models.PositiveSmallIntegerField() num = models.PositiveSmallIntegerField()
parent = models.ForeignKey('self', null=True)
def __unicode__(self):
return unicode(self.num)
class Event(models.Model): class Event(models.Model):
date = models.DateTimeField(auto_now_add=True) date = models.DateTimeField(auto_now_add=True)

View File

@ -6,7 +6,7 @@ from django.contrib.admin.util import display_for_field, label_for_field, lookup
from django.contrib.admin.util import NestedObjects from django.contrib.admin.util import NestedObjects
from django.contrib.admin.views.main import EMPTY_CHANGELIST_VALUE from django.contrib.admin.views.main import EMPTY_CHANGELIST_VALUE
from django.contrib.sites.models import Site from django.contrib.sites.models import Site
from django.db import models from django.db import models, DEFAULT_DB_ALIAS
from django.test import TestCase from django.test import TestCase
from django.utils import unittest from django.utils import unittest
from django.utils.formats import localize from django.utils.formats import localize
@ -20,51 +20,50 @@ class NestedObjectsTests(TestCase):
""" """
def setUp(self): def setUp(self):
self.n = NestedObjects() self.n = NestedObjects(using=DEFAULT_DB_ALIAS)
self.objs = [Count.objects.create(num=i) for i in range(5)] self.objs = [Count.objects.create(num=i) for i in range(5)]
def _check(self, target): def _check(self, target):
self.assertEquals(self.n.nested(lambda obj: obj.num), target) self.assertEquals(self.n.nested(lambda obj: obj.num), target)
def _add(self, obj, parent=None): def _connect(self, i, j):
# don't bother providing the extra args that NestedObjects ignores self.objs[i].parent = self.objs[j]
self.n.add(None, None, obj, None, parent) self.objs[i].save()
def _collect(self, *indices):
self.n.collect([self.objs[i] for i in indices])
def test_unrelated_roots(self): def test_unrelated_roots(self):
self._add(self.objs[0]) self._connect(2, 1)
self._add(self.objs[1]) self._collect(0)
self._add(self.objs[2], self.objs[1]) self._collect(1)
self._check([0, 1, [2]]) self._check([0, 1, [2]])
def test_siblings(self): def test_siblings(self):
self._add(self.objs[0]) self._connect(1, 0)
self._add(self.objs[1], self.objs[0]) self._connect(2, 0)
self._add(self.objs[2], self.objs[0]) self._collect(0)
self._check([0, [1, 2]]) self._check([0, [1, 2]])
def test_duplicate_instances(self):
self._add(self.objs[0])
self._add(self.objs[1])
dupe = Count.objects.get(num=1)
self._add(dupe, self.objs[0])
self._check([0, 1])
def test_non_added_parent(self): def test_non_added_parent(self):
self._add(self.objs[0], self.objs[1]) self._connect(0, 1)
self._collect(0)
self._check([0]) self._check([0])
def test_cyclic(self): def test_cyclic(self):
self._add(self.objs[0], self.objs[2]) self._connect(0, 2)
self._add(self.objs[1], self.objs[0]) self._connect(1, 0)
self._add(self.objs[2], self.objs[1]) self._connect(2, 1)
self._add(self.objs[0], self.objs[2]) self._collect(0)
self._check([0, [1, [2]]]) self._check([0, [1, [2]]])
def test_queries(self):
self._connect(1, 0)
self._connect(2, 0)
# 1 query to fetch all children of 0 (1 and 2)
# 1 query to fetch all children of 1 and 2 (none)
# Should not require additional queries to populate the nested graph.
self.assertNumQueries(2, self._collect, 0)
class UtilTests(unittest.TestCase): class UtilTests(unittest.TestCase):
def test_values_from_lookup_field(self): def test_values_from_lookup_field(self):