diff --git a/django/http/request.py b/django/http/request.py index ad18706e1e..749b9f2561 100644 --- a/django/http/request.py +++ b/django/http/request.py @@ -4,7 +4,6 @@ import copy import os import re import sys -import warnings from io import BytesIO from pprint import pformat try: @@ -66,11 +65,14 @@ class HttpRequest(object): host = '%s:%s' % (host, server_port) allowed_hosts = ['*'] if settings.DEBUG else settings.ALLOWED_HOSTS - if validate_host(host, allowed_hosts): + domain, port = split_domain_port(host) + if domain and validate_host(domain, allowed_hosts): return host else: - raise SuspiciousOperation( - "Invalid HTTP_HOST header (you may need to set ALLOWED_HOSTS): %s" % host) + msg = "Invalid HTTP_HOST header: %r." % host + if domain: + msg += "You may need to add %r to ALLOWED_HOSTS." % domain + raise SuspiciousOperation(msg) def get_full_path(self): # RFC 3986 requires query string arguments to be in the ASCII range. @@ -454,9 +456,30 @@ def bytes_to_text(s, encoding): return s +def split_domain_port(host): + """ + Return a (domain, port) tuple from a given host. + + Returned domain is lower-cased. If the host is invalid, the domain will be + empty. + """ + host = host.lower() + + if not host_validation_re.match(host): + return '', '' + + if host[-1] == ']': + # It's an IPv6 address without a port. + return host, '' + bits = host.rsplit(':', 1) + if len(bits) == 2: + return tuple(bits) + return bits[0], '' + + def validate_host(host, allowed_hosts): """ - Validate the given host header value for this site. + Validate the given host for this site. Check that the host looks valid and matches a host or host pattern in the given list of ``allowed_hosts``. Any pattern beginning with a period @@ -464,31 +487,20 @@ def validate_host(host, allowed_hosts): ``example.com`` and any subdomain), ``*`` matches anything, and anything else must match exactly. + Note: This function assumes that the given host is lower-cased and has + already had the port, if any, stripped off. + Return ``True`` for a valid host, ``False`` otherwise. """ - # All validation is case-insensitive - host = host.lower() - - # Basic sanity check - if not host_validation_re.match(host): - return False - - # Validate only the domain part. - if host[-1] == ']': - # It's an IPv6 address without a port. - domain = host - else: - domain = host.rsplit(':', 1)[0] - for pattern in allowed_hosts: pattern = pattern.lower() match = ( pattern == '*' or pattern.startswith('.') and ( - domain.endswith(pattern) or domain == pattern[1:] + host.endswith(pattern) or host == pattern[1:] ) or - pattern == domain + pattern == host ) if match: return True diff --git a/tests/requests/tests.py b/tests/requests/tests.py index e176775449..daf426ea47 100644 --- a/tests/requests/tests.py +++ b/tests/requests/tests.py @@ -11,16 +11,16 @@ from django.core import signals from django.core.exceptions import SuspiciousOperation from django.core.handlers.wsgi import WSGIRequest, LimitedStream from django.http import HttpRequest, HttpResponse, parse_cookie, build_request_repr, UnreadablePostError -from django.test import TransactionTestCase +from django.test import SimpleTestCase, TransactionTestCase from django.test.client import FakePayload from django.test.utils import override_settings, str_prefix from django.utils import six -from django.utils import unittest +from django.utils.unittest import skipIf from django.utils.http import cookie_date, urlencode from django.utils.timezone import utc -class RequestsTests(unittest.TestCase): +class RequestsTests(SimpleTestCase): def test_httprequest(self): request = HttpRequest() self.assertEqual(list(request.GET.keys()), []) @@ -287,6 +287,56 @@ class RequestsTests(unittest.TestCase): self.assertEqual(request.get_host(), 'example.com') + @override_settings(ALLOWED_HOSTS=[]) + def test_get_host_suggestion_of_allowed_host(self): + """get_host() makes helpful suggestions if a valid-looking host is not in ALLOWED_HOSTS.""" + msg_invalid_host = "Invalid HTTP_HOST header: %r." + msg_suggestion = msg_invalid_host + "You may need to add %r to ALLOWED_HOSTS." + + for host in [ # Valid-looking hosts + 'example.com', + '12.34.56.78', + '[2001:19f0:feee::dead:beef:cafe]', + 'xn--4ca9at.com', # Punnycode for öäü.com + ]: + request = HttpRequest() + request.META = {'HTTP_HOST': host} + self.assertRaisesMessage( + SuspiciousOperation, + msg_suggestion % (host, host), + request.get_host + ) + + for domain, port in [ # Valid-looking hosts with a port number + ('example.com', 80), + ('12.34.56.78', 443), + ('[2001:19f0:feee::dead:beef:cafe]', 8080), + ]: + host = '%s:%s' % (domain, port) + request = HttpRequest() + request.META = {'HTTP_HOST': host} + self.assertRaisesMessage( + SuspiciousOperation, + msg_suggestion % (host, domain), + request.get_host + ) + + for host in [ # Invalid hosts + 'example.com@evil.tld', + 'example.com:dr.frankenstein@evil.tld', + 'example.com:dr.frankenstein@evil.tld:80', + 'example.com:80/badpath', + 'example.com: recovermypassword.com', + ]: + request = HttpRequest() + request.META = {'HTTP_HOST': host} + self.assertRaisesMessage( + SuspiciousOperation, + msg_invalid_host % host, + request.get_host + ) + + def test_near_expiration(self): "Cookie will expire when an near expiration time is provided" response = HttpResponse() @@ -587,7 +637,7 @@ class RequestsTests(unittest.TestCase): request.body -@unittest.skipIf(connection.vendor == 'sqlite' +@skipIf(connection.vendor == 'sqlite' and connection.settings_dict['NAME'] in ('', ':memory:'), "Cannot establish two connections to an in-memory SQLite database.") class DatabaseConnectionHandlingTests(TransactionTestCase):