Fixed #12937 -- Corrected the operation of select_related() when following an reverse relation on an inherited model. Thanks to subsume for the report.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@12814 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Russell Keith-Magee 2010-03-20 15:02:59 +00:00
parent 4528f39886
commit bfa080f402
4 changed files with 41 additions and 16 deletions

View File

@ -1113,7 +1113,7 @@ class EmptyQuerySet(QuerySet):
def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0, def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
requested=None, offset=0, only_load=None): requested=None, offset=0, only_load=None, local_only=False):
""" """
Helper function that recursively returns an object with the specified Helper function that recursively returns an object with the specified
related attributes already populated. related attributes already populated.
@ -1141,6 +1141,8 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
* only_load - if the query has had only() or defer() applied, * only_load - if the query has had only() or defer() applied,
this is the list of field names that will be returned. If None, this is the list of field names that will be returned. If None,
the full field list for `klass` can be assumed. the full field list for `klass` can be assumed.
* local_only - Only populate local fields. This is used when building
following reverse select-related relations
""" """
if max_depth and requested is None and cur_depth > max_depth: if max_depth and requested is None and cur_depth > max_depth:
# We've recursed deeply enough; stop now. # We've recursed deeply enough; stop now.
@ -1153,9 +1155,11 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
skip = set() skip = set()
init_list = [] init_list = []
# Build the list of fields that *haven't* been requested # Build the list of fields that *haven't* been requested
for field in klass._meta.fields: for field, model in klass._meta.get_fields_with_model():
if field.name not in load_fields: if field.name not in load_fields:
skip.add(field.name) skip.add(field.name)
elif local_only and model is not None:
continue
else: else:
init_list.append(field.attname) init_list.append(field.attname)
# Retrieve all the requested fields # Retrieve all the requested fields
@ -1174,7 +1178,11 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
else: else:
# Load all fields on klass # Load all fields on klass
field_count = len(klass._meta.fields) if local_only:
field_names = [f.attname for f in klass._meta.local_fields]
else:
field_names = [f.attname for f in klass._meta.fields]
field_count = len(field_names)
fields = row[index_start : index_start + field_count] fields = row[index_start : index_start + field_count]
# If all the select_related columns are None, then the related # If all the select_related columns are None, then the related
# object must be non-existent - set the relation to None. # object must be non-existent - set the relation to None.
@ -1182,7 +1190,7 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
if fields == (None,) * field_count: if fields == (None,) * field_count:
obj = None obj = None
else: else:
obj = klass(*fields) obj = klass(**dict(zip(field_names, fields)))
# If an object was retrieved, set the database state. # If an object was retrieved, set the database state.
if obj: if obj:
@ -1229,7 +1237,7 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
next = requested[f.related_query_name()] next = requested[f.related_query_name()]
# Recursively retrieve the data for the related object # Recursively retrieve the data for the related object
cached_row = get_cached_row(model, row, index_end, using, cached_row = get_cached_row(model, row, index_end, using,
max_depth, cur_depth+1, next) max_depth, cur_depth+1, next, local_only=True)
# If the recursive descent found an object, populate the # If the recursive descent found an object, populate the
# descriptor caches relevant to the object # descriptor caches relevant to the object
if cached_row: if cached_row:
@ -1242,7 +1250,20 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
# If the related object exists, populate # If the related object exists, populate
# the descriptor cache. # the descriptor cache.
setattr(rel_obj, f.get_cache_name(), obj) setattr(rel_obj, f.get_cache_name(), obj)
# Now populate all the non-local field values
# on the related object
for rel_field,rel_model in rel_obj._meta.get_fields_with_model():
if rel_model is not None:
setattr(rel_obj, rel_field.attname, getattr(obj, rel_field.attname))
# populate the field cache for any related object
# that has already been retrieved
if rel_field.rel:
try:
cached_obj = getattr(obj, rel_field.get_cache_name())
setattr(rel_obj, rel_field.get_cache_name(), cached_obj)
except AttributeError:
# Related object hasn't been cached yet
pass
return obj, index_end return obj, index_end
def delete_objects(seen_objs, using): def delete_objects(seen_objs, using):

View File

@ -215,7 +215,7 @@ class SQLCompiler(object):
return result return result
def get_default_columns(self, with_aliases=False, col_aliases=None, def get_default_columns(self, with_aliases=False, col_aliases=None,
start_alias=None, opts=None, as_pairs=False): start_alias=None, opts=None, as_pairs=False, local_only=False):
""" """
Computes the default columns for selecting every field in the base Computes the default columns for selecting every field in the base
model. Will sometimes be called to pull in related models (e.g. via model. Will sometimes be called to pull in related models (e.g. via
@ -240,6 +240,8 @@ class SQLCompiler(object):
if start_alias: if start_alias:
seen = {None: start_alias} seen = {None: start_alias}
for field, model in opts.get_fields_with_model(): for field, model in opts.get_fields_with_model():
if local_only and model is not None:
continue
if start_alias: if start_alias:
try: try:
alias = seen[model] alias = seen[model]
@ -643,7 +645,7 @@ class SQLCompiler(object):
) )
used.add(alias) used.add(alias)
columns, aliases = self.get_default_columns(start_alias=alias, columns, aliases = self.get_default_columns(start_alias=alias,
opts=model._meta, as_pairs=True) opts=model._meta, as_pairs=True, local_only=True)
self.query.related_select_cols.extend(columns) self.query.related_select_cols.extend(columns)
self.query.related_select_fields.extend(model._meta.fields) self.query.related_select_fields.extend(model._meta.fields)

View File

@ -43,8 +43,7 @@ class StatDetails(models.Model):
class AdvancedUserStat(UserStat): class AdvancedUserStat(UserStat):
pass karma = models.IntegerField()
class Image(models.Model): class Image(models.Model):
name = models.CharField(max_length=100) name = models.CharField(max_length=100)

View File

@ -22,7 +22,7 @@ class ReverseSelectRelatedTestCase(TestCase):
user2 = User.objects.create(username="bob") user2 = User.objects.create(username="bob")
results2 = UserStatResult.objects.create(results='moar results') results2 = UserStatResult.objects.create(results='moar results')
advstat = AdvancedUserStat.objects.create(user=user2, posts=200, advstat = AdvancedUserStat.objects.create(user=user2, posts=200, karma=5,
results=results2) results=results2)
StatDetails.objects.create(base_stats=advstat, comments=250) StatDetails.objects.create(base_stats=advstat, comments=250)
@ -74,13 +74,16 @@ class ReverseSelectRelatedTestCase(TestCase):
self.assertQueries(2) self.assertQueries(2)
def test_follow_from_child_class(self): def test_follow_from_child_class(self):
stat = AdvancedUserStat.objects.select_related("statdetails").get(posts=200) stat = AdvancedUserStat.objects.select_related('user', 'statdetails').get(posts=200)
self.assertEqual(stat.statdetails.comments, 250) self.assertEqual(stat.statdetails.comments, 250)
self.assertEqual(stat.user.username, 'bob')
self.assertQueries(1) self.assertQueries(1)
def test_follow_inheritance(self): def test_follow_inheritance(self):
stat = UserStat.objects.select_related('advanceduserstat').get(posts=200) stat = UserStat.objects.select_related('user', 'advanceduserstat').get(posts=200)
self.assertEqual(stat.advanceduserstat.posts, 200) self.assertEqual(stat.advanceduserstat.posts, 200)
self.assertEqual(stat.user.username, 'bob')
self.assertEqual(stat.advanceduserstat.user.username, 'bob')
self.assertQueries(1) self.assertQueries(1)
def test_nullable_relation(self): def test_nullable_relation(self):