diff --git a/django/contrib/sessions/middleware.py b/django/contrib/sessions/middleware.py index e76c08ee5da..63013eef7a5 100644 --- a/django/contrib/sessions/middleware.py +++ b/django/contrib/sessions/middleware.py @@ -15,6 +15,7 @@ class SessionMiddleware(MiddlewareMixin): def __init__(self, get_response=None): self._get_response_none_deprecation(get_response) self.get_response = get_response + self._async_check() engine = import_module(settings.SESSION_ENGINE) self.SessionStore = engine.SessionStore diff --git a/django/core/handlers/asgi.py b/django/core/handlers/asgi.py index bb782dad9be..82d2e1ab9de 100644 --- a/django/core/handlers/asgi.py +++ b/django/core/handlers/asgi.py @@ -1,4 +1,3 @@ -import asyncio import logging import sys import tempfile @@ -132,7 +131,7 @@ class ASGIHandler(base.BaseHandler): def __init__(self): super().__init__() - self.load_middleware() + self.load_middleware(is_async=True) async def __call__(self, scope, receive, send): """ @@ -158,12 +157,8 @@ class ASGIHandler(base.BaseHandler): if request is None: await self.send_response(error_response, send) return - # Get the response, using a threadpool via sync_to_async, if needed. - if asyncio.iscoroutinefunction(self.get_response): - response = await self.get_response(request) - else: - # If get_response is synchronous, run it non-blocking. - response = await sync_to_async(self.get_response)(request) + # Get the response, using the async mode of BaseHandler. + response = await self.get_response_async(request) response._handler_class = self.__class__ # Increase chunk size on file responses (ASGI servers handles low-level # chunking). @@ -264,7 +259,7 @@ class ASGIHandler(base.BaseHandler): 'body': chunk, 'more_body': not last, }) - response.close() + await sync_to_async(response.close)() @classmethod def chunk_bytes(cls, data): diff --git a/django/core/handlers/base.py b/django/core/handlers/base.py index 418bc7a46bf..e7fbaf594e2 100644 --- a/django/core/handlers/base.py +++ b/django/core/handlers/base.py @@ -1,6 +1,9 @@ +import asyncio import logging import types +from asgiref.sync import async_to_sync, sync_to_async + from django.conf import settings from django.core.exceptions import ImproperlyConfigured, MiddlewareNotUsed from django.core.signals import request_finished @@ -20,7 +23,7 @@ class BaseHandler: _exception_middleware = None _middleware_chain = None - def load_middleware(self): + def load_middleware(self, is_async=False): """ Populate middleware lists from settings.MIDDLEWARE. @@ -30,10 +33,28 @@ class BaseHandler: self._template_response_middleware = [] self._exception_middleware = [] - handler = convert_exception_to_response(self._get_response) + get_response = self._get_response_async if is_async else self._get_response + handler = convert_exception_to_response(get_response) + handler_is_async = is_async for middleware_path in reversed(settings.MIDDLEWARE): middleware = import_string(middleware_path) + middleware_can_sync = getattr(middleware, 'sync_capable', True) + middleware_can_async = getattr(middleware, 'async_capable', False) + if not middleware_can_sync and not middleware_can_async: + raise RuntimeError( + 'Middleware %s must have at least one of ' + 'sync_capable/async_capable set to True.' % middleware_path + ) + elif not handler_is_async and middleware_can_sync: + middleware_is_async = False + else: + middleware_is_async = middleware_can_async try: + # Adapt handler, if needed. + handler = self.adapt_method_mode( + middleware_is_async, handler, handler_is_async, + debug=settings.DEBUG, name='middleware %s' % middleware_path, + ) mw_instance = middleware(handler) except MiddlewareNotUsed as exc: if settings.DEBUG: @@ -49,24 +70,56 @@ class BaseHandler: ) if hasattr(mw_instance, 'process_view'): - self._view_middleware.insert(0, mw_instance.process_view) + self._view_middleware.insert( + 0, + self.adapt_method_mode(is_async, mw_instance.process_view), + ) if hasattr(mw_instance, 'process_template_response'): - self._template_response_middleware.append(mw_instance.process_template_response) + self._template_response_middleware.append( + self.adapt_method_mode(is_async, mw_instance.process_template_response), + ) if hasattr(mw_instance, 'process_exception'): - self._exception_middleware.append(mw_instance.process_exception) + # The exception-handling stack is still always synchronous for + # now, so adapt that way. + self._exception_middleware.append( + self.adapt_method_mode(False, mw_instance.process_exception), + ) handler = convert_exception_to_response(mw_instance) + handler_is_async = middleware_is_async + # Adapt the top of the stack, if needed. + handler = self.adapt_method_mode(is_async, handler, handler_is_async) # We only assign to this when initialization is complete as it is used # as a flag for initialization being complete. self._middleware_chain = handler - def make_view_atomic(self, view): - non_atomic_requests = getattr(view, '_non_atomic_requests', set()) - for db in connections.all(): - if db.settings_dict['ATOMIC_REQUESTS'] and db.alias not in non_atomic_requests: - view = transaction.atomic(using=db.alias)(view) - return view + def adapt_method_mode( + self, is_async, method, method_is_async=None, debug=False, name=None, + ): + """ + Adapt a method to be in the correct "mode": + - If is_async is False: + - Synchronous methods are left alone + - Asynchronous methods are wrapped with async_to_sync + - If is_async is True: + - Synchronous methods are wrapped with sync_to_async() + - Asynchronous methods are left alone + """ + if method_is_async is None: + method_is_async = asyncio.iscoroutinefunction(method) + if debug and not name: + name = name or 'method %s()' % method.__qualname__ + if is_async: + if not method_is_async: + if debug: + logger.debug('Synchronous %s adapted.', name) + return sync_to_async(method, thread_sensitive=True) + elif method_is_async: + if debug: + logger.debug('Asynchronous %s adapted.' % name) + return async_to_sync(method) + return method def get_response(self, request): """Return an HttpResponse object for the given HttpRequest.""" @@ -82,6 +135,26 @@ class BaseHandler: ) return response + async def get_response_async(self, request): + """ + Asynchronous version of get_response. + + Funneling everything, including WSGI, into a single async + get_response() is too slow. Avoid the context switch by using + a separate async response path. + """ + # Setup default url resolver for this thread. + set_urlconf(settings.ROOT_URLCONF) + response = await self._middleware_chain(request) + response._resource_closers.append(request.close) + if response.status_code >= 400: + await sync_to_async(log_response)( + '%s: %s', response.reason_phrase, request.path, + response=response, + request=request, + ) + return response + def _get_response(self, request): """ Resolve and call the view, then apply view, exception, and @@ -89,17 +162,7 @@ class BaseHandler: inside the request/response middleware. """ response = None - - if hasattr(request, 'urlconf'): - urlconf = request.urlconf - set_urlconf(urlconf) - resolver = get_resolver(urlconf) - else: - resolver = get_resolver() - - resolver_match = resolver.resolve(request.path_info) - callback, callback_args, callback_kwargs = resolver_match - request.resolver_match = resolver_match + callback, callback_args, callback_kwargs = self.resolve_request(request) # Apply view middleware for middleware_method in self._view_middleware: @@ -109,6 +172,9 @@ class BaseHandler: if response is None: wrapped_callback = self.make_view_atomic(callback) + # If it is an asynchronous view, run it in a subthread. + if asyncio.iscoroutinefunction(wrapped_callback): + wrapped_callback = async_to_sync(wrapped_callback) try: response = wrapped_callback(request, *callback_args, **callback_kwargs) except Exception as e: @@ -137,6 +203,123 @@ class BaseHandler: return response + async def _get_response_async(self, request): + """ + Resolve and call the view, then apply view, exception, and + template_response middleware. This method is everything that happens + inside the request/response middleware. + """ + response = None + callback, callback_args, callback_kwargs = self.resolve_request(request) + + # Apply view middleware. + for middleware_method in self._view_middleware: + response = await middleware_method(request, callback, callback_args, callback_kwargs) + if response: + break + + if response is None: + wrapped_callback = self.make_view_atomic(callback) + # If it is a synchronous view, run it in a subthread + if not asyncio.iscoroutinefunction(wrapped_callback): + wrapped_callback = sync_to_async(wrapped_callback, thread_sensitive=True) + try: + response = await wrapped_callback(request, *callback_args, **callback_kwargs) + except Exception as e: + response = await sync_to_async( + self.process_exception_by_middleware, + thread_sensitive=True, + )(e, request) + + # Complain if the view returned None or an uncalled coroutine. + self.check_response(response, callback) + + # If the response supports deferred rendering, apply template + # response middleware and then render the response + if hasattr(response, 'render') and callable(response.render): + for middleware_method in self._template_response_middleware: + response = await middleware_method(request, response) + # Complain if the template response middleware returned None or + # an uncalled coroutine. + self.check_response( + response, + middleware_method, + name='%s.process_template_response' % ( + middleware_method.__self__.__class__.__name__, + ) + ) + try: + if asyncio.iscoroutinefunction(response.render): + response = await response.render() + else: + response = await sync_to_async(response.render, thread_sensitive=True)() + except Exception as e: + response = await sync_to_async( + self.process_exception_by_middleware, + thread_sensitive=True, + )(e, request) + + # Make sure the response is not a coroutine + if asyncio.iscoroutine(response): + raise RuntimeError('Response is still a coroutine.') + return response + + def resolve_request(self, request): + """ + Retrieve/set the urlconf for the request. Return the view resolved, + with its args and kwargs. + """ + # Work out the resolver. + if hasattr(request, 'urlconf'): + urlconf = request.urlconf + set_urlconf(urlconf) + resolver = get_resolver(urlconf) + else: + resolver = get_resolver() + # Resolve the view, and assign the match object back to the request. + resolver_match = resolver.resolve(request.path_info) + request.resolver_match = resolver_match + return resolver_match + + def check_response(self, response, callback, name=None): + """ + Raise an error if the view returned None or an uncalled coroutine. + """ + if not(response is None or asyncio.iscoroutine(response)): + return + if not name: + if isinstance(callback, types.FunctionType): # FBV + name = 'The view %s.%s' % (callback.__module__, callback.__name__) + else: # CBV + name = 'The view %s.%s.__call__' % ( + callback.__module__, + callback.__class__.__name__, + ) + if response is None: + raise ValueError( + "%s didn't return an HttpResponse object. It returned None " + "instead." % name + ) + elif asyncio.iscoroutine(response): + raise ValueError( + "%s didn't return an HttpResponse object. It returned an " + "unawaited coroutine instead. You may need to add an 'await' " + "into your view." % name + ) + + # Other utility methods. + + def make_view_atomic(self, view): + non_atomic_requests = getattr(view, '_non_atomic_requests', set()) + for db in connections.all(): + if db.settings_dict['ATOMIC_REQUESTS'] and db.alias not in non_atomic_requests: + if asyncio.iscoroutinefunction(view): + raise RuntimeError( + 'You cannot use ATOMIC_REQUESTS with async views.' + ) + view = transaction.atomic(using=db.alias)(view) + return view + def process_exception_by_middleware(self, exception, request): """ Pass the exception to the exception middleware. If no middleware @@ -148,23 +331,6 @@ class BaseHandler: return response raise - def check_response(self, response, callback, name=None): - """Raise an error if the view returned None.""" - if response is not None: - return - if not name: - if isinstance(callback, types.FunctionType): # FBV - name = 'The view %s.%s' % (callback.__module__, callback.__name__) - else: # CBV - name = 'The view %s.%s.__call__' % ( - callback.__module__, - callback.__class__.__name__, - ) - raise ValueError( - "%s didn't return an HttpResponse object. It returned None " - "instead." % name - ) - def reset_urlconf(sender, **kwargs): """Reset the URLconf after each request is finished.""" diff --git a/django/core/handlers/exception.py b/django/core/handlers/exception.py index 66443ce5601..50880f27841 100644 --- a/django/core/handlers/exception.py +++ b/django/core/handlers/exception.py @@ -1,7 +1,10 @@ +import asyncio import logging import sys from functools import wraps +from asgiref.sync import sync_to_async + from django.conf import settings from django.core import signals from django.core.exceptions import ( @@ -28,14 +31,24 @@ def convert_exception_to_response(get_response): no middleware leaks an exception and that the next middleware in the stack can rely on getting a response instead of an exception. """ - @wraps(get_response) - def inner(request): - try: - response = get_response(request) - except Exception as exc: - response = response_for_exception(request, exc) - return response - return inner + if asyncio.iscoroutinefunction(get_response): + @wraps(get_response) + async def inner(request): + try: + response = await get_response(request) + except Exception as exc: + response = await sync_to_async(response_for_exception)(request, exc) + return response + return inner + else: + @wraps(get_response) + def inner(request): + try: + response = get_response(request) + except Exception as exc: + response = response_for_exception(request, exc) + return response + return inner def response_for_exception(request, exc): diff --git a/django/test/__init__.py b/django/test/__init__.py index 4782d721840..d1f953a8dd1 100644 --- a/django/test/__init__.py +++ b/django/test/__init__.py @@ -1,6 +1,8 @@ """Django Unit Test framework.""" -from django.test.client import Client, RequestFactory +from django.test.client import ( + AsyncClient, AsyncRequestFactory, Client, RequestFactory, +) from django.test.testcases import ( LiveServerTestCase, SimpleTestCase, TestCase, TransactionTestCase, skipIfDBFeature, skipUnlessAnyDBFeature, skipUnlessDBFeature, @@ -11,8 +13,9 @@ from django.test.utils import ( ) __all__ = [ - 'Client', 'RequestFactory', 'TestCase', 'TransactionTestCase', - 'SimpleTestCase', 'LiveServerTestCase', 'skipIfDBFeature', - 'skipUnlessAnyDBFeature', 'skipUnlessDBFeature', 'ignore_warnings', - 'modify_settings', 'override_settings', 'override_system_checks', 'tag', + 'AsyncClient', 'AsyncRequestFactory', 'Client', 'RequestFactory', + 'TestCase', 'TransactionTestCase', 'SimpleTestCase', 'LiveServerTestCase', + 'skipIfDBFeature', 'skipUnlessAnyDBFeature', 'skipUnlessDBFeature', + 'ignore_warnings', 'modify_settings', 'override_settings', + 'override_system_checks', 'tag', ] diff --git a/django/test/client.py b/django/test/client.py index 34fc9f3cf11..c4608324616 100644 --- a/django/test/client.py +++ b/django/test/client.py @@ -9,7 +9,10 @@ from importlib import import_module from io import BytesIO from urllib.parse import unquote_to_bytes, urljoin, urlparse, urlsplit +from asgiref.sync import sync_to_async + from django.conf import settings +from django.core.handlers.asgi import ASGIRequest from django.core.handlers.base import BaseHandler from django.core.handlers.wsgi import WSGIRequest from django.core.serializers.json import DjangoJSONEncoder @@ -157,6 +160,52 @@ class ClientHandler(BaseHandler): return response +class AsyncClientHandler(BaseHandler): + """An async version of ClientHandler.""" + def __init__(self, enforce_csrf_checks=True, *args, **kwargs): + self.enforce_csrf_checks = enforce_csrf_checks + super().__init__(*args, **kwargs) + + async def __call__(self, scope): + # Set up middleware if needed. We couldn't do this earlier, because + # settings weren't available. + if self._middleware_chain is None: + self.load_middleware(is_async=True) + # Extract body file from the scope, if provided. + if '_body_file' in scope: + body_file = scope.pop('_body_file') + else: + body_file = FakePayload('') + + request_started.disconnect(close_old_connections) + await sync_to_async(request_started.send)(sender=self.__class__, scope=scope) + request_started.connect(close_old_connections) + request = ASGIRequest(scope, body_file) + # Sneaky little hack so that we can easily get round + # CsrfViewMiddleware. This makes life easier, and is probably required + # for backwards compatibility with external tests against admin views. + request._dont_enforce_csrf_checks = not self.enforce_csrf_checks + # Request goes through middleware. + response = await self.get_response_async(request) + # Simulate behaviors of most Web servers. + conditional_content_removal(request, response) + # Attach the originating ASGI request to the response so that it could + # be later retrieved. + response.asgi_request = request + # Emulate a server by calling the close method on completion. + if response.streaming: + response.streaming_content = await sync_to_async(closing_iterator_wrapper)( + response.streaming_content, + response.close, + ) + else: + request_finished.disconnect(close_old_connections) + # Will fire request_finished. + await sync_to_async(response.close)() + request_finished.connect(close_old_connections) + return response + + def store_rendered_templates(store, signal, sender, template, context, **kwargs): """ Store templates and contexts that are rendered. @@ -421,7 +470,194 @@ class RequestFactory: return self.request(**r) -class Client(RequestFactory): +class AsyncRequestFactory(RequestFactory): + """ + Class that lets you create mock ASGI-like Request objects for use in + testing. Usage: + + rf = AsyncRequestFactory() + get_request = await rf.get('/hello/') + post_request = await rf.post('/submit/', {'foo': 'bar'}) + + Once you have a request object you can pass it to any view function, + including synchronous ones. The reason we have a separate class here is: + a) this makes ASGIRequest subclasses, and + b) AsyncTestClient can subclass it. + """ + def _base_scope(self, **request): + """The base scope for a request.""" + # This is a minimal valid ASGI scope, plus: + # - headers['cookie'] for cookie support, + # - 'client' often useful, see #8551. + scope = { + 'asgi': {'version': '3.0'}, + 'type': 'http', + 'http_version': '1.1', + 'client': ['127.0.0.1', 0], + 'server': ('testserver', '80'), + 'scheme': 'http', + 'method': 'GET', + 'headers': [], + **self.defaults, + **request, + } + scope['headers'].append(( + b'cookie', + b'; '.join(sorted( + ('%s=%s' % (morsel.key, morsel.coded_value)).encode('ascii') + for morsel in self.cookies.values() + )), + )) + return scope + + def request(self, **request): + """Construct a generic request object.""" + # This is synchronous, which means all methods on this class are. + # AsyncClient, however, has an async request function, which makes all + # its methods async. + if '_body_file' in request: + body_file = request.pop('_body_file') + else: + body_file = FakePayload('') + return ASGIRequest(self._base_scope(**request), body_file) + + def generic( + self, method, path, data='', content_type='application/octet-stream', + secure=False, **extra, + ): + """Construct an arbitrary HTTP request.""" + parsed = urlparse(str(path)) # path can be lazy. + data = force_bytes(data, settings.DEFAULT_CHARSET) + s = { + 'method': method, + 'path': self._get_path(parsed), + 'server': ('127.0.0.1', '443' if secure else '80'), + 'scheme': 'https' if secure else 'http', + 'headers': [(b'host', b'testserver')], + } + if data: + s['headers'].extend([ + (b'content-length', bytes(len(data))), + (b'content-type', content_type.encode('ascii')), + ]) + s['_body_file'] = FakePayload(data) + s.update(extra) + # If QUERY_STRING is absent or empty, we want to extract it from the + # URL. + if not s.get('query_string'): + s['query_string'] = parsed[4] + return self.request(**s) + + +class ClientMixin: + """ + Mixin with common methods between Client and AsyncClient. + """ + def store_exc_info(self, **kwargs): + """Store exceptions when they are generated by a view.""" + self.exc_info = sys.exc_info() + + def check_exception(self, response): + """ + Look for a signaled exception, clear the current context exception + data, re-raise the signaled exception, and clear the signaled exception + from the local cache. + """ + response.exc_info = self.exc_info + if self.exc_info: + _, exc_value, _ = self.exc_info + self.exc_info = None + if self.raise_request_exception: + raise exc_value + + @property + def session(self): + """Return the current session variables.""" + engine = import_module(settings.SESSION_ENGINE) + cookie = self.cookies.get(settings.SESSION_COOKIE_NAME) + if cookie: + return engine.SessionStore(cookie.value) + session = engine.SessionStore() + session.save() + self.cookies[settings.SESSION_COOKIE_NAME] = session.session_key + return session + + def login(self, **credentials): + """ + Set the Factory to appear as if it has successfully logged into a site. + + Return True if login is possible or False if the provided credentials + are incorrect. + """ + from django.contrib.auth import authenticate + user = authenticate(**credentials) + if user: + self._login(user) + return True + return False + + def force_login(self, user, backend=None): + def get_backend(): + from django.contrib.auth import load_backend + for backend_path in settings.AUTHENTICATION_BACKENDS: + backend = load_backend(backend_path) + if hasattr(backend, 'get_user'): + return backend_path + + if backend is None: + backend = get_backend() + user.backend = backend + self._login(user, backend) + + def _login(self, user, backend=None): + from django.contrib.auth import login + # Create a fake request to store login details. + request = HttpRequest() + if self.session: + request.session = self.session + else: + engine = import_module(settings.SESSION_ENGINE) + request.session = engine.SessionStore() + login(request, user, backend) + # Save the session values. + request.session.save() + # Set the cookie to represent the session. + session_cookie = settings.SESSION_COOKIE_NAME + self.cookies[session_cookie] = request.session.session_key + cookie_data = { + 'max-age': None, + 'path': '/', + 'domain': settings.SESSION_COOKIE_DOMAIN, + 'secure': settings.SESSION_COOKIE_SECURE or None, + 'expires': None, + } + self.cookies[session_cookie].update(cookie_data) + + def logout(self): + """Log out the user by removing the cookies and session object.""" + from django.contrib.auth import get_user, logout + request = HttpRequest() + if self.session: + request.session = self.session + request.user = get_user(request) + else: + engine = import_module(settings.SESSION_ENGINE) + request.session = engine.SessionStore() + logout(request) + self.cookies = SimpleCookie() + + def _parse_json(self, response, **extra): + if not hasattr(response, '_json'): + if not JSON_CONTENT_TYPE_RE.match(response.get('Content-Type')): + raise ValueError( + 'Content-Type header is "%s", not "application/json"' + % response.get('Content-Type') + ) + response._json = json.loads(response.content.decode(response.charset), **extra) + return response._json + + +class Client(ClientMixin, RequestFactory): """ A class that can act as a client for testing purposes. @@ -446,23 +682,6 @@ class Client(RequestFactory): self.exc_info = None self.extra = None - def store_exc_info(self, **kwargs): - """Store exceptions when they are generated by a view.""" - self.exc_info = sys.exc_info() - - @property - def session(self): - """Return the current session variables.""" - engine = import_module(settings.SESSION_ENGINE) - cookie = self.cookies.get(settings.SESSION_COOKIE_NAME) - if cookie: - return engine.SessionStore(cookie.value) - - session = engine.SessionStore() - session.save() - self.cookies[settings.SESSION_COOKIE_NAME] = session.session_key - return session - def request(self, **request): """ The master request method. Compose the environment dictionary and pass @@ -486,15 +705,8 @@ class Client(RequestFactory): finally: signals.template_rendered.disconnect(dispatch_uid=signal_uid) got_request_exception.disconnect(dispatch_uid=exception_uid) - # Look for a signaled exception, clear the current context exception - # data, then re-raise the signaled exception. Also clear the signaled - # exception from the local cache. - response.exc_info = self.exc_info - if self.exc_info: - _, exc_value, _ = self.exc_info - self.exc_info = None - if self.raise_request_exception: - raise exc_value + # Check for signaled exceptions. + self.check_exception(response) # Save the client and request that stimulated the response. response.client = self response.request = request @@ -583,85 +795,6 @@ class Client(RequestFactory): response = self._handle_redirects(response, data=data, **extra) return response - def login(self, **credentials): - """ - Set the Factory to appear as if it has successfully logged into a site. - - Return True if login is possible; False if the provided credentials - are incorrect. - """ - from django.contrib.auth import authenticate - user = authenticate(**credentials) - if user: - self._login(user) - return True - else: - return False - - def force_login(self, user, backend=None): - def get_backend(): - from django.contrib.auth import load_backend - for backend_path in settings.AUTHENTICATION_BACKENDS: - backend = load_backend(backend_path) - if hasattr(backend, 'get_user'): - return backend_path - if backend is None: - backend = get_backend() - user.backend = backend - self._login(user, backend) - - def _login(self, user, backend=None): - from django.contrib.auth import login - engine = import_module(settings.SESSION_ENGINE) - - # Create a fake request to store login details. - request = HttpRequest() - - if self.session: - request.session = self.session - else: - request.session = engine.SessionStore() - login(request, user, backend) - - # Save the session values. - request.session.save() - - # Set the cookie to represent the session. - session_cookie = settings.SESSION_COOKIE_NAME - self.cookies[session_cookie] = request.session.session_key - cookie_data = { - 'max-age': None, - 'path': '/', - 'domain': settings.SESSION_COOKIE_DOMAIN, - 'secure': settings.SESSION_COOKIE_SECURE or None, - 'expires': None, - } - self.cookies[session_cookie].update(cookie_data) - - def logout(self): - """Log out the user by removing the cookies and session object.""" - from django.contrib.auth import get_user, logout - - request = HttpRequest() - engine = import_module(settings.SESSION_ENGINE) - if self.session: - request.session = self.session - request.user = get_user(request) - else: - request.session = engine.SessionStore() - logout(request) - self.cookies = SimpleCookie() - - def _parse_json(self, response, **extra): - if not hasattr(response, '_json'): - if not JSON_CONTENT_TYPE_RE.match(response.get('Content-Type')): - raise ValueError( - 'Content-Type header is "{}", not "application/json"' - .format(response.get('Content-Type')) - ) - response._json = json.loads(response.content.decode(response.charset), **extra) - return response._json - def _handle_redirects(self, response, data='', content_type='', **extra): """ Follow any redirects by requesting responses from the server using GET. @@ -714,3 +847,66 @@ class Client(RequestFactory): raise RedirectCycleError("Too many redirects.", last_response=response) return response + + +class AsyncClient(ClientMixin, AsyncRequestFactory): + """ + An async version of Client that creates ASGIRequests and calls through an + async request path. + + Does not currently support "follow" on its methods. + """ + def __init__(self, enforce_csrf_checks=False, raise_request_exception=True, **defaults): + super().__init__(**defaults) + self.handler = AsyncClientHandler(enforce_csrf_checks) + self.raise_request_exception = raise_request_exception + self.exc_info = None + self.extra = None + + async def request(self, **request): + """ + The master request method. Compose the scope dictionary and pass to the + handler, return the result of the handler. Assume defaults for the + query environment, which can be overridden using the arguments to the + request. + """ + if 'follow' in request: + raise NotImplementedError( + 'AsyncClient request methods do not accept the follow ' + 'parameter.' + ) + scope = self._base_scope(**request) + # Curry a data dictionary into an instance of the template renderer + # callback function. + data = {} + on_template_render = partial(store_rendered_templates, data) + signal_uid = 'template-render-%s' % id(request) + signals.template_rendered.connect(on_template_render, dispatch_uid=signal_uid) + # Capture exceptions created by the handler. + exception_uid = 'request-exception-%s' % id(request) + got_request_exception.connect(self.store_exc_info, dispatch_uid=exception_uid) + try: + response = await self.handler(scope) + finally: + signals.template_rendered.disconnect(dispatch_uid=signal_uid) + got_request_exception.disconnect(dispatch_uid=exception_uid) + # Check for signaled exceptions. + self.check_exception(response) + # Save the client and request that stimulated the response. + response.client = self + response.request = request + # Add any rendered template detail to the response. + response.templates = data.get('templates', []) + response.context = data.get('context') + response.json = partial(self._parse_json, response) + # Attach the ResolverMatch instance to the response. + response.resolver_match = SimpleLazyObject(lambda: resolve(request['path'])) + # Flatten a single context. Not really necessary anymore thanks to the + # __getattr__ flattening in ContextList, but has some edge case + # backwards compatibility implications. + if response.context and len(response.context) == 1: + response.context = response.context[0] + # Update persistent cookie data. + if response.cookies: + self.cookies.update(response.cookies) + return response diff --git a/django/test/testcases.py b/django/test/testcases.py index 2c24708a3bf..7ebddf80e5e 100644 --- a/django/test/testcases.py +++ b/django/test/testcases.py @@ -33,7 +33,7 @@ from django.db import DEFAULT_DB_ALIAS, connection, connections, transaction from django.forms.fields import CharField from django.http import QueryDict from django.http.request import split_domain_port, validate_host -from django.test.client import Client +from django.test.client import AsyncClient, Client from django.test.html import HTMLParseError, parse_html from django.test.signals import setting_changed, template_rendered from django.test.utils import ( @@ -151,6 +151,7 @@ class SimpleTestCase(unittest.TestCase): # The class we'll use for the test client self.client. # Can be overridden in derived classes. client_class = Client + async_client_class = AsyncClient _overridden_settings = None _modified_settings = None @@ -292,6 +293,7 @@ class SimpleTestCase(unittest.TestCase): * Clear the mail test outbox. """ self.client = self.client_class() + self.async_client = self.async_client_class() mail.outbox = [] def _post_teardown(self): diff --git a/django/test/utils.py b/django/test/utils.py index e626667b095..d1f7d195463 100644 --- a/django/test/utils.py +++ b/django/test/utils.py @@ -1,3 +1,4 @@ +import asyncio import logging import re import sys @@ -362,12 +363,22 @@ class TestContextDecorator: raise TypeError('Can only decorate subclasses of unittest.TestCase') def decorate_callable(self, func): - @wraps(func) - def inner(*args, **kwargs): - with self as context: - if self.kwarg_name: - kwargs[self.kwarg_name] = context - return func(*args, **kwargs) + if asyncio.iscoroutinefunction(func): + # If the inner function is an async function, we must execute async + # as well so that the `with` statement executes at the right time. + @wraps(func) + async def inner(*args, **kwargs): + with self as context: + if self.kwarg_name: + kwargs[self.kwarg_name] = context + return await func(*args, **kwargs) + else: + @wraps(func) + def inner(*args, **kwargs): + with self as context: + if self.kwarg_name: + kwargs[self.kwarg_name] = context + return func(*args, **kwargs) return inner def __call__(self, decorated): diff --git a/django/utils/decorators.py b/django/utils/decorators.py index bb2e498e46a..5c9a5d01c76 100644 --- a/django/utils/decorators.py +++ b/django/utils/decorators.py @@ -150,3 +150,30 @@ def make_middleware_decorator(middleware_class): return _wrapped_view return _decorator return _make_decorator + + +def sync_and_async_middleware(func): + """ + Mark a middleware factory as returning a hybrid middleware supporting both + types of request. + """ + func.sync_capable = True + func.async_capable = True + return func + + +def sync_only_middleware(func): + """ + Mark a middleware factory as returning a sync middleware. + This is the default. + """ + func.sync_capable = True + func.async_capable = False + return func + + +def async_only_middleware(func): + """Mark a middleware factory as returning an async middleware.""" + func.sync_capable = False + func.async_capable = True + return func diff --git a/django/utils/deprecation.py b/django/utils/deprecation.py index 81e7c3a15b4..6336558a817 100644 --- a/django/utils/deprecation.py +++ b/django/utils/deprecation.py @@ -1,6 +1,9 @@ +import asyncio import inspect import warnings +from asgiref.sync import sync_to_async + class RemovedInNextVersionWarning(DeprecationWarning): pass @@ -80,14 +83,31 @@ class DeprecationInstanceCheck(type): class MiddlewareMixin: + sync_capable = True + async_capable = True + # RemovedInDjango40Warning: when the deprecation ends, replace with: # def __init__(self, get_response): def __init__(self, get_response=None): self._get_response_none_deprecation(get_response) self.get_response = get_response + self._async_check() super().__init__() + def _async_check(self): + """ + If get_response is a coroutine function, turns us into async mode so + a thread is not consumed during a whole request. + """ + if asyncio.iscoroutinefunction(self.get_response): + # Mark the class as async-capable, but do the actual switch + # inside __call__ to avoid swapping out dunder methods + self._is_coroutine = asyncio.coroutines._is_coroutine + def __call__(self, request): + # Exit out to async mode, if needed + if asyncio.iscoroutinefunction(self.get_response): + return self.__acall__(request) response = None if hasattr(self, 'process_request'): response = self.process_request(request) @@ -96,6 +116,19 @@ class MiddlewareMixin: response = self.process_response(request, response) return response + async def __acall__(self, request): + """ + Async version of __call__ that is swapped in when an async request + is running. + """ + response = None + if hasattr(self, 'process_request'): + response = await sync_to_async(self.process_request)(request) + response = response or await self.get_response(request) + if hasattr(self, 'process_response'): + response = await sync_to_async(self.process_response)(request, response) + return response + def _get_response_none_deprecation(self, get_response): if get_response is None: warnings.warn( diff --git a/docs/index.txt b/docs/index.txt index 3b33478f14d..f82e3d85c12 100644 --- a/docs/index.txt +++ b/docs/index.txt @@ -110,8 +110,7 @@ manipulating the data of your Web application. Learn more about it below: :doc:`Custom lookups ` | :doc:`Query Expressions ` | :doc:`Conditional Expressions ` | - :doc:`Database Functions ` | - :doc:`Asynchronous Support ` + :doc:`Database Functions ` * **Other:** :doc:`Supported databases ` | @@ -131,7 +130,8 @@ to know about views via the links below: :doc:`URLconfs ` | :doc:`View functions ` | :doc:`Shortcuts ` | - :doc:`Decorators ` + :doc:`Decorators ` | + :doc:`Asynchronous Support ` * **Reference:** :doc:`Built-in Views ` | diff --git a/docs/ref/utils.txt b/docs/ref/utils.txt index 741ed585f28..0375b2e63b0 100644 --- a/docs/ref/utils.txt +++ b/docs/ref/utils.txt @@ -210,6 +210,31 @@ The functions defined in this module share the following properties: def my_view(request): pass +.. function:: sync_only_middleware(middleware) + + .. versionadded:: 3.1 + + Marks a middleware as :ref:`synchronous-only `. (The + default in Django, but this allows you to future-proof if the default ever + changes in a future release.) + +.. function:: async_only_middleware(middleware) + + .. versionadded:: 3.1 + + Marks a middleware as :ref:`asynchronous-only `. Django + will wrap it in an asynchronous event loop when it is called from the WSGI + request path. + +.. function:: sync_and_async_middleware(middleware) + + .. versionadded:: 3.1 + + Marks a middleware as :ref:`sync and async compatible `, + this allows to avoid converting requests. You must implement detection of + the current request type to use this decorator. See :ref:`asynchronous + middleware documentation ` for details. + ``django.utils.encoding`` ========================= diff --git a/docs/releases/3.1.txt b/docs/releases/3.1.txt index 642575d023b..efc57e8d016 100644 --- a/docs/releases/3.1.txt +++ b/docs/releases/3.1.txt @@ -27,6 +27,43 @@ officially support the latest release of each series. What's new in Django 3.1 ======================== +Asynchronous views and middleware support +----------------------------------------- + +Django now supports a fully asynchronous request path, including: + +* :ref:`Asynchronous views ` +* :ref:`Asynchronous middleware ` +* :ref:`Asynchronous tests and test client ` + +To get started with async views, you need to declare a view using +``async def``:: + + async def my_view(request): + await asyncio.sleep(0.5) + return HttpResponse('Hello, async world!') + +All asynchronous features are supported whether you are running under WSGI or +ASGI mode. However, there will be performance penalties using async code in +WSGI mode. You can read more about the specifics in :doc:`/topics/async` +documentation. + +You are free to mix async and sync views, middleware, and tests as much as you +want. Django will ensure that you always end up with the right execution +context. We expect most projects will keep the majority of their views +synchronous, and only have a select few running in async mode - but it is +entirely your choice. + +Django's ORM, cache layer, and other pieces of code that do long-running +network calls do not yet support async access. We expect to add support for +them in upcoming releases. Async views are ideal, however, if you are doing a +lot of API or HTTP calls inside your view, you can now natively do all those +HTTP calls in parallel to considerably speed up your view's execution. + +Asynchronous support should be entirely backwards-compatible and we have tried +to ensure that it has no speed regressions for your existing, synchronous code. +It should have no noticeable effect on any existing Django projects. + Minor features -------------- diff --git a/docs/spelling_wordlist b/docs/spelling_wordlist index 99bdba17b37..3d2d214409b 100644 --- a/docs/spelling_wordlist +++ b/docs/spelling_wordlist @@ -144,6 +144,7 @@ databrowse datafile dataset datasets +datastores datatype datetimes Debian diff --git a/docs/topics/async.txt b/docs/topics/async.txt index b502e6b65bd..2ca76c972b3 100644 --- a/docs/topics/async.txt +++ b/docs/topics/async.txt @@ -6,13 +6,106 @@ Asynchronous support .. currentmodule:: asgiref.sync -Django has developing support for asynchronous ("async") Python, but does not -yet support asynchronous views or middleware; they will be coming in a future -release. +Django has support for writing asynchronous ("async") views, along with an +entirely async-enabled request stack if you are running under +:doc:`ASGI ` rather than WSGI. Async views will +still work under WSGI, but with performance penalties, and without the ability +to have efficient long-running requests. -There is limited support for other parts of the async ecosystem; namely, Django -can natively talk :doc:`ASGI `, and some async -safety support. +We're still working on asynchronous support for the ORM and other parts of +Django; you can expect to see these in future releases. For now, you can use +the :func:`sync_to_async` adapter to interact with normal Django, as well as +use a whole range of Python asyncio libraries natively. See below for more +details. + +.. versionchanged:: 3.1 + + Support for async views was added. + +Async views +=========== + +.. versionadded:: 3.1 + +Any view can be declared async by making the callable part of it return a +coroutine - commonly, this is done using ``async def``. For a function-based +view, this means declaring the whole view using ``async def``. For a +class-based view, this means making its ``__call__()`` method an ``async def`` +(not its ``__init__()`` or ``as_view()``). + +.. note:: + + Django uses ``asyncio.iscoroutinefunction`` to test if your view is + asynchronous or not. If you implement your own method of returning a + coroutine, ensure you set the ``_is_coroutine`` attribute of the view + to ``asyncio.coroutines._is_coroutine`` so this function returns ``True``. + +Under a WSGI server, asynchronous views will run in their own, one-off event +loop. This means that you can do things like parallel, async HTTP calls to APIs +without any issues, but you will not get the benefits of an asynchronous +request stack. + +If you want these benefits - which are mostly around the ability to service +hundreds of connections without using any Python threads (enabling slow +streaming, long-polling, and other exciting response types) - you will need to +deploy Django using :doc:`ASGI ` instead. + +.. warning:: + + You will only get the benefits of a fully-asynchronous request stack if you + have *no synchronous middleware* loaded into your site; if there is a piece + of synchronous middleware, then Django must use a thread per request to + safely emulate a synchronous environment for it. + + Middleware can be built to support :ref:`both sync and async + ` contexts. Some of Django's middleware is built like + this, but not all. To see what middleware Django has to adapt, you can turn + on debug logging for the ``django.request`` logger and look for log + messages about *`"Synchronous middleware ... adapted"*. + +In either ASGI or WSGI mode, though, you can safely use asynchronous support to +run code in parallel rather than serially, which is especially handy when +dealing with external APIs or datastores. + +If you want to call a part of Django that is still synchronous (like the ORM) +you will need to wrap it in a :func:`sync_to_async` call, like this:: + + from asgiref.sync import sync_to_async + + results = sync_to_async(MyModel.objects.get)(pk=123) + +You may find it easier to move any ORM code into its own function and call that +entire function using :func:`sync_to_async`. If you accidentally try to call +part of Django that is still synchronous-only from an async view, you will +trigger Django's :ref:`asynchronous safety protection ` to +protect your data from corruption. + +Performance +----------- + +When running in a mode that does not match the view (e.g. an async view under +WSGI, or a traditional sync view under ASGI), Django must emulate the other +call style to allow your code to run. This context-switch causes a small +performance penalty of around a millisecond. + +This is true of middleware as well, however. Django will attempt to minimize +the number of context-switches. If you have an ASGI server, but all your +middleware and views are synchronous, it will switch just once, before it +enters the middleware stack. + +If, however, you put synchronous middleware between an ASGI server and an +asynchronous view, it will have to switch into sync mode for the middleware and +then back to asynchronous mode for the view, holding the synchronous thread +open for middleware exception propagation. This may not be noticeable, but bear +in mind that even adding a single piece of synchronous middleware can drag your +whole async project down to running with one thread per request, and the +associated performance penalties. + +You should do your own performance testing to see what effect ASGI vs. WSGI has +on your code. In some cases, there may be a performance increase even for +purely-synchronous codebase under ASGI because the request-handling code is +still all running asynchronously. In general, though, you will only want to +enable ASGI mode if you have asynchronous code in your site. .. _async-safety: diff --git a/docs/topics/http/middleware.txt b/docs/topics/http/middleware.txt index d72d39de5ed..3fe00b947f5 100644 --- a/docs/topics/http/middleware.txt +++ b/docs/topics/http/middleware.txt @@ -71,6 +71,10 @@ method from the handler which takes care of applying :ref:`view middleware applying :ref:`template-response ` and :ref:`exception ` middleware. +Middleware can either support only synchronous Python (the default), only +asynchronous Python, or both. See :ref:`async-middleware` for details of how to +advertise what you support, and know what kind of request you are getting. + Middleware can live anywhere on your Python path. ``__init__(get_response)`` @@ -282,6 +286,81 @@ if the very next middleware in the chain raises an that exception; instead it will get an :class:`~django.http.HttpResponse` object with a :attr:`~django.http.HttpResponse.status_code` of 404. +.. _async-middleware: + +Asynchronous support +==================== + +.. versionadded:: 3.1 + +Middleware can support any combination of synchronous and asynchronous +requests. Django will adapt requests to fit the middleware's requirements if it +cannot support both, but at a performance penalty. + +By default, Django assumes that your middleware is capable of handling only +synchronous requests. To change these assumptions, set the following attributes +on your middleware factory function or class: + +* ``sync_capable`` is a boolean indicating if the middleware can handle + synchronous requests. Defaults to ``True``. + +* ``async_capable`` is a boolean indicating if the middleware can handle + asynchronous requests. Defaults to ``False``. + +If your middleware has both ``sync_capable = True`` and +``async_capable = True``, then Django will pass it the request in whatever form +it is currently in. You can work out what type of request you have by seeing +if the ``get_response`` object you are passed is a coroutine function or not +(using :py:func:`asyncio.iscoroutinefunction`). + +The ``django.utils.decorators`` module contains +:func:`~django.utils.decorators.sync_only_middleware`, +:func:`~django.utils.decorators.async_only_middleware`, and +:func:`~django.utils.decorators.sync_and_async_middleware` decorators that +allow you to apply these flags to middleware factory functions. + +The returned callable must match the sync or async nature of the +``get_response`` method. If you have an asynchronous ``get_response``, you must +return a coroutine function (``async def``). + +``process_view``, ``process_template_response`` and ``process_exception`` +methods, if they are provided, should also be adapted to match the sync/async +mode. However, Django will individually adapt them as required if you do not, +at an additional performance penalty. + +Here's an example of how to detect and adapt your middleware if it supports +both:: + + import asyncio + from django.utils.decorators import sync_and_async_middleware + + @sync_and_async_middleware + def simple_middleware(get_response): + # One-time configuration and initialization goes here. + if asyncio.iscoroutinefunction(get_response): + async def middleware(request): + # Do something here! + response = await get_response(request) + return response + + else: + def middleware(request): + # Do something here! + response = get_response(request) + return response + + return middleware + +.. note:: + + If you declare a hybrid middleware that supports both synchronous and + asynchronous calls, the kind of call you get may not match the underlying + view. Django will optimize the middleware call stack to have as few + sync/async transitions as possible. + + Thus, even if you are wrapping an async view, you may be called in sync + mode if there is other, synchronous middleware between you and the view. + .. _upgrading-middleware: Upgrading pre-Django 1.10-style middleware @@ -292,8 +371,8 @@ Upgrading pre-Django 1.10-style middleware Django provides ``django.utils.deprecation.MiddlewareMixin`` to ease creating middleware classes that are compatible with both :setting:`MIDDLEWARE` and the -old ``MIDDLEWARE_CLASSES``. All middleware classes included with Django -are compatible with both settings. +old ``MIDDLEWARE_CLASSES``, and support synchronous and asynchronous requests. +All middleware classes included with Django are compatible with both settings. The mixin provides an ``__init__()`` method that requires a ``get_response`` argument and stores it in ``self.get_response``. @@ -345,3 +424,7 @@ These are the behavioral differences between using :setting:`MIDDLEWARE` and HTTP response, and then the next middleware in line will see that response. Middleware are never skipped due to a middleware raising an exception. + +.. versionchanged:: 3.1 + + Support for asynchronous requests was added to the ``MiddlewareMixin``. diff --git a/docs/topics/http/views.txt b/docs/topics/http/views.txt index 30761926767..1d3f9b5a113 100644 --- a/docs/topics/http/views.txt +++ b/docs/topics/http/views.txt @@ -202,3 +202,28 @@ in a test view. For example:: response = self.client.get('/403/') # Make assertions on the response here. For example: self.assertContains(response, 'Error handler content', status_code=403) + +.. _async-views: + +Asynchronous views +================== + +.. versionadded:: 3.1 + +As well as being synchronous functions, views can also be asynchronous +functions (``async def``). Django will automatically detect these and run them +in an asynchronous context. You will need to be using an asynchronous (ASGI) +server to get the full power of them, however. + +Here's an example of an asynchronous view:: + + from django.http import HttpResponse + import datetime + + async def current_datetime(request): + now = datetime.datetime.now() + html = 'It is now %s.' % now + return HttpResponse(html) + +You can read more about Django's asynchronous support, and how to best use +asynchronous views, in :doc:`/topics/async`. diff --git a/docs/topics/testing/advanced.txt b/docs/topics/testing/advanced.txt index a909c5f3d83..8aca92ea36a 100644 --- a/docs/topics/testing/advanced.txt +++ b/docs/topics/testing/advanced.txt @@ -67,6 +67,17 @@ The following is a unit test using the request factory:: response = MyView.as_view()(request) self.assertEqual(response.status_code, 200) +AsyncRequestFactory +------------------- + +``RequestFactory`` creates WSGI-like requests. If you want to create ASGI-like +requests, including having a correct ASGI ``scope``, you can instead use +``django.test.AsyncRequestFactory``. + +This class is directly API-compatible with ``RequestFactory``, with the only +difference being that it returns ``ASGIRequest`` instances rather than +``WSGIRequest`` instances. All of its methods are still synchronous callables. + Testing class-based views ========================= diff --git a/docs/topics/testing/tools.txt b/docs/topics/testing/tools.txt index 02433bbf5ed..64fee656bbf 100644 --- a/docs/topics/testing/tools.txt +++ b/docs/topics/testing/tools.txt @@ -1755,6 +1755,62 @@ You can also exclude tests by tag. To run core tests if they are not slow: test has two tags and you select one of them and exclude the other, the test won't be run. +.. _async-tests: + +Testing asynchronous code +========================= + +.. versionadded:: 3.1 + +If you merely want to test the output of your asynchronous views, the standard +test client will run them inside their own asynchronous loop without any extra +work needed on your part. + +However, if you want to write fully-asynchronous tests for a Django project, +you will need to take several things into account. + +Firstly, your tests must be ``async def`` methods on the test class (in order +to give them an asynchronous context). Django will automatically detect +any ``async def`` tests and wrap them so they run in their own event loop. + +If you are testing from an asynchronous function, you must also use the +asynchronous test client. This is available as ``django.test.AsyncClient``, +or as ``self.async_client`` on any test. + +With the exception of the ``follow`` parameter, which is not supported, +``AsyncClient`` has the same methods and signatures as the synchronous (normal) +test client, but any method that makes a request must be awaited:: + + async def test_my_thing(self): + response = await self.async_client.get('/some-url/') + self.assertEqual(response.status_code, 200) + +The asynchronous client can also call synchronous views; it runs through +Django's :doc:`asynchronous request path `, which supports both. +Any view called through the ``AsyncClient`` will get an ``ASGIRequest`` object +for its ``request`` rather than the ``WSGIRequest`` that the normal client +creates. + +.. warning:: + + If you are using test decorators, they must be async-compatible to ensure + they work correctly. Django's built-in decorators will behave correctly, but + third-party ones may appear to not execute (they will "wrap" the wrong part + of the execution flow and not your test). + + If you need to use these decorators, then you should decorate your test + methods with :func:`~asgiref.sync.async_to_sync` *inside* of them instead:: + + from asgiref.sync import async_to_sync + from django.test import TestCase + + class MyTests(TestCase): + + @mock.patch(...) + @async_to_sync + def test_my_thing(self): + ... + .. _topics-testing-email: Email services diff --git a/tests/asgi/tests.py b/tests/asgi/tests.py index fada34d0d82..c123f027fb3 100644 --- a/tests/asgi/tests.py +++ b/tests/asgi/tests.py @@ -7,7 +7,7 @@ from asgiref.testing import ApplicationCommunicator from django.core.asgi import get_asgi_application from django.core.signals import request_started from django.db import close_old_connections -from django.test import SimpleTestCase, override_settings +from django.test import AsyncRequestFactory, SimpleTestCase, override_settings from .urls import test_filename @@ -15,21 +15,11 @@ from .urls import test_filename @skipIf(sys.platform == 'win32' and (3, 8, 0) < sys.version_info < (3, 8, 1), 'https://bugs.python.org/issue38563') @override_settings(ROOT_URLCONF='asgi.urls') class ASGITest(SimpleTestCase): + async_request_factory = AsyncRequestFactory() def setUp(self): request_started.disconnect(close_old_connections) - def _get_scope(self, **kwargs): - return { - 'type': 'http', - 'asgi': {'version': '3.0', 'spec_version': '2.1'}, - 'http_version': '1.1', - 'method': 'GET', - 'query_string': b'', - 'server': ('testserver', 80), - **kwargs, - } - def tearDown(self): request_started.connect(close_old_connections) @@ -39,7 +29,8 @@ class ASGITest(SimpleTestCase): """ application = get_asgi_application() # Construct HTTP request. - communicator = ApplicationCommunicator(application, self._get_scope(path='/')) + scope = self.async_request_factory._base_scope(path='/') + communicator = ApplicationCommunicator(application, scope) await communicator.send_input({'type': 'http.request'}) # Read the response. response_start = await communicator.receive_output() @@ -62,7 +53,8 @@ class ASGITest(SimpleTestCase): """ application = get_asgi_application() # Construct HTTP request. - communicator = ApplicationCommunicator(application, self._get_scope(path='/file/')) + scope = self.async_request_factory._base_scope(path='/file/') + communicator = ApplicationCommunicator(application, scope) await communicator.send_input({'type': 'http.request'}) # Get the file content. with open(test_filename, 'rb') as test_file: @@ -82,12 +74,14 @@ class ASGITest(SimpleTestCase): response_body = await communicator.receive_output() self.assertEqual(response_body['type'], 'http.response.body') self.assertEqual(response_body['body'], test_file_contents) + # Allow response.close() to finish. + await communicator.wait() async def test_headers(self): application = get_asgi_application() communicator = ApplicationCommunicator( application, - self._get_scope( + self.async_request_factory._base_scope( path='/meta/', headers=[ [b'content-type', b'text/plain; charset=utf-8'], @@ -116,10 +110,11 @@ class ASGITest(SimpleTestCase): application = get_asgi_application() for query_string in (b'name=Andrew', 'name=Andrew'): with self.subTest(query_string=query_string): - communicator = ApplicationCommunicator( - application, - self._get_scope(path='/', query_string=query_string), + scope = self.async_request_factory._base_scope( + path='/', + query_string=query_string, ) + communicator = ApplicationCommunicator(application, scope) await communicator.send_input({'type': 'http.request'}) response_start = await communicator.receive_output() self.assertEqual(response_start['type'], 'http.response.start') @@ -130,17 +125,16 @@ class ASGITest(SimpleTestCase): async def test_disconnect(self): application = get_asgi_application() - communicator = ApplicationCommunicator(application, self._get_scope(path='/')) + scope = self.async_request_factory._base_scope(path='/') + communicator = ApplicationCommunicator(application, scope) await communicator.send_input({'type': 'http.disconnect'}) with self.assertRaises(asyncio.TimeoutError): await communicator.receive_output() async def test_wrong_connection_type(self): application = get_asgi_application() - communicator = ApplicationCommunicator( - application, - self._get_scope(path='/', type='other'), - ) + scope = self.async_request_factory._base_scope(path='/', type='other') + communicator = ApplicationCommunicator(application, scope) await communicator.send_input({'type': 'http.request'}) msg = 'Django can only handle ASGI/HTTP connections, not other.' with self.assertRaisesMessage(ValueError, msg): @@ -148,10 +142,8 @@ class ASGITest(SimpleTestCase): async def test_non_unicode_query_string(self): application = get_asgi_application() - communicator = ApplicationCommunicator( - application, - self._get_scope(path='/', query_string=b'\xff'), - ) + scope = self.async_request_factory._base_scope(path='/', query_string=b'\xff') + communicator = ApplicationCommunicator(application, scope) await communicator.send_input({'type': 'http.request'}) response_start = await communicator.receive_output() self.assertEqual(response_start['type'], 'http.response.start') diff --git a/tests/handlers/tests.py b/tests/handlers/tests.py index fc7074833b9..dac2d967f3e 100644 --- a/tests/handlers/tests.py +++ b/tests/handlers/tests.py @@ -106,6 +106,16 @@ class TransactionsPerRequestTests(TransactionTestCase): connection.settings_dict['ATOMIC_REQUESTS'] = old_atomic_requests self.assertContains(response, 'True') + async def test_auto_transaction_async_view(self): + old_atomic_requests = connection.settings_dict['ATOMIC_REQUESTS'] + try: + connection.settings_dict['ATOMIC_REQUESTS'] = True + msg = 'You cannot use ATOMIC_REQUESTS with async views.' + with self.assertRaisesMessage(RuntimeError, msg): + await self.async_client.get('/async_regular/') + finally: + connection.settings_dict['ATOMIC_REQUESTS'] = old_atomic_requests + def test_no_auto_transaction(self): old_atomic_requests = connection.settings_dict['ATOMIC_REQUESTS'] try: @@ -157,6 +167,11 @@ def empty_middleware(get_response): class HandlerRequestTests(SimpleTestCase): request_factory = RequestFactory() + def test_async_view(self): + """Calling an async view down the normal synchronous path.""" + response = self.client.get('/async_regular/') + self.assertEqual(response.status_code, 200) + def test_suspiciousop_in_view_returns_400(self): response = self.client.get('/suspicious/') self.assertEqual(response.status_code, 400) @@ -224,3 +239,39 @@ class ScriptNameTests(SimpleTestCase): 'PATH_INFO': '/milestones/accounts/login/help', }) self.assertEqual(script_name, '/mst') + + +@override_settings(ROOT_URLCONF='handlers.urls') +class AsyncHandlerRequestTests(SimpleTestCase): + """Async variants of the normal handler request tests.""" + + async def test_sync_view(self): + """Calling a sync view down the asynchronous path.""" + response = await self.async_client.get('/regular/') + self.assertEqual(response.status_code, 200) + + async def test_async_view(self): + """Calling an async view down the asynchronous path.""" + response = await self.async_client.get('/async_regular/') + self.assertEqual(response.status_code, 200) + + async def test_suspiciousop_in_view_returns_400(self): + response = await self.async_client.get('/suspicious/') + self.assertEqual(response.status_code, 400) + + async def test_no_response(self): + msg = ( + "The view handlers.views.no_response didn't return an " + "HttpResponse object. It returned None instead." + ) + with self.assertRaisesMessage(ValueError, msg): + await self.async_client.get('/no_response_fbv/') + + async def test_unawaited_response(self): + msg = ( + "The view handlers.views.async_unawaited didn't return an " + "HttpResponse object. It returned an unawaited coroutine instead. " + "You may need to add an 'await' into your view." + ) + with self.assertRaisesMessage(ValueError, msg): + await self.async_client.get('/unawaited/') diff --git a/tests/handlers/urls.py b/tests/handlers/urls.py index b008395267d..a438da55b4f 100644 --- a/tests/handlers/urls.py +++ b/tests/handlers/urls.py @@ -4,6 +4,7 @@ from . import views urlpatterns = [ path('regular/', views.regular), + path('async_regular/', views.async_regular), path('no_response_fbv/', views.no_response), path('no_response_cbv/', views.NoResponse()), path('streaming/', views.streaming), @@ -12,4 +13,5 @@ urlpatterns = [ path('suspicious/', views.suspicious), path('malformed_post/', views.malformed_post), path('httpstatus_enum/', views.httpstatus_enum), + path('unawaited/', views.async_unawaited), ] diff --git a/tests/handlers/views.py b/tests/handlers/views.py index 872fd52676d..9180c5e5a48 100644 --- a/tests/handlers/views.py +++ b/tests/handlers/views.py @@ -1,3 +1,4 @@ +import asyncio from http import HTTPStatus from django.core.exceptions import SuspiciousOperation @@ -44,3 +45,12 @@ def malformed_post(request): def httpstatus_enum(request): return HttpResponse(status=HTTPStatus.OK) + + +async def async_regular(request): + return HttpResponse(b'regular content') + + +async def async_unawaited(request): + """Return an unawaited coroutine (common error for async views).""" + return asyncio.sleep(0) diff --git a/tests/middleware_exceptions/middleware.py b/tests/middleware_exceptions/middleware.py index 63502c6902a..69c6db57e79 100644 --- a/tests/middleware_exceptions/middleware.py +++ b/tests/middleware_exceptions/middleware.py @@ -1,6 +1,9 @@ from django.http import Http404, HttpResponse from django.template import engines from django.template.response import TemplateResponse +from django.utils.decorators import ( + async_only_middleware, sync_and_async_middleware, sync_only_middleware, +) log = [] @@ -18,6 +21,12 @@ class ProcessExceptionMiddleware(BaseMiddleware): return HttpResponse('Exception caught') +@async_only_middleware +class AsyncProcessExceptionMiddleware(BaseMiddleware): + async def process_exception(self, request, exception): + return HttpResponse('Exception caught') + + class ProcessExceptionLogMiddleware(BaseMiddleware): def process_exception(self, request, exception): log.append('process-exception') @@ -33,6 +42,12 @@ class ProcessViewMiddleware(BaseMiddleware): return HttpResponse('Processed view %s' % view_func.__name__) +@async_only_middleware +class AsyncProcessViewMiddleware(BaseMiddleware): + async def process_view(self, request, view_func, view_args, view_kwargs): + return HttpResponse('Processed view %s' % view_func.__name__) + + class ProcessViewNoneMiddleware(BaseMiddleware): def process_view(self, request, view_func, view_args, view_kwargs): log.append('processed view %s' % view_func.__name__) @@ -51,6 +66,13 @@ class TemplateResponseMiddleware(BaseMiddleware): return response +@async_only_middleware +class AsyncTemplateResponseMiddleware(BaseMiddleware): + async def process_template_response(self, request, response): + response.context_data['mw'].append(self.__class__.__name__) + return response + + class LogMiddleware(BaseMiddleware): def __call__(self, request): response = self.get_response(request) @@ -63,6 +85,48 @@ class NoTemplateResponseMiddleware(BaseMiddleware): return None +@async_only_middleware +class AsyncNoTemplateResponseMiddleware(BaseMiddleware): + async def process_template_response(self, request, response): + return None + + class NotFoundMiddleware(BaseMiddleware): def __call__(self, request): raise Http404('not found') + + +class TeapotMiddleware(BaseMiddleware): + def __call__(self, request): + response = self.get_response(request) + response.status_code = 418 + return response + + +@async_only_middleware +def async_teapot_middleware(get_response): + async def middleware(request): + response = await get_response(request) + response.status_code = 418 + return response + + return middleware + + +@sync_and_async_middleware +class SyncAndAsyncMiddleware(BaseMiddleware): + pass + + +@sync_only_middleware +class DecoratedTeapotMiddleware(TeapotMiddleware): + pass + + +class NotSyncOrAsyncMiddleware(BaseMiddleware): + """Middleware that is deliberately neither sync or async.""" + sync_capable = False + async_capable = False + + def __call__(self, request): + return self.get_response(request) diff --git a/tests/middleware_exceptions/tests.py b/tests/middleware_exceptions/tests.py index 3e614ae0dec..697841a35d5 100644 --- a/tests/middleware_exceptions/tests.py +++ b/tests/middleware_exceptions/tests.py @@ -180,3 +180,162 @@ class MiddlewareNotUsedTests(SimpleTestCase): with self.assertRaisesMessage(AssertionError, 'no logs'): with self.assertLogs('django.request', 'DEBUG'): self.client.get('/middleware_exceptions/view/') + + +@override_settings( + DEBUG=True, + ROOT_URLCONF='middleware_exceptions.urls', +) +class MiddlewareSyncAsyncTests(SimpleTestCase): + @override_settings(MIDDLEWARE=[ + 'middleware_exceptions.middleware.TeapotMiddleware', + ]) + def test_sync_teapot_middleware(self): + response = self.client.get('/middleware_exceptions/view/') + self.assertEqual(response.status_code, 418) + + @override_settings(MIDDLEWARE=[ + 'middleware_exceptions.middleware.DecoratedTeapotMiddleware', + ]) + def test_sync_decorated_teapot_middleware(self): + response = self.client.get('/middleware_exceptions/view/') + self.assertEqual(response.status_code, 418) + + @override_settings(MIDDLEWARE=[ + 'middleware_exceptions.middleware.async_teapot_middleware', + ]) + def test_async_teapot_middleware(self): + with self.assertLogs('django.request', 'DEBUG') as cm: + response = self.client.get('/middleware_exceptions/view/') + self.assertEqual(response.status_code, 418) + self.assertEqual( + cm.records[0].getMessage(), + "Synchronous middleware " + "middleware_exceptions.middleware.async_teapot_middleware " + "adapted.", + ) + + @override_settings(MIDDLEWARE=[ + 'middleware_exceptions.middleware.NotSyncOrAsyncMiddleware', + ]) + def test_not_sync_or_async_middleware(self): + msg = ( + 'Middleware ' + 'middleware_exceptions.middleware.NotSyncOrAsyncMiddleware must ' + 'have at least one of sync_capable/async_capable set to True.' + ) + with self.assertRaisesMessage(RuntimeError, msg): + self.client.get('/middleware_exceptions/view/') + + @override_settings(MIDDLEWARE=[ + 'middleware_exceptions.middleware.TeapotMiddleware', + ]) + async def test_sync_teapot_middleware_async(self): + with self.assertLogs('django.request', 'DEBUG') as cm: + response = await self.async_client.get('/middleware_exceptions/view/') + self.assertEqual(response.status_code, 418) + self.assertEqual( + cm.records[0].getMessage(), + "Asynchronous middleware " + "middleware_exceptions.middleware.TeapotMiddleware adapted.", + ) + + @override_settings(MIDDLEWARE=[ + 'middleware_exceptions.middleware.async_teapot_middleware', + ]) + async def test_async_teapot_middleware_async(self): + with self.assertLogs('django.request', 'WARNING') as cm: + response = await self.async_client.get('/middleware_exceptions/view/') + self.assertEqual(response.status_code, 418) + self.assertEqual( + cm.records[0].getMessage(), + 'Unknown Status Code: /middleware_exceptions/view/', + ) + + @override_settings( + DEBUG=False, + MIDDLEWARE=[ + 'middleware_exceptions.middleware.AsyncNoTemplateResponseMiddleware', + ], + ) + def test_async_process_template_response_returns_none_with_sync_client(self): + msg = ( + "AsyncNoTemplateResponseMiddleware.process_template_response " + "didn't return an HttpResponse object." + ) + with self.assertRaisesMessage(ValueError, msg): + self.client.get('/middleware_exceptions/template_response/') + + @override_settings(MIDDLEWARE=[ + 'middleware_exceptions.middleware.SyncAndAsyncMiddleware', + ]) + async def test_async_and_sync_middleware_async_call(self): + response = await self.async_client.get('/middleware_exceptions/view/') + self.assertEqual(response.content, b'OK') + self.assertEqual(response.status_code, 200) + + @override_settings(MIDDLEWARE=[ + 'middleware_exceptions.middleware.SyncAndAsyncMiddleware', + ]) + def test_async_and_sync_middleware_sync_call(self): + response = self.client.get('/middleware_exceptions/view/') + self.assertEqual(response.content, b'OK') + self.assertEqual(response.status_code, 200) + + +@override_settings(ROOT_URLCONF='middleware_exceptions.urls') +class AsyncMiddlewareTests(SimpleTestCase): + @override_settings(MIDDLEWARE=[ + 'middleware_exceptions.middleware.AsyncTemplateResponseMiddleware', + ]) + async def test_process_template_response(self): + response = await self.async_client.get( + '/middleware_exceptions/template_response/' + ) + self.assertEqual( + response.content, + b'template_response OK\nAsyncTemplateResponseMiddleware', + ) + + @override_settings(MIDDLEWARE=[ + 'middleware_exceptions.middleware.AsyncNoTemplateResponseMiddleware', + ]) + async def test_process_template_response_returns_none(self): + msg = ( + "AsyncNoTemplateResponseMiddleware.process_template_response " + "didn't return an HttpResponse object. It returned None instead." + ) + with self.assertRaisesMessage(ValueError, msg): + await self.async_client.get('/middleware_exceptions/template_response/') + + @override_settings(MIDDLEWARE=[ + 'middleware_exceptions.middleware.AsyncProcessExceptionMiddleware', + ]) + async def test_exception_in_render_passed_to_process_exception(self): + response = await self.async_client.get( + '/middleware_exceptions/exception_in_render/' + ) + self.assertEqual(response.content, b'Exception caught') + + @override_settings(MIDDLEWARE=[ + 'middleware_exceptions.middleware.AsyncProcessExceptionMiddleware', + ]) + async def test_exception_in_async_render_passed_to_process_exception(self): + response = await self.async_client.get( + '/middleware_exceptions/async_exception_in_render/' + ) + self.assertEqual(response.content, b'Exception caught') + + @override_settings(MIDDLEWARE=[ + 'middleware_exceptions.middleware.AsyncProcessExceptionMiddleware', + ]) + async def test_view_exception_handled_by_process_exception(self): + response = await self.async_client.get('/middleware_exceptions/error/') + self.assertEqual(response.content, b'Exception caught') + + @override_settings(MIDDLEWARE=[ + 'middleware_exceptions.middleware.AsyncProcessViewMiddleware', + ]) + async def test_process_view_return_response(self): + response = await self.async_client.get('/middleware_exceptions/view/') + self.assertEqual(response.content, b'Processed view normal_view') diff --git a/tests/middleware_exceptions/urls.py b/tests/middleware_exceptions/urls.py index 46332916b6c..d676ef470cf 100644 --- a/tests/middleware_exceptions/urls.py +++ b/tests/middleware_exceptions/urls.py @@ -8,4 +8,9 @@ urlpatterns = [ path('middleware_exceptions/permission_denied/', views.permission_denied), path('middleware_exceptions/exception_in_render/', views.exception_in_render), path('middleware_exceptions/template_response/', views.template_response), + # Async views. + path( + 'middleware_exceptions/async_exception_in_render/', + views.async_exception_in_render, + ), ] diff --git a/tests/middleware_exceptions/views.py b/tests/middleware_exceptions/views.py index 3ae54081abb..7a1d244863c 100644 --- a/tests/middleware_exceptions/views.py +++ b/tests/middleware_exceptions/views.py @@ -27,3 +27,11 @@ def exception_in_render(request): raise Exception('Exception in HttpResponse.render()') return CustomHttpResponse('Error') + + +async def async_exception_in_render(request): + class CustomHttpResponse(HttpResponse): + async def render(self): + raise Exception('Exception in HttpResponse.render()') + + return CustomHttpResponse('Error') diff --git a/tests/test_client/tests.py b/tests/test_client/tests.py index ce9ce40de51..93ac3e37a70 100644 --- a/tests/test_client/tests.py +++ b/tests/test_client/tests.py @@ -25,9 +25,10 @@ from unittest import mock from django.contrib.auth.models import User from django.core import mail -from django.http import HttpResponse +from django.http import HttpResponse, HttpResponseNotAllowed from django.test import ( - Client, RequestFactory, SimpleTestCase, TestCase, override_settings, + AsyncRequestFactory, Client, RequestFactory, SimpleTestCase, TestCase, + override_settings, ) from django.urls import reverse_lazy @@ -918,3 +919,57 @@ class RequestFactoryTest(SimpleTestCase): protocol = request.META["SERVER_PROTOCOL"] echoed_request_line = "TRACE {} {}".format(url_path, protocol) self.assertContains(response, echoed_request_line) + + +@override_settings(ROOT_URLCONF='test_client.urls') +class AsyncClientTest(TestCase): + async def test_response_resolver_match(self): + response = await self.async_client.get('/async_get_view/') + self.assertTrue(hasattr(response, 'resolver_match')) + self.assertEqual(response.resolver_match.url_name, 'async_get_view') + + async def test_follow_parameter_not_implemented(self): + msg = 'AsyncClient request methods do not accept the follow parameter.' + tests = ( + 'get', + 'post', + 'put', + 'patch', + 'delete', + 'head', + 'options', + 'trace', + ) + for method_name in tests: + with self.subTest(method=method_name): + method = getattr(self.async_client, method_name) + with self.assertRaisesMessage(NotImplementedError, msg): + await method('/redirect_view/', follow=True) + + +@override_settings(ROOT_URLCONF='test_client.urls') +class AsyncRequestFactoryTest(SimpleTestCase): + request_factory = AsyncRequestFactory() + + async def test_request_factory(self): + tests = ( + 'get', + 'post', + 'put', + 'patch', + 'delete', + 'head', + 'options', + 'trace', + ) + for method_name in tests: + with self.subTest(method=method_name): + async def async_generic_view(request): + if request.method.lower() != method_name: + return HttpResponseNotAllowed(method_name) + return HttpResponse(status=200) + + method = getattr(self.request_factory, method_name) + request = method('/somewhere/') + response = await async_generic_view(request) + self.assertEqual(response.status_code, 200) diff --git a/tests/test_client/urls.py b/tests/test_client/urls.py index 61cbe005472..16cca52c38d 100644 --- a/tests/test_client/urls.py +++ b/tests/test_client/urls.py @@ -44,4 +44,6 @@ urlpatterns = [ path('accounts/no_trailing_slash', RedirectView.as_view(url='login/')), path('accounts/login/', auth_views.LoginView.as_view(template_name='login.html')), path('accounts/logout/', auth_views.LogoutView.as_view()), + # Async views. + path('async_get_view/', views.async_get_view, name='async_get_view'), ] diff --git a/tests/test_client/views.py b/tests/test_client/views.py index 2d076fafaf2..c2aef76508c 100644 --- a/tests/test_client/views.py +++ b/tests/test_client/views.py @@ -25,6 +25,10 @@ def get_view(request): return HttpResponse(t.render(c)) +async def async_get_view(request): + return HttpResponse(b'GET content.') + + def trace_view(request): """ A simple view that expects a TRACE request and echoes its status line.