mirror of https://github.com/django/django.git
Fixed #24201 -- Added order_with_respect_to support to GenericForeignKey.
This commit is contained in:
parent
e1427cc609
commit
7bec480fe2
|
@ -7,9 +7,10 @@ from django.core import checks
|
|||
from django.core.exceptions import FieldDoesNotExist, ObjectDoesNotExist
|
||||
from django.db import DEFAULT_DB_ALIAS, connection, models, router, transaction
|
||||
from django.db.models import DO_NOTHING, signals
|
||||
from django.db.models.base import ModelBase
|
||||
from django.db.models.base import ModelBase, make_foreign_order_accessors
|
||||
from django.db.models.fields.related import (
|
||||
ForeignObject, ForeignObjectRel, ForeignRelatedObjectsDescriptor,
|
||||
lazy_related_operation,
|
||||
)
|
||||
from django.db.models.query_utils import PathInfo
|
||||
from django.utils.encoding import python_2_unicode_compatible, smart_text
|
||||
|
@ -61,6 +62,20 @@ class GenericForeignKey(object):
|
|||
|
||||
setattr(cls, name, self)
|
||||
|
||||
def get_filter_kwargs_for_object(self, obj):
|
||||
"""See corresponding method on Field"""
|
||||
return {
|
||||
self.fk_field: getattr(obj, self.fk_field),
|
||||
self.ct_field: getattr(obj, self.ct_field),
|
||||
}
|
||||
|
||||
def get_forward_related_filter(self, obj):
|
||||
"""See corresponding method on RelatedField"""
|
||||
return {
|
||||
self.fk_field: obj.pk,
|
||||
self.ct_field: ContentType.objects.get_for_model(obj).pk,
|
||||
}
|
||||
|
||||
def __str__(self):
|
||||
model = self.model
|
||||
app = model._meta.app_label
|
||||
|
@ -368,6 +383,21 @@ class GenericRelation(ForeignObject):
|
|||
self.model = cls
|
||||
setattr(cls, self.name, ReverseGenericRelatedObjectsDescriptor(self.remote_field))
|
||||
|
||||
# Add get_RELATED_order() and set_RELATED_order() methods if the model
|
||||
# on the other end of this relation is ordered with respect to this.
|
||||
def matching_gfk(field):
|
||||
return (
|
||||
isinstance(field, GenericForeignKey) and
|
||||
self.content_type_field_name == field.ct_field and
|
||||
self.object_id_field_name == field.fk_field
|
||||
)
|
||||
|
||||
def make_generic_foreign_order_accessors(related_model, model):
|
||||
if matching_gfk(model._meta.order_with_respect_to):
|
||||
make_foreign_order_accessors(model, related_model)
|
||||
|
||||
lazy_related_operation(make_generic_foreign_order_accessors, self.model, self.remote_field.model)
|
||||
|
||||
def set_attributes_from_rel(self):
|
||||
pass
|
||||
|
||||
|
|
|
@ -311,21 +311,15 @@ class ModelBase(type):
|
|||
cls.get_next_in_order = curry(cls._get_next_or_previous_in_order, is_next=True)
|
||||
cls.get_previous_in_order = curry(cls._get_next_or_previous_in_order, is_next=False)
|
||||
|
||||
# defer creating accessors on the foreign class until we are
|
||||
# certain it has been created
|
||||
def make_foreign_order_accessors(cls, model, field):
|
||||
setattr(
|
||||
field.remote_field.model,
|
||||
'get_%s_order' % cls.__name__.lower(),
|
||||
curry(method_get_order, cls)
|
||||
)
|
||||
setattr(
|
||||
field.remote_field.model,
|
||||
'set_%s_order' % cls.__name__.lower(),
|
||||
curry(method_set_order, cls)
|
||||
)
|
||||
wrt = opts.order_with_respect_to
|
||||
lazy_related_operation(make_foreign_order_accessors, cls, wrt.remote_field.model, field=wrt)
|
||||
# Defer creating accessors on the foreign class until it has been
|
||||
# created and registered. If remote_field is None, we're ordering
|
||||
# with respect to a GenericForeignKey and don't know what the
|
||||
# foreign class is - we'll add those accessors later in
|
||||
# contribute_to_class().
|
||||
if opts.order_with_respect_to.remote_field:
|
||||
wrt = opts.order_with_respect_to
|
||||
remote = wrt.remote_field.model
|
||||
lazy_related_operation(make_foreign_order_accessors, cls, remote)
|
||||
|
||||
# Give the class a docstring -- its definition.
|
||||
if cls.__doc__ is None:
|
||||
|
@ -803,8 +797,8 @@ class Model(six.with_metaclass(ModelBase)):
|
|||
# If this is a model with an order_with_respect_to
|
||||
# autopopulate the _order field
|
||||
field = meta.order_with_respect_to
|
||||
order_value = cls._base_manager.using(using).filter(
|
||||
**{field.name: getattr(self, field.attname)}).count()
|
||||
filter_args = field.get_filter_kwargs_for_object(self)
|
||||
order_value = cls._base_manager.using(using).filter(**filter_args).count()
|
||||
self._order = order_value
|
||||
|
||||
fields = meta.local_concrete_fields
|
||||
|
@ -892,9 +886,8 @@ class Model(six.with_metaclass(ModelBase)):
|
|||
op = 'gt' if is_next else 'lt'
|
||||
order = '_order' if is_next else '-_order'
|
||||
order_field = self._meta.order_with_respect_to
|
||||
obj = self._default_manager.filter(**{
|
||||
order_field.name: getattr(self, order_field.attname)
|
||||
}).filter(**{
|
||||
filter_args = order_field.get_filter_kwargs_for_object(self)
|
||||
obj = self._default_manager.filter(**filter_args).filter(**{
|
||||
'_order__%s' % op: self._default_manager.values('_order').filter(**{
|
||||
self._meta.pk.name: self.pk
|
||||
})
|
||||
|
@ -1653,23 +1646,34 @@ class Model(six.with_metaclass(ModelBase)):
|
|||
def method_set_order(ordered_obj, self, id_list, using=None):
|
||||
if using is None:
|
||||
using = DEFAULT_DB_ALIAS
|
||||
rel_val = getattr(self, ordered_obj._meta.order_with_respect_to.remote_field.field_name)
|
||||
order_name = ordered_obj._meta.order_with_respect_to.name
|
||||
order_wrt = ordered_obj._meta.order_with_respect_to
|
||||
filter_args = order_wrt.get_forward_related_filter(self)
|
||||
# FIXME: It would be nice if there was an "update many" version of update
|
||||
# for situations like this.
|
||||
with transaction.atomic(using=using, savepoint=False):
|
||||
for i, j in enumerate(id_list):
|
||||
ordered_obj.objects.filter(**{'pk': j, order_name: rel_val}).update(_order=i)
|
||||
ordered_obj.objects.filter(pk=j, **filter_args).update(_order=i)
|
||||
|
||||
|
||||
def method_get_order(ordered_obj, self):
|
||||
rel_val = getattr(self, ordered_obj._meta.order_with_respect_to.remote_field.field_name)
|
||||
order_name = ordered_obj._meta.order_with_respect_to.name
|
||||
order_wrt = ordered_obj._meta.order_with_respect_to
|
||||
filter_args = order_wrt.get_forward_related_filter(self)
|
||||
pk_name = ordered_obj._meta.pk.name
|
||||
return [r[pk_name] for r in
|
||||
ordered_obj.objects.filter(**{order_name: rel_val}).values(pk_name)]
|
||||
return ordered_obj.objects.filter(**filter_args).values_list(pk_name, flat=True)
|
||||
|
||||
|
||||
def make_foreign_order_accessors(model, related_model):
|
||||
setattr(
|
||||
related_model,
|
||||
'get_%s_order' % model.__name__.lower(),
|
||||
curry(method_get_order, model)
|
||||
)
|
||||
setattr(
|
||||
related_model,
|
||||
'set_%s_order' % model.__name__.lower(),
|
||||
curry(method_set_order, model)
|
||||
)
|
||||
|
||||
########
|
||||
# MISC #
|
||||
########
|
||||
|
|
|
@ -678,6 +678,13 @@ class Field(RegisterLookupMixin):
|
|||
setattr(cls, 'get_%s_display' % self.name,
|
||||
curry(cls._get_FIELD_display, field=self))
|
||||
|
||||
def get_filter_kwargs_for_object(self, obj):
|
||||
"""
|
||||
Return a dict that when passed as kwargs to self.model.filter(), would
|
||||
yield all instances having the same value for this field as obj has.
|
||||
"""
|
||||
return {self.name: getattr(obj, self.attname)}
|
||||
|
||||
def get_attname(self):
|
||||
return self.name
|
||||
|
||||
|
|
|
@ -303,6 +303,33 @@ class RelatedField(Field):
|
|||
field.do_related_class(related, model)
|
||||
lazy_related_operation(resolve_related_class, cls, self.remote_field.model, field=self)
|
||||
|
||||
def get_forward_related_filter(self, obj):
|
||||
"""
|
||||
Return the keyword arguments that when supplied to
|
||||
self.model.object.filter(), would select all instances related through
|
||||
this field to the remote obj. This is used to build the querysets
|
||||
returned by related descriptors. obj is an instance of
|
||||
self.related_field.model.
|
||||
"""
|
||||
return {
|
||||
'%s__%s' % (self.name, rh_field.name): getattr(obj, rh_field.attname)
|
||||
for _, rh_field in self.related_fields
|
||||
}
|
||||
|
||||
def get_reverse_related_filter(self, obj):
|
||||
"""
|
||||
Complement to get_forward_related_filter(). Return the keyword
|
||||
arguments that when passed to self.related_field.model.object.filter()
|
||||
select all instances of self.related_field.model related through
|
||||
this field to obj. obj is an instance of self.model.
|
||||
"""
|
||||
base_filter = {
|
||||
rh_field.attname: getattr(obj, lh_field.attname)
|
||||
for lh_field, rh_field in self.related_fields
|
||||
}
|
||||
base_filter.update(self.get_extra_descriptor_filter(obj) or {})
|
||||
return base_filter
|
||||
|
||||
@property
|
||||
def swappable_setting(self):
|
||||
"""
|
||||
|
@ -453,11 +480,9 @@ class SingleRelatedObjectDescriptor(object):
|
|||
if related_pk is None:
|
||||
rel_obj = None
|
||||
else:
|
||||
params = {}
|
||||
for lh_field, rh_field in self.related.field.related_fields:
|
||||
params['%s__%s' % (self.related.field.name, rh_field.name)] = getattr(instance, rh_field.attname)
|
||||
filter_args = self.related.field.get_forward_related_filter(instance)
|
||||
try:
|
||||
rel_obj = self.get_queryset(instance=instance).get(**params)
|
||||
rel_obj = self.get_queryset(instance=instance).get(**filter_args)
|
||||
except self.related.related_model.DoesNotExist:
|
||||
rel_obj = None
|
||||
else:
|
||||
|
@ -603,16 +628,8 @@ class ReverseSingleRelatedObjectDescriptor(object):
|
|||
if None in val:
|
||||
rel_obj = None
|
||||
else:
|
||||
params = {
|
||||
rh_field.attname: getattr(instance, lh_field.attname)
|
||||
for lh_field, rh_field in self.field.related_fields}
|
||||
qs = self.get_queryset(instance=instance)
|
||||
extra_filter = self.field.get_extra_descriptor_filter(instance)
|
||||
if isinstance(extra_filter, dict):
|
||||
params.update(extra_filter)
|
||||
qs = qs.filter(**params)
|
||||
else:
|
||||
qs = qs.filter(extra_filter, **params)
|
||||
qs = qs.filter(**self.field.get_reverse_related_filter(instance))
|
||||
# Assuming the database enforces foreign keys, this won't fail.
|
||||
rel_obj = qs.get()
|
||||
if not self.field.remote_field.multiple:
|
||||
|
|
|
@ -187,6 +187,13 @@ Minor features
|
|||
makes it possible to use ``REMOTE_USER`` for setups where the header is only
|
||||
populated on login pages instead of every request in the session.
|
||||
|
||||
:mod:`django.contrib.contenttypes`
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
* It's now possible to use
|
||||
:attr:`~django.db.models.Options.order_with_respect_to` with a
|
||||
``GenericForeignKey``.
|
||||
|
||||
:mod:`django.contrib.gis`
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
|
|
|
@ -1,5 +1,9 @@
|
|||
from __future__ import unicode_literals
|
||||
|
||||
from django.contrib.contenttypes.fields import (
|
||||
GenericForeignKey, GenericRelation,
|
||||
)
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
from django.db import models
|
||||
from django.utils.encoding import python_2_unicode_compatible
|
||||
from django.utils.http import urlquote
|
||||
|
@ -76,3 +80,38 @@ class FooWithBrokenAbsoluteUrl(FooWithoutUrl):
|
|||
|
||||
def get_absolute_url(self):
|
||||
return "/users/%s/" % self.unknown_field
|
||||
|
||||
|
||||
class Question(models.Model):
|
||||
text = models.CharField(max_length=200)
|
||||
answer_set = GenericRelation('Answer')
|
||||
|
||||
|
||||
@python_2_unicode_compatible
|
||||
class Answer(models.Model):
|
||||
text = models.CharField(max_length=200)
|
||||
content_type = models.ForeignKey(ContentType, models.CASCADE)
|
||||
object_id = models.PositiveIntegerField()
|
||||
question = GenericForeignKey()
|
||||
|
||||
class Meta:
|
||||
order_with_respect_to = 'question'
|
||||
|
||||
def __str__(self):
|
||||
return self.text
|
||||
|
||||
|
||||
@python_2_unicode_compatible
|
||||
class Post(models.Model):
|
||||
"""An ordered tag on an item."""
|
||||
title = models.CharField(max_length=200)
|
||||
content_type = models.ForeignKey(ContentType, models.CASCADE, null=True)
|
||||
object_id = models.PositiveIntegerField(null=True)
|
||||
parent = GenericForeignKey()
|
||||
children = GenericRelation('Post')
|
||||
|
||||
class Meta:
|
||||
order_with_respect_to = 'parent'
|
||||
|
||||
def __str__(self):
|
||||
return self.title
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
from order_with_respect_to.tests import (
|
||||
OrderWithRespectToTests, OrderWithRespectToTests2,
|
||||
)
|
||||
|
||||
from .models import Answer, Post, Question
|
||||
|
||||
|
||||
class OrderWithRespectToGFKTests(OrderWithRespectToTests):
|
||||
Answer = Answer
|
||||
Question = Question
|
||||
|
||||
del OrderWithRespectToTests
|
||||
|
||||
|
||||
class OrderWithRespectToGFKTests2(OrderWithRespectToTests2):
|
||||
Post = Post
|
||||
|
||||
del OrderWithRespectToTests2
|
|
@ -1,5 +1,9 @@
|
|||
"""
|
||||
Tests for the order_with_respect_to Meta attribute.
|
||||
|
||||
We explicitly declare app_label on these models, because they are reused by
|
||||
contenttypes_tests. When those tests are run in isolation, these models need
|
||||
app_label because order_with_respect_to isn't in INSTALLED_APPS.
|
||||
"""
|
||||
|
||||
from django.db import models
|
||||
|
@ -10,6 +14,9 @@ from django.utils.encoding import python_2_unicode_compatible
|
|||
class Question(models.Model):
|
||||
text = models.CharField(max_length=200)
|
||||
|
||||
class Meta:
|
||||
app_label = 'order_with_respect_to'
|
||||
|
||||
|
||||
@python_2_unicode_compatible
|
||||
class Answer(models.Model):
|
||||
|
@ -18,6 +25,7 @@ class Answer(models.Model):
|
|||
|
||||
class Meta:
|
||||
order_with_respect_to = 'question'
|
||||
app_label = 'order_with_respect_to'
|
||||
|
||||
def __str__(self):
|
||||
return six.text_type(self.text)
|
||||
|
@ -30,6 +38,7 @@ class Post(models.Model):
|
|||
|
||||
class Meta:
|
||||
order_with_respect_to = "parent"
|
||||
app_label = 'order_with_respect_to'
|
||||
|
||||
def __str__(self):
|
||||
return self.title
|
||||
|
|
|
@ -10,13 +10,17 @@ from .models import Answer, Post, Question
|
|||
|
||||
class OrderWithRespectToTests(TestCase):
|
||||
|
||||
# Hook to allow subclasses to run these tests with alternate models.
|
||||
Answer = Answer
|
||||
Question = Question
|
||||
|
||||
@classmethod
|
||||
def setUpTestData(cls):
|
||||
cls.q1 = Question.objects.create(text="Which Beatle starts with the letter 'R'?")
|
||||
Answer.objects.create(text="John", question=cls.q1)
|
||||
Answer.objects.create(text="Paul", question=cls.q1)
|
||||
Answer.objects.create(text="George", question=cls.q1)
|
||||
Answer.objects.create(text="Ringo", question=cls.q1)
|
||||
cls.q1 = cls.Question.objects.create(text="Which Beatle starts with the letter 'R'?")
|
||||
cls.Answer.objects.create(text="John", question=cls.q1)
|
||||
cls.Answer.objects.create(text="Paul", question=cls.q1)
|
||||
cls.Answer.objects.create(text="George", question=cls.q1)
|
||||
cls.Answer.objects.create(text="Ringo", question=cls.q1)
|
||||
|
||||
def test_default_to_insertion_order(self):
|
||||
# Answers will always be ordered in the order they were inserted.
|
||||
|
@ -30,30 +34,30 @@ class OrderWithRespectToTests(TestCase):
|
|||
def test_previous_and_next_in_order(self):
|
||||
# We can retrieve the answers related to a particular object, in the
|
||||
# order they were created, once we have a particular object.
|
||||
a1 = Answer.objects.filter(question=self.q1)[0]
|
||||
a1 = self.q1.answer_set.all()[0]
|
||||
self.assertEqual(a1.text, "John")
|
||||
self.assertEqual(a1.get_next_in_order().text, "Paul")
|
||||
|
||||
a2 = list(Answer.objects.filter(question=self.q1))[-1]
|
||||
a2 = list(self.q1.answer_set.all())[-1]
|
||||
self.assertEqual(a2.text, "Ringo")
|
||||
self.assertEqual(a2.get_previous_in_order().text, "George")
|
||||
|
||||
def test_item_ordering(self):
|
||||
# We can retrieve the ordering of the queryset from a particular item.
|
||||
a1 = Answer.objects.filter(question=self.q1)[1]
|
||||
a1 = self.q1.answer_set.all()[1]
|
||||
id_list = [o.pk for o in self.q1.answer_set.all()]
|
||||
self.assertEqual(a1.question.get_answer_order(), id_list)
|
||||
self.assertSequenceEqual(a1.question.get_answer_order(), id_list)
|
||||
|
||||
# It doesn't matter which answer we use to check the order, it will
|
||||
# always be the same.
|
||||
a2 = Answer.objects.create(text="Number five", question=self.q1)
|
||||
self.assertEqual(
|
||||
a1.question.get_answer_order(), a2.question.get_answer_order()
|
||||
a2 = self.Answer.objects.create(text="Number five", question=self.q1)
|
||||
self.assertListEqual(
|
||||
list(a1.question.get_answer_order()), list(a2.question.get_answer_order())
|
||||
)
|
||||
|
||||
def test_change_ordering(self):
|
||||
# The ordering can be altered
|
||||
a = Answer.objects.create(text="Number five", question=self.q1)
|
||||
a = self.Answer.objects.create(text="Number five", question=self.q1)
|
||||
|
||||
# Swap the last two items in the order list
|
||||
id_list = [o.pk for o in self.q1.answer_set.all()]
|
||||
|
@ -61,7 +65,7 @@ class OrderWithRespectToTests(TestCase):
|
|||
id_list.insert(-1, x)
|
||||
|
||||
# By default, the ordering is different from the swapped version
|
||||
self.assertNotEqual(a.question.get_answer_order(), id_list)
|
||||
self.assertNotEqual(list(a.question.get_answer_order()), id_list)
|
||||
|
||||
# Change the ordering to the swapped version -
|
||||
# this changes the ordering of the queryset.
|
||||
|
@ -76,19 +80,25 @@ class OrderWithRespectToTests(TestCase):
|
|||
|
||||
class OrderWithRespectToTests2(TestCase):
|
||||
|
||||
# Provide the Post model as a class attribute so that we can subclass this
|
||||
# test case in contenttypes_tests.test_order_with_respect_to and run these
|
||||
# tests with alternative implementations of Post.
|
||||
Post = Post
|
||||
|
||||
def test_recursive_ordering(self):
|
||||
p1 = Post.objects.create(title='1')
|
||||
p2 = Post.objects.create(title='2')
|
||||
p1_1 = Post.objects.create(title="1.1", parent=p1)
|
||||
p1_2 = Post.objects.create(title="1.2", parent=p1)
|
||||
Post.objects.create(title="2.1", parent=p2)
|
||||
p1_3 = Post.objects.create(title="1.3", parent=p1)
|
||||
self.assertEqual(p1.get_post_order(), [p1_1.pk, p1_2.pk, p1_3.pk])
|
||||
p1 = self.Post.objects.create(title="1")
|
||||
p2 = self.Post.objects.create(title="2")
|
||||
p1_1 = self.Post.objects.create(title="1.1", parent=p1)
|
||||
p1_2 = self.Post.objects.create(title="1.2", parent=p1)
|
||||
self.Post.objects.create(title="2.1", parent=p2)
|
||||
p1_3 = self.Post.objects.create(title="1.3", parent=p1)
|
||||
self.assertSequenceEqual(p1.get_post_order(), [p1_1.pk, p1_2.pk, p1_3.pk])
|
||||
|
||||
def test_duplicate_order_field(self):
|
||||
|
||||
class Bar(models.Model):
|
||||
pass
|
||||
class Meta:
|
||||
app_label = 'order_with_respect_to'
|
||||
|
||||
class Foo(models.Model):
|
||||
bar = models.ForeignKey(Bar, models.CASCADE)
|
||||
|
@ -96,6 +106,7 @@ class OrderWithRespectToTests2(TestCase):
|
|||
|
||||
class Meta:
|
||||
order_with_respect_to = 'bar'
|
||||
app_label = 'order_with_respect_to'
|
||||
|
||||
count = 0
|
||||
for field in Foo._meta.local_fields:
|
||||
|
|
Loading…
Reference in New Issue