mirror of https://github.com/django/django.git
Fixed #21391 -- Allow model signals to lazily reference their senders.
This commit is contained in:
parent
03bc0a8ac5
commit
eb38257e51
|
@ -1,5 +1,6 @@
|
|||
import collections
|
||||
import sys
|
||||
import types
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.management.color import color_style
|
||||
|
@ -25,7 +26,7 @@ def get_validation_errors(outfile, app=None):
|
|||
validates all models of all installed apps. Writes errors, if any, to outfile.
|
||||
Returns number of errors.
|
||||
"""
|
||||
from django.db import models, connection
|
||||
from django.db import connection, models
|
||||
from django.db.models.loading import get_app_errors
|
||||
from django.db.models.deletion import SET_NULL, SET_DEFAULT
|
||||
|
||||
|
@ -363,6 +364,8 @@ def get_validation_errors(outfile, app=None):
|
|||
for it in opts.index_together:
|
||||
validate_local_fields(e, opts, "index_together", it)
|
||||
|
||||
validate_model_signals(e)
|
||||
|
||||
return len(e.errors)
|
||||
|
||||
|
||||
|
@ -382,3 +385,28 @@ def validate_local_fields(e, opts, field_name, fields):
|
|||
e.add(opts, '"%s" refers to %s. ManyToManyFields are not supported in %s.' % (field_name, f.name, field_name))
|
||||
if f not in opts.local_fields:
|
||||
e.add(opts, '"%s" refers to %s. This is not in the same model as the %s statement.' % (field_name, f.name, field_name))
|
||||
|
||||
|
||||
def validate_model_signals(e):
|
||||
"""Ensure lazily referenced model signals senders are installed."""
|
||||
from django.db import models
|
||||
|
||||
for name in dir(models.signals):
|
||||
obj = getattr(models.signals, name)
|
||||
if isinstance(obj, models.signals.ModelSignal):
|
||||
for reference, receivers in obj.unresolved_references.items():
|
||||
for receiver, _, _ in receivers:
|
||||
# The receiver is either a function or an instance of class
|
||||
# defining a `__call__` method.
|
||||
if isinstance(receiver, types.FunctionType):
|
||||
description = "The `%s` function" % receiver.__name__
|
||||
else:
|
||||
description = "An instance of the `%s` class" % receiver.__class__.__name__
|
||||
e.add(
|
||||
receiver.__module__,
|
||||
"%s was connected to the `%s` signal "
|
||||
"with a lazy reference to the '%s' sender, "
|
||||
"which has not been installed." % (
|
||||
description, name, '.'.join(reference)
|
||||
)
|
||||
)
|
||||
|
|
|
@ -1,20 +1,70 @@
|
|||
from collections import defaultdict
|
||||
|
||||
from django.db.models.loading import get_model
|
||||
from django.dispatch import Signal
|
||||
from django.utils import six
|
||||
|
||||
|
||||
class_prepared = Signal(providing_args=["class"])
|
||||
|
||||
pre_init = Signal(providing_args=["instance", "args", "kwargs"], use_caching=True)
|
||||
post_init = Signal(providing_args=["instance"], use_caching=True)
|
||||
|
||||
pre_save = Signal(providing_args=["instance", "raw", "using", "update_fields"],
|
||||
class ModelSignal(Signal):
|
||||
"""
|
||||
Signal subclass that allows the sender to be lazily specified as a string
|
||||
of the `app_label.ModelName` form.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(ModelSignal, self).__init__(*args, **kwargs)
|
||||
self.unresolved_references = defaultdict(list)
|
||||
class_prepared.connect(self._resolve_references)
|
||||
|
||||
def _resolve_references(self, sender, **kwargs):
|
||||
opts = sender._meta
|
||||
reference = (opts.app_label, opts.object_name)
|
||||
try:
|
||||
receivers = self.unresolved_references.pop(reference)
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
for receiver, weak, dispatch_uid in receivers:
|
||||
super(ModelSignal, self).connect(
|
||||
receiver, sender=sender, weak=weak, dispatch_uid=dispatch_uid
|
||||
)
|
||||
|
||||
def connect(self, receiver, sender=None, weak=True, dispatch_uid=None):
|
||||
if isinstance(sender, six.string_types):
|
||||
try:
|
||||
app_label, object_name = sender.split('.')
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
"Specified sender must either be a model or a "
|
||||
"model name of the 'app_label.ModelName' form."
|
||||
)
|
||||
sender = get_model(app_label, object_name, only_installed=False)
|
||||
if sender is None:
|
||||
reference = (app_label, object_name)
|
||||
self.unresolved_references[reference].append(
|
||||
(receiver, weak, dispatch_uid)
|
||||
)
|
||||
return
|
||||
super(ModelSignal, self).connect(
|
||||
receiver, sender=sender, weak=weak, dispatch_uid=dispatch_uid
|
||||
)
|
||||
|
||||
pre_init = ModelSignal(providing_args=["instance", "args", "kwargs"], use_caching=True)
|
||||
post_init = ModelSignal(providing_args=["instance"], use_caching=True)
|
||||
|
||||
pre_save = ModelSignal(providing_args=["instance", "raw", "using", "update_fields"],
|
||||
use_caching=True)
|
||||
post_save = Signal(providing_args=["instance", "raw", "created", "using", "update_fields"], use_caching=True)
|
||||
post_save = ModelSignal(providing_args=["instance", "raw", "created", "using", "update_fields"], use_caching=True)
|
||||
|
||||
pre_delete = Signal(providing_args=["instance", "using"], use_caching=True)
|
||||
post_delete = Signal(providing_args=["instance", "using"], use_caching=True)
|
||||
pre_delete = ModelSignal(providing_args=["instance", "using"], use_caching=True)
|
||||
post_delete = ModelSignal(providing_args=["instance", "using"], use_caching=True)
|
||||
|
||||
m2m_changed = ModelSignal(providing_args=["action", "instance", "reverse", "model", "pk_set", "using"], use_caching=True)
|
||||
|
||||
pre_migrate = Signal(providing_args=["app", "create_models", "verbosity", "interactive", "db"])
|
||||
pre_syncdb = pre_migrate
|
||||
post_migrate = Signal(providing_args=["class", "app", "created_models", "verbosity", "interactive", "db"])
|
||||
post_syncdb = post_migrate
|
||||
|
||||
m2m_changed = Signal(providing_args=["action", "instance", "reverse", "model", "pk_set", "using"], use_caching=True)
|
||||
|
|
|
@ -22,7 +22,7 @@ Model signals
|
|||
:synopsis: Signals sent by the model system.
|
||||
|
||||
The :mod:`django.db.models.signals` module defines a set of signals sent by the
|
||||
module system.
|
||||
model system.
|
||||
|
||||
.. warning::
|
||||
|
||||
|
@ -37,6 +37,14 @@ module system.
|
|||
so if your handler is a local function, it may be garbage collected. To
|
||||
prevent this, pass ``weak=False`` when you call the signal's :meth:`~django.dispatch.Signal.connect`.
|
||||
|
||||
.. versionadded:: 1.7
|
||||
|
||||
Model signals ``sender`` model can be lazily referenced when connecting a
|
||||
receiver by specifying its full application label. For example, an
|
||||
``Answer`` model defined in the ``polls`` application could be referenced
|
||||
as ``'polls.Answer'``. This sort of reference can be quite handy when
|
||||
dealing with circular import dependencies and swappable models.
|
||||
|
||||
pre_init
|
||||
--------
|
||||
|
||||
|
|
|
@ -425,7 +425,7 @@ Models
|
|||
* Is it now possible to avoid creating a backward relation for
|
||||
:class:`~django.db.models.OneToOneField` by setting its
|
||||
:attr:`~django.db.models.ForeignKey.related_name` to
|
||||
`'+'` or ending it with `'+'`.
|
||||
``'+'`` or ending it with ``'+'``.
|
||||
|
||||
* :class:`F expressions <django.db.models.F>` support the power operator
|
||||
(``**``).
|
||||
|
@ -436,6 +436,10 @@ Signals
|
|||
* The ``enter`` argument was added to the
|
||||
:data:`~django.test.signals.setting_changed` signal.
|
||||
|
||||
* The model signals can be now be connected to using a ``str`` of the
|
||||
``'app_label.ModelName'`` form – just like related fields – to lazily
|
||||
reference their senders.
|
||||
|
||||
Templates
|
||||
^^^^^^^^^
|
||||
|
||||
|
|
|
@ -413,6 +413,19 @@ different User model.
|
|||
class Article(models.Model):
|
||||
author = models.ForeignKey(settings.AUTH_USER_MODEL)
|
||||
|
||||
.. versionadded:: 1.7
|
||||
|
||||
When connecting to signals sent by the User model, you should specify the
|
||||
custom model using the :setting:`AUTH_USER_MODEL` setting. For example::
|
||||
|
||||
from django.conf import settings
|
||||
from django.db.models.signals import post_save
|
||||
|
||||
def post_save_receiver(signal, sender, instance, **kwargs):
|
||||
pass
|
||||
|
||||
post_save.connect(post_save_receiver, sender=settings.AUTH_USER_MODEL)
|
||||
|
||||
Specifying a custom User model
|
||||
------------------------------
|
||||
|
||||
|
|
|
@ -1,10 +1,22 @@
|
|||
from django.core import management
|
||||
from django.core.management.validation import (
|
||||
ModelErrorCollection, validate_model_signals
|
||||
)
|
||||
from django.db.models.signals import post_init
|
||||
from django.test import TestCase
|
||||
from django.utils import six
|
||||
|
||||
|
||||
class ModelValidationTest(TestCase):
|
||||
class OnPostInit(object):
|
||||
def __call__(self, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
def on_post_init(**kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class ModelValidationTest(TestCase):
|
||||
def test_models_validate(self):
|
||||
# All our models should validate properly
|
||||
# Validation Tests:
|
||||
|
@ -13,3 +25,23 @@ class ModelValidationTest(TestCase):
|
|||
# * related_name='+' doesn't clash with another '+'
|
||||
# See: https://code.djangoproject.com/ticket/21375
|
||||
management.call_command("validate", stdout=six.StringIO())
|
||||
|
||||
def test_model_signal(self):
|
||||
unresolved_references = post_init.unresolved_references.copy()
|
||||
post_init.connect(on_post_init, sender='missing-app.Model')
|
||||
post_init.connect(OnPostInit(), sender='missing-app.Model')
|
||||
e = ModelErrorCollection(six.StringIO())
|
||||
validate_model_signals(e)
|
||||
self.assertSetEqual(set(e.errors), {
|
||||
('model_validation.tests',
|
||||
"The `on_post_init` function was connected to the `post_init` "
|
||||
"signal with a lazy reference to the 'missing-app.Model' "
|
||||
"sender, which has not been installed."
|
||||
),
|
||||
('model_validation.tests',
|
||||
"An instance of the `OnPostInit` class was connected to "
|
||||
"the `post_init` signal with a lazy reference to the "
|
||||
"'missing-app.Model' sender, which has not been installed."
|
||||
)
|
||||
})
|
||||
post_init.unresolved_references = unresolved_references
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from __future__ import unicode_literals
|
||||
|
||||
from django.db import models
|
||||
from django.db.models import signals
|
||||
from django.dispatch import receiver
|
||||
from django.test import TestCase
|
||||
|
@ -8,8 +9,7 @@ from django.utils import six
|
|||
from .models import Author, Book, Car, Person
|
||||
|
||||
|
||||
class SignalTests(TestCase):
|
||||
|
||||
class BaseSignalTest(TestCase):
|
||||
def setUp(self):
|
||||
# Save up the number of connected signals so that we can check at the
|
||||
# end that all the signals we register get properly unregistered (#9989)
|
||||
|
@ -30,6 +30,8 @@ class SignalTests(TestCase):
|
|||
)
|
||||
self.assertEqual(self.pre_signals, post_signals)
|
||||
|
||||
|
||||
class SignalTests(BaseSignalTest):
|
||||
def test_save_signals(self):
|
||||
data = []
|
||||
|
||||
|
@ -239,3 +241,48 @@ class SignalTests(TestCase):
|
|||
self.assertTrue(a._run)
|
||||
self.assertTrue(b._run)
|
||||
self.assertEqual(signals.post_save.receivers, [])
|
||||
|
||||
|
||||
class LazyModelRefTest(BaseSignalTest):
|
||||
def setUp(self):
|
||||
super(LazyModelRefTest, self).setUp()
|
||||
self.received = []
|
||||
|
||||
def receiver(self, **kwargs):
|
||||
self.received.append(kwargs)
|
||||
|
||||
def test_invalid_sender_model_name(self):
|
||||
with self.assertRaisesMessage(ValueError,
|
||||
"Specified sender must either be a model or a "
|
||||
"model name of the 'app_label.ModelName' form."):
|
||||
signals.post_init.connect(self.receiver, sender='invalid')
|
||||
|
||||
def test_already_loaded_model(self):
|
||||
signals.post_init.connect(
|
||||
self.receiver, sender='signals.Book', weak=False
|
||||
)
|
||||
try:
|
||||
instance = Book()
|
||||
self.assertEqual(self.received, [{
|
||||
'signal': signals.post_init,
|
||||
'sender': Book,
|
||||
'instance': instance
|
||||
}])
|
||||
finally:
|
||||
signals.post_init.disconnect(self.receiver, sender=Book)
|
||||
|
||||
def test_not_loaded_model(self):
|
||||
signals.post_init.connect(
|
||||
self.receiver, sender='signals.Created', weak=False
|
||||
)
|
||||
|
||||
try:
|
||||
class Created(models.Model):
|
||||
pass
|
||||
|
||||
instance = Created()
|
||||
self.assertEqual(self.received, [{
|
||||
'signal': signals.post_init, 'sender': Created, 'instance': instance
|
||||
}])
|
||||
finally:
|
||||
signals.post_init.disconnect(self.receiver, sender=Created)
|
||||
|
|
Loading…
Reference in New Issue