Fixed #15089 -- Allowed contrib.sites to lookup the current site based on request.get_host().

Thanks Claude Paroz, Riccardo Magliocchetti, and Damian Moore
for contributions to the patch.
This commit is contained in:
Tim Graham 2014-09-30 13:15:59 -04:00
parent 3339605821
commit 32c7d3c061
7 changed files with 100 additions and 28 deletions

View File

@ -72,7 +72,7 @@ def shortcut(request, content_type_id, object_id):
# Fall back to the current site (if possible). # Fall back to the current site (if possible).
if object_domain is None: if object_domain is None:
try: try:
object_domain = Site.objects.get_current().domain object_domain = Site.objects.get_current(request).domain
except Site.DoesNotExist: except Site.DoesNotExist:
pass pass

View File

@ -1,4 +1,4 @@
from .models import Site from .shortcuts import get_current_site
class CurrentSiteMiddleware(object): class CurrentSiteMiddleware(object):
@ -7,4 +7,4 @@ class CurrentSiteMiddleware(object):
""" """
def process_request(self, request): def process_request(self, request):
request.site = Site.objects.get_current() request.site = get_current_site(request)

View File

@ -34,26 +34,39 @@ def _simple_domain_name_validator(value):
class SiteManager(models.Manager): class SiteManager(models.Manager):
def get_current(self): def _get_site_by_id(self, site_id):
if site_id not in SITE_CACHE:
site = self.get(pk=site_id)
SITE_CACHE[site_id] = site
return SITE_CACHE[site_id]
def _get_site_by_request(self, request):
host = request.get_host()
if host not in SITE_CACHE:
site = self.get(domain__iexact=host)
SITE_CACHE[host] = site
return SITE_CACHE[host]
def get_current(self, request=None):
""" """
Returns the current ``Site`` based on the SITE_ID in the Returns the current Site based on the SITE_ID in the project's settings.
project's settings. The ``Site`` object is cached the first If SITE_ID isn't defined, it returns the site with domain matching
time it's retrieved from the database. request.get_host(). The ``Site`` object is cached the first time it's
retrieved from the database.
""" """
from django.conf import settings from django.conf import settings
try: if getattr(settings, 'SITE_ID', ''):
sid = settings.SITE_ID site_id = settings.SITE_ID
except AttributeError: return self._get_site_by_id(site_id)
raise ImproperlyConfigured( elif request:
"You're using the Django \"sites framework\" without having " return self._get_site_by_request(request)
"set the SITE_ID setting. Create a site in your database and "
"set the SITE_ID setting to fix this error.") raise ImproperlyConfigured(
try: "You're using the Django \"sites framework\" without having "
current_site = SITE_CACHE[sid] "set the SITE_ID setting. Create a site in your database and "
except KeyError: "set the SITE_ID setting or pass a request to "
current_site = self.get(pk=sid) "Site.objects.get_current() to fix this error."
SITE_CACHE[sid] = current_site )
return current_site
def clear_cache(self): def clear_cache(self):
"""Clears the ``Site`` object cache.""" """Clears the ``Site`` object cache."""
@ -103,5 +116,9 @@ def clear_site_cache(sender, **kwargs):
del SITE_CACHE[instance.pk] del SITE_CACHE[instance.pk]
except KeyError: except KeyError:
pass pass
try:
del SITE_CACHE[Site.objects.get(pk=instance.pk).domain]
except (KeyError, Site.DoesNotExist):
pass
pre_save.connect(clear_site_cache, sender=Site) pre_save.connect(clear_site_cache, sender=Site)
pre_delete.connect(clear_site_cache, sender=Site) pre_delete.connect(clear_site_cache, sender=Site)

View File

@ -12,7 +12,7 @@ def get_current_site(request):
# the Site models when django.contrib.sites isn't installed. # the Site models when django.contrib.sites isn't installed.
if apps.is_installed('django.contrib.sites'): if apps.is_installed('django.contrib.sites'):
from .models import Site from .models import Site
return Site.objects.get_current() return Site.objects.get_current(request)
else: else:
from .requests import RequestSite from .requests import RequestSite
return RequestSite(request) return RequestSite(request)

View File

@ -5,8 +5,9 @@ from django.core.exceptions import ObjectDoesNotExist, ValidationError
from django.http import HttpRequest from django.http import HttpRequest
from django.test import TestCase, modify_settings, override_settings from django.test import TestCase, modify_settings, override_settings
from . import models
from .middleware import CurrentSiteMiddleware from .middleware import CurrentSiteMiddleware
from .models import Site from .models import clear_site_cache, Site
from .requests import RequestSite from .requests import RequestSite
from .shortcuts import get_current_site from .shortcuts import get_current_site
@ -15,7 +16,12 @@ from .shortcuts import get_current_site
class SitesFrameworkTests(TestCase): class SitesFrameworkTests(TestCase):
def setUp(self): def setUp(self):
Site(id=settings.SITE_ID, domain="example.com", name="example.com").save() self.site = Site(
id=settings.SITE_ID,
domain="example.com",
name="example.com",
)
self.site.save()
def test_save_another(self): def test_save_another(self):
# Regression for #17415 # Regression for #17415
@ -71,6 +77,17 @@ class SitesFrameworkTests(TestCase):
self.assertIsInstance(site, RequestSite) self.assertIsInstance(site, RequestSite)
self.assertEqual(site.name, "example.com") self.assertEqual(site.name, "example.com")
@override_settings(SITE_ID='', ALLOWED_HOSTS=['example.com'])
def test_get_current_site_no_site_id(self):
request = HttpRequest()
request.META = {
"SERVER_NAME": "example.com",
"SERVER_PORT": "80",
}
del settings.SITE_ID
site = get_current_site(request)
self.assertEqual(site.name, "example.com")
def test_domain_name_with_whitespaces(self): def test_domain_name_with_whitespaces(self):
# Regression for #17320 # Regression for #17320
# Domain names are not allowed contain whitespace characters # Domain names are not allowed contain whitespace characters
@ -81,6 +98,26 @@ class SitesFrameworkTests(TestCase):
site.domain = "test\ntest" site.domain = "test\ntest"
self.assertRaises(ValidationError, site.full_clean) self.assertRaises(ValidationError, site.full_clean)
def test_clear_site_cache(self):
request = HttpRequest()
request.META = {
"SERVER_NAME": "example.com",
"SERVER_PORT": "80",
}
self.assertEqual(models.SITE_CACHE, {})
get_current_site(request)
expected_cache = {self.site.id: self.site}
self.assertEqual(models.SITE_CACHE, expected_cache)
with self.settings(SITE_ID=''):
get_current_site(request)
expected_cache.update({self.site.domain: self.site})
self.assertEqual(models.SITE_CACHE, expected_cache)
clear_site_cache(Site, instance=self.site)
self.assertEqual(models.SITE_CACHE, {})
class MiddlewareTest(TestCase): class MiddlewareTest(TestCase):

View File

@ -18,10 +18,6 @@ The sites framework is mainly based on a simple model:
.. class:: models.Site .. class:: models.Site
A model for storing the ``domain`` and ``name`` attributes of a Web site. A model for storing the ``domain`` and ``name`` attributes of a Web site.
The :setting:`SITE_ID` setting specifies the database ID of the
:class:`~django.contrib.sites.models.Site` object (accessible using
the automatically added ``id`` attribute) associated with that
particular settings file.
.. attribute:: domain .. attribute:: domain
@ -31,6 +27,14 @@ The sites framework is mainly based on a simple model:
A human-readable "verbose" name for the Web site. A human-readable "verbose" name for the Web site.
The :setting:`SITE_ID` setting specifies the database ID of the
:class:`~django.contrib.sites.models.Site` object associated with that
particular settings file. It the setting is omitted, the
:func:`~django.contrib.sites.shortcuts.get_current_site` function will
try to get the current site by comparing the
:attr:`~django.contrib.sites.models.Site.domain` with the host name from
the :meth:`request.get_host() <django.http.HttpRequest.get_host>` method.
How you use this is up to you, but Django uses it in a couple of ways How you use this is up to you, but Django uses it in a couple of ways
automatically via simple conventions. automatically via simple conventions.
@ -308,6 +312,11 @@ model(s). It's a model :doc:`manager </topics/db/managers>` that
automatically filters its queries to include only objects associated automatically filters its queries to include only objects associated
with the current :class:`~django.contrib.sites.models.Site`. with the current :class:`~django.contrib.sites.models.Site`.
.. admonition:: Mandatory :setting:`SITE_ID`
The ``CurrentSiteManager`` is only usable when the :setting:`SITE_ID`
setting is defined in your settings.
Use :class:`~django.contrib.sites.managers.CurrentSiteManager` by adding it to Use :class:`~django.contrib.sites.managers.CurrentSiteManager` by adding it to
your model explicitly. For example:: your model explicitly. For example::
@ -492,3 +501,9 @@ Finally, to avoid repetitive fallback code, the framework provides a
.. versionchanged:: 1.7 .. versionchanged:: 1.7
This function used to be defined in ``django.contrib.sites.models``. This function used to be defined in ``django.contrib.sites.models``.
.. versionchanged:: 1.8
This function will now lookup the current site based on
:meth:`request.get_host() <django.http.HttpRequest.get_host>` if the
:setting:`SITE_ID` setting is not defined.

View File

@ -139,7 +139,10 @@ Minor features
:mod:`django.contrib.sites` :mod:`django.contrib.sites`
^^^^^^^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^^^^^
* ... * :func:`~django.contrib.sites.shortcuts.get_current_site` will now lookup
the current site based on :meth:`request.get_host()
<django.http.HttpRequest.get_host>` if the :setting:`SITE_ID` setting is not
defined.
:mod:`django.contrib.staticfiles` :mod:`django.contrib.staticfiles`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^