diff --git a/django/core/management/validation.py b/django/core/management/validation.py index 0a976d72a1..5bf9413c20 100644 --- a/django/core/management/validation.py +++ b/django/core/management/validation.py @@ -1,5 +1,6 @@ import collections import sys +import types from django.conf import settings from django.core.management.color import color_style @@ -25,7 +26,7 @@ def get_validation_errors(outfile, app=None): validates all models of all installed apps. Writes errors, if any, to outfile. Returns number of errors. """ - from django.db import models, connection + from django.db import connection, models from django.db.models.loading import get_app_errors from django.db.models.deletion import SET_NULL, SET_DEFAULT @@ -363,6 +364,8 @@ def get_validation_errors(outfile, app=None): for it in opts.index_together: validate_local_fields(e, opts, "index_together", it) + validate_model_signals(e) + return len(e.errors) @@ -382,3 +385,28 @@ def validate_local_fields(e, opts, field_name, fields): e.add(opts, '"%s" refers to %s. ManyToManyFields are not supported in %s.' % (field_name, f.name, field_name)) if f not in opts.local_fields: e.add(opts, '"%s" refers to %s. This is not in the same model as the %s statement.' % (field_name, f.name, field_name)) + + +def validate_model_signals(e): + """Ensure lazily referenced model signals senders are installed.""" + from django.db import models + + for name in dir(models.signals): + obj = getattr(models.signals, name) + if isinstance(obj, models.signals.ModelSignal): + for reference, receivers in obj.unresolved_references.items(): + for receiver, _, _ in receivers: + # The receiver is either a function or an instance of class + # defining a `__call__` method. + if isinstance(receiver, types.FunctionType): + description = "The `%s` function" % receiver.__name__ + else: + description = "An instance of the `%s` class" % receiver.__class__.__name__ + e.add( + receiver.__module__, + "%s was connected to the `%s` signal " + "with a lazy reference to the '%s' sender, " + "which has not been installed." % ( + description, name, '.'.join(reference) + ) + ) diff --git a/django/db/models/signals.py b/django/db/models/signals.py index 6b7605839c..6b011c2099 100644 --- a/django/db/models/signals.py +++ b/django/db/models/signals.py @@ -1,20 +1,70 @@ +from collections import defaultdict + +from django.db.models.loading import get_model from django.dispatch import Signal +from django.utils import six + class_prepared = Signal(providing_args=["class"]) -pre_init = Signal(providing_args=["instance", "args", "kwargs"], use_caching=True) -post_init = Signal(providing_args=["instance"], use_caching=True) -pre_save = Signal(providing_args=["instance", "raw", "using", "update_fields"], - use_caching=True) -post_save = Signal(providing_args=["instance", "raw", "created", "using", "update_fields"], use_caching=True) +class ModelSignal(Signal): + """ + Signal subclass that allows the sender to be lazily specified as a string + of the `app_label.ModelName` form. + """ -pre_delete = Signal(providing_args=["instance", "using"], use_caching=True) -post_delete = Signal(providing_args=["instance", "using"], use_caching=True) + def __init__(self, *args, **kwargs): + super(ModelSignal, self).__init__(*args, **kwargs) + self.unresolved_references = defaultdict(list) + class_prepared.connect(self._resolve_references) + + def _resolve_references(self, sender, **kwargs): + opts = sender._meta + reference = (opts.app_label, opts.object_name) + try: + receivers = self.unresolved_references.pop(reference) + except KeyError: + pass + else: + for receiver, weak, dispatch_uid in receivers: + super(ModelSignal, self).connect( + receiver, sender=sender, weak=weak, dispatch_uid=dispatch_uid + ) + + def connect(self, receiver, sender=None, weak=True, dispatch_uid=None): + if isinstance(sender, six.string_types): + try: + app_label, object_name = sender.split('.') + except ValueError: + raise ValueError( + "Specified sender must either be a model or a " + "model name of the 'app_label.ModelName' form." + ) + sender = get_model(app_label, object_name, only_installed=False) + if sender is None: + reference = (app_label, object_name) + self.unresolved_references[reference].append( + (receiver, weak, dispatch_uid) + ) + return + super(ModelSignal, self).connect( + receiver, sender=sender, weak=weak, dispatch_uid=dispatch_uid + ) + +pre_init = ModelSignal(providing_args=["instance", "args", "kwargs"], use_caching=True) +post_init = ModelSignal(providing_args=["instance"], use_caching=True) + +pre_save = ModelSignal(providing_args=["instance", "raw", "using", "update_fields"], + use_caching=True) +post_save = ModelSignal(providing_args=["instance", "raw", "created", "using", "update_fields"], use_caching=True) + +pre_delete = ModelSignal(providing_args=["instance", "using"], use_caching=True) +post_delete = ModelSignal(providing_args=["instance", "using"], use_caching=True) + +m2m_changed = ModelSignal(providing_args=["action", "instance", "reverse", "model", "pk_set", "using"], use_caching=True) pre_migrate = Signal(providing_args=["app", "create_models", "verbosity", "interactive", "db"]) pre_syncdb = pre_migrate post_migrate = Signal(providing_args=["class", "app", "created_models", "verbosity", "interactive", "db"]) post_syncdb = post_migrate - -m2m_changed = Signal(providing_args=["action", "instance", "reverse", "model", "pk_set", "using"], use_caching=True) diff --git a/docs/ref/signals.txt b/docs/ref/signals.txt index 051f1fa915..1abfed8806 100644 --- a/docs/ref/signals.txt +++ b/docs/ref/signals.txt @@ -22,7 +22,7 @@ Model signals :synopsis: Signals sent by the model system. The :mod:`django.db.models.signals` module defines a set of signals sent by the -module system. +model system. .. warning:: @@ -37,6 +37,14 @@ module system. so if your handler is a local function, it may be garbage collected. To prevent this, pass ``weak=False`` when you call the signal's :meth:`~django.dispatch.Signal.connect`. +.. versionadded:: 1.7 + + Model signals ``sender`` model can be lazily referenced when connecting a + receiver by specifying its full application label. For example, an + ``Answer`` model defined in the ``polls`` application could be referenced + as ``'polls.Answer'``. This sort of reference can be quite handy when + dealing with circular import dependencies and swappable models. + pre_init -------- diff --git a/docs/releases/1.7.txt b/docs/releases/1.7.txt index 8c3226ff48..4999fc11a8 100644 --- a/docs/releases/1.7.txt +++ b/docs/releases/1.7.txt @@ -425,7 +425,7 @@ Models * Is it now possible to avoid creating a backward relation for :class:`~django.db.models.OneToOneField` by setting its :attr:`~django.db.models.ForeignKey.related_name` to - `'+'` or ending it with `'+'`. + ``'+'`` or ending it with ``'+'``. * :class:`F expressions ` support the power operator (``**``). @@ -436,6 +436,10 @@ Signals * The ``enter`` argument was added to the :data:`~django.test.signals.setting_changed` signal. +* The model signals can be now be connected to using a ``str`` of the + ``'app_label.ModelName'`` form – just like related fields – to lazily + reference their senders. + Templates ^^^^^^^^^ diff --git a/docs/topics/auth/customizing.txt b/docs/topics/auth/customizing.txt index fa6075fac8..7b15a2f8d3 100644 --- a/docs/topics/auth/customizing.txt +++ b/docs/topics/auth/customizing.txt @@ -413,6 +413,19 @@ different User model. class Article(models.Model): author = models.ForeignKey(settings.AUTH_USER_MODEL) + .. versionadded:: 1.7 + + When connecting to signals sent by the User model, you should specify the + custom model using the :setting:`AUTH_USER_MODEL` setting. For example:: + + from django.conf import settings + from django.db.models.signals import post_save + + def post_save_receiver(signal, sender, instance, **kwargs): + pass + + post_save.connect(post_save_receiver, sender=settings.AUTH_USER_MODEL) + Specifying a custom User model ------------------------------ diff --git a/tests/model_validation/tests.py b/tests/model_validation/tests.py index 494af97f96..c55a1307e1 100644 --- a/tests/model_validation/tests.py +++ b/tests/model_validation/tests.py @@ -1,10 +1,22 @@ from django.core import management +from django.core.management.validation import ( + ModelErrorCollection, validate_model_signals +) +from django.db.models.signals import post_init from django.test import TestCase from django.utils import six -class ModelValidationTest(TestCase): +class OnPostInit(object): + def __call__(self, **kwargs): + pass + +def on_post_init(**kwargs): + pass + + +class ModelValidationTest(TestCase): def test_models_validate(self): # All our models should validate properly # Validation Tests: @@ -13,3 +25,23 @@ class ModelValidationTest(TestCase): # * related_name='+' doesn't clash with another '+' # See: https://code.djangoproject.com/ticket/21375 management.call_command("validate", stdout=six.StringIO()) + + def test_model_signal(self): + unresolved_references = post_init.unresolved_references.copy() + post_init.connect(on_post_init, sender='missing-app.Model') + post_init.connect(OnPostInit(), sender='missing-app.Model') + e = ModelErrorCollection(six.StringIO()) + validate_model_signals(e) + self.assertSetEqual(set(e.errors), { + ('model_validation.tests', + "The `on_post_init` function was connected to the `post_init` " + "signal with a lazy reference to the 'missing-app.Model' " + "sender, which has not been installed." + ), + ('model_validation.tests', + "An instance of the `OnPostInit` class was connected to " + "the `post_init` signal with a lazy reference to the " + "'missing-app.Model' sender, which has not been installed." + ) + }) + post_init.unresolved_references = unresolved_references diff --git a/tests/signals/tests.py b/tests/signals/tests.py index 4be8d9b65d..c5c5b9c5e7 100644 --- a/tests/signals/tests.py +++ b/tests/signals/tests.py @@ -1,5 +1,6 @@ from __future__ import unicode_literals +from django.db import models from django.db.models import signals from django.dispatch import receiver from django.test import TestCase @@ -8,8 +9,7 @@ from django.utils import six from .models import Author, Book, Car, Person -class SignalTests(TestCase): - +class BaseSignalTest(TestCase): def setUp(self): # Save up the number of connected signals so that we can check at the # end that all the signals we register get properly unregistered (#9989) @@ -30,6 +30,8 @@ class SignalTests(TestCase): ) self.assertEqual(self.pre_signals, post_signals) + +class SignalTests(BaseSignalTest): def test_save_signals(self): data = [] @@ -239,3 +241,48 @@ class SignalTests(TestCase): self.assertTrue(a._run) self.assertTrue(b._run) self.assertEqual(signals.post_save.receivers, []) + + +class LazyModelRefTest(BaseSignalTest): + def setUp(self): + super(LazyModelRefTest, self).setUp() + self.received = [] + + def receiver(self, **kwargs): + self.received.append(kwargs) + + def test_invalid_sender_model_name(self): + with self.assertRaisesMessage(ValueError, + "Specified sender must either be a model or a " + "model name of the 'app_label.ModelName' form."): + signals.post_init.connect(self.receiver, sender='invalid') + + def test_already_loaded_model(self): + signals.post_init.connect( + self.receiver, sender='signals.Book', weak=False + ) + try: + instance = Book() + self.assertEqual(self.received, [{ + 'signal': signals.post_init, + 'sender': Book, + 'instance': instance + }]) + finally: + signals.post_init.disconnect(self.receiver, sender=Book) + + def test_not_loaded_model(self): + signals.post_init.connect( + self.receiver, sender='signals.Created', weak=False + ) + + try: + class Created(models.Model): + pass + + instance = Created() + self.assertEqual(self.received, [{ + 'signal': signals.post_init, 'sender': Created, 'instance': instance + }]) + finally: + signals.post_init.disconnect(self.receiver, sender=Created)