Fixed #32416 -- Made ThreadedWSGIServer close connections after each thread.

ThreadedWSGIServer is used by LiveServerTestCase.
This commit is contained in:
Chris Jerdonek 2021-02-14 22:10:59 -08:00 committed by Mariusz Felisiak
parent 71a936f9d8
commit 823a9e6bac
4 changed files with 95 additions and 3 deletions

View File

@ -16,6 +16,7 @@ from wsgiref import simple_server
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.core.handlers.wsgi import LimitedStream from django.core.handlers.wsgi import LimitedStream
from django.core.wsgi import get_wsgi_application from django.core.wsgi import get_wsgi_application
from django.db import connections
from django.utils.module_loading import import_string from django.utils.module_loading import import_string
__all__ = ('WSGIServer', 'WSGIRequestHandler') __all__ = ('WSGIServer', 'WSGIRequestHandler')
@ -81,6 +82,28 @@ class ThreadedWSGIServer(socketserver.ThreadingMixIn, WSGIServer):
"""A threaded version of the WSGIServer""" """A threaded version of the WSGIServer"""
daemon_threads = True daemon_threads = True
def __init__(self, *args, connections_override=None, **kwargs):
super().__init__(*args, **kwargs)
self.connections_override = connections_override
# socketserver.ThreadingMixIn.process_request() passes this method as
# the target to a new Thread object.
def process_request_thread(self, request, client_address):
if self.connections_override:
# Override this thread's database connections with the ones
# provided by the parent thread.
for alias, conn in self.connections_override.items():
connections[alias] = conn
super().process_request_thread(request, client_address)
def _close_connections(self):
# Used for mocking in tests.
connections.close_all()
def close_request(self, request):
self._close_connections()
super().close_request(request)
class ServerHandler(simple_server.ServerHandler): class ServerHandler(simple_server.ServerHandler):
http_version = '1.1' http_version = '1.1'

View File

@ -83,6 +83,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
"the sqlite backend's close() method is a no-op when using an " "the sqlite backend's close() method is a no-op when using an "
"in-memory database": { "in-memory database": {
'servers.test_liveserverthread.LiveServerThreadTest.test_closes_connections', 'servers.test_liveserverthread.LiveServerThreadTest.test_closes_connections',
'servers.tests.LiveServerTestCloseConnectionTest.test_closes_connections',
}, },
}) })
return skips return skips

View File

@ -1513,11 +1513,12 @@ class LiveServerThread(threading.Thread):
finally: finally:
connections.close_all() connections.close_all()
def _create_server(self): def _create_server(self, connections_override=None):
return self.server_class( return self.server_class(
(self.host, self.port), (self.host, self.port),
QuietWSGIRequestHandler, QuietWSGIRequestHandler,
allow_reuse_address=False, allow_reuse_address=False,
connections_override=connections_override,
) )
def terminate(self): def terminate(self):
@ -1600,7 +1601,7 @@ class LiveServerTestCase(TransactionTestCase):
def _tearDownClassInternal(cls): def _tearDownClassInternal(cls):
# Terminate the live server's thread. # Terminate the live server's thread.
cls.server_thread.terminate() cls.server_thread.terminate()
# Restore sqlite in-memory database connections' non-shareability. # Restore shared connections' non-shareability.
for conn in cls.server_thread.connections_override.values(): for conn in cls.server_thread.connections_override.values():
conn.dec_thread_sharing() conn.dec_thread_sharing()

View File

@ -4,13 +4,15 @@ Tests for django.core.servers.
import errno import errno
import os import os
import socket import socket
import threading
from http.client import HTTPConnection from http.client import HTTPConnection
from urllib.error import HTTPError from urllib.error import HTTPError
from urllib.parse import urlencode from urllib.parse import urlencode
from urllib.request import urlopen from urllib.request import urlopen
from django.conf import settings from django.conf import settings
from django.core.servers.basehttp import WSGIServer from django.core.servers.basehttp import ThreadedWSGIServer, WSGIServer
from django.db import DEFAULT_DB_ALIAS, connections
from django.test import LiveServerTestCase, override_settings from django.test import LiveServerTestCase, override_settings
from django.test.testcases import LiveServerThread, QuietWSGIRequestHandler from django.test.testcases import LiveServerThread, QuietWSGIRequestHandler
@ -40,6 +42,71 @@ class LiveServerBase(LiveServerTestCase):
return urlopen(self.live_server_url + url) return urlopen(self.live_server_url + url)
class CloseConnectionTestServer(ThreadedWSGIServer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# This event is set right after the first time a request closes its
# database connections.
self._connections_closed = threading.Event()
def _close_connections(self):
super()._close_connections()
self._connections_closed.set()
class CloseConnectionTestLiveServerThread(LiveServerThread):
server_class = CloseConnectionTestServer
def _create_server(self, connections_override=None):
return super()._create_server(connections_override=self.connections_override)
class LiveServerTestCloseConnectionTest(LiveServerBase):
server_thread_class = CloseConnectionTestLiveServerThread
@classmethod
def _make_connections_override(cls):
conn = connections[DEFAULT_DB_ALIAS]
cls.conn = conn
cls.old_conn_max_age = conn.settings_dict['CONN_MAX_AGE']
# Set the connection's CONN_MAX_AGE to None to simulate the
# CONN_MAX_AGE setting being set to None on the server. This prevents
# Django from closing the connection and allows testing that
# ThreadedWSGIServer closes connections.
conn.settings_dict['CONN_MAX_AGE'] = None
# Pass a database connection through to the server to check it is being
# closed by ThreadedWSGIServer.
return {DEFAULT_DB_ALIAS: conn}
@classmethod
def tearDownConnectionTest(cls):
cls.conn.settings_dict['CONN_MAX_AGE'] = cls.old_conn_max_age
@classmethod
def tearDownClass(cls):
cls.tearDownConnectionTest()
super().tearDownClass()
def test_closes_connections(self):
# The server's request thread sets this event after closing
# its database connections.
closed_event = self.server_thread.httpd._connections_closed
conn = self.conn
# Open a connection to the database.
conn.connect()
self.assertIsNotNone(conn.connection)
with self.urlopen('/model_view/') as f:
# The server can access the database.
self.assertEqual(f.read().splitlines(), [b'jane', b'robert'])
# Wait for the server's request thread to close the connection.
# A timeout of 0.1 seconds should be more than enough. If the wait
# times out, the assertion after should fail.
closed_event.wait(timeout=0.1)
self.assertIsNone(conn.connection)
class FailingLiveServerThread(LiveServerThread): class FailingLiveServerThread(LiveServerThread):
def _create_server(self): def _create_server(self):
raise RuntimeError('Error creating server.') raise RuntimeError('Error creating server.')