diff --git a/django/contrib/syndication/views.py b/django/contrib/syndication/views.py index a9d1bff5cf..2378a14874 100644 --- a/django/contrib/syndication/views.py +++ b/django/contrib/syndication/views.py @@ -1,3 +1,5 @@ +from inspect import getattr_static, unwrap + from django.contrib.sites.shortcuts import get_current_site from django.core.exceptions import ImproperlyConfigured, ObjectDoesNotExist from django.http import Http404, HttpResponse @@ -82,10 +84,21 @@ class Feed: # Check co_argcount rather than try/excepting the function and # catching the TypeError, because something inside the function # may raise the TypeError. This technique is more accurate. + func = unwrap(attr) try: - code = attr.__code__ + code = func.__code__ except AttributeError: - code = attr.__call__.__code__ + func = unwrap(attr.__call__) + code = func.__code__ + # If function doesn't have arguments and it is not a static method, + # it was decorated without using @functools.wraps. + if not code.co_argcount and not isinstance( + getattr_static(self, func.__name__, None), staticmethod + ): + raise ImproperlyConfigured( + f"Feed method {attname!r} decorated by {func.__name__!r} needs to " + f"use @functools.wraps." + ) if code.co_argcount == 2: # one argument is 'self' return attr(obj) else: diff --git a/tests/syndication_tests/feeds.py b/tests/syndication_tests/feeds.py index 223a0b0bb1..a35dc29e20 100644 --- a/tests/syndication_tests/feeds.py +++ b/tests/syndication_tests/feeds.py @@ -1,3 +1,5 @@ +from functools import wraps + from django.contrib.syndication import views from django.utils import feedgenerator from django.utils.timezone import get_fixed_timezone @@ -5,6 +7,23 @@ from django.utils.timezone import get_fixed_timezone from .models import Article, Entry +def wraps_decorator(f): + @wraps(f) + def wrapper(*args, **kwargs): + value = f(*args, **kwargs) + return f"{value} -- decorated by @wraps." + + return wrapper + + +def common_decorator(f): + def wrapper(*args, **kwargs): + value = f(*args, **kwargs) + return f"{value} -- common decorated." + + return wrapper + + class TestRss2Feed(views.Feed): title = "My blog" description = "A more thorough description of my blog." @@ -47,11 +66,45 @@ class TestRss2FeedWithCallableObject(TestRss2Feed): ttl = TimeToLive() -class TestRss2FeedWithStaticMethod(TestRss2Feed): +class TestRss2FeedWithDecoratedMethod(TestRss2Feed): + class TimeToLive: + @wraps_decorator + def __call__(self): + return 800 + + @staticmethod + @wraps_decorator + def feed_copyright(): + return "Copyright (c) 2022, John Doe" + + ttl = TimeToLive() + @staticmethod def categories(): return ("javascript", "vue") + @wraps_decorator + def title(self): + return "Overridden title" + + @wraps_decorator + def item_title(self, item): + return f"Overridden item title: {item.title}" + + @wraps_decorator + def description(self, obj): + return "Overridden description" + + @wraps_decorator + def item_description(self): + return "Overridden item description" + + +class TestRss2FeedWithWrongDecoratedMethod(TestRss2Feed): + @common_decorator + def item_description(self, item): + return f"Overridden item description: {item.title}" + class TestRss2FeedWithGuidIsPermaLinkTrue(TestRss2Feed): def item_guid_is_permalink(self, item): diff --git a/tests/syndication_tests/tests.py b/tests/syndication_tests/tests.py index 6aaf80c1a9..a68ed879db 100644 --- a/tests/syndication_tests/tests.py +++ b/tests/syndication_tests/tests.py @@ -202,11 +202,38 @@ class SyndicationFeedTest(FeedTestCase): chan = doc.getElementsByTagName("rss")[0].getElementsByTagName("channel")[0] self.assertChildNodeContent(chan, {"ttl": "700"}) - def test_rss2_feed_with_static_methods(self): - response = self.client.get("/syndication/rss2/with-static-methods/") + def test_rss2_feed_with_decorated_methods(self): + response = self.client.get("/syndication/rss2/with-decorated-methods/") doc = minidom.parseString(response.content) chan = doc.getElementsByTagName("rss")[0].getElementsByTagName("channel")[0] self.assertCategories(chan, ["javascript", "vue"]) + self.assertChildNodeContent( + chan, + { + "title": "Overridden title -- decorated by @wraps.", + "description": "Overridden description -- decorated by @wraps.", + "ttl": "800 -- decorated by @wraps.", + "copyright": "Copyright (c) 2022, John Doe -- decorated by @wraps.", + }, + ) + items = chan.getElementsByTagName("item") + self.assertChildNodeContent( + items[0], + { + "title": ( + f"Overridden item title: {self.e1.title} -- decorated by @wraps." + ), + "description": "Overridden item description -- decorated by @wraps.", + }, + ) + + def test_rss2_feed_with_wrong_decorated_methods(self): + msg = ( + "Feed method 'item_description' decorated by 'wrapper' needs to use " + "@functools.wraps." + ) + with self.assertRaisesMessage(ImproperlyConfigured, msg): + self.client.get("/syndication/rss2/with-wrong-decorated-methods/") def test_rss2_feed_guid_permalink_false(self): """ diff --git a/tests/syndication_tests/urls.py b/tests/syndication_tests/urls.py index 5d2b23bf0a..50f673373e 100644 --- a/tests/syndication_tests/urls.py +++ b/tests/syndication_tests/urls.py @@ -7,7 +7,14 @@ urlpatterns = [ path( "syndication/rss2/with-callable-object/", feeds.TestRss2FeedWithCallableObject() ), - path("syndication/rss2/with-static-methods/", feeds.TestRss2FeedWithStaticMethod()), + path( + "syndication/rss2/with-decorated-methods/", + feeds.TestRss2FeedWithDecoratedMethod(), + ), + path( + "syndication/rss2/with-wrong-decorated-methods/", + feeds.TestRss2FeedWithWrongDecoratedMethod(), + ), path("syndication/rss2/articles//", feeds.TestGetObjectFeed()), path( "syndication/rss2/guid_ispermalink_true/",