mirror of https://github.com/django/django.git
107 lines
3.5 KiB
Python
107 lines
3.5 KiB
Python
from asgiref.sync import iscoroutinefunction
|
|
|
|
from django.conf import settings
|
|
from django.http import HttpRequest, HttpResponse
|
|
from django.test import SimpleTestCase
|
|
from django.views.decorators.csrf import (
|
|
csrf_exempt,
|
|
csrf_protect,
|
|
ensure_csrf_cookie,
|
|
requires_csrf_token,
|
|
)
|
|
|
|
# Fixed 64-character alphanumeric token used as both the POST field value
# and the CSRF cookie value in these tests. NOTE(review): the length
# presumably has to match Django's expected CSRF token length — confirm
# against django.middleware.csrf before changing it.
CSRF_TOKEN = (
    "1bcdefghij"
    "2bcdefghij"
    "3bcdefghij"
    "4bcdefghij"
    "5bcdefghij"
    "6bcdefghij"
    "ABCD"
)
|
|
|
|
|
|
class CsrfTestMixin:
    """Build POST requests for exercising the CSRF view decorators."""

    def get_request(self, token=CSRF_TOKEN):
        """
        Return a POST ``HttpRequest``.

        When ``token`` is truthy it is stored both in the
        ``csrfmiddlewaretoken`` POST field and in the CSRF cookie, so the two
        values match. Pass a falsy value (e.g. ``None``) to get a request
        that carries neither.
        """
        request = HttpRequest()
        request.method = "POST"
        if not token:
            # Token-less request: nothing else to populate.
            return request
        request.COOKIES[settings.CSRF_COOKIE_NAME] = token
        request.POST["csrfmiddlewaretoken"] = token
        return request
|
|
|
|
|
|
class CsrfProtectTests(CsrfTestMixin, SimpleTestCase):
    """Behavior of views wrapped with ``csrf_protect``."""

    def test_csrf_protect_decorator(self):
        """
        A POST carrying the token reaches the view and is marked as
        CSRF-processed; a POST without a token gets a 403 and a warning on
        the ``django.security.csrf`` logger.
        """

        @csrf_protect
        def sync_view(request):
            return HttpResponse()

        # Token present in both POST data and cookie: the view runs.
        accepted = self.get_request()
        self.assertEqual(sync_view(accepted).status_code, 200)
        self.assertIs(accepted.csrf_processing_done, True)

        # No token at all: the check fails and a warning is logged.
        rejected = self.get_request(token=None)
        with self.assertLogs("django.security.csrf", "WARNING"):
            rejected_response = sync_view(rejected)
        self.assertEqual(rejected_response.status_code, 403)
|
|
|
|
|
|
class RequiresCsrfTokenTests(CsrfTestMixin, SimpleTestCase):
    """Behavior of views wrapped with ``requires_csrf_token``."""

    def test_requires_csrf_token_decorator(self):
        """
        The view is reached whether or not the request carries a token, and
        a token-less request logs no warning on ``django.security.csrf``.
        """

        @requires_csrf_token
        def sync_view(request):
            return HttpResponse()

        with_token = self.get_request()
        self.assertEqual(sync_view(with_token).status_code, 200)
        self.assertIs(with_token.csrf_processing_done, True)

        # Missing token is not rejected by this decorator.
        without_token = self.get_request(token=None)
        with self.assertNoLogs("django.security.csrf", "WARNING"):
            response_without_token = sync_view(without_token)
        self.assertEqual(response_without_token.status_code, 200)
|
|
|
|
|
|
class EnsureCsrfCookieTests(CsrfTestMixin, SimpleTestCase):
    """Behavior of views wrapped with ``ensure_csrf_cookie``."""

    def test_ensure_csrf_cookie_decorator(self):
        """
        The view is reached whether or not the request carries a token, and
        a token-less request logs no warning on ``django.security.csrf``.
        """

        @ensure_csrf_cookie
        def sync_view(request):
            return HttpResponse()

        with_token = self.get_request()
        self.assertEqual(sync_view(with_token).status_code, 200)
        self.assertIs(with_token.csrf_processing_done, True)

        # Missing token is not rejected by this decorator.
        without_token = self.get_request(token=None)
        with self.assertNoLogs("django.security.csrf", "WARNING"):
            response_without_token = sync_view(without_token)
        self.assertEqual(response_without_token.status_code, 200)
|
|
|
|
|
|
class CsrfExemptTests(SimpleTestCase):
    """Behavior of views wrapped with ``csrf_exempt``."""

    def test_wrapped_sync_function_is_not_coroutine_function(self):
        """Wrapping a plain function yields a non-coroutine callable."""

        def sync_view(request):
            return HttpResponse()

        self.assertIs(iscoroutinefunction(csrf_exempt(sync_view)), False)

    def test_wrapped_async_function_is_coroutine_function(self):
        """Wrapping a coroutine function yields a coroutine function."""

        async def async_view(request):
            return HttpResponse()

        self.assertIs(iscoroutinefunction(csrf_exempt(async_view)), True)

    def test_csrf_exempt_decorator(self):
        """A decorated sync view is flagged exempt and still returns its response."""

        @csrf_exempt
        def sync_view(request):
            return HttpResponse()

        self.assertIs(sync_view.csrf_exempt, True)
        response = sync_view(HttpRequest())
        self.assertIsInstance(response, HttpResponse)

    async def test_csrf_exempt_decorator_async_view(self):
        """A decorated async view is flagged exempt and still returns its response."""

        @csrf_exempt
        async def async_view(request):
            return HttpResponse()

        self.assertIs(async_view.csrf_exempt, True)
        response = await async_view(HttpRequest())
        self.assertIsInstance(response, HttpResponse)
|