Fixed -- Made ``defer()`` work with reverse relations

Reverse o2o fields are now usable with defer.
This commit is contained in:
Tai Lee 2012-11-28 18:16:00 +02:00 committed by Anssi Kääriäinen
parent 2a0e4c249f
commit 6ebf115206
3 changed files with 50 additions and 5 deletions
django/db/models/sql
tests/regressiontests/defer_regress

View File

@ -599,17 +599,22 @@ class Query(object):
for name in parts[:-1]: for name in parts[:-1]:
old_model = cur_model old_model = cur_model
source = opts.get_field_by_name(name)[0] source = opts.get_field_by_name(name)[0]
cur_model = source.rel.to if is_reverse_o2o(source):
cur_model = source.model
else:
cur_model = source.rel.to
opts = cur_model._meta opts = cur_model._meta
# Even if we're "just passing through" this model, we must add # Even if we're "just passing through" this model, we must add
# both the current model's pk and the related reference field # both the current model's pk and the related reference field
# to the things we select. # (if it's not a reverse relation) to the things we select.
must_include[old_model].add(source) if not is_reverse_o2o(source):
must_include[old_model].add(source)
add_to_dict(must_include, cur_model, opts.pk) add_to_dict(must_include, cur_model, opts.pk)
field, model, _, _ = opts.get_field_by_name(parts[-1]) field, model, _, _ = opts.get_field_by_name(parts[-1])
if model is None: if model is None:
model = cur_model model = cur_model
add_to_dict(seen, model, field) if not is_reverse_o2o(field):
add_to_dict(seen, model, field)
if defer: if defer:
# We need to load all fields for each model, except those that # We need to load all fields for each model, except those that
@ -1983,3 +1988,10 @@ def add_to_dict(data, key, value):
data[key].add(value) data[key].add(value)
else: else:
data[key] = set([value]) data[key] = set([value])
def is_reverse_o2o(field):
"""
A little helper to check if the given field is reverse-o2o. The field is
expected to be some sort of relation field or related object.
"""
return not hasattr(field, 'rel') and field.field.unique

View File

@ -55,6 +55,10 @@ class Feature(models.Model):
class SpecialFeature(models.Model): class SpecialFeature(models.Model):
feature = models.ForeignKey(Feature) feature = models.ForeignKey(Feature)
class OneToOneItem(models.Model):
item = models.OneToOneField(Item, related_name="one_to_one_item")
name = models.CharField(max_length=15)
class ItemAndSimpleItem(models.Model): class ItemAndSimpleItem(models.Model):
item = models.ForeignKey(Item) item = models.ForeignKey(Item)
simple = models.ForeignKey(SimpleItem) simple = models.ForeignKey(SimpleItem)

View File

@ -9,7 +9,7 @@ from django.db.models.loading import cache
from django.test import TestCase from django.test import TestCase
from .models import (ResolveThis, Item, RelatedItem, Child, Leaf, Proxy, from .models import (ResolveThis, Item, RelatedItem, Child, Leaf, Proxy,
SimpleItem, Feature, ItemAndSimpleItem, SpecialFeature) SimpleItem, Feature, ItemAndSimpleItem, OneToOneItem, SpecialFeature)
class DeferRegressionTest(TestCase): class DeferRegressionTest(TestCase):
@ -111,6 +111,7 @@ class DeferRegressionTest(TestCase):
Item, Item,
ItemAndSimpleItem, ItemAndSimpleItem,
Leaf, Leaf,
OneToOneItem,
Proxy, Proxy,
RelatedItem, RelatedItem,
ResolveThis, ResolveThis,
@ -147,6 +148,7 @@ class DeferRegressionTest(TestCase):
"Leaf_Deferred_name_value", "Leaf_Deferred_name_value",
"Leaf_Deferred_second_child_id_value", "Leaf_Deferred_second_child_id_value",
"Leaf_Deferred_value", "Leaf_Deferred_value",
"OneToOneItem",
"Proxy", "Proxy",
"RelatedItem", "RelatedItem",
"RelatedItem_Deferred_", "RelatedItem_Deferred_",
@ -182,6 +184,33 @@ class DeferRegressionTest(TestCase):
self.assertEqual(1, qs.count()) self.assertEqual(1, qs.count())
self.assertEqual('Foobar', qs[0].name) self.assertEqual('Foobar', qs[0].name)
def test_reverse_one_to_one_relations(self):
# Refs #14694. Test reverse relations which are known unique (reverse
# side has o2ofield or unique FK) - the o2o case
item = Item.objects.create(name="first", value=42)
o2o = OneToOneItem.objects.create(item=item, name="second")
self.assertEqual(len(Item.objects.defer('one_to_one_item__name')), 1)
self.assertEqual(len(Item.objects.select_related('one_to_one_item')), 1)
self.assertEqual(len(Item.objects.select_related(
'one_to_one_item').defer('one_to_one_item__name')), 1)
self.assertEqual(len(Item.objects.select_related('one_to_one_item').defer('value')), 1)
# Make sure that `only()` doesn't break when we pass in a unique relation,
# rather than a field on the relation.
self.assertEqual(len(Item.objects.only('one_to_one_item')), 1)
with self.assertNumQueries(1):
i = Item.objects.select_related('one_to_one_item')[0]
self.assertEquals(i.one_to_one_item.pk, o2o.pk)
self.assertEquals(i.one_to_one_item.name, "second")
with self.assertNumQueries(1):
i = Item.objects.select_related('one_to_one_item').defer(
'value', 'one_to_one_item__name')[0]
self.assertEquals(i.one_to_one_item.pk, o2o.pk)
self.assertEquals(i.name, "first")
with self.assertNumQueries(1):
self.assertEquals(i.one_to_one_item.name, "second")
with self.assertNumQueries(1):
self.assertEquals(i.value, 42)
def test_defer_with_select_related(self): def test_defer_with_select_related(self):
item1 = Item.objects.create(name="first", value=47) item1 = Item.objects.create(name="first", value=47)
item2 = Item.objects.create(name="second", value=42) item2 = Item.objects.create(name="second", value=42)