diff --git a/tests/regressiontests/generic_relations_regress/models.py b/tests/regressiontests/generic_relations_regress/models.py index dacc9b380b..bdd8ec5eda 100644 --- a/tests/regressiontests/generic_relations_regress/models.py +++ b/tests/regressiontests/generic_relations_regress/models.py @@ -6,7 +6,7 @@ from django.utils.encoding import python_2_unicode_compatible __all__ = ('Link', 'Place', 'Restaurant', 'Person', 'Address', 'CharLink', 'TextLink', 'OddRelation1', 'OddRelation2', - 'Contact', 'Organization', 'Note') + 'Contact', 'Organization', 'Note', 'Company') @python_2_unicode_compatible class Link(models.Model): @@ -84,3 +84,10 @@ class Organization(models.Model): name = models.CharField(max_length=255) contacts = models.ManyToManyField(Contact, related_name='organizations') +@python_2_unicode_compatible +class Company(models.Model): + name = models.CharField(max_length=100) + links = generic.GenericRelation(Link) + + def __str__(self): + return "Company: %s" % self.name diff --git a/tests/regressiontests/generic_relations_regress/tests.py b/tests/regressiontests/generic_relations_regress/tests.py index 262c2e4917..7690fd560a 100644 --- a/tests/regressiontests/generic_relations_regress/tests.py +++ b/tests/regressiontests/generic_relations_regress/tests.py @@ -2,7 +2,7 @@ from django.db.models import Q from django.test import TestCase from .models import (Address, Place, Restaurant, Link, CharLink, TextLink, - Person, Contact, Note, Organization, OddRelation1, OddRelation2) + Person, Contact, Note, Organization, OddRelation1, OddRelation2, Company) class GenericRelationTests(TestCase): @@ -80,3 +80,21 @@ class GenericRelationTests(TestCase): ) self.assertEqual(str(qs.query).count('JOIN'), 2) + def test_generic_relation_ordering(self): + """ + Test that ordering over a generic relation does not include extraneous + duplicate results, nor excludes rows not participating in the relation. + """ + p1 = Place.objects.create(name="South Park") + p2 = Place.objects.create(name="The City") + c = Company.objects.create(name="Chubby's Intl.") + l1 = Link.objects.create(content_object=p1) + l2 = Link.objects.create(content_object=c) + + places = list(Place.objects.order_by('links__id')) + def count_places(place): + return len(filter(lambda p: p.id == place.id, places)) + + self.assertEqual(len(places), 2) + self.assertEqual(count_places(p1), 1) + self.assertEqual(count_places(p2), 1)