Fixed #27181 -- Allowed contrib.sites to match domains with trailing ".".

This commit is contained in:
Anton Samarchyan 2016-11-29 18:17:10 -05:00 committed by Tim Graham
parent 95627cb0aa
commit 05d2c5a66d
4 changed files with 23 additions and 7 deletions

View File

@ -46,8 +46,6 @@ class SiteManager(models.Manager):
except Site.DoesNotExist: except Site.DoesNotExist:
# Fallback to looking up site after stripping port from the host. # Fallback to looking up site after stripping port from the host.
domain, port = split_domain_port(host) domain, port = split_domain_port(host)
if not port:
raise
if domain not in SITE_CACHE: if domain not in SITE_CACHE:
SITE_CACHE[domain] = self.get(domain__iexact=domain) SITE_CACHE[domain] = self.get(domain__iexact=domain)
return SITE_CACHE[domain] return SITE_CACHE[domain]

View File

@ -554,9 +554,10 @@ def split_domain_port(host):
# It's an IPv6 address without a port. # It's an IPv6 address without a port.
return host, '' return host, ''
bits = host.rsplit(':', 1) bits = host.rsplit(':', 1)
if len(bits) == 2: domain, port = bits if len(bits) == 2 else (bits[0], '')
return tuple(bits) # Remove a trailing dot (if present) from the domain.
return bits[0], '' domain = domain[:-1] if domain.endswith('.') else domain
return domain, port
def validate_host(host, allowed_hosts): def validate_host(host, allowed_hosts):
@ -574,8 +575,6 @@ def validate_host(host, allowed_hosts):
Return ``True`` for a valid host, ``False`` otherwise. Return ``True`` for a valid host, ``False`` otherwise.
""" """
host = host[:-1] if host.endswith('.') else host
for pattern in allowed_hosts: for pattern in allowed_hosts:
if pattern == '*' or is_same_domain(host, pattern): if pattern == '*' or is_same_domain(host, pattern):
return True return True

View File

@ -11,6 +11,7 @@ from django.core.handlers.wsgi import LimitedStream, WSGIRequest
from django.http import ( from django.http import (
HttpRequest, HttpResponse, RawPostDataException, UnreadablePostError, HttpRequest, HttpResponse, RawPostDataException, UnreadablePostError,
) )
from django.http.request import split_domain_port
from django.test import RequestFactory, SimpleTestCase, override_settings from django.test import RequestFactory, SimpleTestCase, override_settings
from django.test.client import FakePayload from django.test.client import FakePayload
from django.test.utils import freeze_time, str_prefix from django.test.utils import freeze_time, str_prefix
@ -842,6 +843,11 @@ class HostValidationTests(SimpleTestCase):
with self.assertRaisesMessage(SuspiciousOperation, msg_suggestion2 % "invalid_hostname.com"): with self.assertRaisesMessage(SuspiciousOperation, msg_suggestion2 % "invalid_hostname.com"):
request.get_host() request.get_host()
def test_split_domain_port_removes_trailing_dot(self):
domain, port = split_domain_port('example.com.:8080')
self.assertEqual(domain, 'example.com')
self.assertEqual(port, '8080')
class BuildAbsoluteURITestCase(SimpleTestCase): class BuildAbsoluteURITestCase(SimpleTestCase):
""" """

View File

@ -93,6 +93,19 @@ class SitesFrameworkTests(TestCase):
site = get_current_site(request) site = get_current_site(request)
self.assertEqual(site.name, "example.com") self.assertEqual(site.name, "example.com")
@override_settings(SITE_ID='', ALLOWED_HOSTS=['example.com'])
def test_get_current_site_host_with_trailing_dot(self):
"""
The site is matched if the name in the request has a trailing dot.
"""
request = HttpRequest()
request.META = {
'SERVER_NAME': 'example.com.',
'SERVER_PORT': '80',
}
site = get_current_site(request)
self.assertEqual(site.name, 'example.com')
@override_settings(SITE_ID='', ALLOWED_HOSTS=['example.com', 'example.net']) @override_settings(SITE_ID='', ALLOWED_HOSTS=['example.com', 'example.net'])
def test_get_current_site_no_site_id_and_handle_port_fallback(self): def test_get_current_site_no_site_id_and_handle_port_fallback(self):
request = HttpRequest() request = HttpRequest()