From f51e409a5fb34020e170494320a421503689aea0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Anssi=20K=C3=A4=C3=A4ri=C3=A4inen?= Date: Fri, 9 Nov 2012 20:21:46 +0200 Subject: [PATCH] Fixed #13781 -- Improved select_related in inheritance situations The select_related code got confused when it needed to travel a reverse relation to a model which had a different parent than the originally travelled relation. Thanks to Trac aliases shauncutts for the report and ungenio for the original patch (committed patch is a somewhat modified version of that). --- django/db/models/options.py | 6 +- django/db/models/query.py | 83 ++++++++------ django/db/models/sql/compiler.py | 13 ++- .../select_related_onetoone/models.py | 42 ++++++++ .../select_related_onetoone/tests.py | 102 +++++++++++++++++- 5 files changed, 203 insertions(+), 43 deletions(-) diff --git a/django/db/models/options.py b/django/db/models/options.py index 7ea6e4b744..ab2f44e2f7 100644 --- a/django/db/models/options.py +++ b/django/db/models/options.py @@ -75,6 +75,7 @@ class Options(object): from django.db.backends.util import truncate_name cls._meta = self + self.model = cls self.installed = re.sub('\.models$', '', cls.__module__) in settings.INSTALLED_APPS # First, construct the default values for these options. self.object_name = cls.__name__ @@ -464,7 +465,7 @@ class Options(object): a granparent or even more distant relation. 
""" if not self.parents: - return + return None if model in self.parents: return [model] for parent in self.parents: @@ -472,8 +473,7 @@ class Options(object): if res: res.insert(0, parent) return res - raise TypeError('%r is not an ancestor of this model' - % model._meta.module_name) + return None def get_parent_list(self): """ diff --git a/django/db/models/query.py b/django/db/models/query.py index 67fef52f36..d5379a5f6a 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -1300,7 +1300,7 @@ class EmptyQuerySet(QuerySet): value_annotation = False def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None, - only_load=None, local_only=False): + only_load=None, from_parent=None): """ Helper function that recursively returns an information for a klass, to be used in get_cached_row. It exists just to compute this information only @@ -1320,8 +1320,10 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None, * only_load - if the query has had only() or defer() applied, this is the list of field names that will be returned. If None, the full field list for `klass` can be assumed. - * local_only - Only populate local fields. This is used when - following reverse select-related relations + * from_parent - the parent model used to get to this model + + Note that when travelling from parent to child, we will only load child + fields which aren't in the parent. """ if max_depth and requested is None and cur_depth > max_depth: # We've recursed deeply enough; stop now. @@ -1347,7 +1349,9 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None, for field, model in klass._meta.get_fields_with_model(): if field.name not in load_fields: skip.add(field.attname) - elif local_only and model is not None: + elif from_parent and issubclass(from_parent, model.__class__): + # Avoid loading fields already loaded for parent model for + # child models. 
continue else: init_list.append(field.attname) @@ -1361,16 +1365,22 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None, else: # Load all fields on klass - # We trying to not populate field_names variable for perfomance reason. - # If field_names variable is set, it is used to instantiate desired fields, - # by passing **dict(zip(field_names, fields)) as kwargs to Model.__init__ method. - # But kwargs version of Model.__init__ is slower, so we should avoid using - # it when it is not really neccesary. - if local_only and len(klass._meta.local_fields) != len(klass._meta.fields): - field_count = len(klass._meta.local_fields) - field_names = [f.attname for f in klass._meta.local_fields] - else: - field_count = len(klass._meta.fields) + field_count = len(klass._meta.fields) + # Check if we need to skip some parent fields. + if from_parent and len(klass._meta.local_fields) != len(klass._meta.fields): + # Only load those fields which haven't been already loaded into + # 'from_parent'. + non_seen_models = [p for p in klass._meta.get_parent_list() + if not issubclass(from_parent, p)] + # Load local fields, too... + non_seen_models.append(klass) + field_names = [f.attname for f in klass._meta.fields + if f.model in non_seen_models] + field_count = len(field_names) + # Try to avoid populating field_names variable for performance reasons. + # If field_names variable is set, we use **kwargs based model init + # which is slower than normal init. 
+ if field_count == len(klass._meta.fields): field_names = () restricted = requested is not None @@ -1392,8 +1402,9 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None, if o.field.unique and select_related_descend(o.field, restricted, requested, only_load.get(o.model), reverse=True): next = requested[o.field.related_query_name()] + parent = klass if issubclass(o.model, klass) else None klass_info = get_klass_info(o.model, max_depth=max_depth, cur_depth=cur_depth+1, - requested=next, only_load=only_load, local_only=True) + requested=next, only_load=only_load, from_parent=parent) reverse_related_fields.append((o.field, klass_info)) if field_names: pk_idx = field_names.index(klass._meta.pk.attname) @@ -1403,7 +1414,8 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None, return klass, field_names, field_count, related_fields, reverse_related_fields, pk_idx -def get_cached_row(row, index_start, using, klass_info, offset=0): +def get_cached_row(row, index_start, using, klass_info, offset=0, + parent_data=()): """ Helper function that recursively returns an object with the specified related attributes already populated. @@ -1418,13 +1430,16 @@ def get_cached_row(row, index_start, using, klass_info, offset=0): * offset - the number of additional fields that are known to exist in row for `klass`. This usually means the number of annotated results on `klass`. - * using - the database alias on which the query is being executed. + * using - the database alias on which the query is being executed. * klass_info - result of the get_klass_info function + * parent_data - parent model data in format (field, value). Used + to populate the non-local fields of child models. 
""" if klass_info is None: return None klass, field_names, field_count, related_fields, reverse_related_fields, pk_idx = klass_info + fields = row[index_start : index_start + field_count] # If the pk column is None (or the Oracle equivalent ''), then the related # object must be non-existent - set the relation to None. @@ -1434,7 +1449,6 @@ def get_cached_row(row, index_start, using, klass_info, offset=0): obj = klass(**dict(zip(field_names, fields))) else: obj = klass(*fields) - # If an object was retrieved, set the database state. if obj: obj._state.db = using @@ -1464,34 +1478,35 @@ def get_cached_row(row, index_start, using, klass_info, offset=0): # Only handle the restricted case - i.e., don't do a depth # descent into reverse relations unless explicitly requested for f, klass_info in reverse_related_fields: + # Transfer data from this object to children. + parent_data = [] + for rel_field, rel_model in klass_info[0]._meta.get_fields_with_model(): + if rel_model is not None and isinstance(obj, rel_model): + parent_data.append((rel_field, getattr(obj, rel_field.attname))) # Recursively retrieve the data for the related object - cached_row = get_cached_row(row, index_end, using, klass_info) + cached_row = get_cached_row(row, index_end, using, klass_info, + parent_data=parent_data) # If the recursive descent found an object, populate the # descriptor caches relevant to the object if cached_row: rel_obj, index_end = cached_row if obj is not None: - # If the field is unique, populate the - # reverse descriptor cache + # populate the reverse descriptor cache setattr(obj, f.related.get_cache_name(), rel_obj) if rel_obj is not None: # If the related object exists, populate # the descriptor cache. setattr(rel_obj, f.get_cache_name(), obj) - # Now populate all the non-local field values - # on the related object - for rel_field, rel_model in rel_obj._meta.get_fields_with_model(): - if rel_model is not None: + # Populate related object caches using parent data. 
+ for rel_field, _ in parent_data: + if rel_field.rel: setattr(rel_obj, rel_field.attname, getattr(obj, rel_field.attname)) - # populate the field cache for any related object - # that has already been retrieved - if rel_field.rel: - try: - cached_obj = getattr(obj, rel_field.get_cache_name()) - setattr(rel_obj, rel_field.get_cache_name(), cached_obj) - except AttributeError: - # Related object hasn't been cached yet - pass + try: + cached_obj = getattr(obj, rel_field.get_cache_name()) + setattr(rel_obj, rel_field.get_cache_name(), cached_obj) + except AttributeError: + # Related object hasn't been cached yet + pass return obj, index_end diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 2ad4542341..8cfb12a8e3 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -240,7 +240,7 @@ class SQLCompiler(object): return result def get_default_columns(self, with_aliases=False, col_aliases=None, - start_alias=None, opts=None, as_pairs=False, local_only=False): + start_alias=None, opts=None, as_pairs=False, from_parent=None): """ Computes the default columns for selecting every field in the base model. Will sometimes be called to pull in related models (e.g. via @@ -265,7 +265,8 @@ class SQLCompiler(object): if start_alias: seen = {None: start_alias} for field, model in opts.get_fields_with_model(): - if local_only and model is not None: + if from_parent and model is not None and issubclass(from_parent, model): + # Avoid loading data for already loaded parents. 
continue if start_alias: try: @@ -686,11 +687,13 @@ class SQLCompiler(object): (alias, table, f.rel.get_related_field().column, f.column), promote=True ) + from_parent = (opts.model if issubclass(model, opts.model) + else None) columns, aliases = self.get_default_columns(start_alias=alias, - opts=model._meta, as_pairs=True, local_only=True) + opts=model._meta, as_pairs=True, from_parent=from_parent) self.query.related_select_cols.extend( - SelectInfo(col, field) for col, field in zip(columns, model._meta.fields)) - + SelectInfo(col, field) for col, field + in zip(columns, model._meta.fields)) next = requested.get(f.related_query_name(), {}) # Use True here because we are looking at the _reverse_ side of # the relation, which is always nullable. diff --git a/tests/regressiontests/select_related_onetoone/models.py b/tests/regressiontests/select_related_onetoone/models.py index 3284defb11..d32faafbb9 100644 --- a/tests/regressiontests/select_related_onetoone/models.py +++ b/tests/regressiontests/select_related_onetoone/models.py @@ -51,6 +51,7 @@ class StatDetails(models.Model): class AdvancedUserStat(UserStat): karma = models.IntegerField() + class Image(models.Model): name = models.CharField(max_length=100) @@ -58,3 +59,44 @@ class Image(models.Model): class Product(models.Model): name = models.CharField(max_length=100) image = models.OneToOneField(Image, null=True) + + +@python_2_unicode_compatible +class Parent1(models.Model): + name1 = models.CharField(max_length=50) + + def __str__(self): + return self.name1 + + +@python_2_unicode_compatible +class Parent2(models.Model): + # Avoid having two "id" fields in the Child1 subclass + id2 = models.AutoField(primary_key=True) + name2 = models.CharField(max_length=50) + + def __str__(self): + return self.name2 + + +@python_2_unicode_compatible +class Child1(Parent1, Parent2): + value = models.IntegerField() + + def __str__(self): + return self.name1 + + +@python_2_unicode_compatible +class Child2(Parent1): + parent2 = 
models.OneToOneField(Parent2) + value = models.IntegerField() + + def __str__(self): + return self.name1 + +class Child3(Child2): + value3 = models.IntegerField() + +class Child4(Child1): + value4 = models.IntegerField() diff --git a/tests/regressiontests/select_related_onetoone/tests.py b/tests/regressiontests/select_related_onetoone/tests.py index 1373f04717..d4a1275e49 100644 --- a/tests/regressiontests/select_related_onetoone/tests.py +++ b/tests/regressiontests/select_related_onetoone/tests.py @@ -1,9 +1,11 @@ from __future__ import absolute_import from django.test import TestCase +from django.utils import unittest from .models import (User, UserProfile, UserStat, UserStatResult, StatDetails, - AdvancedUserStat, Image, Product) + AdvancedUserStat, Image, Product, Parent1, Parent2, Child1, Child2, Child3, + Child4) class ReverseSelectRelatedTestCase(TestCase): @@ -21,6 +23,14 @@ class ReverseSelectRelatedTestCase(TestCase): advstat = AdvancedUserStat.objects.create(user=user2, posts=200, karma=5, results=results2) StatDetails.objects.create(base_stats=advstat, comments=250) + p1 = Parent1(name1="Only Parent1") + p1.save() + c1 = Child1(name1="Child1 Parent1", name2="Child1 Parent2", value=1) + c1.save() + p2 = Parent2(name2="Child2 Parent2") + p2.save() + c2 = Child2(name1="Child2 Parent1", parent2=p2, value=2) + c2.save() def test_basic(self): with self.assertNumQueries(1): @@ -108,3 +118,93 @@ class ReverseSelectRelatedTestCase(TestCase): image = Image.objects.select_related('product').get() with self.assertRaises(Product.DoesNotExist): image.product + + def test_parent_only(self): + with self.assertNumQueries(1): + p = Parent1.objects.select_related('child1').get(name1="Only Parent1") + with self.assertNumQueries(0): + with self.assertRaises(Child1.DoesNotExist): + p.child1 + + def test_multiple_subclass(self): + with self.assertNumQueries(1): + p = Parent1.objects.select_related('child1').get(name1="Child1 Parent1") + self.assertEqual(p.child1.name2, 
'Child1 Parent2') + + def test_onetoone_with_subclass(self): + with self.assertNumQueries(1): + p = Parent2.objects.select_related('child2').get(name2="Child2 Parent2") + self.assertEqual(p.child2.name1, 'Child2 Parent1') + + def test_onetoone_with_two_subclasses(self): + with self.assertNumQueries(1): + p = Parent2.objects.select_related('child2', "child2__child3").get(name2="Child2 Parent2") + self.assertEqual(p.child2.name1, 'Child2 Parent1') + with self.assertRaises(Child3.DoesNotExist): + p.child2.child3 + p3 = Parent2(name2="Child3 Parent2") + p3.save() + c2 = Child3(name1="Child3 Parent1", parent2=p3, value=2, value3=3) + c2.save() + with self.assertNumQueries(1): + p = Parent2.objects.select_related('child2', "child2__child3").get(name2="Child3 Parent2") + self.assertEqual(p.child2.name1, 'Child3 Parent1') + self.assertEqual(p.child2.child3.value3, 3) + self.assertEqual(p.child2.child3.value, p.child2.value) + self.assertEqual(p.child2.name1, p.child2.child3.name1) + + def test_multiinheritance_two_subclasses(self): + with self.assertNumQueries(1): + p = Parent1.objects.select_related('child1', 'child1__child4').get(name1="Child1 Parent1") + self.assertEqual(p.child1.name2, 'Child1 Parent2') + self.assertEqual(p.child1.name1, p.name1) + with self.assertRaises(Child4.DoesNotExist): + p.child1.child4 + Child4(name1='n1', name2='n2', value=1, value4=4).save() + with self.assertNumQueries(1): + p = Parent2.objects.select_related('child1', 'child1__child4').get(name2="n2") + self.assertEqual(p.name2, 'n2') + self.assertEqual(p.child1.name1, 'n1') + self.assertEqual(p.child1.name2, p.name2) + self.assertEqual(p.child1.value, 1) + self.assertEqual(p.child1.child4.name1, p.child1.name1) + self.assertEqual(p.child1.child4.name2, p.child1.name2) + self.assertEqual(p.child1.child4.value, p.child1.value) + self.assertEqual(p.child1.child4.value4, 4) + + @unittest.expectedFailure + def test_inheritance_deferred(self): + c = Child4.objects.create(name1='n1', name2='n2', 
value=1, value4=4) + with self.assertNumQueries(1): + p = Parent2.objects.select_related('child1').only( + 'id2', 'child1__value').get(name2="n2") + self.assertEqual(p.id2, c.id2) + self.assertEqual(p.child1.value, 1) + p = Parent2.objects.select_related('child1').only( + 'id2', 'child1__value').get(name2="n2") + with self.assertNumQueries(1): + self.assertEquals(p.name2, 'n2') + p = Parent2.objects.select_related('child1').only( + 'id2', 'child1__value').get(name2="n2") + with self.assertNumQueries(1): + self.assertEquals(p.child1.name2, 'n2') + + @unittest.expectedFailure + def test_inheritance_deferred2(self): + c = Child4.objects.create(name1='n1', name2='n2', value=1, value4=4) + qs = Parent2.objects.select_related('child1', 'child4').only( + 'id2', 'child1__value', 'child1__child4__value4') + with self.assertNumQueries(1): + p = qs.get(name2="n2") + self.assertEqual(p.id2, c.id2) + self.assertEqual(p.child1.value, 1) + self.assertEqual(p.child1.child4.value4, 4) + self.assertEqual(p.child1.child4.id2, c.id2) + p = qs.get(name2="n2") + with self.assertNumQueries(1): + self.assertEquals(p.child1.name2, 'n2') + p = qs.get(name2="n2") + with self.assertNumQueries(1): + self.assertEquals(p.child1.name1, 'n1') + with self.assertNumQueries(1): + self.assertEquals(p.child1.child4.name1, 'n1')