[1.8.x] Fixed #25389 -- Fixed pickling a SimpleLazyObject wrapping a model.

Pickling a `SimpleLazyObject` wrapping a model did not work correctly; in
particular it did not add the `_django_version` attribute added in 42736ac8.
Now it will handle this and other custom `__reduce__` methods correctly.

Backport of 35355a4ffe from master
This commit is contained in:
Ben Kraft 2015-09-11 23:06:25 -07:00 committed by Tim Graham
parent 29c9a7d220
commit c03f0c282d
4 changed files with 128 additions and 24 deletions

View File

@ -6,7 +6,6 @@ from functools import wraps
from django.utils import six from django.utils import six
from django.utils.deprecation import RemovedInDjango19Warning from django.utils.deprecation import RemovedInDjango19Warning
from django.utils.six.moves import copyreg
# You can't trivially replace this with `functools.partial` because this binds # You can't trivially replace this with `functools.partial` because this binds
@ -268,32 +267,30 @@ class LazyObject(object):
raise NotImplementedError('subclasses of LazyObject must provide a _setup() method') raise NotImplementedError('subclasses of LazyObject must provide a _setup() method')
# Because we have messed with __class__ below, we confuse pickle as to what # Because we have messed with __class__ below, we confuse pickle as to what
# class we are pickling. It also appears to stop __reduce__ from being # class we are pickling. We're going to have to initialize the wrapped
# called. So, we define __getstate__ in a way that cooperates with the way # object to successfully pickle it, so we might as well just pickle the
# that pickle interprets this class. This fails when the wrapped class is # wrapped object since they're supposed to act the same way.
# a builtin, but it is better than nothing. #
def __getstate__(self): # Unfortunately, if we try to simply act like the wrapped object, the ruse
# will break down when pickle gets our id(). Thus we end up with pickle
# thinking, in effect, that we are a distinct object from the wrapped
# object, but with the same __dict__. This can cause problems (see #25389).
#
# So instead, we define our own __reduce__ method and custom unpickler. We
# pickle the wrapped object as the unpickler's argument, so that pickle
# will pickle it normally, and then the unpickler simply returns its
# argument.
def __reduce__(self):
if self._wrapped is empty: if self._wrapped is empty:
self._setup() self._setup()
return self._wrapped.__dict__ return (unpickle_lazyobject, (self._wrapped,))
# Python 3.3 will call __reduce__ when pickling; this method is needed # We have to explicitly override __getstate__ so that older versions of
# to serialize and deserialize correctly. # pickle don't try to pickle the __dict__ (which in the case of a
@classmethod # SimpleLazyObject may contain a lambda). The value will end up being
def __newobj__(cls, *args): # ignored by our __reduce__ and custom unpickler.
return cls.__new__(cls, *args) def __getstate__(self):
return {}
def __reduce_ex__(self, proto):
if proto >= 2:
# On Py3, since the default protocol is 3, pickle uses the
# ``__newobj__`` method (& more efficient opcodes) for writing.
return (self.__newobj__, (self.__class__,), self.__getstate__())
else:
# On Py2, the default protocol is 0 (for back-compat) & the above
# code fails miserably (see regression test). Instead, we return
# exactly what's returned if there's no ``__reduce__`` method at
# all.
return (copyreg._reconstructor, (self.__class__, object, None), self.__getstate__())
def __deepcopy__(self, memo): def __deepcopy__(self, memo):
if self._wrapped is empty: if self._wrapped is empty:
@ -332,6 +329,15 @@ class LazyObject(object):
__contains__ = new_method_proxy(operator.contains) __contains__ = new_method_proxy(operator.contains)
def unpickle_lazyobject(wrapped):
"""
Used to unpickle lazy objects. Just return its argument, which will be the
wrapped object.
"""
return wrapped
unpickle_lazyobject.__safe_for_unpickling__ = True
# Workaround for http://bugs.python.org/issue12370 # Workaround for http://bugs.python.org/issue12370
_super = super _super = super

View File

@ -56,3 +56,5 @@ Bugfixes
* Fixed incorrect queries with multiple many-to-many fields on a model with the * Fixed incorrect queries with multiple many-to-many fields on a model with the
same 'to' model and with ``related_name`` set to '+' (:ticket:`24505`, same 'to' model and with ``related_name`` set to '+' (:ticket:`24505`,
:ticket:`25486`). :ticket:`25486`).
* Fixed pickling a ``SimpleLazyObject`` wrapping a model (:ticket:`25389`).

View File

@ -11,3 +11,7 @@ class Category(models.Model):
class Thing(models.Model): class Thing(models.Model):
name = models.CharField(max_length=100) name = models.CharField(max_length=100)
category = models.ForeignKey(Category) category = models.ForeignKey(Category)
class CategoryInfo(models.Model):
category = models.OneToOneField(Category)

View File

@ -3,11 +3,14 @@ from __future__ import unicode_literals
import copy import copy
import pickle import pickle
import sys import sys
import warnings
from unittest import TestCase from unittest import TestCase
from django.utils import six from django.utils import six
from django.utils.functional import LazyObject, SimpleLazyObject, empty from django.utils.functional import LazyObject, SimpleLazyObject, empty
from .models import Category, CategoryInfo
class Foo(object): class Foo(object):
""" """
@ -273,3 +276,92 @@ class SimpleLazyObjectTestCase(LazyObjectTestCase):
self.assertNotIn(6, lazy_set) self.assertNotIn(6, lazy_set)
self.assertEqual(len(lazy_list), 5) self.assertEqual(len(lazy_list), 5)
self.assertEqual(len(lazy_set), 4) self.assertEqual(len(lazy_set), 4)
class BaseBaz(object):
"""
A base class with a funky __reduce__ method, meant to simulate the
__reduce__ method of Model, which sets self._django_version.
"""
def __init__(self):
self.baz = 'wrong'
def __reduce__(self):
self.baz = 'right'
return super(BaseBaz, self).__reduce__()
def __eq__(self, other):
if self.__class__ != other.__class__:
return False
for attr in ['bar', 'baz', 'quux']:
if hasattr(self, attr) != hasattr(other, attr):
return False
elif getattr(self, attr, None) != getattr(other, attr, None):
return False
return True
class Baz(BaseBaz):
"""
A class that inherits from BaseBaz and has its own __reduce_ex__ method.
"""
def __init__(self, bar):
self.bar = bar
super(Baz, self).__init__()
def __reduce_ex__(self, proto):
self.quux = 'quux'
return super(Baz, self).__reduce_ex__(proto)
class BazProxy(Baz):
"""
A class that acts as a proxy for Baz. It does some scary mucking about with
dicts, which simulates some crazy things that people might do with
e.g. proxy models.
"""
def __init__(self, baz):
self.__dict__ = baz.__dict__
self._baz = baz
super(BaseBaz, self).__init__()
class SimpleLazyObjectPickleTestCase(TestCase):
"""
Regression test for pickling a SimpleLazyObject wrapping a model (#25389).
Also covers other classes with a custom __reduce__ method.
"""
def test_pickle_with_reduce(self):
"""
Test in a fairly synthetic setting.
"""
# Test every pickle protocol available
for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
lazy_objs = [
SimpleLazyObject(lambda: BaseBaz()),
SimpleLazyObject(lambda: Baz(1)),
SimpleLazyObject(lambda: BazProxy(Baz(2))),
]
for obj in lazy_objs:
pickled = pickle.dumps(obj, protocol)
unpickled = pickle.loads(pickled)
self.assertEqual(unpickled, obj)
self.assertEqual(unpickled.baz, 'right')
def test_pickle_model(self):
"""
Test on an actual model, based on the report in #25426.
"""
category = Category.objects.create(name="thing1")
CategoryInfo.objects.create(category=category)
# Test every pickle protocol available
for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
lazy_category = SimpleLazyObject(lambda: category)
# Test both if we accessed a field on the model and if we didn't.
lazy_category.categoryinfo
lazy_category_2 = SimpleLazyObject(lambda: category)
with warnings.catch_warnings(record=True) as recorded:
self.assertEqual(pickle.loads(pickle.dumps(lazy_category, protocol)), category)
self.assertEqual(pickle.loads(pickle.dumps(lazy_category_2, protocol)), category)
# Assert that there were no warnings.
self.assertEqual(len(recorded), 0)