From 58cd220f51d5e294cb9e67c12a6e9d08523e282f Mon Sep 17 00:00:00 2001 From: Russell Keith-Magee Date: Wed, 27 Jan 2010 13:30:29 +0000 Subject: [PATCH] Fixed #7270 -- Added the ability to follow reverse OneToOneFields in select_related(). Thanks to George Vilches, Ben Davis, and Alex Gaynor for their work on various stages of this patch. git-svn-id: http://code.djangoproject.com/svn/django/trunk@12307 bcc190cf-cafb-0310-a4f2-bffc1f526a37 --- django/db/models/fields/related.py | 4 +- django/db/models/query.py | 74 ++++++++++++++++- django/db/models/query_utils.py | 18 +++- django/db/models/related.py | 3 + django/db/models/sql/compiler.py | 68 ++++++++++++++- docs/ref/models/querysets.txt | 24 ++++-- .../select_related_onetoone/__init__.py | 0 .../select_related_onetoone/models.py | 46 ++++++++++ .../select_related_onetoone/tests.py | 83 +++++++++++++++++++ 9 files changed, 306 insertions(+), 14 deletions(-) create mode 100644 tests/regressiontests/select_related_onetoone/__init__.py create mode 100644 tests/regressiontests/select_related_onetoone/models.py create mode 100644 tests/regressiontests/select_related_onetoone/tests.py diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index 4020d5e268..5de6fb1067 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -189,7 +189,7 @@ class SingleRelatedObjectDescriptor(object): # SingleRelatedObjectDescriptor instance. def __init__(self, related): self.related = related - self.cache_name = '_%s_cache' % related.get_accessor_name() + self.cache_name = related.get_cache_name() def __get__(self, instance, instance_type=None): if instance is None: @@ -319,7 +319,7 @@ class ReverseSingleRelatedObjectDescriptor(object): # cache. This cache also might not exist if the related object # hasn't been accessed yet. if related: - cache_name = '_%s_cache' % self.field.related.get_accessor_name() + cache_name = self.field.related.get_cache_name() try: delattr(related, cache_name) except AttributeError: diff --git a/django/db/models/query.py b/django/db/models/query.py index 3b290a6457..8cb3dbecfc 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -1116,6 +1116,29 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0, """ Helper function that recursively returns an object with the specified related attributes already populated. + + This method may be called recursively to populate deep select_related() + clauses. + + Arguments: + * klass - the class to retrieve (and instantiate) + * row - the row of data returned by the database cursor + * index_start - the index of the row at which data for this + object is known to start + * max_depth - the maximum depth to which a select_related() + relationship should be explored. + * cur_depth - the current depth in the select_related() tree. + Used in recursive calls to determin if we should dig deeper. + * requested - A dictionary describing the select_related() tree + that is to be retrieved. keys are field names; values are + dictionaries describing the keys on that related object that + are themselves to be select_related(). + * offset - the number of additional fields that are known to + exist in `row` for `klass`. This usually means the number of + annotated results on `klass`. + * only_load - if the query has had only() or defer() applied, + this is the list of field names that will be returned. If None, + the full field list for `klass` can be assumed. """ if max_depth and requested is None and cur_depth > max_depth: # We've recursed deeply enough; stop now. @@ -1127,14 +1150,18 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0, # Handle deferred fields. skip = set() init_list = [] - pk_val = row[index_start + klass._meta.pk_index()] + # Build the list of fields that *haven't* been requested for field in klass._meta.fields: if field.name not in load_fields: skip.add(field.name) else: init_list.append(field.attname) + # Retrieve all the requested fields field_count = len(init_list) fields = row[index_start : index_start + field_count] + # If all the select_related columns are None, then the related + # object must be non-existent - set the relation to None. + # Otherwise, construct the related object. if fields == (None,) * field_count: obj = None elif skip: @@ -1143,14 +1170,20 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0, else: obj = klass(*fields) else: + # Load all fields on klass field_count = len(klass._meta.fields) fields = row[index_start : index_start + field_count] + # If all the select_related columns are None, then the related + # object must be non-existent - set the relation to None. + # Otherwise, construct the related object. if fields == (None,) * field_count: obj = None else: obj = klass(*fields) index_end = index_start + field_count + offset + # Iterate over each related object, populating any + # select_related() fields for f in klass._meta.fields: if not select_related_descend(f, restricted, requested): continue @@ -1158,12 +1191,51 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0, next = requested[f.name] else: next = None + # Recursively retrieve the data for the related object cached_row = get_cached_row(f.rel.to, row, index_end, max_depth, cur_depth+1, next) + # If the recursive descent found an object, populate the + # descriptor caches relevant to the object if cached_row: rel_obj, index_end = cached_row if obj is not None: + # If the base object exists, populate the + # descriptor cache setattr(obj, f.get_cache_name(), rel_obj) + if f.unique: + # If the field is unique, populate the + # reverse descriptor cache on the related object + setattr(rel_obj, f.related.get_cache_name(), obj) + + # Now do the same, but for reverse related objects. + # Only handle the restricted case - i.e., don't do a depth + # descent into reverse relations unless explicitly requested + if restricted: + related_fields = [ + (o.field, o.model) + for o in klass._meta.get_all_related_objects() + if o.field.unique + ] + for f, model in related_fields: + if not select_related_descend(f, restricted, requested, reverse=True): + continue + next = requested[f.related_query_name()] + # Recursively retrieve the data for the related object + cached_row = get_cached_row(model, row, index_end, max_depth, + cur_depth+1, next) + # If the recursive descent found an object, populate the + # descriptor caches relevant to the object + if cached_row: + rel_obj, index_end = cached_row + if obj is not None: + # If the field is unique, populate the + # reverse descriptor cache + setattr(obj, f.related.get_cache_name(), rel_obj) + if rel_obj is not None: + # If the related object exists, populate + # the descriptor cache. + setattr(rel_obj, f.get_cache_name(), obj) + return obj, index_end def delete_objects(seen_objs, using): diff --git a/django/db/models/query_utils.py b/django/db/models/query_utils.py index 9f6083ce7e..8e804ec3ef 100644 --- a/django/db/models/query_utils.py +++ b/django/db/models/query_utils.py @@ -197,19 +197,29 @@ class DeferredAttribute(object): """ instance.__dict__[self.field_name] = value -def select_related_descend(field, restricted, requested): +def select_related_descend(field, restricted, requested, reverse=False): """ Returns True if this field should be used to descend deeper for select_related() purposes. Used by both the query construction code (sql.query.fill_related_selections()) and the model instance creation code (query.get_cached_row()). + + Arguments: + * field - the field to be checked + * restricted - a boolean field, indicating if the field list has been + manually restricted using a requested clause) + * requested - The select_related() dictionary. + * reverse - boolean, True if we are checking a reverse select related """ if not field.rel: return False - if field.rel.parent_link: - return False - if restricted and field.name not in requested: + if field.rel.parent_link and not reverse: return False + if restricted: + if reverse and field.related_query_name() not in requested: + return False + if not reverse and field.name not in requested: + return False if not restricted and field.null: return False return True diff --git a/django/db/models/related.py b/django/db/models/related.py index afdf3f7b61..e4afd8a6f8 100644 --- a/django/db/models/related.py +++ b/django/db/models/related.py @@ -45,3 +45,6 @@ class RelatedObject(object): return self.field.rel.related_name or (self.opts.object_name.lower() + '_set') else: return self.field.rel.related_name or (self.opts.object_name.lower()) + + def get_cache_name(self): + return "_%s_cache" % self.get_accessor_name() diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 6a95d32259..1625a0e6c9 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -520,7 +520,7 @@ class SQLCompiler(object): # Setup for the case when only particular related fields should be # included in the related selection. - if requested is None and restricted is not False: + if requested is None: if isinstance(self.query.select_related, dict): requested = self.query.select_related restricted = True @@ -600,6 +600,72 @@ class SQLCompiler(object): self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1, used, next, restricted, new_nullable, dupe_set, avoid) + if restricted: + related_fields = [ + (o.field, o.model) + for o in opts.get_all_related_objects() + if o.field.unique + ] + for f, model in related_fields: + if not select_related_descend(f, restricted, requested, reverse=True): + continue + # The "avoid" set is aliases we want to avoid just for this + # particular branch of the recursion. They aren't permanently + # forbidden from reuse in the related selection tables (which is + # what "used" specifies). + avoid = avoid_set.copy() + dupe_set = orig_dupe_set.copy() + table = model._meta.db_table + + int_opts = opts + alias = root_alias + alias_chain = [] + chain = opts.get_base_chain(f.rel.to) + if chain is not None: + for int_model in chain: + # Proxy model have elements in base chain + # with no parents, assign the new options + # object and skip to the next base in that + # case + if not int_opts.parents[int_model]: + int_opts = int_model._meta + continue + lhs_col = int_opts.parents[int_model].column + dedupe = lhs_col in opts.duplicate_targets + if dedupe: + avoid.update(self.query.dupe_avoidance.get(id(opts), lhs_col), + ()) + dupe_set.add((opts, lhs_col)) + int_opts = int_model._meta + alias = self.query.join( + (alias, int_opts.db_table, lhs_col, int_opts.pk.column), + exclusions=used, promote=True, reuse=used + ) + alias_chain.append(alias) + for dupe_opts, dupe_col in dupe_set: + self.query.update_dupe_avoidance(dupe_opts, dupe_col, alias) + dedupe = f.column in opts.duplicate_targets + if dupe_set or dedupe: + avoid.update(self.query.dupe_avoidance.get((id(opts), f.column), ())) + if dedupe: + dupe_set.add((opts, f.column)) + alias = self.query.join( + (alias, table, f.rel.get_related_field().column, f.column), + exclusions=used.union(avoid), + promote=True + ) + used.add(alias) + columns, aliases = self.get_default_columns(start_alias=alias, + opts=model._meta, as_pairs=True) + self.query.related_select_cols.extend(columns) + self.query.related_select_fields.extend(model._meta.fields) + + next = requested.get(f.related_query_name(), {}) + new_nullable = f.null or None + + self.fill_related_selections(model._meta, table, cur_depth+1, + used, next, restricted, new_nullable) + def deferred_to_columns(self): """ Converts the self.deferred_loading data structure to mapping of table diff --git a/docs/ref/models/querysets.txt b/docs/ref/models/querysets.txt index 4740d9ca10..db2fa5687c 100644 --- a/docs/ref/models/querysets.txt +++ b/docs/ref/models/querysets.txt @@ -619,17 +619,29 @@ This is also valid:: ...and would also pull in the ``building`` relation. -You can only refer to ``ForeignKey`` relations in the list of fields passed to -``select_related``. You *can* refer to foreign keys that have ``null=True`` -(unlike the default ``select_related()`` call). It's an error to use both a -list of fields and the ``depth`` parameter in the same ``select_related()`` -call, since they are conflicting options. +You can refer to any ``ForeignKey`` or ``OneToOneField`` relation in +the list of fields passed to ``select_related``. Ths includes foreign +keys that have ``null=True`` (unlike the default ``select_related()`` +call). It's an error to use both a list of fields and the ``depth`` +parameter in the same ``select_related()`` call, since they are +conflicting options. .. versionadded:: 1.0 Both the ``depth`` argument and the ability to specify field names in the call to ``select_related()`` are new in Django version 1.0. +.. versionchanged:: 1.2 + +You can also refer to the reverse direction of a ``OneToOneFields`` in +the list of fields passed to ``select_related`` -- that is, you can traverse +a ``OneToOneField`` back to the object on which the field is defined. Instead +of specifying the field name, use the ``related_name`` for the field on the +related object. + +``OneToOneFields`` will not be traversed in the reverse direction if you +are performing a depth-based ``select_related``. + .. _queryset-extra: ``extra(select=None, where=None, params=None, tables=None, order_by=None, select_params=None)`` @@ -1335,7 +1347,7 @@ extract two field values, where only one is expected:: entries = Entry.objects.filter(blog__in=list(values)) Note the ``list()`` call around the Blog ``QuerySet`` to force execution of - the first query. Without it, a nested query would be executed, because + the first query. Without it, a nested query would be executed, because :ref:`querysets-are-lazy`. gt diff --git a/tests/regressiontests/select_related_onetoone/__init__.py b/tests/regressiontests/select_related_onetoone/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/regressiontests/select_related_onetoone/models.py b/tests/regressiontests/select_related_onetoone/models.py new file mode 100644 index 0000000000..0014278768 --- /dev/null +++ b/tests/regressiontests/select_related_onetoone/models.py @@ -0,0 +1,46 @@ +from django.db import models + + +class User(models.Model): + username = models.CharField(max_length=100) + email = models.EmailField() + + def __unicode__(self): + return self.username + + +class UserProfile(models.Model): + user = models.OneToOneField(User) + city = models.CharField(max_length=100) + state = models.CharField(max_length=2) + + def __unicode__(self): + return "%s, %s" % (self.city, self.state) + + +class UserStatResult(models.Model): + results = models.CharField(max_length=50) + + def __unicode__(self): + return 'UserStatResults, results = %s' % (self.results,) + + +class UserStat(models.Model): + user = models.OneToOneField(User, primary_key=True) + posts = models.IntegerField() + results = models.ForeignKey(UserStatResult) + + def __unicode__(self): + return 'UserStat, posts = %s' % (self.posts,) + + +class StatDetails(models.Model): + base_stats = models.OneToOneField(UserStat) + comments = models.IntegerField() + + def __unicode__(self): + return 'StatDetails, comments = %s' % (self.comments,) + + +class AdvancedUserStat(UserStat): + pass diff --git a/tests/regressiontests/select_related_onetoone/tests.py b/tests/regressiontests/select_related_onetoone/tests.py new file mode 100644 index 0000000000..b9e5beb7e9 --- /dev/null +++ b/tests/regressiontests/select_related_onetoone/tests.py @@ -0,0 +1,83 @@ +from django import db +from django.conf import settings +from django.test import TestCase + +from models import User, UserProfile, UserStat, UserStatResult, StatDetails, AdvancedUserStat + +class ReverseSelectRelatedTestCase(TestCase): + def setUp(self): + # Explicitly enable debug for these tests - we need to count + # the queries that have been issued. + self.old_debug = settings.DEBUG + settings.DEBUG = True + + user = User.objects.create(username="test") + userprofile = UserProfile.objects.create(user=user, state="KS", + city="Lawrence") + results = UserStatResult.objects.create(results='first results') + userstat = UserStat.objects.create(user=user, posts=150, + results=results) + details = StatDetails.objects.create(base_stats=userstat, comments=259) + + user2 = User.objects.create(username="bob") + results2 = UserStatResult.objects.create(results='moar results') + advstat = AdvancedUserStat.objects.create(user=user2, posts=200, + results=results2) + StatDetails.objects.create(base_stats=advstat, comments=250) + + db.reset_queries() + + def assertQueries(self, queries): + self.assertEqual(len(db.connection.queries), queries) + + def tearDown(self): + settings.DEBUG = self.old_debug + + def test_basic(self): + u = User.objects.select_related("userprofile").get(username="test") + self.assertEqual(u.userprofile.state, "KS") + self.assertQueries(1) + + def test_follow_next_level(self): + u = User.objects.select_related("userstat__results").get(username="test") + self.assertEqual(u.userstat.posts, 150) + self.assertEqual(u.userstat.results.results, 'first results') + self.assertQueries(1) + + def test_follow_two(self): + u = User.objects.select_related("userprofile", "userstat").get(username="test") + self.assertEqual(u.userprofile.state, "KS") + self.assertEqual(u.userstat.posts, 150) + self.assertQueries(1) + + def test_follow_two_next_level(self): + u = User.objects.select_related("userstat__results", "userstat__statdetails").get(username="test") + self.assertEqual(u.userstat.results.results, 'first results') + self.assertEqual(u.userstat.statdetails.comments, 259) + self.assertQueries(1) + + def test_forward_and_back(self): + stat = UserStat.objects.select_related("user__userprofile").get(user__username="test") + self.assertEqual(stat.user.userprofile.state, 'KS') + self.assertEqual(stat.user.userstat.posts, 150) + self.assertQueries(1) + + def test_back_and_forward(self): + u = User.objects.select_related("userstat").get(username="test") + self.assertEqual(u.userstat.user.username, 'test') + self.assertQueries(1) + + def test_not_followed_by_default(self): + u = User.objects.select_related().get(username="test") + self.assertEqual(u.userstat.posts, 150) + self.assertQueries(2) + + def test_follow_from_child_class(self): + stat = AdvancedUserStat.objects.select_related("statdetails").get(posts=200) + self.assertEqual(stat.statdetails.comments, 250) + self.assertQueries(1) + + def test_follow_inheritance(self): + stat = UserStat.objects.select_related('advanceduserstat').get(posts=200) + self.assertEqual(stat.advanceduserstat.posts, 200) + self.assertQueries(1)