Fixed #27181 -- Allowed contrib.sites to match domains with trailing ".".
This commit is contained in:
parent
95627cb0aa
commit
05d2c5a66d
|
@ -46,8 +46,6 @@ class SiteManager(models.Manager):
|
||||||
except Site.DoesNotExist:
|
except Site.DoesNotExist:
|
||||||
# Fallback to looking up site after stripping port from the host.
|
# Fallback to looking up site after stripping port from the host.
|
||||||
domain, port = split_domain_port(host)
|
domain, port = split_domain_port(host)
|
||||||
if not port:
|
|
||||||
raise
|
|
||||||
if domain not in SITE_CACHE:
|
if domain not in SITE_CACHE:
|
||||||
SITE_CACHE[domain] = self.get(domain__iexact=domain)
|
SITE_CACHE[domain] = self.get(domain__iexact=domain)
|
||||||
return SITE_CACHE[domain]
|
return SITE_CACHE[domain]
|
||||||
|
|
|
@ -554,9 +554,10 @@ def split_domain_port(host):
|
||||||
# It's an IPv6 address without a port.
|
# It's an IPv6 address without a port.
|
||||||
return host, ''
|
return host, ''
|
||||||
bits = host.rsplit(':', 1)
|
bits = host.rsplit(':', 1)
|
||||||
if len(bits) == 2:
|
domain, port = bits if len(bits) == 2 else (bits[0], '')
|
||||||
return tuple(bits)
|
# Remove a trailing dot (if present) from the domain.
|
||||||
return bits[0], ''
|
domain = domain[:-1] if domain.endswith('.') else domain
|
||||||
|
return domain, port
|
||||||
|
|
||||||
|
|
||||||
def validate_host(host, allowed_hosts):
|
def validate_host(host, allowed_hosts):
|
||||||
|
@ -574,8 +575,6 @@ def validate_host(host, allowed_hosts):
|
||||||
|
|
||||||
Return ``True`` for a valid host, ``False`` otherwise.
|
Return ``True`` for a valid host, ``False`` otherwise.
|
||||||
"""
|
"""
|
||||||
host = host[:-1] if host.endswith('.') else host
|
|
||||||
|
|
||||||
for pattern in allowed_hosts:
|
for pattern in allowed_hosts:
|
||||||
if pattern == '*' or is_same_domain(host, pattern):
|
if pattern == '*' or is_same_domain(host, pattern):
|
||||||
return True
|
return True
|
||||||
|
|
|
@ -11,6 +11,7 @@ from django.core.handlers.wsgi import LimitedStream, WSGIRequest
|
||||||
from django.http import (
|
from django.http import (
|
||||||
HttpRequest, HttpResponse, RawPostDataException, UnreadablePostError,
|
HttpRequest, HttpResponse, RawPostDataException, UnreadablePostError,
|
||||||
)
|
)
|
||||||
|
from django.http.request import split_domain_port
|
||||||
from django.test import RequestFactory, SimpleTestCase, override_settings
|
from django.test import RequestFactory, SimpleTestCase, override_settings
|
||||||
from django.test.client import FakePayload
|
from django.test.client import FakePayload
|
||||||
from django.test.utils import freeze_time, str_prefix
|
from django.test.utils import freeze_time, str_prefix
|
||||||
|
@ -842,6 +843,11 @@ class HostValidationTests(SimpleTestCase):
|
||||||
with self.assertRaisesMessage(SuspiciousOperation, msg_suggestion2 % "invalid_hostname.com"):
|
with self.assertRaisesMessage(SuspiciousOperation, msg_suggestion2 % "invalid_hostname.com"):
|
||||||
request.get_host()
|
request.get_host()
|
||||||
|
|
||||||
|
def test_split_domain_port_removes_trailing_dot(self):
|
||||||
|
domain, port = split_domain_port('example.com.:8080')
|
||||||
|
self.assertEqual(domain, 'example.com')
|
||||||
|
self.assertEqual(port, '8080')
|
||||||
|
|
||||||
|
|
||||||
class BuildAbsoluteURITestCase(SimpleTestCase):
|
class BuildAbsoluteURITestCase(SimpleTestCase):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -93,6 +93,19 @@ class SitesFrameworkTests(TestCase):
|
||||||
site = get_current_site(request)
|
site = get_current_site(request)
|
||||||
self.assertEqual(site.name, "example.com")
|
self.assertEqual(site.name, "example.com")
|
||||||
|
|
||||||
|
@override_settings(SITE_ID='', ALLOWED_HOSTS=['example.com'])
|
||||||
|
def test_get_current_site_host_with_trailing_dot(self):
|
||||||
|
"""
|
||||||
|
The site is matched if the name in the request has a trailing dot.
|
||||||
|
"""
|
||||||
|
request = HttpRequest()
|
||||||
|
request.META = {
|
||||||
|
'SERVER_NAME': 'example.com.',
|
||||||
|
'SERVER_PORT': '80',
|
||||||
|
}
|
||||||
|
site = get_current_site(request)
|
||||||
|
self.assertEqual(site.name, 'example.com')
|
||||||
|
|
||||||
@override_settings(SITE_ID='', ALLOWED_HOSTS=['example.com', 'example.net'])
|
@override_settings(SITE_ID='', ALLOWED_HOSTS=['example.com', 'example.net'])
|
||||||
def test_get_current_site_no_site_id_and_handle_port_fallback(self):
|
def test_get_current_site_no_site_id_and_handle_port_fallback(self):
|
||||||
request = HttpRequest()
|
request = HttpRequest()
|
||||||
|
|
Loading…
Reference in New Issue