Fixed #30015 -- Ensured request body is properly consumed for keep-alive connections.

This commit is contained in:
Konstantin Alekseev 2018-12-06 18:08:01 +03:00 committed by Carlton Gibson
parent 1939dd49d1
commit b514dc14f4
4 changed files with 43 additions and 0 deletions

View File

@ -14,6 +14,7 @@ import sys
from wsgiref import simple_server from wsgiref import simple_server
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.core.handlers.wsgi import LimitedStream
from django.core.wsgi import get_wsgi_application from django.core.wsgi import get_wsgi_application
from django.utils.module_loading import import_string from django.utils.module_loading import import_string
@ -80,6 +81,20 @@ class ThreadedWSGIServer(socketserver.ThreadingMixIn, WSGIServer):
class ServerHandler(simple_server.ServerHandler): class ServerHandler(simple_server.ServerHandler):
http_version = '1.1' http_version = '1.1'
def __init__(self, stdin, stdout, stderr, environ, **kwargs):
"""
Setup a limited stream, so we can discard unread request data
at the end of the request. Django already uses `LimitedStream`
in `WSGIRequest` but it shouldn't discard the data since the
upstream servers usually do this. Hence we fix this only for
our testserver/runserver.
"""
try:
content_length = int(environ.get('CONTENT_LENGTH'))
except (ValueError, TypeError):
content_length = 0
super().__init__(LimitedStream(stdin, content_length), stdout, stderr, environ, **kwargs)
def cleanup_headers(self): def cleanup_headers(self):
super().cleanup_headers() super().cleanup_headers()
# HTTP/1.1 requires us to support persistent connections, so # HTTP/1.1 requires us to support persistent connections, so
@ -92,6 +107,10 @@ class ServerHandler(simple_server.ServerHandler):
if self.headers.get('Connection') == 'close': if self.headers.get('Connection') == 'close':
self.request_handler.close_connection = True self.request_handler.close_connection = True
def close(self):
self.get_stdin()._read_limited()
super().close()
def handle_error(self): def handle_error(self):
# Ignore broken pipe errors, otherwise pass on # Ignore broken pipe errors, otherwise pass on
if not is_broken_pipe_error(): if not is_broken_pipe_error():

View File

@ -111,6 +111,23 @@ class LiveServerViews(LiveServerBase):
finally: finally:
conn.close() conn.close()
def test_keep_alive_connection_clears_previous_request_data(self):
conn = HTTPConnection(LiveServerViews.server_thread.host, LiveServerViews.server_thread.port)
try:
conn.request('POST', '/method_view/', b'{}', headers={"Connection": "keep-alive"})
response = conn.getresponse()
self.assertFalse(response.will_close)
self.assertEqual(response.status, 200)
self.assertEqual(response.read(), b'POST')
conn.request('POST', '/method_view/', b'{}', headers={"Connection": "close"})
response = conn.getresponse()
self.assertFalse(response.will_close)
self.assertEqual(response.status, 200)
self.assertEqual(response.read(), b'POST')
finally:
conn.close()
def test_404(self): def test_404(self):
with self.assertRaises(HTTPError) as err: with self.assertRaises(HTTPError) as err:
self.urlopen('/') self.urlopen('/')

View File

@ -11,4 +11,5 @@ urlpatterns = [
url(r'^subview_calling_view/$', views.subview_calling_view), url(r'^subview_calling_view/$', views.subview_calling_view),
url(r'^subview/$', views.subview), url(r'^subview/$', views.subview),
url(r'^check_model_instance_from_subview/$', views.check_model_instance_from_subview), url(r'^check_model_instance_from_subview/$', views.check_model_instance_from_subview),
url(r'^method_view/$', views.method_view),
] ]

View File

@ -1,6 +1,7 @@
from urllib.request import urlopen from urllib.request import urlopen
from django.http import HttpResponse, StreamingHttpResponse from django.http import HttpResponse, StreamingHttpResponse
from django.views.decorators.csrf import csrf_exempt
from .models import Person from .models import Person
@ -42,3 +43,8 @@ def check_model_instance_from_subview(request):
pass pass
with urlopen(request.GET['url'] + '/model_view/') as response: with urlopen(request.GET['url'] + '/model_view/') as response:
return HttpResponse('subview calling view: {}'.format(response.read().decode())) return HttpResponse('subview calling view: {}'.format(response.read().decode()))
@csrf_exempt
def method_view(request):
return HttpResponse(request.method)