Fixed #11319 - Added lookup support for ForeignKey.to_field. Also reverted no-longer-needed model formsets workaround for lack of such support from r10756. Thanks Russell and Alex for review.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@15303 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Carl Meyer 2011-01-25 03:14:28 +00:00
parent 7c888a7aa9
commit 227c5e80db
8 changed files with 122 additions and 14 deletions

View File

@ -178,9 +178,20 @@ class RelatedField(object):
# the primary key may itself be an object - so we need to keep drilling # the primary key may itself be an object - so we need to keep drilling
# down until we hit a value that can be used for a comparison. # down until we hit a value that can be used for a comparison.
v = value v = value
# In the case of an FK to 'self', this check allows to_field to be used
# for both forwards and reverse lookups across the FK. (For normal FKs,
# it's only relevant for forward lookups).
if isinstance(v, self.rel.to):
field_name = getattr(self.rel, "field_name", None)
else:
field_name = None
try: try:
while True: while True:
v = getattr(v, v._meta.pk.name) if field_name is None:
field_name = v._meta.pk.name
v = getattr(v, field_name)
field_name = None
except AttributeError: except AttributeError:
pass pass
except exceptions.ObjectDoesNotExist: except exceptions.ObjectDoesNotExist:

View File

@ -1364,6 +1364,11 @@ class Query(object):
table = opts.db_table table = opts.db_table
from_col = local_field.column from_col = local_field.column
to_col = field.column to_col = field.column
# In case of a recursive FK, use the to_field for
# reverse lookups as well
if orig_field.model is local_field.model:
target = opts.get_field(field.rel.field_name)
else:
target = opts.pk target = opts.pk
orig_opts._join_cache[name] = (table, from_col, to_col, orig_opts._join_cache[name] = (table, from_col, to_col,
opts, target) opts, target)

View File

@ -700,13 +700,9 @@ class BaseInlineFormSet(BaseModelFormSet):
self.save_as_new = save_as_new self.save_as_new = save_as_new
# is there a better way to get the object descriptor? # is there a better way to get the object descriptor?
self.rel_name = RelatedObject(self.fk.rel.to, self.model, self.fk).get_accessor_name() self.rel_name = RelatedObject(self.fk.rel.to, self.model, self.fk).get_accessor_name()
if self.fk.rel.field_name == self.fk.rel.to._meta.pk.name:
backlink_value = self.instance
else:
backlink_value = getattr(self.instance, self.fk.rel.field_name)
if queryset is None: if queryset is None:
queryset = self.model._default_manager queryset = self.model._default_manager
qs = queryset.filter(**{self.fk.name: backlink_value}) qs = queryset.filter(**{self.fk.name: self.instance})
super(BaseInlineFormSet, self).__init__(data, files, prefix=prefix, super(BaseInlineFormSet, self).__init__(data, files, prefix=prefix,
queryset=qs) queryset=qs)

View File

@ -158,11 +158,9 @@ class CustomPKTests(TestCase):
new_bar = Bar.objects.create() new_bar = Bar.objects.create()
new_foo = Foo.objects.create(bar=new_bar) new_foo = Foo.objects.create(bar=new_bar)
# FIXME: This still doesn't work, but will require some changes in f = Foo.objects.get(bar=new_bar.pk)
# get_db_prep_lookup to fix it. self.assertEqual(f, new_foo)
# f = Foo.objects.get(bar=new_bar.pk) self.assertEqual(f.bar, new_bar)
# self.assertEqual(f, new_foo)
# self.assertEqual(f.bar, new_bar)
f = Foo.objects.get(bar=new_bar) f = Foo.objects.get(bar=new_bar)
self.assertEqual(f, new_foo), self.assertEqual(f, new_foo),

View File

@ -44,3 +44,10 @@ class Email(Contact):
class Researcher(models.Model): class Researcher(models.Model):
contacts = models.ManyToManyField(Contact, related_name="research_contacts") contacts = models.ManyToManyField(Contact, related_name="research_contacts")
class Food(models.Model):
name = models.CharField(max_length=20, unique=True)
class Eaten(models.Model):
food = models.ForeignKey(Food, to_field="name")
meal = models.CharField(max_length=20)

View File

@ -5,7 +5,7 @@ from django.db import backend, connection, transaction, DEFAULT_DB_ALIAS
from django.test import TestCase, TransactionTestCase, skipUnlessDBFeature from django.test import TestCase, TransactionTestCase, skipUnlessDBFeature
from models import (Book, Award, AwardNote, Person, Child, Toy, PlayedWith, from models import (Book, Award, AwardNote, Person, Child, Toy, PlayedWith,
PlayedWithNote, Contact, Email, Researcher) PlayedWithNote, Contact, Email, Researcher, Food, Eaten)
# Can't run this test under SQLite, because you can't # Can't run this test under SQLite, because you can't
@ -119,6 +119,16 @@ class DeleteCascadeTransactionTests(TransactionTestCase):
email.delete() email.delete()
def test_to_field(self):
"""
Cascade deletion works with ForeignKey.to_field set to non-PK.
"""
apple = Food.objects.create(name="apple")
eaten = Eaten.objects.create(food=apple, meal="lunch")
apple.delete()
class LargeDeleteTests(TestCase): class LargeDeleteTests(TestCase):
def test_large_deletes(self): def test_large_deletes(self):
"Regression for #13309 -- if the number of objects > chunk size, deletion still occurs" "Regression for #13309 -- if the number of objects > chunk size, deletion still occurs"

View File

@ -274,3 +274,23 @@ class Plaything(models.Model):
class Article(models.Model): class Article(models.Model):
name = models.CharField(max_length=20) name = models.CharField(max_length=20)
created = models.DateTimeField() created = models.DateTimeField()
class Food(models.Model):
name = models.CharField(max_length=20, unique=True)
def __unicode__(self):
return self.name
class Eaten(models.Model):
food = models.ForeignKey(Food, to_field="name")
meal = models.CharField(max_length=20)
def __unicode__(self):
return u"%s at %s" % (self.food, self.meal)
class Node(models.Model):
num = models.IntegerField(unique=True)
parent = models.ForeignKey("self", to_field="num", null=True)
def __unicode__(self):
return u"%s" % self.num

View File

@ -14,7 +14,7 @@ from django.utils.datastructures import SortedDict
from models import (Annotation, Article, Author, Celebrity, Child, Cover, Detail, from models import (Annotation, Article, Author, Celebrity, Child, Cover, Detail,
DumbCategory, ExtraInfo, Fan, Item, LeafA, LoopX, LoopZ, ManagedModel, DumbCategory, ExtraInfo, Fan, Item, LeafA, LoopX, LoopZ, ManagedModel,
Member, NamedCategory, Note, Number, Plaything, PointerA, Ranking, Related, Member, NamedCategory, Note, Number, Plaything, PointerA, Ranking, Related,
Report, ReservedName, Tag, TvChef, Valid, X) Report, ReservedName, Tag, TvChef, Valid, X, Food, Eaten, Node)
class BaseQuerysetTest(TestCase): class BaseQuerysetTest(TestCase):
@ -1515,6 +1515,67 @@ class EscapingTests(TestCase):
) )
class ToFieldTests(TestCase):
def test_in_query(self):
apple = Food.objects.create(name="apple")
pear = Food.objects.create(name="pear")
lunch = Eaten.objects.create(food=apple, meal="lunch")
dinner = Eaten.objects.create(food=pear, meal="dinner")
self.assertEqual(
set(Eaten.objects.filter(food__in=[apple, pear])),
set([lunch, dinner]),
)
def test_reverse_in(self):
apple = Food.objects.create(name="apple")
pear = Food.objects.create(name="pear")
lunch_apple = Eaten.objects.create(food=apple, meal="lunch")
lunch_pear = Eaten.objects.create(food=pear, meal="dinner")
self.assertEqual(
set(Food.objects.filter(eaten__in=[lunch_apple, lunch_pear])),
set([apple, pear])
)
def test_single_object(self):
apple = Food.objects.create(name="apple")
lunch = Eaten.objects.create(food=apple, meal="lunch")
dinner = Eaten.objects.create(food=apple, meal="dinner")
self.assertEqual(
set(Eaten.objects.filter(food=apple)),
set([lunch, dinner])
)
def test_single_object_reverse(self):
apple = Food.objects.create(name="apple")
lunch = Eaten.objects.create(food=apple, meal="lunch")
self.assertEqual(
set(Food.objects.filter(eaten=lunch)),
set([apple])
)
def test_recursive_fk(self):
node1 = Node.objects.create(num=42)
node2 = Node.objects.create(num=1, parent=node1)
self.assertEqual(
list(Node.objects.filter(parent=node1)),
[node2]
)
def test_recursive_fk_reverse(self):
node1 = Node.objects.create(num=42)
node2 = Node.objects.create(num=1, parent=node1)
self.assertEqual(
list(Node.objects.filter(node=node2)),
[node1]
)
class ConditionalTests(BaseQuerysetTest): class ConditionalTests(BaseQuerysetTest):
"""Tests whose execution depend on dfferent environment conditions like """Tests whose execution depend on dfferent environment conditions like
Python version or DB backend features""" Python version or DB backend features"""