Fixed #29253 -- Made method_decorator(list) copy attributes.

This commit is contained in:
Chris Jerdonek 2018-04-16 10:38:37 -07:00 committed by Tim Graham
parent a480ef89ad
commit fdc936c913
2 changed files with 82 additions and 67 deletions

View File

@ -12,6 +12,44 @@ class classonlymethod(classmethod):
return super().__get__(instance, cls)
def _update_method_wrapper(_wrapper, decorator):
# _multi_decorate()'s bound_method isn't available in this scope. Cheat by
# using it on a dummy function.
@decorator
def dummy(*args, **kwargs):
pass
update_wrapper(_wrapper, dummy)
def _multi_decorate(decorators, method):
"""
Decorate `method` with one or more function decorators. `decorators` can be
a single decorator or an iterable of decorators.
"""
if hasattr(decorators, '__iter__'):
# Apply a list/tuple of decorators if 'decorators' is one. Decorator
# functions are applied so that the call order is the same as the
# order in which they appear in the iterable.
decorators = decorators[::-1]
else:
decorators = [decorators]
def _wrapper(self, *args, **kwargs):
# bound_method has the signature that 'decorator' expects i.e. no
# 'self' argument.
bound_method = method.__get__(self, type(self))
for dec in decorators:
bound_method = dec(bound_method)
return bound_method(*args, **kwargs)
# Copy any attributes that a decorator adds to the function it decorates.
for dec in decorators:
_update_method_wrapper(_wrapper, dec)
# Preserve any existing attributes of 'method', including the name.
update_wrapper(_wrapper, method)
return _wrapper
def method_decorator(decorator, name=''):
"""
Convert a function decorator into a method decorator
@ -21,70 +59,30 @@ def method_decorator(decorator, name=''):
# defined on. If 'obj' is a class, the 'name' is required to be the name
# of the method that will be decorated.
def _dec(obj):
is_class = isinstance(obj, type)
if is_class:
if name and hasattr(obj, name):
func = getattr(obj, name)
if not callable(func):
raise TypeError(
"Cannot decorate '{0}' as it isn't a callable "
"attribute of {1} ({2})".format(name, obj, func)
)
else:
raise ValueError(
"The keyword argument `name` must be the name of a method "
"of the decorated class: {0}. Got '{1}' instead".format(
obj, name,
)
)
else:
func = obj
if not isinstance(obj, type):
return _multi_decorate(decorator, obj)
if not (name and hasattr(obj, name)):
raise ValueError(
"The keyword argument `name` must be the name of a method "
"of the decorated class: %s. Got '%s' instead." % (obj, name)
)
method = getattr(obj, name)
if not callable(method):
raise TypeError(
"Cannot decorate '%s' as it isn't a callable attribute of "
"%s (%s)." % (name, obj, method)
)
_wrapper = _multi_decorate(decorator, method)
setattr(obj, name, _wrapper)
return obj
def decorate(function):
"""
Apply a list/tuple of decorators if decorator is one. Decorator
functions are applied so that the call order is the same as the
order in which they appear in the iterable.
"""
if hasattr(decorator, '__iter__'):
for dec in decorator[::-1]:
function = dec(function)
return function
return decorator(function)
def _wrapper(self, *args, **kwargs):
@decorate
def bound_func(*args2, **kwargs2):
return func.__get__(self, type(self))(*args2, **kwargs2)
# bound_func has the signature that 'decorator' expects i.e. no
# 'self' argument, but it is a closure over self so it can call
# 'func' correctly.
return bound_func(*args, **kwargs)
# In case 'decorator' adds attributes to the function it decorates, we
# want to copy those. We don't have access to bound_func in this scope,
# but we can cheat by using it on a dummy function.
@decorate
def dummy(*args, **kwargs):
pass
update_wrapper(_wrapper, dummy)
# Need to preserve any existing attributes of 'func', including the name.
update_wrapper(_wrapper, func)
if is_class:
setattr(obj, name, _wrapper)
return obj
return _wrapper
# Don't worry about making _dec look similar to a list/tuple as it's rather
# meaningless.
if not hasattr(decorator, '__iter__'):
update_wrapper(_dec, decorator)
# Change the name to aid debugging.
if hasattr(decorator, '__name__'):
_dec.__name__ = 'method_decorator(%s)' % decorator.__name__
else:
_dec.__name__ = 'method_decorator(%s)' % decorator.__class__.__name__
obj = decorator if hasattr(decorator, '__name__') else decorator.__class__
_dec.__name__ = 'method_decorator(%s)' % obj.__name__
return _dec

View File

@ -168,7 +168,7 @@ def myattr_dec(func):
def wrapper(*args, **kwargs):
return func(*args, **kwargs)
wrapper.myattr = True
return wraps(func)(wrapper)
return wrapper
myattr_dec_m = method_decorator(myattr_dec)
@ -178,7 +178,7 @@ def myattr2_dec(func):
def wrapper(*args, **kwargs):
return func(*args, **kwargs)
wrapper.myattr2 = True
return wraps(func)(wrapper)
return wrapper
myattr2_dec_m = method_decorator(myattr2_dec)
@ -209,13 +209,23 @@ class MethodDecoratorTests(SimpleTestCase):
def test_preserve_attributes(self):
# Sanity check myattr_dec and myattr2_dec
@myattr_dec
def func():
pass
self.assertIs(getattr(func, 'myattr', False), True)
@myattr2_dec
def func():
pass
self.assertIs(getattr(func, 'myattr2', False), True)
@myattr_dec
@myattr2_dec
def func():
pass
self.assertIs(getattr(func, 'myattr', False), True)
self.assertIs(getattr(func, 'myattr2', False), True)
self.assertIs(getattr(func, 'myattr2', False), False)
# Decorate using method_decorator() on the method.
class TestPlain:
@ -235,16 +245,23 @@ class MethodDecoratorTests(SimpleTestCase):
"A method"
pass
# Decorate using an iterable of decorators.
decorators = (myattr_dec_m, myattr2_dec_m)
@method_decorator(decorators, "method")
class TestIterable:
# Decorate using an iterable of function decorators.
@method_decorator((myattr_dec, myattr2_dec), 'method')
class TestFunctionIterable:
def method(self):
"A method"
pass
tests = (TestPlain, TestMethodAndClass, TestIterable)
# Decorate using an iterable of method decorators.
decorators = (myattr_dec_m, myattr2_dec_m)
@method_decorator(decorators, "method")
class TestMethodIterable:
def method(self):
"A method"
pass
tests = (TestPlain, TestMethodAndClass, TestFunctionIterable, TestMethodIterable)
for Test in tests:
with self.subTest(Test=Test):
self.assertIs(getattr(Test().method, 'myattr', False), True)