From 7bec480fe2ace94c8e7f0c88485442bfa74436b4 Mon Sep 17 00:00:00 2001 From: Alex Hill Date: Fri, 27 Mar 2015 00:52:11 +0800 Subject: [PATCH] Fixed #24201 -- Added order_with_respect_to support to GenericForeignKey. --- django/contrib/contenttypes/fields.py | 32 +++++++++- django/db/models/base.py | 58 ++++++++++--------- django/db/models/fields/__init__.py | 7 +++ django/db/models/fields/related.py | 43 +++++++++----- docs/releases/1.9.txt | 7 +++ tests/contenttypes_tests/models.py | 39 +++++++++++++ .../test_order_with_respect_to.py | 18 ++++++ tests/order_with_respect_to/models.py | 9 +++ tests/order_with_respect_to/tests.py | 55 +++++++++++------- 9 files changed, 205 insertions(+), 63 deletions(-) create mode 100644 tests/contenttypes_tests/test_order_with_respect_to.py diff --git a/django/contrib/contenttypes/fields.py b/django/contrib/contenttypes/fields.py index 837a9d5bcac..1cd2fc0fa6c 100644 --- a/django/contrib/contenttypes/fields.py +++ b/django/contrib/contenttypes/fields.py @@ -7,9 +7,10 @@ from django.core import checks from django.core.exceptions import FieldDoesNotExist, ObjectDoesNotExist from django.db import DEFAULT_DB_ALIAS, connection, models, router, transaction from django.db.models import DO_NOTHING, signals -from django.db.models.base import ModelBase +from django.db.models.base import ModelBase, make_foreign_order_accessors from django.db.models.fields.related import ( ForeignObject, ForeignObjectRel, ForeignRelatedObjectsDescriptor, + lazy_related_operation, ) from django.db.models.query_utils import PathInfo from django.utils.encoding import python_2_unicode_compatible, smart_text @@ -61,6 +62,20 @@ class GenericForeignKey(object): setattr(cls, name, self) + def get_filter_kwargs_for_object(self, obj): + """See corresponding method on Field""" + return { + self.fk_field: getattr(obj, self.fk_field), + self.ct_field: getattr(obj, self.ct_field), + } + + def get_forward_related_filter(self, obj): + """See corresponding method on RelatedField""" + return { + self.fk_field: obj.pk, + self.ct_field: ContentType.objects.get_for_model(obj).pk, + } + def __str__(self): model = self.model app = model._meta.app_label @@ -368,6 +383,21 @@ class GenericRelation(ForeignObject): self.model = cls setattr(cls, self.name, ReverseGenericRelatedObjectsDescriptor(self.remote_field)) + # Add get_RELATED_order() and set_RELATED_order() methods if the model + # on the other end of this relation is ordered with respect to this. + def matching_gfk(field): + return ( + isinstance(field, GenericForeignKey) and + self.content_type_field_name == field.ct_field and + self.object_id_field_name == field.fk_field + ) + + def make_generic_foreign_order_accessors(related_model, model): + if matching_gfk(model._meta.order_with_respect_to): + make_foreign_order_accessors(model, related_model) + + lazy_related_operation(make_generic_foreign_order_accessors, self.model, self.remote_field.model) + def set_attributes_from_rel(self): pass diff --git a/django/db/models/base.py b/django/db/models/base.py index 3cbabbc576d..2c4a475cd45 100644 --- a/django/db/models/base.py +++ b/django/db/models/base.py @@ -311,21 +311,15 @@ class ModelBase(type): cls.get_next_in_order = curry(cls._get_next_or_previous_in_order, is_next=True) cls.get_previous_in_order = curry(cls._get_next_or_previous_in_order, is_next=False) - # defer creating accessors on the foreign class until we are - # certain it has been created - def make_foreign_order_accessors(cls, model, field): - setattr( - field.remote_field.model, - 'get_%s_order' % cls.__name__.lower(), - curry(method_get_order, cls) - ) - setattr( - field.remote_field.model, - 'set_%s_order' % cls.__name__.lower(), - curry(method_set_order, cls) - ) - wrt = opts.order_with_respect_to - lazy_related_operation(make_foreign_order_accessors, cls, wrt.remote_field.model, field=wrt) + # Defer creating accessors on the foreign class until it has been + # created and registered. If remote_field is None, we're ordering + # with respect to a GenericForeignKey and don't know what the + # foreign class is - we'll add those accessors later in + # contribute_to_class(). + if opts.order_with_respect_to.remote_field: + wrt = opts.order_with_respect_to + remote = wrt.remote_field.model + lazy_related_operation(make_foreign_order_accessors, cls, remote) # Give the class a docstring -- its definition. if cls.__doc__ is None: @@ -803,8 +797,8 @@ class Model(six.with_metaclass(ModelBase)): # If this is a model with an order_with_respect_to # autopopulate the _order field field = meta.order_with_respect_to - order_value = cls._base_manager.using(using).filter( - **{field.name: getattr(self, field.attname)}).count() + filter_args = field.get_filter_kwargs_for_object(self) + order_value = cls._base_manager.using(using).filter(**filter_args).count() self._order = order_value fields = meta.local_concrete_fields @@ -892,9 +886,8 @@ class Model(six.with_metaclass(ModelBase)): op = 'gt' if is_next else 'lt' order = '_order' if is_next else '-_order' order_field = self._meta.order_with_respect_to - obj = self._default_manager.filter(**{ - order_field.name: getattr(self, order_field.attname) - }).filter(**{ + filter_args = order_field.get_filter_kwargs_for_object(self) + obj = self._default_manager.filter(**filter_args).filter(**{ '_order__%s' % op: self._default_manager.values('_order').filter(**{ self._meta.pk.name: self.pk }) @@ -1653,23 +1646,34 @@ class Model(six.with_metaclass(ModelBase)): def method_set_order(ordered_obj, self, id_list, using=None): if using is None: using = DEFAULT_DB_ALIAS - rel_val = getattr(self, ordered_obj._meta.order_with_respect_to.remote_field.field_name) - order_name = ordered_obj._meta.order_with_respect_to.name + order_wrt = ordered_obj._meta.order_with_respect_to + filter_args = order_wrt.get_forward_related_filter(self) # FIXME: It would be nice if there was an "update many" version of update # for situations like this. with transaction.atomic(using=using, savepoint=False): for i, j in enumerate(id_list): - ordered_obj.objects.filter(**{'pk': j, order_name: rel_val}).update(_order=i) + ordered_obj.objects.filter(pk=j, **filter_args).update(_order=i) def method_get_order(ordered_obj, self): - rel_val = getattr(self, ordered_obj._meta.order_with_respect_to.remote_field.field_name) - order_name = ordered_obj._meta.order_with_respect_to.name + order_wrt = ordered_obj._meta.order_with_respect_to + filter_args = order_wrt.get_forward_related_filter(self) pk_name = ordered_obj._meta.pk.name - return [r[pk_name] for r in - ordered_obj.objects.filter(**{order_name: rel_val}).values(pk_name)] + return ordered_obj.objects.filter(**filter_args).values_list(pk_name, flat=True) +def make_foreign_order_accessors(model, related_model): + setattr( + related_model, + 'get_%s_order' % model.__name__.lower(), + curry(method_get_order, model) + ) + setattr( + related_model, + 'set_%s_order' % model.__name__.lower(), + curry(method_set_order, model) + ) + ######## # MISC # ######## diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py index 5cd96d2902f..91bf0285d94 100644 --- a/django/db/models/fields/__init__.py +++ b/django/db/models/fields/__init__.py @@ -678,6 +678,13 @@ class Field(RegisterLookupMixin): setattr(cls, 'get_%s_display' % self.name, curry(cls._get_FIELD_display, field=self)) + def get_filter_kwargs_for_object(self, obj): + """ + Return a dict that when passed as kwargs to self.model.filter(), would + yield all instances having the same value for this field as obj has. + """ + return {self.name: getattr(obj, self.attname)} + def get_attname(self): return self.name diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index c433ce239de..5e7bda5d5c8 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -303,6 +303,33 @@ class RelatedField(Field): field.do_related_class(related, model) lazy_related_operation(resolve_related_class, cls, self.remote_field.model, field=self) + def get_forward_related_filter(self, obj): + """ + Return the keyword arguments that when supplied to + self.model.object.filter(), would select all instances related through + this field to the remote obj. This is used to build the querysets + returned by related descriptors. obj is an instance of + self.related_field.model. + """ + return { + '%s__%s' % (self.name, rh_field.name): getattr(obj, rh_field.attname) + for _, rh_field in self.related_fields + } + + def get_reverse_related_filter(self, obj): + """ + Complement to get_forward_related_filter(). Return the keyword + arguments that when passed to self.related_field.model.object.filter() + select all instances of self.related_field.model related through + this field to obj. obj is an instance of self.model. + """ + base_filter = { + rh_field.attname: getattr(obj, lh_field.attname) + for lh_field, rh_field in self.related_fields + } + base_filter.update(self.get_extra_descriptor_filter(obj) or {}) + return base_filter + @property def swappable_setting(self): """ @@ -453,11 +480,9 @@ class SingleRelatedObjectDescriptor(object): if related_pk is None: rel_obj = None else: - params = {} - for lh_field, rh_field in self.related.field.related_fields: - params['%s__%s' % (self.related.field.name, rh_field.name)] = getattr(instance, rh_field.attname) + filter_args = self.related.field.get_forward_related_filter(instance) try: - rel_obj = self.get_queryset(instance=instance).get(**params) + rel_obj = self.get_queryset(instance=instance).get(**filter_args) except self.related.related_model.DoesNotExist: rel_obj = None else: @@ -603,16 +628,8 @@ class ReverseSingleRelatedObjectDescriptor(object): if None in val: rel_obj = None else: - params = { - rh_field.attname: getattr(instance, lh_field.attname) - for lh_field, rh_field in self.field.related_fields} qs = self.get_queryset(instance=instance) - extra_filter = self.field.get_extra_descriptor_filter(instance) - if isinstance(extra_filter, dict): - params.update(extra_filter) - qs = qs.filter(**params) - else: - qs = qs.filter(extra_filter, **params) + qs = qs.filter(**self.field.get_reverse_related_filter(instance)) # Assuming the database enforces foreign keys, this won't fail. rel_obj = qs.get() if not self.field.remote_field.multiple: diff --git a/docs/releases/1.9.txt b/docs/releases/1.9.txt index 6460b8a7bc3..1a5c488d0f4 100644 --- a/docs/releases/1.9.txt +++ b/docs/releases/1.9.txt @@ -187,6 +187,13 @@ Minor features makes it possible to use ``REMOTE_USER`` for setups where the header is only populated on login pages instead of every request in the session. +:mod:`django.contrib.contenttypes` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* It's now possible to use + :attr:`~django.db.models.Options.order_with_respect_to` with a + ``GenericForeignKey``. + :mod:`django.contrib.gis` ^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/tests/contenttypes_tests/models.py b/tests/contenttypes_tests/models.py index 88b926fce6f..a302fe97f01 100644 --- a/tests/contenttypes_tests/models.py +++ b/tests/contenttypes_tests/models.py @@ -1,5 +1,9 @@ from __future__ import unicode_literals +from django.contrib.contenttypes.fields import ( + GenericForeignKey, GenericRelation, +) +from django.contrib.contenttypes.models import ContentType from django.db import models from django.utils.encoding import python_2_unicode_compatible from django.utils.http import urlquote @@ -76,3 +80,38 @@ class FooWithBrokenAbsoluteUrl(FooWithoutUrl): def get_absolute_url(self): return "/users/%s/" % self.unknown_field + + +class Question(models.Model): + text = models.CharField(max_length=200) + answer_set = GenericRelation('Answer') + + +@python_2_unicode_compatible +class Answer(models.Model): + text = models.CharField(max_length=200) + content_type = models.ForeignKey(ContentType, models.CASCADE) + object_id = models.PositiveIntegerField() + question = GenericForeignKey() + + class Meta: + order_with_respect_to = 'question' + + def __str__(self): + return self.text + + +@python_2_unicode_compatible +class Post(models.Model): + """An ordered tag on an item.""" + title = models.CharField(max_length=200) + content_type = models.ForeignKey(ContentType, models.CASCADE, null=True) + object_id = models.PositiveIntegerField(null=True) + parent = GenericForeignKey() + children = GenericRelation('Post') + + class Meta: + order_with_respect_to = 'parent' + + def __str__(self): + return self.title diff --git a/tests/contenttypes_tests/test_order_with_respect_to.py b/tests/contenttypes_tests/test_order_with_respect_to.py new file mode 100644 index 00000000000..5d7f99d48b0 --- /dev/null +++ b/tests/contenttypes_tests/test_order_with_respect_to.py @@ -0,0 +1,18 @@ +from order_with_respect_to.tests import ( + OrderWithRespectToTests, OrderWithRespectToTests2, +) + +from .models import Answer, Post, Question + + +class OrderWithRespectToGFKTests(OrderWithRespectToTests): + Answer = Answer + Question = Question + +del OrderWithRespectToTests + + +class OrderWithRespectToGFKTests2(OrderWithRespectToTests2): + Post = Post + +del OrderWithRespectToTests2 diff --git a/tests/order_with_respect_to/models.py b/tests/order_with_respect_to/models.py index 7afc917b4bb..5f531a32a54 100644 --- a/tests/order_with_respect_to/models.py +++ b/tests/order_with_respect_to/models.py @@ -1,5 +1,9 @@ """ Tests for the order_with_respect_to Meta attribute. + +We explicitly declare app_label on these models, because they are reused by +contenttypes_tests. When those tests are run in isolation, these models need +app_label because order_with_respect_to isn't in INSTALLED_APPS. """ from django.db import models @@ -10,6 +14,9 @@ from django.utils.encoding import python_2_unicode_compatible class Question(models.Model): text = models.CharField(max_length=200) + class Meta: + app_label = 'order_with_respect_to' + @python_2_unicode_compatible class Answer(models.Model): @@ -18,6 +25,7 @@ class Answer(models.Model): class Meta: order_with_respect_to = 'question' + app_label = 'order_with_respect_to' def __str__(self): return six.text_type(self.text) @@ -30,6 +38,7 @@ class Post(models.Model): class Meta: order_with_respect_to = "parent" + app_label = 'order_with_respect_to' def __str__(self): return self.title diff --git a/tests/order_with_respect_to/tests.py b/tests/order_with_respect_to/tests.py index ff92fdb6fbf..6c98ee1b9b1 100644 --- a/tests/order_with_respect_to/tests.py +++ b/tests/order_with_respect_to/tests.py @@ -10,13 +10,17 @@ from .models import Answer, Post, Question class OrderWithRespectToTests(TestCase): + # Hook to allow subclasses to run these tests with alternate models. + Answer = Answer + Question = Question + @classmethod def setUpTestData(cls): - cls.q1 = Question.objects.create(text="Which Beatle starts with the letter 'R'?") - Answer.objects.create(text="John", question=cls.q1) - Answer.objects.create(text="Paul", question=cls.q1) - Answer.objects.create(text="George", question=cls.q1) - Answer.objects.create(text="Ringo", question=cls.q1) + cls.q1 = cls.Question.objects.create(text="Which Beatle starts with the letter 'R'?") + cls.Answer.objects.create(text="John", question=cls.q1) + cls.Answer.objects.create(text="Paul", question=cls.q1) + cls.Answer.objects.create(text="George", question=cls.q1) + cls.Answer.objects.create(text="Ringo", question=cls.q1) def test_default_to_insertion_order(self): # Answers will always be ordered in the order they were inserted. @@ -30,30 +34,30 @@ class OrderWithRespectToTests(TestCase): def test_previous_and_next_in_order(self): # We can retrieve the answers related to a particular object, in the # order they were created, once we have a particular object. - a1 = Answer.objects.filter(question=self.q1)[0] + a1 = self.q1.answer_set.all()[0] self.assertEqual(a1.text, "John") self.assertEqual(a1.get_next_in_order().text, "Paul") - a2 = list(Answer.objects.filter(question=self.q1))[-1] + a2 = list(self.q1.answer_set.all())[-1] self.assertEqual(a2.text, "Ringo") self.assertEqual(a2.get_previous_in_order().text, "George") def test_item_ordering(self): # We can retrieve the ordering of the queryset from a particular item. - a1 = Answer.objects.filter(question=self.q1)[1] + a1 = self.q1.answer_set.all()[1] id_list = [o.pk for o in self.q1.answer_set.all()] - self.assertEqual(a1.question.get_answer_order(), id_list) + self.assertSequenceEqual(a1.question.get_answer_order(), id_list) # It doesn't matter which answer we use to check the order, it will # always be the same. - a2 = Answer.objects.create(text="Number five", question=self.q1) - self.assertEqual( - a1.question.get_answer_order(), a2.question.get_answer_order() + a2 = self.Answer.objects.create(text="Number five", question=self.q1) + self.assertListEqual( + list(a1.question.get_answer_order()), list(a2.question.get_answer_order()) ) def test_change_ordering(self): # The ordering can be altered - a = Answer.objects.create(text="Number five", question=self.q1) + a = self.Answer.objects.create(text="Number five", question=self.q1) # Swap the last two items in the order list id_list = [o.pk for o in self.q1.answer_set.all()] @@ -61,7 +65,7 @@ class OrderWithRespectToTests(TestCase): id_list.insert(-1, x) # By default, the ordering is different from the swapped version - self.assertNotEqual(a.question.get_answer_order(), id_list) + self.assertNotEqual(list(a.question.get_answer_order()), id_list) # Change the ordering to the swapped version - # this changes the ordering of the queryset. @@ -76,19 +80,25 @@ class OrderWithRespectToTests(TestCase): class OrderWithRespectToTests2(TestCase): + # Provide the Post model as a class attribute so that we can subclass this + # test case in contenttypes_tests.test_order_with_respect_to and run these + # tests with alternative implementations of Post. + Post = Post + def test_recursive_ordering(self): - p1 = Post.objects.create(title='1') - p2 = Post.objects.create(title='2') - p1_1 = Post.objects.create(title="1.1", parent=p1) - p1_2 = Post.objects.create(title="1.2", parent=p1) - Post.objects.create(title="2.1", parent=p2) - p1_3 = Post.objects.create(title="1.3", parent=p1) - self.assertEqual(p1.get_post_order(), [p1_1.pk, p1_2.pk, p1_3.pk]) + p1 = self.Post.objects.create(title="1") + p2 = self.Post.objects.create(title="2") + p1_1 = self.Post.objects.create(title="1.1", parent=p1) + p1_2 = self.Post.objects.create(title="1.2", parent=p1) + self.Post.objects.create(title="2.1", parent=p2) + p1_3 = self.Post.objects.create(title="1.3", parent=p1) + self.assertSequenceEqual(p1.get_post_order(), [p1_1.pk, p1_2.pk, p1_3.pk]) def test_duplicate_order_field(self): class Bar(models.Model): - pass + class Meta: + app_label = 'order_with_respect_to' class Foo(models.Model): bar = models.ForeignKey(Bar, models.CASCADE) @@ -96,6 +106,7 @@ class OrderWithRespectToTests2(TestCase): class Meta: order_with_respect_to = 'bar' + app_label = 'order_with_respect_to' count = 0 for field in Foo._meta.local_fields: