Fixed #7270 -- Added the ability to follow reverse OneToOneFields in select_related(). Thanks to George Vilches, Ben Davis, and Alex Gaynor for their work on various stages of this patch.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@12307 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Russell Keith-Magee 2010-01-27 13:30:29 +00:00
parent 8e8d4b5888
commit 58cd220f51
9 changed files with 306 additions and 14 deletions

View File

@ -189,7 +189,7 @@ class SingleRelatedObjectDescriptor(object):
# SingleRelatedObjectDescriptor instance. # SingleRelatedObjectDescriptor instance.
def __init__(self, related): def __init__(self, related):
self.related = related self.related = related
self.cache_name = '_%s_cache' % related.get_accessor_name() self.cache_name = related.get_cache_name()
def __get__(self, instance, instance_type=None): def __get__(self, instance, instance_type=None):
if instance is None: if instance is None:
@ -319,7 +319,7 @@ class ReverseSingleRelatedObjectDescriptor(object):
# cache. This cache also might not exist if the related object # cache. This cache also might not exist if the related object
# hasn't been accessed yet. # hasn't been accessed yet.
if related: if related:
cache_name = '_%s_cache' % self.field.related.get_accessor_name() cache_name = self.field.related.get_cache_name()
try: try:
delattr(related, cache_name) delattr(related, cache_name)
except AttributeError: except AttributeError:

View File

@ -1116,6 +1116,29 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0,
""" """
Helper function that recursively returns an object with the specified Helper function that recursively returns an object with the specified
related attributes already populated. related attributes already populated.
This method may be called recursively to populate deep select_related()
clauses.
Arguments:
* klass - the class to retrieve (and instantiate)
* row - the row of data returned by the database cursor
* index_start - the index of the row at which data for this
object is known to start
* max_depth - the maximum depth to which a select_related()
relationship should be explored.
* cur_depth - the current depth in the select_related() tree.
Used in recursive calls to determin if we should dig deeper.
* requested - A dictionary describing the select_related() tree
that is to be retrieved. keys are field names; values are
dictionaries describing the keys on that related object that
are themselves to be select_related().
* offset - the number of additional fields that are known to
exist in `row` for `klass`. This usually means the number of
annotated results on `klass`.
* only_load - if the query has had only() or defer() applied,
this is the list of field names that will be returned. If None,
the full field list for `klass` can be assumed.
""" """
if max_depth and requested is None and cur_depth > max_depth: if max_depth and requested is None and cur_depth > max_depth:
# We've recursed deeply enough; stop now. # We've recursed deeply enough; stop now.
@ -1127,14 +1150,18 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0,
# Handle deferred fields. # Handle deferred fields.
skip = set() skip = set()
init_list = [] init_list = []
pk_val = row[index_start + klass._meta.pk_index()] # Build the list of fields that *haven't* been requested
for field in klass._meta.fields: for field in klass._meta.fields:
if field.name not in load_fields: if field.name not in load_fields:
skip.add(field.name) skip.add(field.name)
else: else:
init_list.append(field.attname) init_list.append(field.attname)
# Retrieve all the requested fields
field_count = len(init_list) field_count = len(init_list)
fields = row[index_start : index_start + field_count] fields = row[index_start : index_start + field_count]
# If all the select_related columns are None, then the related
# object must be non-existent - set the relation to None.
# Otherwise, construct the related object.
if fields == (None,) * field_count: if fields == (None,) * field_count:
obj = None obj = None
elif skip: elif skip:
@ -1143,14 +1170,20 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0,
else: else:
obj = klass(*fields) obj = klass(*fields)
else: else:
# Load all fields on klass
field_count = len(klass._meta.fields) field_count = len(klass._meta.fields)
fields = row[index_start : index_start + field_count] fields = row[index_start : index_start + field_count]
# If all the select_related columns are None, then the related
# object must be non-existent - set the relation to None.
# Otherwise, construct the related object.
if fields == (None,) * field_count: if fields == (None,) * field_count:
obj = None obj = None
else: else:
obj = klass(*fields) obj = klass(*fields)
index_end = index_start + field_count + offset index_end = index_start + field_count + offset
# Iterate over each related object, populating any
# select_related() fields
for f in klass._meta.fields: for f in klass._meta.fields:
if not select_related_descend(f, restricted, requested): if not select_related_descend(f, restricted, requested):
continue continue
@ -1158,12 +1191,51 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0,
next = requested[f.name] next = requested[f.name]
else: else:
next = None next = None
# Recursively retrieve the data for the related object
cached_row = get_cached_row(f.rel.to, row, index_end, max_depth, cached_row = get_cached_row(f.rel.to, row, index_end, max_depth,
cur_depth+1, next) cur_depth+1, next)
# If the recursive descent found an object, populate the
# descriptor caches relevant to the object
if cached_row: if cached_row:
rel_obj, index_end = cached_row rel_obj, index_end = cached_row
if obj is not None: if obj is not None:
# If the base object exists, populate the
# descriptor cache
setattr(obj, f.get_cache_name(), rel_obj) setattr(obj, f.get_cache_name(), rel_obj)
if f.unique:
# If the field is unique, populate the
# reverse descriptor cache on the related object
setattr(rel_obj, f.related.get_cache_name(), obj)
# Now do the same, but for reverse related objects.
# Only handle the restricted case - i.e., don't do a depth
# descent into reverse relations unless explicitly requested
if restricted:
related_fields = [
(o.field, o.model)
for o in klass._meta.get_all_related_objects()
if o.field.unique
]
for f, model in related_fields:
if not select_related_descend(f, restricted, requested, reverse=True):
continue
next = requested[f.related_query_name()]
# Recursively retrieve the data for the related object
cached_row = get_cached_row(model, row, index_end, max_depth,
cur_depth+1, next)
# If the recursive descent found an object, populate the
# descriptor caches relevant to the object
if cached_row:
rel_obj, index_end = cached_row
if obj is not None:
# If the field is unique, populate the
# reverse descriptor cache
setattr(obj, f.related.get_cache_name(), rel_obj)
if rel_obj is not None:
# If the related object exists, populate
# the descriptor cache.
setattr(rel_obj, f.get_cache_name(), obj)
return obj, index_end return obj, index_end
def delete_objects(seen_objs, using): def delete_objects(seen_objs, using):

View File

@ -197,19 +197,29 @@ class DeferredAttribute(object):
""" """
instance.__dict__[self.field_name] = value instance.__dict__[self.field_name] = value
def select_related_descend(field, restricted, requested): def select_related_descend(field, restricted, requested, reverse=False):
""" """
Returns True if this field should be used to descend deeper for Returns True if this field should be used to descend deeper for
select_related() purposes. Used by both the query construction code select_related() purposes. Used by both the query construction code
(sql.query.fill_related_selections()) and the model instance creation code (sql.query.fill_related_selections()) and the model instance creation code
(query.get_cached_row()). (query.get_cached_row()).
Arguments:
* field - the field to be checked
* restricted - a boolean field, indicating if the field list has been
manually restricted using a requested clause)
* requested - The select_related() dictionary.
* reverse - boolean, True if we are checking a reverse select related
""" """
if not field.rel: if not field.rel:
return False return False
if field.rel.parent_link: if field.rel.parent_link and not reverse:
return False
if restricted and field.name not in requested:
return False return False
if restricted:
if reverse and field.related_query_name() not in requested:
return False
if not reverse and field.name not in requested:
return False
if not restricted and field.null: if not restricted and field.null:
return False return False
return True return True

View File

@ -45,3 +45,6 @@ class RelatedObject(object):
return self.field.rel.related_name or (self.opts.object_name.lower() + '_set') return self.field.rel.related_name or (self.opts.object_name.lower() + '_set')
else: else:
return self.field.rel.related_name or (self.opts.object_name.lower()) return self.field.rel.related_name or (self.opts.object_name.lower())
def get_cache_name(self):
return "_%s_cache" % self.get_accessor_name()

View File

@ -520,7 +520,7 @@ class SQLCompiler(object):
# Setup for the case when only particular related fields should be # Setup for the case when only particular related fields should be
# included in the related selection. # included in the related selection.
if requested is None and restricted is not False: if requested is None:
if isinstance(self.query.select_related, dict): if isinstance(self.query.select_related, dict):
requested = self.query.select_related requested = self.query.select_related
restricted = True restricted = True
@ -600,6 +600,72 @@ class SQLCompiler(object):
self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1, self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1,
used, next, restricted, new_nullable, dupe_set, avoid) used, next, restricted, new_nullable, dupe_set, avoid)
if restricted:
related_fields = [
(o.field, o.model)
for o in opts.get_all_related_objects()
if o.field.unique
]
for f, model in related_fields:
if not select_related_descend(f, restricted, requested, reverse=True):
continue
# The "avoid" set is aliases we want to avoid just for this
# particular branch of the recursion. They aren't permanently
# forbidden from reuse in the related selection tables (which is
# what "used" specifies).
avoid = avoid_set.copy()
dupe_set = orig_dupe_set.copy()
table = model._meta.db_table
int_opts = opts
alias = root_alias
alias_chain = []
chain = opts.get_base_chain(f.rel.to)
if chain is not None:
for int_model in chain:
# Proxy model have elements in base chain
# with no parents, assign the new options
# object and skip to the next base in that
# case
if not int_opts.parents[int_model]:
int_opts = int_model._meta
continue
lhs_col = int_opts.parents[int_model].column
dedupe = lhs_col in opts.duplicate_targets
if dedupe:
avoid.update(self.query.dupe_avoidance.get(id(opts), lhs_col),
())
dupe_set.add((opts, lhs_col))
int_opts = int_model._meta
alias = self.query.join(
(alias, int_opts.db_table, lhs_col, int_opts.pk.column),
exclusions=used, promote=True, reuse=used
)
alias_chain.append(alias)
for dupe_opts, dupe_col in dupe_set:
self.query.update_dupe_avoidance(dupe_opts, dupe_col, alias)
dedupe = f.column in opts.duplicate_targets
if dupe_set or dedupe:
avoid.update(self.query.dupe_avoidance.get((id(opts), f.column), ()))
if dedupe:
dupe_set.add((opts, f.column))
alias = self.query.join(
(alias, table, f.rel.get_related_field().column, f.column),
exclusions=used.union(avoid),
promote=True
)
used.add(alias)
columns, aliases = self.get_default_columns(start_alias=alias,
opts=model._meta, as_pairs=True)
self.query.related_select_cols.extend(columns)
self.query.related_select_fields.extend(model._meta.fields)
next = requested.get(f.related_query_name(), {})
new_nullable = f.null or None
self.fill_related_selections(model._meta, table, cur_depth+1,
used, next, restricted, new_nullable)
def deferred_to_columns(self): def deferred_to_columns(self):
""" """
Converts the self.deferred_loading data structure to mapping of table Converts the self.deferred_loading data structure to mapping of table

View File

@ -619,17 +619,29 @@ This is also valid::
...and would also pull in the ``building`` relation. ...and would also pull in the ``building`` relation.
You can only refer to ``ForeignKey`` relations in the list of fields passed to You can refer to any ``ForeignKey`` or ``OneToOneField`` relation in
``select_related``. You *can* refer to foreign keys that have ``null=True`` the list of fields passed to ``select_related``. Ths includes foreign
(unlike the default ``select_related()`` call). It's an error to use both a keys that have ``null=True`` (unlike the default ``select_related()``
list of fields and the ``depth`` parameter in the same ``select_related()`` call). It's an error to use both a list of fields and the ``depth``
call, since they are conflicting options. parameter in the same ``select_related()`` call, since they are
conflicting options.
.. versionadded:: 1.0 .. versionadded:: 1.0
Both the ``depth`` argument and the ability to specify field names in the call Both the ``depth`` argument and the ability to specify field names in the call
to ``select_related()`` are new in Django version 1.0. to ``select_related()`` are new in Django version 1.0.
.. versionchanged:: 1.2
You can also refer to the reverse direction of a ``OneToOneFields`` in
the list of fields passed to ``select_related`` -- that is, you can traverse
a ``OneToOneField`` back to the object on which the field is defined. Instead
of specifying the field name, use the ``related_name`` for the field on the
related object.
``OneToOneFields`` will not be traversed in the reverse direction if you
are performing a depth-based ``select_related``.
.. _queryset-extra: .. _queryset-extra:
``extra(select=None, where=None, params=None, tables=None, order_by=None, select_params=None)`` ``extra(select=None, where=None, params=None, tables=None, order_by=None, select_params=None)``
@ -1335,7 +1347,7 @@ extract two field values, where only one is expected::
entries = Entry.objects.filter(blog__in=list(values)) entries = Entry.objects.filter(blog__in=list(values))
Note the ``list()`` call around the Blog ``QuerySet`` to force execution of Note the ``list()`` call around the Blog ``QuerySet`` to force execution of
the first query. Without it, a nested query would be executed, because the first query. Without it, a nested query would be executed, because
:ref:`querysets-are-lazy`. :ref:`querysets-are-lazy`.
gt gt

View File

@ -0,0 +1,46 @@
from django.db import models
class User(models.Model):
username = models.CharField(max_length=100)
email = models.EmailField()
def __unicode__(self):
return self.username
class UserProfile(models.Model):
user = models.OneToOneField(User)
city = models.CharField(max_length=100)
state = models.CharField(max_length=2)
def __unicode__(self):
return "%s, %s" % (self.city, self.state)
class UserStatResult(models.Model):
results = models.CharField(max_length=50)
def __unicode__(self):
return 'UserStatResults, results = %s' % (self.results,)
class UserStat(models.Model):
user = models.OneToOneField(User, primary_key=True)
posts = models.IntegerField()
results = models.ForeignKey(UserStatResult)
def __unicode__(self):
return 'UserStat, posts = %s' % (self.posts,)
class StatDetails(models.Model):
base_stats = models.OneToOneField(UserStat)
comments = models.IntegerField()
def __unicode__(self):
return 'StatDetails, comments = %s' % (self.comments,)
class AdvancedUserStat(UserStat):
pass

View File

@ -0,0 +1,83 @@
from django import db
from django.conf import settings
from django.test import TestCase
from models import User, UserProfile, UserStat, UserStatResult, StatDetails, AdvancedUserStat
class ReverseSelectRelatedTestCase(TestCase):
def setUp(self):
# Explicitly enable debug for these tests - we need to count
# the queries that have been issued.
self.old_debug = settings.DEBUG
settings.DEBUG = True
user = User.objects.create(username="test")
userprofile = UserProfile.objects.create(user=user, state="KS",
city="Lawrence")
results = UserStatResult.objects.create(results='first results')
userstat = UserStat.objects.create(user=user, posts=150,
results=results)
details = StatDetails.objects.create(base_stats=userstat, comments=259)
user2 = User.objects.create(username="bob")
results2 = UserStatResult.objects.create(results='moar results')
advstat = AdvancedUserStat.objects.create(user=user2, posts=200,
results=results2)
StatDetails.objects.create(base_stats=advstat, comments=250)
db.reset_queries()
def assertQueries(self, queries):
self.assertEqual(len(db.connection.queries), queries)
def tearDown(self):
settings.DEBUG = self.old_debug
def test_basic(self):
u = User.objects.select_related("userprofile").get(username="test")
self.assertEqual(u.userprofile.state, "KS")
self.assertQueries(1)
def test_follow_next_level(self):
u = User.objects.select_related("userstat__results").get(username="test")
self.assertEqual(u.userstat.posts, 150)
self.assertEqual(u.userstat.results.results, 'first results')
self.assertQueries(1)
def test_follow_two(self):
u = User.objects.select_related("userprofile", "userstat").get(username="test")
self.assertEqual(u.userprofile.state, "KS")
self.assertEqual(u.userstat.posts, 150)
self.assertQueries(1)
def test_follow_two_next_level(self):
u = User.objects.select_related("userstat__results", "userstat__statdetails").get(username="test")
self.assertEqual(u.userstat.results.results, 'first results')
self.assertEqual(u.userstat.statdetails.comments, 259)
self.assertQueries(1)
def test_forward_and_back(self):
stat = UserStat.objects.select_related("user__userprofile").get(user__username="test")
self.assertEqual(stat.user.userprofile.state, 'KS')
self.assertEqual(stat.user.userstat.posts, 150)
self.assertQueries(1)
def test_back_and_forward(self):
u = User.objects.select_related("userstat").get(username="test")
self.assertEqual(u.userstat.user.username, 'test')
self.assertQueries(1)
def test_not_followed_by_default(self):
u = User.objects.select_related().get(username="test")
self.assertEqual(u.userstat.posts, 150)
self.assertQueries(2)
def test_follow_from_child_class(self):
stat = AdvancedUserStat.objects.select_related("statdetails").get(posts=200)
self.assertEqual(stat.statdetails.comments, 250)
self.assertQueries(1)
def test_follow_inheritance(self):
stat = UserStat.objects.select_related('advanceduserstat').get(posts=200)
self.assertEqual(stat.advanceduserstat.posts, 200)
self.assertQueries(1)