From bfa080f402ddfc383abdbac96fa198e7f2f8ec59 Mon Sep 17 00:00:00 2001 From: Russell Keith-Magee Date: Sat, 20 Mar 2010 15:02:59 +0000 Subject: [PATCH] Fixed #12937 -- Corrected the operation of select_related() when following an reverse relation on an inherited model. Thanks to subsume for the report. git-svn-id: http://code.djangoproject.com/svn/django/trunk@12814 bcc190cf-cafb-0310-a4f2-bffc1f526a37 --- django/db/models/query.py | 33 +++++++++++++++---- django/db/models/sql/compiler.py | 6 ++-- .../select_related_onetoone/models.py | 3 +- .../select_related_onetoone/tests.py | 15 +++++---- 4 files changed, 41 insertions(+), 16 deletions(-) diff --git a/django/db/models/query.py b/django/db/models/query.py index f6b4419d27..8adf0d555c 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -1113,7 +1113,7 @@ class EmptyQuerySet(QuerySet): def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0, - requested=None, offset=0, only_load=None): + requested=None, offset=0, only_load=None, local_only=False): """ Helper function that recursively returns an object with the specified related attributes already populated. @@ -1141,6 +1141,8 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0, * only_load - if the query has had only() or defer() applied, this is the list of field names that will be returned. If None, the full field list for `klass` can be assumed. + * local_only - Only populate local fields. This is used when building + following reverse select-related relations """ if max_depth and requested is None and cur_depth > max_depth: # We've recursed deeply enough; stop now. @@ -1153,9 +1155,11 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0, skip = set() init_list = [] # Build the list of fields that *haven't* been requested - for field in klass._meta.fields: + for field, model in klass._meta.get_fields_with_model(): if field.name not in load_fields: skip.add(field.name) + elif local_only and model is not None: + continue else: init_list.append(field.attname) # Retrieve all the requested fields @@ -1174,7 +1178,11 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0, else: # Load all fields on klass - field_count = len(klass._meta.fields) + if local_only: + field_names = [f.attname for f in klass._meta.local_fields] + else: + field_names = [f.attname for f in klass._meta.fields] + field_count = len(field_names) fields = row[index_start : index_start + field_count] # If all the select_related columns are None, then the related # object must be non-existent - set the relation to None. @@ -1182,7 +1190,7 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0, if fields == (None,) * field_count: obj = None else: - obj = klass(*fields) + obj = klass(**dict(zip(field_names, fields))) # If an object was retrieved, set the database state. if obj: @@ -1229,7 +1237,7 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0, next = requested[f.related_query_name()] # Recursively retrieve the data for the related object cached_row = get_cached_row(model, row, index_end, using, - max_depth, cur_depth+1, next) + max_depth, cur_depth+1, next, local_only=True) # If the recursive descent found an object, populate the # descriptor caches relevant to the object if cached_row: @@ -1242,7 +1250,20 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0, # If the related object exists, populate # the descriptor cache. setattr(rel_obj, f.get_cache_name(), obj) - + # Now populate all the non-local field values + # on the related object + for rel_field,rel_model in rel_obj._meta.get_fields_with_model(): + if rel_model is not None: + setattr(rel_obj, rel_field.attname, getattr(obj, rel_field.attname)) + # populate the field cache for any related object + # that has already been retrieved + if rel_field.rel: + try: + cached_obj = getattr(obj, rel_field.get_cache_name()) + setattr(rel_obj, rel_field.get_cache_name(), cached_obj) + except AttributeError: + # Related object hasn't been cached yet + pass return obj, index_end def delete_objects(seen_objs, using): diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index b7d63d381e..2fe03302a9 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -215,7 +215,7 @@ class SQLCompiler(object): return result def get_default_columns(self, with_aliases=False, col_aliases=None, - start_alias=None, opts=None, as_pairs=False): + start_alias=None, opts=None, as_pairs=False, local_only=False): """ Computes the default columns for selecting every field in the base model. Will sometimes be called to pull in related models (e.g. via @@ -240,6 +240,8 @@ class SQLCompiler(object): if start_alias: seen = {None: start_alias} for field, model in opts.get_fields_with_model(): + if local_only and model is not None: + continue if start_alias: try: alias = seen[model] @@ -643,7 +645,7 @@ class SQLCompiler(object): ) used.add(alias) columns, aliases = self.get_default_columns(start_alias=alias, - opts=model._meta, as_pairs=True) + opts=model._meta, as_pairs=True, local_only=True) self.query.related_select_cols.extend(columns) self.query.related_select_fields.extend(model._meta.fields) diff --git a/tests/regressiontests/select_related_onetoone/models.py b/tests/regressiontests/select_related_onetoone/models.py index 6b46366530..3d6da9b4c5 100644 --- a/tests/regressiontests/select_related_onetoone/models.py +++ b/tests/regressiontests/select_related_onetoone/models.py @@ -43,8 +43,7 @@ class StatDetails(models.Model): class AdvancedUserStat(UserStat): - pass - + karma = models.IntegerField() class Image(models.Model): name = models.CharField(max_length=100) diff --git a/tests/regressiontests/select_related_onetoone/tests.py b/tests/regressiontests/select_related_onetoone/tests.py index 5a4a3e4ed6..4ccb58440a 100644 --- a/tests/regressiontests/select_related_onetoone/tests.py +++ b/tests/regressiontests/select_related_onetoone/tests.py @@ -2,7 +2,7 @@ from django import db from django.conf import settings from django.test import TestCase -from models import (User, UserProfile, UserStat, UserStatResult, StatDetails, +from models import (User, UserProfile, UserStat, UserStatResult, StatDetails, AdvancedUserStat, Image, Product) class ReverseSelectRelatedTestCase(TestCase): @@ -22,7 +22,7 @@ class ReverseSelectRelatedTestCase(TestCase): user2 = User.objects.create(username="bob") results2 = UserStatResult.objects.create(results='moar results') - advstat = AdvancedUserStat.objects.create(user=user2, posts=200, + advstat = AdvancedUserStat.objects.create(user=user2, posts=200, karma=5, results=results2) StatDetails.objects.create(base_stats=advstat, comments=250) @@ -74,18 +74,21 @@ class ReverseSelectRelatedTestCase(TestCase): self.assertQueries(2) def test_follow_from_child_class(self): - stat = AdvancedUserStat.objects.select_related("statdetails").get(posts=200) + stat = AdvancedUserStat.objects.select_related('user', 'statdetails').get(posts=200) self.assertEqual(stat.statdetails.comments, 250) + self.assertEqual(stat.user.username, 'bob') self.assertQueries(1) def test_follow_inheritance(self): - stat = UserStat.objects.select_related('advanceduserstat').get(posts=200) + stat = UserStat.objects.select_related('user', 'advanceduserstat').get(posts=200) self.assertEqual(stat.advanceduserstat.posts, 200) + self.assertEqual(stat.user.username, 'bob') + self.assertEqual(stat.advanceduserstat.user.username, 'bob') self.assertQueries(1) - + def test_nullable_relation(self): im = Image.objects.create(name="imag1") p1 = Product.objects.create(name="Django Plushie", image=im) p2 = Product.objects.create(name="Talking Django Plushie") - + self.assertEqual(len(Product.objects.select_related("image")), 2)