Fixed #12937 -- Corrected the operation of select_related() when following an reverse relation on an inherited model. Thanks to subsume for the report.
git-svn-id: http://code.djangoproject.com/svn/django/trunk@12814 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
parent
4528f39886
commit
bfa080f402
|
@ -1113,7 +1113,7 @@ class EmptyQuerySet(QuerySet):
|
||||||
|
|
||||||
|
|
||||||
def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
|
def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
|
||||||
requested=None, offset=0, only_load=None):
|
requested=None, offset=0, only_load=None, local_only=False):
|
||||||
"""
|
"""
|
||||||
Helper function that recursively returns an object with the specified
|
Helper function that recursively returns an object with the specified
|
||||||
related attributes already populated.
|
related attributes already populated.
|
||||||
|
@ -1141,6 +1141,8 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
|
||||||
* only_load - if the query has had only() or defer() applied,
|
* only_load - if the query has had only() or defer() applied,
|
||||||
this is the list of field names that will be returned. If None,
|
this is the list of field names that will be returned. If None,
|
||||||
the full field list for `klass` can be assumed.
|
the full field list for `klass` can be assumed.
|
||||||
|
* local_only - Only populate local fields. This is used when building
|
||||||
|
following reverse select-related relations
|
||||||
"""
|
"""
|
||||||
if max_depth and requested is None and cur_depth > max_depth:
|
if max_depth and requested is None and cur_depth > max_depth:
|
||||||
# We've recursed deeply enough; stop now.
|
# We've recursed deeply enough; stop now.
|
||||||
|
@ -1153,9 +1155,11 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
|
||||||
skip = set()
|
skip = set()
|
||||||
init_list = []
|
init_list = []
|
||||||
# Build the list of fields that *haven't* been requested
|
# Build the list of fields that *haven't* been requested
|
||||||
for field in klass._meta.fields:
|
for field, model in klass._meta.get_fields_with_model():
|
||||||
if field.name not in load_fields:
|
if field.name not in load_fields:
|
||||||
skip.add(field.name)
|
skip.add(field.name)
|
||||||
|
elif local_only and model is not None:
|
||||||
|
continue
|
||||||
else:
|
else:
|
||||||
init_list.append(field.attname)
|
init_list.append(field.attname)
|
||||||
# Retrieve all the requested fields
|
# Retrieve all the requested fields
|
||||||
|
@ -1174,7 +1178,11 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Load all fields on klass
|
# Load all fields on klass
|
||||||
field_count = len(klass._meta.fields)
|
if local_only:
|
||||||
|
field_names = [f.attname for f in klass._meta.local_fields]
|
||||||
|
else:
|
||||||
|
field_names = [f.attname for f in klass._meta.fields]
|
||||||
|
field_count = len(field_names)
|
||||||
fields = row[index_start : index_start + field_count]
|
fields = row[index_start : index_start + field_count]
|
||||||
# If all the select_related columns are None, then the related
|
# If all the select_related columns are None, then the related
|
||||||
# object must be non-existent - set the relation to None.
|
# object must be non-existent - set the relation to None.
|
||||||
|
@ -1182,7 +1190,7 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
|
||||||
if fields == (None,) * field_count:
|
if fields == (None,) * field_count:
|
||||||
obj = None
|
obj = None
|
||||||
else:
|
else:
|
||||||
obj = klass(*fields)
|
obj = klass(**dict(zip(field_names, fields)))
|
||||||
|
|
||||||
# If an object was retrieved, set the database state.
|
# If an object was retrieved, set the database state.
|
||||||
if obj:
|
if obj:
|
||||||
|
@ -1229,7 +1237,7 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
|
||||||
next = requested[f.related_query_name()]
|
next = requested[f.related_query_name()]
|
||||||
# Recursively retrieve the data for the related object
|
# Recursively retrieve the data for the related object
|
||||||
cached_row = get_cached_row(model, row, index_end, using,
|
cached_row = get_cached_row(model, row, index_end, using,
|
||||||
max_depth, cur_depth+1, next)
|
max_depth, cur_depth+1, next, local_only=True)
|
||||||
# If the recursive descent found an object, populate the
|
# If the recursive descent found an object, populate the
|
||||||
# descriptor caches relevant to the object
|
# descriptor caches relevant to the object
|
||||||
if cached_row:
|
if cached_row:
|
||||||
|
@ -1242,7 +1250,20 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
|
||||||
# If the related object exists, populate
|
# If the related object exists, populate
|
||||||
# the descriptor cache.
|
# the descriptor cache.
|
||||||
setattr(rel_obj, f.get_cache_name(), obj)
|
setattr(rel_obj, f.get_cache_name(), obj)
|
||||||
|
# Now populate all the non-local field values
|
||||||
|
# on the related object
|
||||||
|
for rel_field,rel_model in rel_obj._meta.get_fields_with_model():
|
||||||
|
if rel_model is not None:
|
||||||
|
setattr(rel_obj, rel_field.attname, getattr(obj, rel_field.attname))
|
||||||
|
# populate the field cache for any related object
|
||||||
|
# that has already been retrieved
|
||||||
|
if rel_field.rel:
|
||||||
|
try:
|
||||||
|
cached_obj = getattr(obj, rel_field.get_cache_name())
|
||||||
|
setattr(rel_obj, rel_field.get_cache_name(), cached_obj)
|
||||||
|
except AttributeError:
|
||||||
|
# Related object hasn't been cached yet
|
||||||
|
pass
|
||||||
return obj, index_end
|
return obj, index_end
|
||||||
|
|
||||||
def delete_objects(seen_objs, using):
|
def delete_objects(seen_objs, using):
|
||||||
|
|
|
@ -215,7 +215,7 @@ class SQLCompiler(object):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def get_default_columns(self, with_aliases=False, col_aliases=None,
|
def get_default_columns(self, with_aliases=False, col_aliases=None,
|
||||||
start_alias=None, opts=None, as_pairs=False):
|
start_alias=None, opts=None, as_pairs=False, local_only=False):
|
||||||
"""
|
"""
|
||||||
Computes the default columns for selecting every field in the base
|
Computes the default columns for selecting every field in the base
|
||||||
model. Will sometimes be called to pull in related models (e.g. via
|
model. Will sometimes be called to pull in related models (e.g. via
|
||||||
|
@ -240,6 +240,8 @@ class SQLCompiler(object):
|
||||||
if start_alias:
|
if start_alias:
|
||||||
seen = {None: start_alias}
|
seen = {None: start_alias}
|
||||||
for field, model in opts.get_fields_with_model():
|
for field, model in opts.get_fields_with_model():
|
||||||
|
if local_only and model is not None:
|
||||||
|
continue
|
||||||
if start_alias:
|
if start_alias:
|
||||||
try:
|
try:
|
||||||
alias = seen[model]
|
alias = seen[model]
|
||||||
|
@ -643,7 +645,7 @@ class SQLCompiler(object):
|
||||||
)
|
)
|
||||||
used.add(alias)
|
used.add(alias)
|
||||||
columns, aliases = self.get_default_columns(start_alias=alias,
|
columns, aliases = self.get_default_columns(start_alias=alias,
|
||||||
opts=model._meta, as_pairs=True)
|
opts=model._meta, as_pairs=True, local_only=True)
|
||||||
self.query.related_select_cols.extend(columns)
|
self.query.related_select_cols.extend(columns)
|
||||||
self.query.related_select_fields.extend(model._meta.fields)
|
self.query.related_select_fields.extend(model._meta.fields)
|
||||||
|
|
||||||
|
|
|
@ -43,8 +43,7 @@ class StatDetails(models.Model):
|
||||||
|
|
||||||
|
|
||||||
class AdvancedUserStat(UserStat):
|
class AdvancedUserStat(UserStat):
|
||||||
pass
|
karma = models.IntegerField()
|
||||||
|
|
||||||
|
|
||||||
class Image(models.Model):
|
class Image(models.Model):
|
||||||
name = models.CharField(max_length=100)
|
name = models.CharField(max_length=100)
|
||||||
|
|
|
@ -2,7 +2,7 @@ from django import db
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
|
|
||||||
from models import (User, UserProfile, UserStat, UserStatResult, StatDetails,
|
from models import (User, UserProfile, UserStat, UserStatResult, StatDetails,
|
||||||
AdvancedUserStat, Image, Product)
|
AdvancedUserStat, Image, Product)
|
||||||
|
|
||||||
class ReverseSelectRelatedTestCase(TestCase):
|
class ReverseSelectRelatedTestCase(TestCase):
|
||||||
|
@ -22,7 +22,7 @@ class ReverseSelectRelatedTestCase(TestCase):
|
||||||
|
|
||||||
user2 = User.objects.create(username="bob")
|
user2 = User.objects.create(username="bob")
|
||||||
results2 = UserStatResult.objects.create(results='moar results')
|
results2 = UserStatResult.objects.create(results='moar results')
|
||||||
advstat = AdvancedUserStat.objects.create(user=user2, posts=200,
|
advstat = AdvancedUserStat.objects.create(user=user2, posts=200, karma=5,
|
||||||
results=results2)
|
results=results2)
|
||||||
StatDetails.objects.create(base_stats=advstat, comments=250)
|
StatDetails.objects.create(base_stats=advstat, comments=250)
|
||||||
|
|
||||||
|
@ -74,18 +74,21 @@ class ReverseSelectRelatedTestCase(TestCase):
|
||||||
self.assertQueries(2)
|
self.assertQueries(2)
|
||||||
|
|
||||||
def test_follow_from_child_class(self):
|
def test_follow_from_child_class(self):
|
||||||
stat = AdvancedUserStat.objects.select_related("statdetails").get(posts=200)
|
stat = AdvancedUserStat.objects.select_related('user', 'statdetails').get(posts=200)
|
||||||
self.assertEqual(stat.statdetails.comments, 250)
|
self.assertEqual(stat.statdetails.comments, 250)
|
||||||
|
self.assertEqual(stat.user.username, 'bob')
|
||||||
self.assertQueries(1)
|
self.assertQueries(1)
|
||||||
|
|
||||||
def test_follow_inheritance(self):
|
def test_follow_inheritance(self):
|
||||||
stat = UserStat.objects.select_related('advanceduserstat').get(posts=200)
|
stat = UserStat.objects.select_related('user', 'advanceduserstat').get(posts=200)
|
||||||
self.assertEqual(stat.advanceduserstat.posts, 200)
|
self.assertEqual(stat.advanceduserstat.posts, 200)
|
||||||
|
self.assertEqual(stat.user.username, 'bob')
|
||||||
|
self.assertEqual(stat.advanceduserstat.user.username, 'bob')
|
||||||
self.assertQueries(1)
|
self.assertQueries(1)
|
||||||
|
|
||||||
def test_nullable_relation(self):
|
def test_nullable_relation(self):
|
||||||
im = Image.objects.create(name="imag1")
|
im = Image.objects.create(name="imag1")
|
||||||
p1 = Product.objects.create(name="Django Plushie", image=im)
|
p1 = Product.objects.create(name="Django Plushie", image=im)
|
||||||
p2 = Product.objects.create(name="Talking Django Plushie")
|
p2 = Product.objects.create(name="Talking Django Plushie")
|
||||||
|
|
||||||
self.assertEqual(len(Product.objects.select_related("image")), 2)
|
self.assertEqual(len(Product.objects.select_related("image")), 2)
|
||||||
|
|
Loading…
Reference in New Issue