Fixed #13781 -- Improved select_related in inheritance situations
The select_related code got confused when it needed to travel a reverse relation to a model which had different parent than the originally travelled relation. Thanks to Trac aliases shauncutts for report and ungenio for original patch (committed patch is somewhat modified version of that).
This commit is contained in:
parent
92d7f541da
commit
f51e409a5f
|
@ -75,6 +75,7 @@ class Options(object):
|
|||
from django.db.backends.util import truncate_name
|
||||
|
||||
cls._meta = self
|
||||
self.model = cls
|
||||
self.installed = re.sub('\.models$', '', cls.__module__) in settings.INSTALLED_APPS
|
||||
# First, construct the default values for these options.
|
||||
self.object_name = cls.__name__
|
||||
|
@ -464,7 +465,7 @@ class Options(object):
|
|||
a granparent or even more distant relation.
|
||||
"""
|
||||
if not self.parents:
|
||||
return
|
||||
return None
|
||||
if model in self.parents:
|
||||
return [model]
|
||||
for parent in self.parents:
|
||||
|
@ -472,8 +473,7 @@ class Options(object):
|
|||
if res:
|
||||
res.insert(0, parent)
|
||||
return res
|
||||
raise TypeError('%r is not an ancestor of this model'
|
||||
% model._meta.module_name)
|
||||
return None
|
||||
|
||||
def get_parent_list(self):
|
||||
"""
|
||||
|
|
|
@ -1300,7 +1300,7 @@ class EmptyQuerySet(QuerySet):
|
|||
value_annotation = False
|
||||
|
||||
def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
|
||||
only_load=None, local_only=False):
|
||||
only_load=None, from_parent=None):
|
||||
"""
|
||||
Helper function that recursively returns an information for a klass, to be
|
||||
used in get_cached_row. It exists just to compute this information only
|
||||
|
@ -1320,8 +1320,10 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
|
|||
* only_load - if the query has had only() or defer() applied,
|
||||
this is the list of field names that will be returned. If None,
|
||||
the full field list for `klass` can be assumed.
|
||||
* local_only - Only populate local fields. This is used when
|
||||
following reverse select-related relations
|
||||
* from_parent - the parent model used to get to this model
|
||||
|
||||
Note that when travelling from parent to child, we will only load child
|
||||
fields which aren't in the parent.
|
||||
"""
|
||||
if max_depth and requested is None and cur_depth > max_depth:
|
||||
# We've recursed deeply enough; stop now.
|
||||
|
@ -1347,7 +1349,9 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
|
|||
for field, model in klass._meta.get_fields_with_model():
|
||||
if field.name not in load_fields:
|
||||
skip.add(field.attname)
|
||||
elif local_only and model is not None:
|
||||
elif from_parent and issubclass(from_parent, model.__class__):
|
||||
# Avoid loading fields already loaded for parent model for
|
||||
# child models.
|
||||
continue
|
||||
else:
|
||||
init_list.append(field.attname)
|
||||
|
@ -1361,16 +1365,22 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
|
|||
else:
|
||||
# Load all fields on klass
|
||||
|
||||
# We trying to not populate field_names variable for perfomance reason.
|
||||
# If field_names variable is set, it is used to instantiate desired fields,
|
||||
# by passing **dict(zip(field_names, fields)) as kwargs to Model.__init__ method.
|
||||
# But kwargs version of Model.__init__ is slower, so we should avoid using
|
||||
# it when it is not really neccesary.
|
||||
if local_only and len(klass._meta.local_fields) != len(klass._meta.fields):
|
||||
field_count = len(klass._meta.local_fields)
|
||||
field_names = [f.attname for f in klass._meta.local_fields]
|
||||
else:
|
||||
field_count = len(klass._meta.fields)
|
||||
# Check if we need to skip some parent fields.
|
||||
if from_parent and len(klass._meta.local_fields) != len(klass._meta.fields):
|
||||
# Only load those fields which haven't been already loaded into
|
||||
# 'from_parent'.
|
||||
non_seen_models = [p for p in klass._meta.get_parent_list()
|
||||
if not issubclass(from_parent, p)]
|
||||
# Load local fields, too...
|
||||
non_seen_models.append(klass)
|
||||
field_names = [f.attname for f in klass._meta.fields
|
||||
if f.model in non_seen_models]
|
||||
field_count = len(field_names)
|
||||
# Try to avoid populating field_names variable for perfomance reasons.
|
||||
# If field_names variable is set, we use **kwargs based model init
|
||||
# which is slower than normal init.
|
||||
if field_count == len(klass._meta.fields):
|
||||
field_names = ()
|
||||
|
||||
restricted = requested is not None
|
||||
|
@ -1392,8 +1402,9 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
|
|||
if o.field.unique and select_related_descend(o.field, restricted, requested,
|
||||
only_load.get(o.model), reverse=True):
|
||||
next = requested[o.field.related_query_name()]
|
||||
parent = klass if issubclass(o.model, klass) else None
|
||||
klass_info = get_klass_info(o.model, max_depth=max_depth, cur_depth=cur_depth+1,
|
||||
requested=next, only_load=only_load, local_only=True)
|
||||
requested=next, only_load=only_load, from_parent=parent)
|
||||
reverse_related_fields.append((o.field, klass_info))
|
||||
if field_names:
|
||||
pk_idx = field_names.index(klass._meta.pk.attname)
|
||||
|
@ -1403,7 +1414,8 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
|
|||
return klass, field_names, field_count, related_fields, reverse_related_fields, pk_idx
|
||||
|
||||
|
||||
def get_cached_row(row, index_start, using, klass_info, offset=0):
|
||||
def get_cached_row(row, index_start, using, klass_info, offset=0,
|
||||
parent_data=()):
|
||||
"""
|
||||
Helper function that recursively returns an object with the specified
|
||||
related attributes already populated.
|
||||
|
@ -1420,11 +1432,14 @@ def get_cached_row(row, index_start, using, klass_info, offset=0):
|
|||
annotated results on `klass`.
|
||||
* using - the database alias on which the query is being executed.
|
||||
* klass_info - result of the get_klass_info function
|
||||
* parent_data - parent model data in format (field, value). Used
|
||||
to populate the non-local fields of child models.
|
||||
"""
|
||||
if klass_info is None:
|
||||
return None
|
||||
klass, field_names, field_count, related_fields, reverse_related_fields, pk_idx = klass_info
|
||||
|
||||
|
||||
fields = row[index_start : index_start + field_count]
|
||||
# If the pk column is None (or the Oracle equivalent ''), then the related
|
||||
# object must be non-existent - set the relation to None.
|
||||
|
@ -1434,7 +1449,6 @@ def get_cached_row(row, index_start, using, klass_info, offset=0):
|
|||
obj = klass(**dict(zip(field_names, fields)))
|
||||
else:
|
||||
obj = klass(*fields)
|
||||
|
||||
# If an object was retrieved, set the database state.
|
||||
if obj:
|
||||
obj._state.db = using
|
||||
|
@ -1464,28 +1478,29 @@ def get_cached_row(row, index_start, using, klass_info, offset=0):
|
|||
# Only handle the restricted case - i.e., don't do a depth
|
||||
# descent into reverse relations unless explicitly requested
|
||||
for f, klass_info in reverse_related_fields:
|
||||
# Transfer data from this object to childs.
|
||||
parent_data = []
|
||||
for rel_field, rel_model in klass_info[0]._meta.get_fields_with_model():
|
||||
if rel_model is not None and isinstance(obj, rel_model):
|
||||
parent_data.append((rel_field, getattr(obj, rel_field.attname)))
|
||||
# Recursively retrieve the data for the related object
|
||||
cached_row = get_cached_row(row, index_end, using, klass_info)
|
||||
cached_row = get_cached_row(row, index_end, using, klass_info,
|
||||
parent_data=parent_data)
|
||||
# If the recursive descent found an object, populate the
|
||||
# descriptor caches relevant to the object
|
||||
if cached_row:
|
||||
rel_obj, index_end = cached_row
|
||||
if obj is not None:
|
||||
# If the field is unique, populate the
|
||||
# reverse descriptor cache
|
||||
# populate the reverse descriptor cache
|
||||
setattr(obj, f.related.get_cache_name(), rel_obj)
|
||||
if rel_obj is not None:
|
||||
# If the related object exists, populate
|
||||
# the descriptor cache.
|
||||
setattr(rel_obj, f.get_cache_name(), obj)
|
||||
# Now populate all the non-local field values
|
||||
# on the related object
|
||||
for rel_field, rel_model in rel_obj._meta.get_fields_with_model():
|
||||
if rel_model is not None:
|
||||
setattr(rel_obj, rel_field.attname, getattr(obj, rel_field.attname))
|
||||
# populate the field cache for any related object
|
||||
# that has already been retrieved
|
||||
# Populate related object caches using parent data.
|
||||
for rel_field, _ in parent_data:
|
||||
if rel_field.rel:
|
||||
setattr(rel_obj, rel_field.attname, getattr(obj, rel_field.attname))
|
||||
try:
|
||||
cached_obj = getattr(obj, rel_field.get_cache_name())
|
||||
setattr(rel_obj, rel_field.get_cache_name(), cached_obj)
|
||||
|
|
|
@ -240,7 +240,7 @@ class SQLCompiler(object):
|
|||
return result
|
||||
|
||||
def get_default_columns(self, with_aliases=False, col_aliases=None,
|
||||
start_alias=None, opts=None, as_pairs=False, local_only=False):
|
||||
start_alias=None, opts=None, as_pairs=False, from_parent=None):
|
||||
"""
|
||||
Computes the default columns for selecting every field in the base
|
||||
model. Will sometimes be called to pull in related models (e.g. via
|
||||
|
@ -265,7 +265,8 @@ class SQLCompiler(object):
|
|||
if start_alias:
|
||||
seen = {None: start_alias}
|
||||
for field, model in opts.get_fields_with_model():
|
||||
if local_only and model is not None:
|
||||
if from_parent and model is not None and issubclass(from_parent, model):
|
||||
# Avoid loading data for already loaded parents.
|
||||
continue
|
||||
if start_alias:
|
||||
try:
|
||||
|
@ -686,11 +687,13 @@ class SQLCompiler(object):
|
|||
(alias, table, f.rel.get_related_field().column, f.column),
|
||||
promote=True
|
||||
)
|
||||
from_parent = (opts.model if issubclass(model, opts.model)
|
||||
else None)
|
||||
columns, aliases = self.get_default_columns(start_alias=alias,
|
||||
opts=model._meta, as_pairs=True, local_only=True)
|
||||
opts=model._meta, as_pairs=True, from_parent=from_parent)
|
||||
self.query.related_select_cols.extend(
|
||||
SelectInfo(col, field) for col, field in zip(columns, model._meta.fields))
|
||||
|
||||
SelectInfo(col, field) for col, field
|
||||
in zip(columns, model._meta.fields))
|
||||
next = requested.get(f.related_query_name(), {})
|
||||
# Use True here because we are looking at the _reverse_ side of
|
||||
# the relation, which is always nullable.
|
||||
|
|
|
@ -51,6 +51,7 @@ class StatDetails(models.Model):
|
|||
class AdvancedUserStat(UserStat):
|
||||
karma = models.IntegerField()
|
||||
|
||||
|
||||
class Image(models.Model):
|
||||
name = models.CharField(max_length=100)
|
||||
|
||||
|
@ -58,3 +59,44 @@ class Image(models.Model):
|
|||
class Product(models.Model):
|
||||
name = models.CharField(max_length=100)
|
||||
image = models.OneToOneField(Image, null=True)
|
||||
|
||||
|
||||
@python_2_unicode_compatible
|
||||
class Parent1(models.Model):
|
||||
name1 = models.CharField(max_length=50)
|
||||
|
||||
def __str__(self):
|
||||
return self.name1
|
||||
|
||||
|
||||
@python_2_unicode_compatible
|
||||
class Parent2(models.Model):
|
||||
# Avoid having two "id" fields in the Child1 subclass
|
||||
id2 = models.AutoField(primary_key=True)
|
||||
name2 = models.CharField(max_length=50)
|
||||
|
||||
def __str__(self):
|
||||
return self.name2
|
||||
|
||||
|
||||
@python_2_unicode_compatible
|
||||
class Child1(Parent1, Parent2):
|
||||
value = models.IntegerField()
|
||||
|
||||
def __str__(self):
|
||||
return self.name1
|
||||
|
||||
|
||||
@python_2_unicode_compatible
|
||||
class Child2(Parent1):
|
||||
parent2 = models.OneToOneField(Parent2)
|
||||
value = models.IntegerField()
|
||||
|
||||
def __str__(self):
|
||||
return self.name1
|
||||
|
||||
class Child3(Child2):
|
||||
value3 = models.IntegerField()
|
||||
|
||||
class Child4(Child1):
|
||||
value4 = models.IntegerField()
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
from __future__ import absolute_import
|
||||
|
||||
from django.test import TestCase
|
||||
from django.utils import unittest
|
||||
|
||||
from .models import (User, UserProfile, UserStat, UserStatResult, StatDetails,
|
||||
AdvancedUserStat, Image, Product)
|
||||
AdvancedUserStat, Image, Product, Parent1, Parent2, Child1, Child2, Child3,
|
||||
Child4)
|
||||
|
||||
|
||||
class ReverseSelectRelatedTestCase(TestCase):
|
||||
|
@ -21,6 +23,14 @@ class ReverseSelectRelatedTestCase(TestCase):
|
|||
advstat = AdvancedUserStat.objects.create(user=user2, posts=200, karma=5,
|
||||
results=results2)
|
||||
StatDetails.objects.create(base_stats=advstat, comments=250)
|
||||
p1 = Parent1(name1="Only Parent1")
|
||||
p1.save()
|
||||
c1 = Child1(name1="Child1 Parent1", name2="Child1 Parent2", value=1)
|
||||
c1.save()
|
||||
p2 = Parent2(name2="Child2 Parent2")
|
||||
p2.save()
|
||||
c2 = Child2(name1="Child2 Parent1", parent2=p2, value=2)
|
||||
c2.save()
|
||||
|
||||
def test_basic(self):
|
||||
with self.assertNumQueries(1):
|
||||
|
@ -108,3 +118,93 @@ class ReverseSelectRelatedTestCase(TestCase):
|
|||
image = Image.objects.select_related('product').get()
|
||||
with self.assertRaises(Product.DoesNotExist):
|
||||
image.product
|
||||
|
||||
def test_parent_only(self):
|
||||
with self.assertNumQueries(1):
|
||||
p = Parent1.objects.select_related('child1').get(name1="Only Parent1")
|
||||
with self.assertNumQueries(0):
|
||||
with self.assertRaises(Child1.DoesNotExist):
|
||||
p.child1
|
||||
|
||||
def test_multiple_subclass(self):
|
||||
with self.assertNumQueries(1):
|
||||
p = Parent1.objects.select_related('child1').get(name1="Child1 Parent1")
|
||||
self.assertEqual(p.child1.name2, 'Child1 Parent2')
|
||||
|
||||
def test_onetoone_with_subclass(self):
|
||||
with self.assertNumQueries(1):
|
||||
p = Parent2.objects.select_related('child2').get(name2="Child2 Parent2")
|
||||
self.assertEqual(p.child2.name1, 'Child2 Parent1')
|
||||
|
||||
def test_onetoone_with_two_subclasses(self):
|
||||
with self.assertNumQueries(1):
|
||||
p = Parent2.objects.select_related('child2', "child2__child3").get(name2="Child2 Parent2")
|
||||
self.assertEqual(p.child2.name1, 'Child2 Parent1')
|
||||
with self.assertRaises(Child3.DoesNotExist):
|
||||
p.child2.child3
|
||||
p3 = Parent2(name2="Child3 Parent2")
|
||||
p3.save()
|
||||
c2 = Child3(name1="Child3 Parent1", parent2=p3, value=2, value3=3)
|
||||
c2.save()
|
||||
with self.assertNumQueries(1):
|
||||
p = Parent2.objects.select_related('child2', "child2__child3").get(name2="Child3 Parent2")
|
||||
self.assertEqual(p.child2.name1, 'Child3 Parent1')
|
||||
self.assertEqual(p.child2.child3.value3, 3)
|
||||
self.assertEqual(p.child2.child3.value, p.child2.value)
|
||||
self.assertEqual(p.child2.name1, p.child2.child3.name1)
|
||||
|
||||
def test_multiinheritance_two_subclasses(self):
|
||||
with self.assertNumQueries(1):
|
||||
p = Parent1.objects.select_related('child1', 'child1__child4').get(name1="Child1 Parent1")
|
||||
self.assertEqual(p.child1.name2, 'Child1 Parent2')
|
||||
self.assertEqual(p.child1.name1, p.name1)
|
||||
with self.assertRaises(Child4.DoesNotExist):
|
||||
p.child1.child4
|
||||
Child4(name1='n1', name2='n2', value=1, value4=4).save()
|
||||
with self.assertNumQueries(1):
|
||||
p = Parent2.objects.select_related('child1', 'child1__child4').get(name2="n2")
|
||||
self.assertEqual(p.name2, 'n2')
|
||||
self.assertEqual(p.child1.name1, 'n1')
|
||||
self.assertEqual(p.child1.name2, p.name2)
|
||||
self.assertEqual(p.child1.value, 1)
|
||||
self.assertEqual(p.child1.child4.name1, p.child1.name1)
|
||||
self.assertEqual(p.child1.child4.name2, p.child1.name2)
|
||||
self.assertEqual(p.child1.child4.value, p.child1.value)
|
||||
self.assertEqual(p.child1.child4.value4, 4)
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_inheritance_deferred(self):
|
||||
c = Child4.objects.create(name1='n1', name2='n2', value=1, value4=4)
|
||||
with self.assertNumQueries(1):
|
||||
p = Parent2.objects.select_related('child1').only(
|
||||
'id2', 'child1__value').get(name2="n2")
|
||||
self.assertEqual(p.id2, c.id2)
|
||||
self.assertEqual(p.child1.value, 1)
|
||||
p = Parent2.objects.select_related('child1').only(
|
||||
'id2', 'child1__value').get(name2="n2")
|
||||
with self.assertNumQueries(1):
|
||||
self.assertEquals(p.name2, 'n2')
|
||||
p = Parent2.objects.select_related('child1').only(
|
||||
'id2', 'child1__value').get(name2="n2")
|
||||
with self.assertNumQueries(1):
|
||||
self.assertEquals(p.child1.name2, 'n2')
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_inheritance_deferred2(self):
|
||||
c = Child4.objects.create(name1='n1', name2='n2', value=1, value4=4)
|
||||
qs = Parent2.objects.select_related('child1', 'child4').only(
|
||||
'id2', 'child1__value', 'child1__child4__value4')
|
||||
with self.assertNumQueries(1):
|
||||
p = qs.get(name2="n2")
|
||||
self.assertEqual(p.id2, c.id2)
|
||||
self.assertEqual(p.child1.value, 1)
|
||||
self.assertEqual(p.child1.child4.value4, 4)
|
||||
self.assertEqual(p.child1.child4.id2, c.id2)
|
||||
p = qs.get(name2="n2")
|
||||
with self.assertNumQueries(1):
|
||||
self.assertEquals(p.child1.name2, 'n2')
|
||||
p = qs.get(name2="n2")
|
||||
with self.assertNumQueries(1):
|
||||
self.assertEquals(p.child1.name1, 'n1')
|
||||
with self.assertNumQueries(1):
|
||||
self.assertEquals(p.child1.child4.name1, 'n1')
|
||||
|
|
Loading…
Reference in New Issue