From 53a5fb3cc0137bebeebc0d4d321dbfe20397b065 Mon Sep 17 00:00:00 2001 From: Simon Charette Date: Mon, 30 May 2016 00:11:31 -0400 Subject: [PATCH] Fixed #26676 -- Prevented prefetching to_attr from caching its result in through attr. Thanks Ursidours for the report. --- django/db/models/query.py | 19 +++++++++++-------- tests/prefetch_related/tests.py | 10 ++++++++-- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/django/db/models/query.py b/django/db/models/query.py index 8bd981db511..2f4ef52e6c5 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -1451,7 +1451,8 @@ def prefetch_related_objects(model_instances, *related_lookups): # We assume that objects retrieved are homogeneous (which is the premise # of prefetch_related), so what applies to first object applies to all. first_obj = obj_list[0] - prefetcher, descriptor, attr_found, is_fetched = get_prefetcher(first_obj, through_attr) + to_attr = lookup.get_current_to_attr(level)[0] + prefetcher, descriptor, attr_found, is_fetched = get_prefetcher(first_obj, through_attr, to_attr) if not attr_found: raise AttributeError("Cannot find '%s' on %s object, '%s' is an invalid " @@ -1504,9 +1505,9 @@ def prefetch_related_objects(model_instances, *related_lookups): obj_list = new_obj_list -def get_prefetcher(instance, attr): +def get_prefetcher(instance, through_attr, to_attr): """ - For the attribute 'attr' on the given instance, finds + For the attribute 'through_attr' on the given instance, finds an object that has a get_prefetch_queryset(). Returns a 4 tuple containing: (the object with get_prefetch_queryset (or None), @@ -1520,9 +1521,9 @@ def get_prefetcher(instance, attr): # For singly related objects, we have to avoid getting the attribute # from the object, as this will trigger the query. So we first try # on the class, in order to get the descriptor object. - rel_obj_descriptor = getattr(instance.__class__, attr, None) + rel_obj_descriptor = getattr(instance.__class__, through_attr, None) if rel_obj_descriptor is None: - attr_found = hasattr(instance, attr) + attr_found = hasattr(instance, through_attr) else: attr_found = True if rel_obj_descriptor: @@ -1536,10 +1537,13 @@ def get_prefetcher(instance, attr): # descriptor doesn't support prefetching, so we go ahead and get # the attribute on the instance rather than the class to # support many related managers - rel_obj = getattr(instance, attr) + rel_obj = getattr(instance, through_attr) if hasattr(rel_obj, 'get_prefetch_queryset'): prefetcher = rel_obj - is_fetched = attr in instance._prefetched_objects_cache + if through_attr != to_attr: + is_fetched = hasattr(instance, to_attr) + else: + is_fetched = through_attr in instance._prefetched_objects_cache return prefetcher, rel_obj_descriptor, attr_found, is_fetched @@ -1619,7 +1623,6 @@ def prefetch_one_level(instances, prefetcher, lookup, level): else: if as_attr: setattr(obj, to_attr, vals) - obj._prefetched_objects_cache[cache_name] = vals else: manager = getattr(obj, to_attr) if leaf and lookup.queryset is not None: diff --git a/tests/prefetch_related/tests.py b/tests/prefetch_related/tests.py index 41eee219b80..c34682a33d8 100644 --- a/tests/prefetch_related/tests.py +++ b/tests/prefetch_related/tests.py @@ -5,7 +5,7 @@ import warnings from django.contrib.contenttypes.models import ContentType from django.core.exceptions import ObjectDoesNotExist from django.db import connection -from django.db.models import Prefetch +from django.db.models import Prefetch, QuerySet from django.db.models.query import get_prefetcher from django.test import TestCase, override_settings from django.test.utils import CaptureQueriesContext @@ -737,6 +737,12 @@ class CustomPrefetchTests(TestCase): with self.assertRaisesMessage(ValueError, 'Prefetch querysets cannot use values().'): Prefetch('houses', House.objects.values('pk')) + def test_to_attr_doesnt_cache_through_attr_as_list(self): + house = House.objects.prefetch_related( + Prefetch('rooms', queryset=Room.objects.all(), to_attr='to_rooms'), + ).get(pk=self.house3.pk) + self.assertIsInstance(house.rooms.all(), QuerySet) + class DefaultManagerTests(TestCase): @@ -1268,7 +1274,7 @@ class Ticket21760Tests(TestCase): house.save() def test_bug(self): - prefetcher = get_prefetcher(self.rooms[0], 'house')[0] + prefetcher = get_prefetcher(self.rooms[0], 'house', 'house')[0] queryset = prefetcher.get_prefetch_queryset(list(Room.objects.all()))[0] self.assertNotIn(' JOIN ', force_text(queryset.query))