diff --git a/django/contrib/sites/managers.py b/django/contrib/sites/managers.py index 59215c44f92..3df485a0400 100644 --- a/django/contrib/sites/managers.py +++ b/django/contrib/sites/managers.py @@ -4,17 +4,38 @@ from django.db.models.fields import FieldDoesNotExist class CurrentSiteManager(models.Manager): "Use this to limit objects to those associated with the current site." - def __init__(self, field_name='site'): + def __init__(self, field_name=None): super(CurrentSiteManager, self).__init__() self.__field_name = field_name self.__is_validated = False - + + def _validate_field_name(self): + field_names = self.model._meta.get_all_field_names() + + # If a custom name is provided, make sure the field exists on the model + if self.__field_name is not None and self.__field_name not in field_names: + raise ValueError("%s couldn't find a field named %s in %s." % \ + (self.__class__.__name__, self.__field_name, self.model._meta.object_name)) + + # Otherwise, see if there is a field called either 'site' or 'sites' + else: + for potential_name in ['site', 'sites']: + if potential_name in field_names: + self.__field_name = potential_name + self.__is_validated = True + break + + # Now do a type check on the field (FK or M2M only) + try: + field = self.model._meta.get_field(self.__field_name) + if not isinstance(field, (models.ForeignKey, models.ManyToManyField)): + raise TypeError("%s must be a ForeignKey or ManyToManyField." %self.__field_name) + except FieldDoesNotExist: + raise ValueError("%s couldn't find a field named %s in %s." % \ + (self.__class__.__name__, self.__field_name, self.model._meta.object_name)) + self.__is_validated = True + def get_query_set(self): if not self.__is_validated: - try: - self.model._meta.get_field(self.__field_name) - except FieldDoesNotExist: - raise ValueError("%s couldn't find a field named %s in %s." % \ - (self.__class__.__name__, self.__field_name, self.model._meta.object_name)) - self.__is_validated = True + self._validate_field_name() return super(CurrentSiteManager, self).get_query_set().filter(**{self.__field_name + '__id__exact': settings.SITE_ID}) diff --git a/tests/regressiontests/sites_framework/__init__.py b/tests/regressiontests/sites_framework/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/regressiontests/sites_framework/models.py b/tests/regressiontests/sites_framework/models.py new file mode 100644 index 00000000000..9ecc3e66600 --- /dev/null +++ b/tests/regressiontests/sites_framework/models.py @@ -0,0 +1,36 @@ +from django.contrib.sites.managers import CurrentSiteManager +from django.contrib.sites.models import Site +from django.db import models + +class AbstractArticle(models.Model): + title = models.CharField(max_length=50) + + objects = models.Manager() + on_site = CurrentSiteManager() + + class Meta: + abstract = True + + def __unicode__(self): + return self.title + +class SyndicatedArticle(AbstractArticle): + sites = models.ManyToManyField(Site) + +class ExclusiveArticle(AbstractArticle): + site = models.ForeignKey(Site) + +class CustomArticle(AbstractArticle): + places_this_article_should_appear = models.ForeignKey(Site) + + objects = models.Manager() + on_site = CurrentSiteManager("places_this_article_should_appear") + +class InvalidArticle(AbstractArticle): + site = models.ForeignKey(Site) + + objects = models.Manager() + on_site = CurrentSiteManager("places_this_article_should_appear") + +class ConfusedArticle(AbstractArticle): + site = models.IntegerField() diff --git a/tests/regressiontests/sites_framework/tests.py b/tests/regressiontests/sites_framework/tests.py new file mode 100644 index 00000000000..b737727a564 --- /dev/null +++ b/tests/regressiontests/sites_framework/tests.py @@ -0,0 +1,34 @@ +from django.conf import settings +from django.contrib.sites.models import Site +from django.test import TestCase + +from models import SyndicatedArticle, ExclusiveArticle, CustomArticle, InvalidArticle, ConfusedArticle + +class SitesFrameworkTestCase(TestCase): + def setUp(self): + Site.objects.get_or_create(id=settings.SITE_ID, domain="example.com", name="example.com") + Site.objects.create(id=settings.SITE_ID+1, domain="example2.com", name="example2.com") + + def test_site_fk(self): + article = ExclusiveArticle.objects.create(title="Breaking News!", site_id=settings.SITE_ID) + self.assertEqual(ExclusiveArticle.on_site.all().get(), article) + + def test_sites_m2m(self): + article = SyndicatedArticle.objects.create(title="Fresh News!") + article.sites.add(Site.objects.get(id=settings.SITE_ID)) + article.sites.add(Site.objects.get(id=settings.SITE_ID+1)) + article2 = SyndicatedArticle.objects.create(title="More News!") + article2.sites.add(Site.objects.get(id=settings.SITE_ID+1)) + self.assertEqual(SyndicatedArticle.on_site.all().get(), article) + + def test_custom_named_field(self): + article = CustomArticle.objects.create(title="Tantalizing News!", places_this_article_should_appear_id=settings.SITE_ID) + self.assertEqual(CustomArticle.on_site.all().get(), article) + + def test_invalid_name(self): + article = InvalidArticle.objects.create(title="Bad News!", site_id=settings.SITE_ID) + self.assertRaises(ValueError, InvalidArticle.on_site.all) + + def test_invalid_field_type(self): + article = ConfusedArticle.objects.create(title="More Bad News!", site=settings.SITE_ID) + self.assertRaises(TypeError, ConfusedArticle.on_site.all)