diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index a16f9553c6..c4f95a12d2 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -237,13 +237,18 @@ class SingleRelatedObjectDescriptor(object): return self.related.model._base_manager.using(db) def get_prefetch_query_set(self, instances): - vals = set(instance._get_pk_val() for instance in instances) - params = {'%s__pk__in' % self.related.field.name: vals} - return (self.get_query_set(instance=instances[0]).filter(**params), - attrgetter(self.related.field.attname), - lambda obj: obj._get_pk_val(), - True, - self.cache_name) + rel_obj_attr = attrgetter(self.related.field.attname) + instance_attr = lambda obj: obj._get_pk_val() + instances_dict = dict((instance_attr(inst), inst) for inst in instances) + params = {'%s__pk__in' % self.related.field.name: instances_dict.keys()} + qs = self.get_query_set(instance=instances[0]).filter(**params) + # Since we're going to assign directly in the cache, + # we must manage the reverse relation cache manually. + rel_obj_cache_name = self.related.field.get_cache_name() + for rel_obj in qs: + instance = instances_dict[rel_obj_attr(rel_obj)] + setattr(rel_obj, rel_obj_cache_name, instance) + return qs, rel_obj_attr, instance_attr, True, self.cache_name def __get__(self, instance, instance_type=None): if instance is None: @@ -324,17 +329,23 @@ class ReverseSingleRelatedObjectDescriptor(object): return QuerySet(self.field.rel.to).using(db) def get_prefetch_query_set(self, instances): - vals = set(getattr(instance, self.field.attname) for instance in instances) + rel_obj_attr = attrgetter(self.field.rel.field_name) + instance_attr = attrgetter(self.field.attname) + instances_dict = dict((instance_attr(inst), inst) for inst in instances) other_field = self.field.rel.get_related_field() if other_field.rel: - params = {'%s__pk__in' % self.field.rel.field_name: vals} + params = {'%s__pk__in' % self.field.rel.field_name: instances_dict.keys()} else: - params = {'%s__in' % self.field.rel.field_name: vals} - return (self.get_query_set(instance=instances[0]).filter(**params), - attrgetter(self.field.rel.field_name), - attrgetter(self.field.attname), - True, - self.cache_name) + params = {'%s__in' % self.field.rel.field_name: instances_dict.keys()} + qs = self.get_query_set(instance=instances[0]).filter(**params) + # Since we're going to assign directly in the cache, + # we must manage the reverse relation cache manually. + if not self.field.rel.multiple: + rel_obj_cache_name = self.field.related.get_cache_name() + for rel_obj in qs: + instance = instances_dict[rel_obj_attr(rel_obj)] + setattr(rel_obj, rel_obj_cache_name, instance) + return qs, rel_obj_attr, instance_attr, True, self.cache_name def __get__(self, instance, instance_type=None): if instance is None: @@ -467,18 +478,24 @@ class ForeignRelatedObjectsDescriptor(object): return self.instance._prefetched_objects_cache[rel_field.related_query_name()] except (AttributeError, KeyError): db = self._db or router.db_for_read(self.model, instance=self.instance) - return super(RelatedManager, self).get_query_set().using(db).filter(**self.core_filters) + qs = super(RelatedManager, self).get_query_set().using(db).filter(**self.core_filters) + qs._known_related_object = (rel_field.name, self.instance) + return qs def get_prefetch_query_set(self, instances): + rel_obj_attr = attrgetter(rel_field.get_attname()) + instance_attr = attrgetter(attname) + instances_dict = dict((instance_attr(inst), inst) for inst in instances) db = self._db or router.db_for_read(self.model, instance=instances[0]) - query = {'%s__%s__in' % (rel_field.name, attname): - set(getattr(obj, attname) for obj in instances)} + query = {'%s__%s__in' % (rel_field.name, attname): instances_dict.keys()} qs = super(RelatedManager, self).get_query_set().using(db).filter(**query) - return (qs, - attrgetter(rel_field.get_attname()), - attrgetter(attname), - False, - rel_field.related_query_name()) + # Since we just bypassed this class' get_query_set(), we must manage + # the reverse relation manually. + for rel_obj in qs: + instance = instances_dict[rel_obj_attr(rel_obj)] + setattr(rel_obj, rel_field.name, instance) + cache_name = rel_field.related_query_name() + return qs, rel_obj_attr, instance_attr, False, cache_name def add(self, *objs): for obj in objs: diff --git a/django/db/models/query.py b/django/db/models/query.py index 65a36975d5..755820c3b0 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -41,6 +41,7 @@ class QuerySet(object): self._for_write = False self._prefetch_related_lookups = [] self._prefetch_done = False + self._known_related_object = None # (attname, rel_obj) ######################## # PYTHON MAGIC METHODS # @@ -282,9 +283,10 @@ class QuerySet(object): init_list.append(field.attname) model_cls = deferred_class_factory(self.model, skip) - # Cache db and model outside the loop + # Cache db, model and known_related_object outside the loop db = self.db model = self.model + kro_attname, kro_instance = self._known_related_object or (None, None) compiler = self.query.get_compiler(using=db) if fill_cache: klass_info = get_klass_info(model, max_depth=max_depth, @@ -294,12 +296,12 @@ class QuerySet(object): obj, _ = get_cached_row(row, index_start, db, klass_info, offset=len(aggregate_select)) else: + # Omit aggregates in object creation. + row_data = row[index_start:aggregate_start] if skip: - row_data = row[index_start:aggregate_start] obj = model_cls(**dict(zip(init_list, row_data))) else: - # Omit aggregates in object creation. - obj = model(*row[index_start:aggregate_start]) + obj = model(*row_data) # Store the source database of the object obj._state.db = db @@ -313,7 +315,11 @@ class QuerySet(object): # Add the aggregates to the model if aggregate_select: for i, aggregate in enumerate(aggregate_select): - setattr(obj, aggregate, row[i+aggregate_start]) + setattr(obj, aggregate, row[i + aggregate_start]) + + # Add the known related object to the model, if there is one + if kro_instance: + setattr(obj, kro_attname, kro_instance) yield obj @@ -864,6 +870,7 @@ class QuerySet(object): c = klass(model=self.model, query=query, using=self._db) c._for_write = self._for_write c._prefetch_related_lookups = self._prefetch_related_lookups[:] + c._known_related_object = self._known_related_object c.__dict__.update(kwargs) if setup and hasattr(c, '_setup_query'): c._setup_query() @@ -1781,9 +1788,7 @@ def prefetch_one_level(instances, prefetcher, attname): rel_obj_cache = {} for rel_obj in all_related_objects: rel_attr_val = rel_obj_attr(rel_obj) - if rel_attr_val not in rel_obj_cache: - rel_obj_cache[rel_attr_val] = [] - rel_obj_cache[rel_attr_val].append(rel_obj) + rel_obj_cache.setdefault(rel_attr_val, []).append(rel_obj) for obj in instances: instance_attr_val = instance_attr(obj) diff --git a/docs/releases/1.5.txt b/docs/releases/1.5.txt index e11991b257..46b599a622 100644 --- a/docs/releases/1.5.txt +++ b/docs/releases/1.5.txt @@ -44,6 +44,24 @@ reasons or when trying to avoid overwriting concurrent changes. See the :meth:`Model.save() ` documentation for more details. +Caching of related model instances +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +When traversing relations, the ORM will avoid re-fetching objects that were +previously loaded. For example, with the tutorial's models:: + + >>> first_poll = Poll.objects.all()[0] + >>> first_choice = first_poll.choice_set.all()[0] + >>> first_choice.poll is first_poll + True + +In Django 1.5, the third line no longer triggers a new SQL query to fetch +``first_choice.poll``; it was set when by the second line. + +For one-to-one relationships, both sides can be cached. For many-to-one +relationships, only the single side of the relationship can be cached. This +is particularly helpful in combination with ``prefetch_related``. + Minor features ~~~~~~~~~~~~~~ diff --git a/tests/modeltests/known_related_objects/__init__.py b/tests/modeltests/known_related_objects/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/modeltests/known_related_objects/fixtures/tournament.json b/tests/modeltests/known_related_objects/fixtures/tournament.json new file mode 100644 index 0000000000..2f2b1c5627 --- /dev/null +++ b/tests/modeltests/known_related_objects/fixtures/tournament.json @@ -0,0 +1,65 @@ +[ + { + "pk": 1, + "model": "known_related_objects.tournament", + "fields": { + "name": "Tourney 1" + } + }, + { + "pk": 2, + "model": "known_related_objects.tournament", + "fields": { + "name": "Tourney 2" + } + }, + { + "pk": 1, + "model": "known_related_objects.pool", + "fields": { + "tournament": 1, + "name": "T1 Pool 1" + } + }, + { + "pk": 2, + "model": "known_related_objects.pool", + "fields": { + "tournament": 1, + "name": "T1 Pool 2" + } + }, + { + "pk": 3, + "model": "known_related_objects.pool", + "fields": { + "tournament": 2, + "name": "T2 Pool 1" + } + }, + { + "pk": 4, + "model": "known_related_objects.pool", + "fields": { + "tournament": 2, + "name": "T2 Pool 2" + } + }, + { + "pk": 1, + "model": "known_related_objects.poolstyle", + "fields": { + "name": "T1 Pool 2 Style", + "pool": 2 + } + }, + { + "pk": 2, + "model": "known_related_objects.poolstyle", + "fields": { + "name": "T2 Pool 1 Style", + "pool": 3 + } + } +] + diff --git a/tests/modeltests/known_related_objects/models.py b/tests/modeltests/known_related_objects/models.py new file mode 100644 index 0000000000..4c516dd7e8 --- /dev/null +++ b/tests/modeltests/known_related_objects/models.py @@ -0,0 +1,19 @@ +""" +Existing related object instance caching. + +Test that queries are not redone when going back through known relations. +""" + +from django.db import models + +class Tournament(models.Model): + name = models.CharField(max_length=30) + +class Pool(models.Model): + name = models.CharField(max_length=30) + tournament = models.ForeignKey(Tournament) + +class PoolStyle(models.Model): + name = models.CharField(max_length=30) + pool = models.OneToOneField(Pool) + diff --git a/tests/modeltests/known_related_objects/tests.py b/tests/modeltests/known_related_objects/tests.py new file mode 100644 index 0000000000..24feab2241 --- /dev/null +++ b/tests/modeltests/known_related_objects/tests.py @@ -0,0 +1,88 @@ +from __future__ import absolute_import + +from django.test import TestCase + +from .models import Tournament, Pool, PoolStyle + +class ExistingRelatedInstancesTests(TestCase): + fixtures = ['tournament.json'] + + def test_foreign_key(self): + with self.assertNumQueries(2): + tournament = Tournament.objects.get(pk=1) + pool = tournament.pool_set.all()[0] + self.assertIs(tournament, pool.tournament) + + def test_foreign_key_prefetch_related(self): + with self.assertNumQueries(2): + tournament = (Tournament.objects.prefetch_related('pool_set').get(pk=1)) + pool = tournament.pool_set.all()[0] + self.assertIs(tournament, pool.tournament) + + def test_foreign_key_multiple_prefetch(self): + with self.assertNumQueries(2): + tournaments = list(Tournament.objects.prefetch_related('pool_set')) + pool1 = tournaments[0].pool_set.all()[0] + self.assertIs(tournaments[0], pool1.tournament) + pool2 = tournaments[1].pool_set.all()[0] + self.assertIs(tournaments[1], pool2.tournament) + + def test_one_to_one(self): + with self.assertNumQueries(2): + style = PoolStyle.objects.get(pk=1) + pool = style.pool + self.assertIs(style, pool.poolstyle) + + def test_one_to_one_select_related(self): + with self.assertNumQueries(1): + style = PoolStyle.objects.select_related('pool').get(pk=1) + pool = style.pool + self.assertIs(style, pool.poolstyle) + + def test_one_to_one_multi_select_related(self): + with self.assertNumQueries(1): + poolstyles = list(PoolStyle.objects.select_related('pool')) + self.assertIs(poolstyles[0], poolstyles[0].pool.poolstyle) + self.assertIs(poolstyles[1], poolstyles[1].pool.poolstyle) + + def test_one_to_one_prefetch_related(self): + with self.assertNumQueries(2): + style = PoolStyle.objects.prefetch_related('pool').get(pk=1) + pool = style.pool + self.assertIs(style, pool.poolstyle) + + def test_one_to_one_multi_prefetch_related(self): + with self.assertNumQueries(2): + poolstyles = list(PoolStyle.objects.prefetch_related('pool')) + self.assertIs(poolstyles[0], poolstyles[0].pool.poolstyle) + self.assertIs(poolstyles[1], poolstyles[1].pool.poolstyle) + + def test_reverse_one_to_one(self): + with self.assertNumQueries(2): + pool = Pool.objects.get(pk=2) + style = pool.poolstyle + self.assertIs(pool, style.pool) + + def test_reverse_one_to_one_select_related(self): + with self.assertNumQueries(1): + pool = Pool.objects.select_related('poolstyle').get(pk=2) + style = pool.poolstyle + self.assertIs(pool, style.pool) + + def test_reverse_one_to_one_prefetch_related(self): + with self.assertNumQueries(2): + pool = Pool.objects.prefetch_related('poolstyle').get(pk=2) + style = pool.poolstyle + self.assertIs(pool, style.pool) + + def test_reverse_one_to_one_multi_select_related(self): + with self.assertNumQueries(1): + pools = list(Pool.objects.select_related('poolstyle')) + self.assertIs(pools[1], pools[1].poolstyle.pool) + self.assertIs(pools[2], pools[2].poolstyle.pool) + + def test_reverse_one_to_one_multi_prefetch_related(self): + with self.assertNumQueries(2): + pools = list(Pool.objects.prefetch_related('poolstyle')) + self.assertIs(pools[1], pools[1].poolstyle.pool) + self.assertIs(pools[2], pools[2].poolstyle.pool)