Fixed #21391 -- Allow model signals to lazily reference their senders.
This commit is contained in:
parent
03bc0a8ac5
commit
eb38257e51
|
@ -1,5 +1,6 @@
|
||||||
import collections
|
import collections
|
||||||
import sys
|
import sys
|
||||||
|
import types
|
||||||
|
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.core.management.color import color_style
|
from django.core.management.color import color_style
|
||||||
|
@ -25,7 +26,7 @@ def get_validation_errors(outfile, app=None):
|
||||||
validates all models of all installed apps. Writes errors, if any, to outfile.
|
validates all models of all installed apps. Writes errors, if any, to outfile.
|
||||||
Returns number of errors.
|
Returns number of errors.
|
||||||
"""
|
"""
|
||||||
from django.db import models, connection
|
from django.db import connection, models
|
||||||
from django.db.models.loading import get_app_errors
|
from django.db.models.loading import get_app_errors
|
||||||
from django.db.models.deletion import SET_NULL, SET_DEFAULT
|
from django.db.models.deletion import SET_NULL, SET_DEFAULT
|
||||||
|
|
||||||
|
@ -363,6 +364,8 @@ def get_validation_errors(outfile, app=None):
|
||||||
for it in opts.index_together:
|
for it in opts.index_together:
|
||||||
validate_local_fields(e, opts, "index_together", it)
|
validate_local_fields(e, opts, "index_together", it)
|
||||||
|
|
||||||
|
validate_model_signals(e)
|
||||||
|
|
||||||
return len(e.errors)
|
return len(e.errors)
|
||||||
|
|
||||||
|
|
||||||
|
@ -382,3 +385,28 @@ def validate_local_fields(e, opts, field_name, fields):
|
||||||
e.add(opts, '"%s" refers to %s. ManyToManyFields are not supported in %s.' % (field_name, f.name, field_name))
|
e.add(opts, '"%s" refers to %s. ManyToManyFields are not supported in %s.' % (field_name, f.name, field_name))
|
||||||
if f not in opts.local_fields:
|
if f not in opts.local_fields:
|
||||||
e.add(opts, '"%s" refers to %s. This is not in the same model as the %s statement.' % (field_name, f.name, field_name))
|
e.add(opts, '"%s" refers to %s. This is not in the same model as the %s statement.' % (field_name, f.name, field_name))
|
||||||
|
|
||||||
|
|
||||||
|
def validate_model_signals(e):
|
||||||
|
"""Ensure lazily referenced model signals senders are installed."""
|
||||||
|
from django.db import models
|
||||||
|
|
||||||
|
for name in dir(models.signals):
|
||||||
|
obj = getattr(models.signals, name)
|
||||||
|
if isinstance(obj, models.signals.ModelSignal):
|
||||||
|
for reference, receivers in obj.unresolved_references.items():
|
||||||
|
for receiver, _, _ in receivers:
|
||||||
|
# The receiver is either a function or an instance of class
|
||||||
|
# defining a `__call__` method.
|
||||||
|
if isinstance(receiver, types.FunctionType):
|
||||||
|
description = "The `%s` function" % receiver.__name__
|
||||||
|
else:
|
||||||
|
description = "An instance of the `%s` class" % receiver.__class__.__name__
|
||||||
|
e.add(
|
||||||
|
receiver.__module__,
|
||||||
|
"%s was connected to the `%s` signal "
|
||||||
|
"with a lazy reference to the '%s' sender, "
|
||||||
|
"which has not been installed." % (
|
||||||
|
description, name, '.'.join(reference)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
|
@ -1,20 +1,70 @@
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
from django.db.models.loading import get_model
|
||||||
from django.dispatch import Signal
|
from django.dispatch import Signal
|
||||||
|
from django.utils import six
|
||||||
|
|
||||||
|
|
||||||
class_prepared = Signal(providing_args=["class"])
|
class_prepared = Signal(providing_args=["class"])
|
||||||
|
|
||||||
pre_init = Signal(providing_args=["instance", "args", "kwargs"], use_caching=True)
|
|
||||||
post_init = Signal(providing_args=["instance"], use_caching=True)
|
|
||||||
|
|
||||||
pre_save = Signal(providing_args=["instance", "raw", "using", "update_fields"],
|
class ModelSignal(Signal):
|
||||||
|
"""
|
||||||
|
Signal subclass that allows the sender to be lazily specified as a string
|
||||||
|
of the `app_label.ModelName` form.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super(ModelSignal, self).__init__(*args, **kwargs)
|
||||||
|
self.unresolved_references = defaultdict(list)
|
||||||
|
class_prepared.connect(self._resolve_references)
|
||||||
|
|
||||||
|
def _resolve_references(self, sender, **kwargs):
|
||||||
|
opts = sender._meta
|
||||||
|
reference = (opts.app_label, opts.object_name)
|
||||||
|
try:
|
||||||
|
receivers = self.unresolved_references.pop(reference)
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
for receiver, weak, dispatch_uid in receivers:
|
||||||
|
super(ModelSignal, self).connect(
|
||||||
|
receiver, sender=sender, weak=weak, dispatch_uid=dispatch_uid
|
||||||
|
)
|
||||||
|
|
||||||
|
def connect(self, receiver, sender=None, weak=True, dispatch_uid=None):
|
||||||
|
if isinstance(sender, six.string_types):
|
||||||
|
try:
|
||||||
|
app_label, object_name = sender.split('.')
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError(
|
||||||
|
"Specified sender must either be a model or a "
|
||||||
|
"model name of the 'app_label.ModelName' form."
|
||||||
|
)
|
||||||
|
sender = get_model(app_label, object_name, only_installed=False)
|
||||||
|
if sender is None:
|
||||||
|
reference = (app_label, object_name)
|
||||||
|
self.unresolved_references[reference].append(
|
||||||
|
(receiver, weak, dispatch_uid)
|
||||||
|
)
|
||||||
|
return
|
||||||
|
super(ModelSignal, self).connect(
|
||||||
|
receiver, sender=sender, weak=weak, dispatch_uid=dispatch_uid
|
||||||
|
)
|
||||||
|
|
||||||
|
pre_init = ModelSignal(providing_args=["instance", "args", "kwargs"], use_caching=True)
|
||||||
|
post_init = ModelSignal(providing_args=["instance"], use_caching=True)
|
||||||
|
|
||||||
|
pre_save = ModelSignal(providing_args=["instance", "raw", "using", "update_fields"],
|
||||||
use_caching=True)
|
use_caching=True)
|
||||||
post_save = Signal(providing_args=["instance", "raw", "created", "using", "update_fields"], use_caching=True)
|
post_save = ModelSignal(providing_args=["instance", "raw", "created", "using", "update_fields"], use_caching=True)
|
||||||
|
|
||||||
pre_delete = Signal(providing_args=["instance", "using"], use_caching=True)
|
pre_delete = ModelSignal(providing_args=["instance", "using"], use_caching=True)
|
||||||
post_delete = Signal(providing_args=["instance", "using"], use_caching=True)
|
post_delete = ModelSignal(providing_args=["instance", "using"], use_caching=True)
|
||||||
|
|
||||||
|
m2m_changed = ModelSignal(providing_args=["action", "instance", "reverse", "model", "pk_set", "using"], use_caching=True)
|
||||||
|
|
||||||
pre_migrate = Signal(providing_args=["app", "create_models", "verbosity", "interactive", "db"])
|
pre_migrate = Signal(providing_args=["app", "create_models", "verbosity", "interactive", "db"])
|
||||||
pre_syncdb = pre_migrate
|
pre_syncdb = pre_migrate
|
||||||
post_migrate = Signal(providing_args=["class", "app", "created_models", "verbosity", "interactive", "db"])
|
post_migrate = Signal(providing_args=["class", "app", "created_models", "verbosity", "interactive", "db"])
|
||||||
post_syncdb = post_migrate
|
post_syncdb = post_migrate
|
||||||
|
|
||||||
m2m_changed = Signal(providing_args=["action", "instance", "reverse", "model", "pk_set", "using"], use_caching=True)
|
|
||||||
|
|
|
@ -22,7 +22,7 @@ Model signals
|
||||||
:synopsis: Signals sent by the model system.
|
:synopsis: Signals sent by the model system.
|
||||||
|
|
||||||
The :mod:`django.db.models.signals` module defines a set of signals sent by the
|
The :mod:`django.db.models.signals` module defines a set of signals sent by the
|
||||||
module system.
|
model system.
|
||||||
|
|
||||||
.. warning::
|
.. warning::
|
||||||
|
|
||||||
|
@ -37,6 +37,14 @@ module system.
|
||||||
so if your handler is a local function, it may be garbage collected. To
|
so if your handler is a local function, it may be garbage collected. To
|
||||||
prevent this, pass ``weak=False`` when you call the signal's :meth:`~django.dispatch.Signal.connect`.
|
prevent this, pass ``weak=False`` when you call the signal's :meth:`~django.dispatch.Signal.connect`.
|
||||||
|
|
||||||
|
.. versionadded:: 1.7
|
||||||
|
|
||||||
|
Model signals ``sender`` model can be lazily referenced when connecting a
|
||||||
|
receiver by specifying its full application label. For example, an
|
||||||
|
``Answer`` model defined in the ``polls`` application could be referenced
|
||||||
|
as ``'polls.Answer'``. This sort of reference can be quite handy when
|
||||||
|
dealing with circular import dependencies and swappable models.
|
||||||
|
|
||||||
pre_init
|
pre_init
|
||||||
--------
|
--------
|
||||||
|
|
||||||
|
|
|
@ -425,7 +425,7 @@ Models
|
||||||
* Is it now possible to avoid creating a backward relation for
|
* Is it now possible to avoid creating a backward relation for
|
||||||
:class:`~django.db.models.OneToOneField` by setting its
|
:class:`~django.db.models.OneToOneField` by setting its
|
||||||
:attr:`~django.db.models.ForeignKey.related_name` to
|
:attr:`~django.db.models.ForeignKey.related_name` to
|
||||||
`'+'` or ending it with `'+'`.
|
``'+'`` or ending it with ``'+'``.
|
||||||
|
|
||||||
* :class:`F expressions <django.db.models.F>` support the power operator
|
* :class:`F expressions <django.db.models.F>` support the power operator
|
||||||
(``**``).
|
(``**``).
|
||||||
|
@ -436,6 +436,10 @@ Signals
|
||||||
* The ``enter`` argument was added to the
|
* The ``enter`` argument was added to the
|
||||||
:data:`~django.test.signals.setting_changed` signal.
|
:data:`~django.test.signals.setting_changed` signal.
|
||||||
|
|
||||||
|
* The model signals can be now be connected to using a ``str`` of the
|
||||||
|
``'app_label.ModelName'`` form – just like related fields – to lazily
|
||||||
|
reference their senders.
|
||||||
|
|
||||||
Templates
|
Templates
|
||||||
^^^^^^^^^
|
^^^^^^^^^
|
||||||
|
|
||||||
|
|
|
@ -413,6 +413,19 @@ different User model.
|
||||||
class Article(models.Model):
|
class Article(models.Model):
|
||||||
author = models.ForeignKey(settings.AUTH_USER_MODEL)
|
author = models.ForeignKey(settings.AUTH_USER_MODEL)
|
||||||
|
|
||||||
|
.. versionadded:: 1.7
|
||||||
|
|
||||||
|
When connecting to signals sent by the User model, you should specify the
|
||||||
|
custom model using the :setting:`AUTH_USER_MODEL` setting. For example::
|
||||||
|
|
||||||
|
from django.conf import settings
|
||||||
|
from django.db.models.signals import post_save
|
||||||
|
|
||||||
|
def post_save_receiver(signal, sender, instance, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
post_save.connect(post_save_receiver, sender=settings.AUTH_USER_MODEL)
|
||||||
|
|
||||||
Specifying a custom User model
|
Specifying a custom User model
|
||||||
------------------------------
|
------------------------------
|
||||||
|
|
||||||
|
|
|
@ -1,10 +1,22 @@
|
||||||
from django.core import management
|
from django.core import management
|
||||||
|
from django.core.management.validation import (
|
||||||
|
ModelErrorCollection, validate_model_signals
|
||||||
|
)
|
||||||
|
from django.db.models.signals import post_init
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
from django.utils import six
|
from django.utils import six
|
||||||
|
|
||||||
|
|
||||||
class ModelValidationTest(TestCase):
|
class OnPostInit(object):
|
||||||
|
def __call__(self, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def on_post_init(**kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ModelValidationTest(TestCase):
|
||||||
def test_models_validate(self):
|
def test_models_validate(self):
|
||||||
# All our models should validate properly
|
# All our models should validate properly
|
||||||
# Validation Tests:
|
# Validation Tests:
|
||||||
|
@ -13,3 +25,23 @@ class ModelValidationTest(TestCase):
|
||||||
# * related_name='+' doesn't clash with another '+'
|
# * related_name='+' doesn't clash with another '+'
|
||||||
# See: https://code.djangoproject.com/ticket/21375
|
# See: https://code.djangoproject.com/ticket/21375
|
||||||
management.call_command("validate", stdout=six.StringIO())
|
management.call_command("validate", stdout=six.StringIO())
|
||||||
|
|
||||||
|
def test_model_signal(self):
|
||||||
|
unresolved_references = post_init.unresolved_references.copy()
|
||||||
|
post_init.connect(on_post_init, sender='missing-app.Model')
|
||||||
|
post_init.connect(OnPostInit(), sender='missing-app.Model')
|
||||||
|
e = ModelErrorCollection(six.StringIO())
|
||||||
|
validate_model_signals(e)
|
||||||
|
self.assertSetEqual(set(e.errors), {
|
||||||
|
('model_validation.tests',
|
||||||
|
"The `on_post_init` function was connected to the `post_init` "
|
||||||
|
"signal with a lazy reference to the 'missing-app.Model' "
|
||||||
|
"sender, which has not been installed."
|
||||||
|
),
|
||||||
|
('model_validation.tests',
|
||||||
|
"An instance of the `OnPostInit` class was connected to "
|
||||||
|
"the `post_init` signal with a lazy reference to the "
|
||||||
|
"'missing-app.Model' sender, which has not been installed."
|
||||||
|
)
|
||||||
|
})
|
||||||
|
post_init.unresolved_references = unresolved_references
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
|
from django.db import models
|
||||||
from django.db.models import signals
|
from django.db.models import signals
|
||||||
from django.dispatch import receiver
|
from django.dispatch import receiver
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
|
@ -8,8 +9,7 @@ from django.utils import six
|
||||||
from .models import Author, Book, Car, Person
|
from .models import Author, Book, Car, Person
|
||||||
|
|
||||||
|
|
||||||
class SignalTests(TestCase):
|
class BaseSignalTest(TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
# Save up the number of connected signals so that we can check at the
|
# Save up the number of connected signals so that we can check at the
|
||||||
# end that all the signals we register get properly unregistered (#9989)
|
# end that all the signals we register get properly unregistered (#9989)
|
||||||
|
@ -30,6 +30,8 @@ class SignalTests(TestCase):
|
||||||
)
|
)
|
||||||
self.assertEqual(self.pre_signals, post_signals)
|
self.assertEqual(self.pre_signals, post_signals)
|
||||||
|
|
||||||
|
|
||||||
|
class SignalTests(BaseSignalTest):
|
||||||
def test_save_signals(self):
|
def test_save_signals(self):
|
||||||
data = []
|
data = []
|
||||||
|
|
||||||
|
@ -239,3 +241,48 @@ class SignalTests(TestCase):
|
||||||
self.assertTrue(a._run)
|
self.assertTrue(a._run)
|
||||||
self.assertTrue(b._run)
|
self.assertTrue(b._run)
|
||||||
self.assertEqual(signals.post_save.receivers, [])
|
self.assertEqual(signals.post_save.receivers, [])
|
||||||
|
|
||||||
|
|
||||||
|
class LazyModelRefTest(BaseSignalTest):
|
||||||
|
def setUp(self):
|
||||||
|
super(LazyModelRefTest, self).setUp()
|
||||||
|
self.received = []
|
||||||
|
|
||||||
|
def receiver(self, **kwargs):
|
||||||
|
self.received.append(kwargs)
|
||||||
|
|
||||||
|
def test_invalid_sender_model_name(self):
|
||||||
|
with self.assertRaisesMessage(ValueError,
|
||||||
|
"Specified sender must either be a model or a "
|
||||||
|
"model name of the 'app_label.ModelName' form."):
|
||||||
|
signals.post_init.connect(self.receiver, sender='invalid')
|
||||||
|
|
||||||
|
def test_already_loaded_model(self):
|
||||||
|
signals.post_init.connect(
|
||||||
|
self.receiver, sender='signals.Book', weak=False
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
instance = Book()
|
||||||
|
self.assertEqual(self.received, [{
|
||||||
|
'signal': signals.post_init,
|
||||||
|
'sender': Book,
|
||||||
|
'instance': instance
|
||||||
|
}])
|
||||||
|
finally:
|
||||||
|
signals.post_init.disconnect(self.receiver, sender=Book)
|
||||||
|
|
||||||
|
def test_not_loaded_model(self):
|
||||||
|
signals.post_init.connect(
|
||||||
|
self.receiver, sender='signals.Created', weak=False
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
class Created(models.Model):
|
||||||
|
pass
|
||||||
|
|
||||||
|
instance = Created()
|
||||||
|
self.assertEqual(self.received, [{
|
||||||
|
'signal': signals.post_init, 'sender': Created, 'instance': instance
|
||||||
|
}])
|
||||||
|
finally:
|
||||||
|
signals.post_init.disconnect(self.receiver, sender=Created)
|
||||||
|
|
Loading…
Reference in New Issue