Fixed #12937 -- Corrected the operation of select_related() when following an reverse relation on an inherited model. Thanks to subsume for the report.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@12814 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Russell Keith-Magee 2010-03-20 15:02:59 +00:00
parent 4528f39886
commit bfa080f402
4 changed files with 41 additions and 16 deletions

View File

@ -1113,7 +1113,7 @@ class EmptyQuerySet(QuerySet):
def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
requested=None, offset=0, only_load=None):
requested=None, offset=0, only_load=None, local_only=False):
"""
Helper function that recursively returns an object with the specified
related attributes already populated.
@ -1141,6 +1141,8 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
* only_load - if the query has had only() or defer() applied,
this is the list of field names that will be returned. If None,
the full field list for `klass` can be assumed.
* local_only - Only populate local fields. This is used when building
following reverse select-related relations
"""
if max_depth and requested is None and cur_depth > max_depth:
# We've recursed deeply enough; stop now.
@ -1153,9 +1155,11 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
skip = set()
init_list = []
# Build the list of fields that *haven't* been requested
for field in klass._meta.fields:
for field, model in klass._meta.get_fields_with_model():
if field.name not in load_fields:
skip.add(field.name)
elif local_only and model is not None:
continue
else:
init_list.append(field.attname)
# Retrieve all the requested fields
@ -1174,7 +1178,11 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
else:
# Load all fields on klass
field_count = len(klass._meta.fields)
if local_only:
field_names = [f.attname for f in klass._meta.local_fields]
else:
field_names = [f.attname for f in klass._meta.fields]
field_count = len(field_names)
fields = row[index_start : index_start + field_count]
# If all the select_related columns are None, then the related
# object must be non-existent - set the relation to None.
@ -1182,7 +1190,7 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
if fields == (None,) * field_count:
obj = None
else:
obj = klass(*fields)
obj = klass(**dict(zip(field_names, fields)))
# If an object was retrieved, set the database state.
if obj:
@ -1229,7 +1237,7 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
next = requested[f.related_query_name()]
# Recursively retrieve the data for the related object
cached_row = get_cached_row(model, row, index_end, using,
max_depth, cur_depth+1, next)
max_depth, cur_depth+1, next, local_only=True)
# If the recursive descent found an object, populate the
# descriptor caches relevant to the object
if cached_row:
@ -1242,7 +1250,20 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
# If the related object exists, populate
# the descriptor cache.
setattr(rel_obj, f.get_cache_name(), obj)
# Now populate all the non-local field values
# on the related object
for rel_field,rel_model in rel_obj._meta.get_fields_with_model():
if rel_model is not None:
setattr(rel_obj, rel_field.attname, getattr(obj, rel_field.attname))
# populate the field cache for any related object
# that has already been retrieved
if rel_field.rel:
try:
cached_obj = getattr(obj, rel_field.get_cache_name())
setattr(rel_obj, rel_field.get_cache_name(), cached_obj)
except AttributeError:
# Related object hasn't been cached yet
pass
return obj, index_end
def delete_objects(seen_objs, using):

View File

@ -215,7 +215,7 @@ class SQLCompiler(object):
return result
def get_default_columns(self, with_aliases=False, col_aliases=None,
start_alias=None, opts=None, as_pairs=False):
start_alias=None, opts=None, as_pairs=False, local_only=False):
"""
Computes the default columns for selecting every field in the base
model. Will sometimes be called to pull in related models (e.g. via
@ -240,6 +240,8 @@ class SQLCompiler(object):
if start_alias:
seen = {None: start_alias}
for field, model in opts.get_fields_with_model():
if local_only and model is not None:
continue
if start_alias:
try:
alias = seen[model]
@ -643,7 +645,7 @@ class SQLCompiler(object):
)
used.add(alias)
columns, aliases = self.get_default_columns(start_alias=alias,
opts=model._meta, as_pairs=True)
opts=model._meta, as_pairs=True, local_only=True)
self.query.related_select_cols.extend(columns)
self.query.related_select_fields.extend(model._meta.fields)

View File

@ -43,8 +43,7 @@ class StatDetails(models.Model):
class AdvancedUserStat(UserStat):
pass
karma = models.IntegerField()
class Image(models.Model):
name = models.CharField(max_length=100)

View File

@ -2,7 +2,7 @@ from django import db
from django.conf import settings
from django.test import TestCase
from models import (User, UserProfile, UserStat, UserStatResult, StatDetails,
from models import (User, UserProfile, UserStat, UserStatResult, StatDetails,
AdvancedUserStat, Image, Product)
class ReverseSelectRelatedTestCase(TestCase):
@ -22,7 +22,7 @@ class ReverseSelectRelatedTestCase(TestCase):
user2 = User.objects.create(username="bob")
results2 = UserStatResult.objects.create(results='moar results')
advstat = AdvancedUserStat.objects.create(user=user2, posts=200,
advstat = AdvancedUserStat.objects.create(user=user2, posts=200, karma=5,
results=results2)
StatDetails.objects.create(base_stats=advstat, comments=250)
@ -74,18 +74,21 @@ class ReverseSelectRelatedTestCase(TestCase):
self.assertQueries(2)
def test_follow_from_child_class(self):
stat = AdvancedUserStat.objects.select_related("statdetails").get(posts=200)
stat = AdvancedUserStat.objects.select_related('user', 'statdetails').get(posts=200)
self.assertEqual(stat.statdetails.comments, 250)
self.assertEqual(stat.user.username, 'bob')
self.assertQueries(1)
def test_follow_inheritance(self):
stat = UserStat.objects.select_related('advanceduserstat').get(posts=200)
stat = UserStat.objects.select_related('user', 'advanceduserstat').get(posts=200)
self.assertEqual(stat.advanceduserstat.posts, 200)
self.assertEqual(stat.user.username, 'bob')
self.assertEqual(stat.advanceduserstat.user.username, 'bob')
self.assertQueries(1)
def test_nullable_relation(self):
im = Image.objects.create(name="imag1")
p1 = Product.objects.create(name="Django Plushie", image=im)
p2 = Product.objects.create(name="Talking Django Plushie")
self.assertEqual(len(Product.objects.select_related("image")), 2)