Fixed #13781 -- Improved select_related in inheritance situations

The select_related code got confused when it needed to travel a
reverse relation to a model which had different parent than the
originally travelled relation.

Thanks to Trac aliases shauncutts for report and ungenio for original
patch (committed patch is somewhat modified version of that).
This commit is contained in:
Anssi Kääriäinen 2012-11-09 20:21:46 +02:00
parent 92d7f541da
commit f51e409a5f
5 changed files with 203 additions and 43 deletions

View File

@ -75,6 +75,7 @@ class Options(object):
from django.db.backends.util import truncate_name from django.db.backends.util import truncate_name
cls._meta = self cls._meta = self
self.model = cls
self.installed = re.sub('\.models$', '', cls.__module__) in settings.INSTALLED_APPS self.installed = re.sub('\.models$', '', cls.__module__) in settings.INSTALLED_APPS
# First, construct the default values for these options. # First, construct the default values for these options.
self.object_name = cls.__name__ self.object_name = cls.__name__
@ -464,7 +465,7 @@ class Options(object):
a granparent or even more distant relation. a granparent or even more distant relation.
""" """
if not self.parents: if not self.parents:
return return None
if model in self.parents: if model in self.parents:
return [model] return [model]
for parent in self.parents: for parent in self.parents:
@ -472,8 +473,7 @@ class Options(object):
if res: if res:
res.insert(0, parent) res.insert(0, parent)
return res return res
raise TypeError('%r is not an ancestor of this model' return None
% model._meta.module_name)
def get_parent_list(self): def get_parent_list(self):
""" """

View File

@ -1300,7 +1300,7 @@ class EmptyQuerySet(QuerySet):
value_annotation = False value_annotation = False
def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None, def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
only_load=None, local_only=False): only_load=None, from_parent=None):
""" """
Helper function that recursively returns an information for a klass, to be Helper function that recursively returns an information for a klass, to be
used in get_cached_row. It exists just to compute this information only used in get_cached_row. It exists just to compute this information only
@ -1320,8 +1320,10 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
* only_load - if the query has had only() or defer() applied, * only_load - if the query has had only() or defer() applied,
this is the list of field names that will be returned. If None, this is the list of field names that will be returned. If None,
the full field list for `klass` can be assumed. the full field list for `klass` can be assumed.
* local_only - Only populate local fields. This is used when * from_parent - the parent model used to get to this model
following reverse select-related relations
Note that when travelling from parent to child, we will only load child
fields which aren't in the parent.
""" """
if max_depth and requested is None and cur_depth > max_depth: if max_depth and requested is None and cur_depth > max_depth:
# We've recursed deeply enough; stop now. # We've recursed deeply enough; stop now.
@ -1347,7 +1349,9 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
for field, model in klass._meta.get_fields_with_model(): for field, model in klass._meta.get_fields_with_model():
if field.name not in load_fields: if field.name not in load_fields:
skip.add(field.attname) skip.add(field.attname)
elif local_only and model is not None: elif from_parent and issubclass(from_parent, model.__class__):
# Avoid loading fields already loaded for parent model for
# child models.
continue continue
else: else:
init_list.append(field.attname) init_list.append(field.attname)
@ -1361,16 +1365,22 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
else: else:
# Load all fields on klass # Load all fields on klass
# We trying to not populate field_names variable for perfomance reason.
# If field_names variable is set, it is used to instantiate desired fields,
# by passing **dict(zip(field_names, fields)) as kwargs to Model.__init__ method.
# But kwargs version of Model.__init__ is slower, so we should avoid using
# it when it is not really neccesary.
if local_only and len(klass._meta.local_fields) != len(klass._meta.fields):
field_count = len(klass._meta.local_fields)
field_names = [f.attname for f in klass._meta.local_fields]
else:
field_count = len(klass._meta.fields) field_count = len(klass._meta.fields)
# Check if we need to skip some parent fields.
if from_parent and len(klass._meta.local_fields) != len(klass._meta.fields):
# Only load those fields which haven't been already loaded into
# 'from_parent'.
non_seen_models = [p for p in klass._meta.get_parent_list()
if not issubclass(from_parent, p)]
# Load local fields, too...
non_seen_models.append(klass)
field_names = [f.attname for f in klass._meta.fields
if f.model in non_seen_models]
field_count = len(field_names)
# Try to avoid populating field_names variable for perfomance reasons.
# If field_names variable is set, we use **kwargs based model init
# which is slower than normal init.
if field_count == len(klass._meta.fields):
field_names = () field_names = ()
restricted = requested is not None restricted = requested is not None
@ -1392,8 +1402,9 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
if o.field.unique and select_related_descend(o.field, restricted, requested, if o.field.unique and select_related_descend(o.field, restricted, requested,
only_load.get(o.model), reverse=True): only_load.get(o.model), reverse=True):
next = requested[o.field.related_query_name()] next = requested[o.field.related_query_name()]
parent = klass if issubclass(o.model, klass) else None
klass_info = get_klass_info(o.model, max_depth=max_depth, cur_depth=cur_depth+1, klass_info = get_klass_info(o.model, max_depth=max_depth, cur_depth=cur_depth+1,
requested=next, only_load=only_load, local_only=True) requested=next, only_load=only_load, from_parent=parent)
reverse_related_fields.append((o.field, klass_info)) reverse_related_fields.append((o.field, klass_info))
if field_names: if field_names:
pk_idx = field_names.index(klass._meta.pk.attname) pk_idx = field_names.index(klass._meta.pk.attname)
@ -1403,7 +1414,8 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
return klass, field_names, field_count, related_fields, reverse_related_fields, pk_idx return klass, field_names, field_count, related_fields, reverse_related_fields, pk_idx
def get_cached_row(row, index_start, using, klass_info, offset=0): def get_cached_row(row, index_start, using, klass_info, offset=0,
parent_data=()):
""" """
Helper function that recursively returns an object with the specified Helper function that recursively returns an object with the specified
related attributes already populated. related attributes already populated.
@ -1420,11 +1432,14 @@ def get_cached_row(row, index_start, using, klass_info, offset=0):
annotated results on `klass`. annotated results on `klass`.
* using - the database alias on which the query is being executed. * using - the database alias on which the query is being executed.
* klass_info - result of the get_klass_info function * klass_info - result of the get_klass_info function
* parent_data - parent model data in format (field, value). Used
to populate the non-local fields of child models.
""" """
if klass_info is None: if klass_info is None:
return None return None
klass, field_names, field_count, related_fields, reverse_related_fields, pk_idx = klass_info klass, field_names, field_count, related_fields, reverse_related_fields, pk_idx = klass_info
fields = row[index_start : index_start + field_count] fields = row[index_start : index_start + field_count]
# If the pk column is None (or the Oracle equivalent ''), then the related # If the pk column is None (or the Oracle equivalent ''), then the related
# object must be non-existent - set the relation to None. # object must be non-existent - set the relation to None.
@ -1434,7 +1449,6 @@ def get_cached_row(row, index_start, using, klass_info, offset=0):
obj = klass(**dict(zip(field_names, fields))) obj = klass(**dict(zip(field_names, fields)))
else: else:
obj = klass(*fields) obj = klass(*fields)
# If an object was retrieved, set the database state. # If an object was retrieved, set the database state.
if obj: if obj:
obj._state.db = using obj._state.db = using
@ -1464,28 +1478,29 @@ def get_cached_row(row, index_start, using, klass_info, offset=0):
# Only handle the restricted case - i.e., don't do a depth # Only handle the restricted case - i.e., don't do a depth
# descent into reverse relations unless explicitly requested # descent into reverse relations unless explicitly requested
for f, klass_info in reverse_related_fields: for f, klass_info in reverse_related_fields:
# Transfer data from this object to childs.
parent_data = []
for rel_field, rel_model in klass_info[0]._meta.get_fields_with_model():
if rel_model is not None and isinstance(obj, rel_model):
parent_data.append((rel_field, getattr(obj, rel_field.attname)))
# Recursively retrieve the data for the related object # Recursively retrieve the data for the related object
cached_row = get_cached_row(row, index_end, using, klass_info) cached_row = get_cached_row(row, index_end, using, klass_info,
parent_data=parent_data)
# If the recursive descent found an object, populate the # If the recursive descent found an object, populate the
# descriptor caches relevant to the object # descriptor caches relevant to the object
if cached_row: if cached_row:
rel_obj, index_end = cached_row rel_obj, index_end = cached_row
if obj is not None: if obj is not None:
# If the field is unique, populate the # populate the reverse descriptor cache
# reverse descriptor cache
setattr(obj, f.related.get_cache_name(), rel_obj) setattr(obj, f.related.get_cache_name(), rel_obj)
if rel_obj is not None: if rel_obj is not None:
# If the related object exists, populate # If the related object exists, populate
# the descriptor cache. # the descriptor cache.
setattr(rel_obj, f.get_cache_name(), obj) setattr(rel_obj, f.get_cache_name(), obj)
# Now populate all the non-local field values # Populate related object caches using parent data.
# on the related object for rel_field, _ in parent_data:
for rel_field, rel_model in rel_obj._meta.get_fields_with_model():
if rel_model is not None:
setattr(rel_obj, rel_field.attname, getattr(obj, rel_field.attname))
# populate the field cache for any related object
# that has already been retrieved
if rel_field.rel: if rel_field.rel:
setattr(rel_obj, rel_field.attname, getattr(obj, rel_field.attname))
try: try:
cached_obj = getattr(obj, rel_field.get_cache_name()) cached_obj = getattr(obj, rel_field.get_cache_name())
setattr(rel_obj, rel_field.get_cache_name(), cached_obj) setattr(rel_obj, rel_field.get_cache_name(), cached_obj)

View File

@ -240,7 +240,7 @@ class SQLCompiler(object):
return result return result
def get_default_columns(self, with_aliases=False, col_aliases=None, def get_default_columns(self, with_aliases=False, col_aliases=None,
start_alias=None, opts=None, as_pairs=False, local_only=False): start_alias=None, opts=None, as_pairs=False, from_parent=None):
""" """
Computes the default columns for selecting every field in the base Computes the default columns for selecting every field in the base
model. Will sometimes be called to pull in related models (e.g. via model. Will sometimes be called to pull in related models (e.g. via
@ -265,7 +265,8 @@ class SQLCompiler(object):
if start_alias: if start_alias:
seen = {None: start_alias} seen = {None: start_alias}
for field, model in opts.get_fields_with_model(): for field, model in opts.get_fields_with_model():
if local_only and model is not None: if from_parent and model is not None and issubclass(from_parent, model):
# Avoid loading data for already loaded parents.
continue continue
if start_alias: if start_alias:
try: try:
@ -686,11 +687,13 @@ class SQLCompiler(object):
(alias, table, f.rel.get_related_field().column, f.column), (alias, table, f.rel.get_related_field().column, f.column),
promote=True promote=True
) )
from_parent = (opts.model if issubclass(model, opts.model)
else None)
columns, aliases = self.get_default_columns(start_alias=alias, columns, aliases = self.get_default_columns(start_alias=alias,
opts=model._meta, as_pairs=True, local_only=True) opts=model._meta, as_pairs=True, from_parent=from_parent)
self.query.related_select_cols.extend( self.query.related_select_cols.extend(
SelectInfo(col, field) for col, field in zip(columns, model._meta.fields)) SelectInfo(col, field) for col, field
in zip(columns, model._meta.fields))
next = requested.get(f.related_query_name(), {}) next = requested.get(f.related_query_name(), {})
# Use True here because we are looking at the _reverse_ side of # Use True here because we are looking at the _reverse_ side of
# the relation, which is always nullable. # the relation, which is always nullable.

View File

@ -51,6 +51,7 @@ class StatDetails(models.Model):
class AdvancedUserStat(UserStat): class AdvancedUserStat(UserStat):
karma = models.IntegerField() karma = models.IntegerField()
class Image(models.Model): class Image(models.Model):
name = models.CharField(max_length=100) name = models.CharField(max_length=100)
@ -58,3 +59,44 @@ class Image(models.Model):
class Product(models.Model): class Product(models.Model):
name = models.CharField(max_length=100) name = models.CharField(max_length=100)
image = models.OneToOneField(Image, null=True) image = models.OneToOneField(Image, null=True)
@python_2_unicode_compatible
class Parent1(models.Model):
name1 = models.CharField(max_length=50)
def __str__(self):
return self.name1
@python_2_unicode_compatible
class Parent2(models.Model):
# Avoid having two "id" fields in the Child1 subclass
id2 = models.AutoField(primary_key=True)
name2 = models.CharField(max_length=50)
def __str__(self):
return self.name2
@python_2_unicode_compatible
class Child1(Parent1, Parent2):
value = models.IntegerField()
def __str__(self):
return self.name1
@python_2_unicode_compatible
class Child2(Parent1):
parent2 = models.OneToOneField(Parent2)
value = models.IntegerField()
def __str__(self):
return self.name1
class Child3(Child2):
value3 = models.IntegerField()
class Child4(Child1):
value4 = models.IntegerField()

View File

@ -1,9 +1,11 @@
from __future__ import absolute_import from __future__ import absolute_import
from django.test import TestCase from django.test import TestCase
from django.utils import unittest
from .models import (User, UserProfile, UserStat, UserStatResult, StatDetails, from .models import (User, UserProfile, UserStat, UserStatResult, StatDetails,
AdvancedUserStat, Image, Product) AdvancedUserStat, Image, Product, Parent1, Parent2, Child1, Child2, Child3,
Child4)
class ReverseSelectRelatedTestCase(TestCase): class ReverseSelectRelatedTestCase(TestCase):
@ -21,6 +23,14 @@ class ReverseSelectRelatedTestCase(TestCase):
advstat = AdvancedUserStat.objects.create(user=user2, posts=200, karma=5, advstat = AdvancedUserStat.objects.create(user=user2, posts=200, karma=5,
results=results2) results=results2)
StatDetails.objects.create(base_stats=advstat, comments=250) StatDetails.objects.create(base_stats=advstat, comments=250)
p1 = Parent1(name1="Only Parent1")
p1.save()
c1 = Child1(name1="Child1 Parent1", name2="Child1 Parent2", value=1)
c1.save()
p2 = Parent2(name2="Child2 Parent2")
p2.save()
c2 = Child2(name1="Child2 Parent1", parent2=p2, value=2)
c2.save()
def test_basic(self): def test_basic(self):
with self.assertNumQueries(1): with self.assertNumQueries(1):
@ -108,3 +118,93 @@ class ReverseSelectRelatedTestCase(TestCase):
image = Image.objects.select_related('product').get() image = Image.objects.select_related('product').get()
with self.assertRaises(Product.DoesNotExist): with self.assertRaises(Product.DoesNotExist):
image.product image.product
def test_parent_only(self):
with self.assertNumQueries(1):
p = Parent1.objects.select_related('child1').get(name1="Only Parent1")
with self.assertNumQueries(0):
with self.assertRaises(Child1.DoesNotExist):
p.child1
def test_multiple_subclass(self):
with self.assertNumQueries(1):
p = Parent1.objects.select_related('child1').get(name1="Child1 Parent1")
self.assertEqual(p.child1.name2, 'Child1 Parent2')
def test_onetoone_with_subclass(self):
with self.assertNumQueries(1):
p = Parent2.objects.select_related('child2').get(name2="Child2 Parent2")
self.assertEqual(p.child2.name1, 'Child2 Parent1')
def test_onetoone_with_two_subclasses(self):
with self.assertNumQueries(1):
p = Parent2.objects.select_related('child2', "child2__child3").get(name2="Child2 Parent2")
self.assertEqual(p.child2.name1, 'Child2 Parent1')
with self.assertRaises(Child3.DoesNotExist):
p.child2.child3
p3 = Parent2(name2="Child3 Parent2")
p3.save()
c2 = Child3(name1="Child3 Parent1", parent2=p3, value=2, value3=3)
c2.save()
with self.assertNumQueries(1):
p = Parent2.objects.select_related('child2', "child2__child3").get(name2="Child3 Parent2")
self.assertEqual(p.child2.name1, 'Child3 Parent1')
self.assertEqual(p.child2.child3.value3, 3)
self.assertEqual(p.child2.child3.value, p.child2.value)
self.assertEqual(p.child2.name1, p.child2.child3.name1)
def test_multiinheritance_two_subclasses(self):
with self.assertNumQueries(1):
p = Parent1.objects.select_related('child1', 'child1__child4').get(name1="Child1 Parent1")
self.assertEqual(p.child1.name2, 'Child1 Parent2')
self.assertEqual(p.child1.name1, p.name1)
with self.assertRaises(Child4.DoesNotExist):
p.child1.child4
Child4(name1='n1', name2='n2', value=1, value4=4).save()
with self.assertNumQueries(1):
p = Parent2.objects.select_related('child1', 'child1__child4').get(name2="n2")
self.assertEqual(p.name2, 'n2')
self.assertEqual(p.child1.name1, 'n1')
self.assertEqual(p.child1.name2, p.name2)
self.assertEqual(p.child1.value, 1)
self.assertEqual(p.child1.child4.name1, p.child1.name1)
self.assertEqual(p.child1.child4.name2, p.child1.name2)
self.assertEqual(p.child1.child4.value, p.child1.value)
self.assertEqual(p.child1.child4.value4, 4)
@unittest.expectedFailure
def test_inheritance_deferred(self):
c = Child4.objects.create(name1='n1', name2='n2', value=1, value4=4)
with self.assertNumQueries(1):
p = Parent2.objects.select_related('child1').only(
'id2', 'child1__value').get(name2="n2")
self.assertEqual(p.id2, c.id2)
self.assertEqual(p.child1.value, 1)
p = Parent2.objects.select_related('child1').only(
'id2', 'child1__value').get(name2="n2")
with self.assertNumQueries(1):
self.assertEquals(p.name2, 'n2')
p = Parent2.objects.select_related('child1').only(
'id2', 'child1__value').get(name2="n2")
with self.assertNumQueries(1):
self.assertEquals(p.child1.name2, 'n2')
@unittest.expectedFailure
def test_inheritance_deferred2(self):
c = Child4.objects.create(name1='n1', name2='n2', value=1, value4=4)
qs = Parent2.objects.select_related('child1', 'child4').only(
'id2', 'child1__value', 'child1__child4__value4')
with self.assertNumQueries(1):
p = qs.get(name2="n2")
self.assertEqual(p.id2, c.id2)
self.assertEqual(p.child1.value, 1)
self.assertEqual(p.child1.child4.value4, 4)
self.assertEqual(p.child1.child4.id2, c.id2)
p = qs.get(name2="n2")
with self.assertNumQueries(1):
self.assertEquals(p.child1.name2, 'n2')
p = qs.get(name2="n2")
with self.assertNumQueries(1):
self.assertEquals(p.child1.name1, 'n1')
with self.assertNumQueries(1):
self.assertEquals(p.child1.child4.name1, 'n1')