Fixed #13781 -- Improved select_related in inheritance situations

The select_related code got confused when it needed to travel a
reverse relation to a model which had different parent than the
originally travelled relation.

Thanks to Trac aliases shauncutts for report and ungenio for original
patch (committed patch is somewhat modified version of that).
This commit is contained in:
Anssi Kääriäinen 2012-11-09 20:21:46 +02:00
parent 92d7f541da
commit f51e409a5f
5 changed files with 203 additions and 43 deletions

View File

@ -75,6 +75,7 @@ class Options(object):
from django.db.backends.util import truncate_name
cls._meta = self
self.model = cls
self.installed = re.sub('\.models$', '', cls.__module__) in settings.INSTALLED_APPS
# First, construct the default values for these options.
self.object_name = cls.__name__
@ -464,7 +465,7 @@ class Options(object):
a granparent or even more distant relation.
"""
if not self.parents:
return
return None
if model in self.parents:
return [model]
for parent in self.parents:
@ -472,8 +473,7 @@ class Options(object):
if res:
res.insert(0, parent)
return res
raise TypeError('%r is not an ancestor of this model'
% model._meta.module_name)
return None
def get_parent_list(self):
"""

View File

@ -1300,7 +1300,7 @@ class EmptyQuerySet(QuerySet):
value_annotation = False
def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
only_load=None, local_only=False):
only_load=None, from_parent=None):
"""
Helper function that recursively returns an information for a klass, to be
used in get_cached_row. It exists just to compute this information only
@ -1320,8 +1320,10 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
* only_load - if the query has had only() or defer() applied,
this is the list of field names that will be returned. If None,
the full field list for `klass` can be assumed.
* local_only - Only populate local fields. This is used when
following reverse select-related relations
* from_parent - the parent model used to get to this model
Note that when travelling from parent to child, we will only load child
fields which aren't in the parent.
"""
if max_depth and requested is None and cur_depth > max_depth:
# We've recursed deeply enough; stop now.
@ -1347,7 +1349,9 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
for field, model in klass._meta.get_fields_with_model():
if field.name not in load_fields:
skip.add(field.attname)
elif local_only and model is not None:
elif from_parent and issubclass(from_parent, model.__class__):
# Avoid loading fields already loaded for parent model for
# child models.
continue
else:
init_list.append(field.attname)
@ -1361,16 +1365,22 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
else:
# Load all fields on klass
# We trying to not populate field_names variable for perfomance reason.
# If field_names variable is set, it is used to instantiate desired fields,
# by passing **dict(zip(field_names, fields)) as kwargs to Model.__init__ method.
# But kwargs version of Model.__init__ is slower, so we should avoid using
# it when it is not really neccesary.
if local_only and len(klass._meta.local_fields) != len(klass._meta.fields):
field_count = len(klass._meta.local_fields)
field_names = [f.attname for f in klass._meta.local_fields]
else:
field_count = len(klass._meta.fields)
field_count = len(klass._meta.fields)
# Check if we need to skip some parent fields.
if from_parent and len(klass._meta.local_fields) != len(klass._meta.fields):
# Only load those fields which haven't been already loaded into
# 'from_parent'.
non_seen_models = [p for p in klass._meta.get_parent_list()
if not issubclass(from_parent, p)]
# Load local fields, too...
non_seen_models.append(klass)
field_names = [f.attname for f in klass._meta.fields
if f.model in non_seen_models]
field_count = len(field_names)
# Try to avoid populating field_names variable for perfomance reasons.
# If field_names variable is set, we use **kwargs based model init
# which is slower than normal init.
if field_count == len(klass._meta.fields):
field_names = ()
restricted = requested is not None
@ -1392,8 +1402,9 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
if o.field.unique and select_related_descend(o.field, restricted, requested,
only_load.get(o.model), reverse=True):
next = requested[o.field.related_query_name()]
parent = klass if issubclass(o.model, klass) else None
klass_info = get_klass_info(o.model, max_depth=max_depth, cur_depth=cur_depth+1,
requested=next, only_load=only_load, local_only=True)
requested=next, only_load=only_load, from_parent=parent)
reverse_related_fields.append((o.field, klass_info))
if field_names:
pk_idx = field_names.index(klass._meta.pk.attname)
@ -1403,7 +1414,8 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
return klass, field_names, field_count, related_fields, reverse_related_fields, pk_idx
def get_cached_row(row, index_start, using, klass_info, offset=0):
def get_cached_row(row, index_start, using, klass_info, offset=0,
parent_data=()):
"""
Helper function that recursively returns an object with the specified
related attributes already populated.
@ -1418,13 +1430,16 @@ def get_cached_row(row, index_start, using, klass_info, offset=0):
* offset - the number of additional fields that are known to
exist in row for `klass`. This usually means the number of
annotated results on `klass`.
* using - the database alias on which the query is being executed.
* using - the database alias on which the query is being executed.
* klass_info - result of the get_klass_info function
* parent_data - parent model data in format (field, value). Used
to populate the non-local fields of child models.
"""
if klass_info is None:
return None
klass, field_names, field_count, related_fields, reverse_related_fields, pk_idx = klass_info
fields = row[index_start : index_start + field_count]
# If the pk column is None (or the Oracle equivalent ''), then the related
# object must be non-existent - set the relation to None.
@ -1434,7 +1449,6 @@ def get_cached_row(row, index_start, using, klass_info, offset=0):
obj = klass(**dict(zip(field_names, fields)))
else:
obj = klass(*fields)
# If an object was retrieved, set the database state.
if obj:
obj._state.db = using
@ -1464,34 +1478,35 @@ def get_cached_row(row, index_start, using, klass_info, offset=0):
# Only handle the restricted case - i.e., don't do a depth
# descent into reverse relations unless explicitly requested
for f, klass_info in reverse_related_fields:
# Transfer data from this object to childs.
parent_data = []
for rel_field, rel_model in klass_info[0]._meta.get_fields_with_model():
if rel_model is not None and isinstance(obj, rel_model):
parent_data.append((rel_field, getattr(obj, rel_field.attname)))
# Recursively retrieve the data for the related object
cached_row = get_cached_row(row, index_end, using, klass_info)
cached_row = get_cached_row(row, index_end, using, klass_info,
parent_data=parent_data)
# If the recursive descent found an object, populate the
# descriptor caches relevant to the object
if cached_row:
rel_obj, index_end = cached_row
if obj is not None:
# If the field is unique, populate the
# reverse descriptor cache
# populate the reverse descriptor cache
setattr(obj, f.related.get_cache_name(), rel_obj)
if rel_obj is not None:
# If the related object exists, populate
# the descriptor cache.
setattr(rel_obj, f.get_cache_name(), obj)
# Now populate all the non-local field values
# on the related object
for rel_field, rel_model in rel_obj._meta.get_fields_with_model():
if rel_model is not None:
# Populate related object caches using parent data.
for rel_field, _ in parent_data:
if rel_field.rel:
setattr(rel_obj, rel_field.attname, getattr(obj, rel_field.attname))
# populate the field cache for any related object
# that has already been retrieved
if rel_field.rel:
try:
cached_obj = getattr(obj, rel_field.get_cache_name())
setattr(rel_obj, rel_field.get_cache_name(), cached_obj)
except AttributeError:
# Related object hasn't been cached yet
pass
try:
cached_obj = getattr(obj, rel_field.get_cache_name())
setattr(rel_obj, rel_field.get_cache_name(), cached_obj)
except AttributeError:
# Related object hasn't been cached yet
pass
return obj, index_end

View File

@ -240,7 +240,7 @@ class SQLCompiler(object):
return result
def get_default_columns(self, with_aliases=False, col_aliases=None,
start_alias=None, opts=None, as_pairs=False, local_only=False):
start_alias=None, opts=None, as_pairs=False, from_parent=None):
"""
Computes the default columns for selecting every field in the base
model. Will sometimes be called to pull in related models (e.g. via
@ -265,7 +265,8 @@ class SQLCompiler(object):
if start_alias:
seen = {None: start_alias}
for field, model in opts.get_fields_with_model():
if local_only and model is not None:
if from_parent and model is not None and issubclass(from_parent, model):
# Avoid loading data for already loaded parents.
continue
if start_alias:
try:
@ -686,11 +687,13 @@ class SQLCompiler(object):
(alias, table, f.rel.get_related_field().column, f.column),
promote=True
)
from_parent = (opts.model if issubclass(model, opts.model)
else None)
columns, aliases = self.get_default_columns(start_alias=alias,
opts=model._meta, as_pairs=True, local_only=True)
opts=model._meta, as_pairs=True, from_parent=from_parent)
self.query.related_select_cols.extend(
SelectInfo(col, field) for col, field in zip(columns, model._meta.fields))
SelectInfo(col, field) for col, field
in zip(columns, model._meta.fields))
next = requested.get(f.related_query_name(), {})
# Use True here because we are looking at the _reverse_ side of
# the relation, which is always nullable.

View File

@ -51,6 +51,7 @@ class StatDetails(models.Model):
class AdvancedUserStat(UserStat):
karma = models.IntegerField()
class Image(models.Model):
name = models.CharField(max_length=100)
@ -58,3 +59,44 @@ class Image(models.Model):
class Product(models.Model):
name = models.CharField(max_length=100)
image = models.OneToOneField(Image, null=True)
@python_2_unicode_compatible
class Parent1(models.Model):
name1 = models.CharField(max_length=50)
def __str__(self):
return self.name1
@python_2_unicode_compatible
class Parent2(models.Model):
# Avoid having two "id" fields in the Child1 subclass
id2 = models.AutoField(primary_key=True)
name2 = models.CharField(max_length=50)
def __str__(self):
return self.name2
@python_2_unicode_compatible
class Child1(Parent1, Parent2):
value = models.IntegerField()
def __str__(self):
return self.name1
@python_2_unicode_compatible
class Child2(Parent1):
parent2 = models.OneToOneField(Parent2)
value = models.IntegerField()
def __str__(self):
return self.name1
class Child3(Child2):
value3 = models.IntegerField()
class Child4(Child1):
value4 = models.IntegerField()

View File

@ -1,9 +1,11 @@
from __future__ import absolute_import
from django.test import TestCase
from django.utils import unittest
from .models import (User, UserProfile, UserStat, UserStatResult, StatDetails,
AdvancedUserStat, Image, Product)
AdvancedUserStat, Image, Product, Parent1, Parent2, Child1, Child2, Child3,
Child4)
class ReverseSelectRelatedTestCase(TestCase):
@ -21,6 +23,14 @@ class ReverseSelectRelatedTestCase(TestCase):
advstat = AdvancedUserStat.objects.create(user=user2, posts=200, karma=5,
results=results2)
StatDetails.objects.create(base_stats=advstat, comments=250)
p1 = Parent1(name1="Only Parent1")
p1.save()
c1 = Child1(name1="Child1 Parent1", name2="Child1 Parent2", value=1)
c1.save()
p2 = Parent2(name2="Child2 Parent2")
p2.save()
c2 = Child2(name1="Child2 Parent1", parent2=p2, value=2)
c2.save()
def test_basic(self):
with self.assertNumQueries(1):
@ -108,3 +118,93 @@ class ReverseSelectRelatedTestCase(TestCase):
image = Image.objects.select_related('product').get()
with self.assertRaises(Product.DoesNotExist):
image.product
def test_parent_only(self):
with self.assertNumQueries(1):
p = Parent1.objects.select_related('child1').get(name1="Only Parent1")
with self.assertNumQueries(0):
with self.assertRaises(Child1.DoesNotExist):
p.child1
def test_multiple_subclass(self):
with self.assertNumQueries(1):
p = Parent1.objects.select_related('child1').get(name1="Child1 Parent1")
self.assertEqual(p.child1.name2, 'Child1 Parent2')
def test_onetoone_with_subclass(self):
with self.assertNumQueries(1):
p = Parent2.objects.select_related('child2').get(name2="Child2 Parent2")
self.assertEqual(p.child2.name1, 'Child2 Parent1')
def test_onetoone_with_two_subclasses(self):
with self.assertNumQueries(1):
p = Parent2.objects.select_related('child2', "child2__child3").get(name2="Child2 Parent2")
self.assertEqual(p.child2.name1, 'Child2 Parent1')
with self.assertRaises(Child3.DoesNotExist):
p.child2.child3
p3 = Parent2(name2="Child3 Parent2")
p3.save()
c2 = Child3(name1="Child3 Parent1", parent2=p3, value=2, value3=3)
c2.save()
with self.assertNumQueries(1):
p = Parent2.objects.select_related('child2', "child2__child3").get(name2="Child3 Parent2")
self.assertEqual(p.child2.name1, 'Child3 Parent1')
self.assertEqual(p.child2.child3.value3, 3)
self.assertEqual(p.child2.child3.value, p.child2.value)
self.assertEqual(p.child2.name1, p.child2.child3.name1)
def test_multiinheritance_two_subclasses(self):
with self.assertNumQueries(1):
p = Parent1.objects.select_related('child1', 'child1__child4').get(name1="Child1 Parent1")
self.assertEqual(p.child1.name2, 'Child1 Parent2')
self.assertEqual(p.child1.name1, p.name1)
with self.assertRaises(Child4.DoesNotExist):
p.child1.child4
Child4(name1='n1', name2='n2', value=1, value4=4).save()
with self.assertNumQueries(1):
p = Parent2.objects.select_related('child1', 'child1__child4').get(name2="n2")
self.assertEqual(p.name2, 'n2')
self.assertEqual(p.child1.name1, 'n1')
self.assertEqual(p.child1.name2, p.name2)
self.assertEqual(p.child1.value, 1)
self.assertEqual(p.child1.child4.name1, p.child1.name1)
self.assertEqual(p.child1.child4.name2, p.child1.name2)
self.assertEqual(p.child1.child4.value, p.child1.value)
self.assertEqual(p.child1.child4.value4, 4)
@unittest.expectedFailure
def test_inheritance_deferred(self):
c = Child4.objects.create(name1='n1', name2='n2', value=1, value4=4)
with self.assertNumQueries(1):
p = Parent2.objects.select_related('child1').only(
'id2', 'child1__value').get(name2="n2")
self.assertEqual(p.id2, c.id2)
self.assertEqual(p.child1.value, 1)
p = Parent2.objects.select_related('child1').only(
'id2', 'child1__value').get(name2="n2")
with self.assertNumQueries(1):
self.assertEquals(p.name2, 'n2')
p = Parent2.objects.select_related('child1').only(
'id2', 'child1__value').get(name2="n2")
with self.assertNumQueries(1):
self.assertEquals(p.child1.name2, 'n2')
@unittest.expectedFailure
def test_inheritance_deferred2(self):
c = Child4.objects.create(name1='n1', name2='n2', value=1, value4=4)
qs = Parent2.objects.select_related('child1', 'child4').only(
'id2', 'child1__value', 'child1__child4__value4')
with self.assertNumQueries(1):
p = qs.get(name2="n2")
self.assertEqual(p.id2, c.id2)
self.assertEqual(p.child1.value, 1)
self.assertEqual(p.child1.child4.value4, 4)
self.assertEqual(p.child1.child4.id2, c.id2)
p = qs.get(name2="n2")
with self.assertNumQueries(1):
self.assertEquals(p.child1.name2, 'n2')
p = qs.get(name2="n2")
with self.assertNumQueries(1):
self.assertEquals(p.child1.name1, 'n1')
with self.assertNumQueries(1):
self.assertEquals(p.child1.child4.name1, 'n1')