From 05d2c5a66dd72c26e5221855e08834a66c844399 Mon Sep 17 00:00:00 2001 From: Anton Samarchyan Date: Tue, 29 Nov 2016 18:17:10 -0500 Subject: [PATCH] Fixed #27181 -- Allowed contrib.sites to match domains with trailing ".". --- django/contrib/sites/models.py | 2 -- django/http/request.py | 9 ++++----- tests/requests/tests.py | 6 ++++++ tests/sites_tests/tests.py | 13 +++++++++++++ 4 files changed, 23 insertions(+), 7 deletions(-) diff --git a/django/contrib/sites/models.py b/django/contrib/sites/models.py index e4c68638da..7b237b0873 100644 --- a/django/contrib/sites/models.py +++ b/django/contrib/sites/models.py @@ -46,8 +46,6 @@ class SiteManager(models.Manager): except Site.DoesNotExist: # Fallback to looking up site after stripping port from the host. domain, port = split_domain_port(host) - if not port: - raise if domain not in SITE_CACHE: SITE_CACHE[domain] = self.get(domain__iexact=domain) return SITE_CACHE[domain] diff --git a/django/http/request.py b/django/http/request.py index 83131d52c3..9ffcd23fbd 100644 --- a/django/http/request.py +++ b/django/http/request.py @@ -554,9 +554,10 @@ def split_domain_port(host): # It's an IPv6 address without a port. return host, '' bits = host.rsplit(':', 1) - if len(bits) == 2: - return tuple(bits) - return bits[0], '' + domain, port = bits if len(bits) == 2 else (bits[0], '') + # Remove a trailing dot (if present) from the domain. + domain = domain[:-1] if domain.endswith('.') else domain + return domain, port def validate_host(host, allowed_hosts): @@ -574,8 +575,6 @@ def validate_host(host, allowed_hosts): Return ``True`` for a valid host, ``False`` otherwise. """ - host = host[:-1] if host.endswith('.') else host - for pattern in allowed_hosts: if pattern == '*' or is_same_domain(host, pattern): return True diff --git a/tests/requests/tests.py b/tests/requests/tests.py index 1a243a02b0..0942619715 100644 --- a/tests/requests/tests.py +++ b/tests/requests/tests.py @@ -11,6 +11,7 @@ from django.core.handlers.wsgi import LimitedStream, WSGIRequest from django.http import ( HttpRequest, HttpResponse, RawPostDataException, UnreadablePostError, ) +from django.http.request import split_domain_port from django.test import RequestFactory, SimpleTestCase, override_settings from django.test.client import FakePayload from django.test.utils import freeze_time, str_prefix @@ -842,6 +843,11 @@ class HostValidationTests(SimpleTestCase): with self.assertRaisesMessage(SuspiciousOperation, msg_suggestion2 % "invalid_hostname.com"): request.get_host() + def test_split_domain_port_removes_trailing_dot(self): + domain, port = split_domain_port('example.com.:8080') + self.assertEqual(domain, 'example.com') + self.assertEqual(port, '8080') + class BuildAbsoluteURITestCase(SimpleTestCase): """ diff --git a/tests/sites_tests/tests.py b/tests/sites_tests/tests.py index 64fe2215d4..3ea66be9bb 100644 --- a/tests/sites_tests/tests.py +++ b/tests/sites_tests/tests.py @@ -93,6 +93,19 @@ class SitesFrameworkTests(TestCase): site = get_current_site(request) self.assertEqual(site.name, "example.com") + @override_settings(SITE_ID='', ALLOWED_HOSTS=['example.com']) + def test_get_current_site_host_with_trailing_dot(self): + """ + The site is matched if the name in the request has a trailing dot. + """ + request = HttpRequest() + request.META = { + 'SERVER_NAME': 'example.com.', + 'SERVER_PORT': '80', + } + site = get_current_site(request) + self.assertEqual(site.name, 'example.com') + @override_settings(SITE_ID='', ALLOWED_HOSTS=['example.com', 'example.net']) def test_get_current_site_no_site_id_and_handle_port_fallback(self): request = HttpRequest()