diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index 4b6a5b0aed..9a657d9d26 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -496,7 +496,7 @@ class ForeignRelatedObjectsDescriptor(object): except (AttributeError, KeyError): db = self._db or router.db_for_read(self.model, instance=self.instance) qs = super(RelatedManager, self).get_query_set().using(db).filter(**self.core_filters) - qs._known_related_object = (rel_field.name, self.instance) + qs._known_related_objects = {rel_field: {self.instance.pk: self.instance}} return qs def get_prefetch_query_set(self, instances): diff --git a/django/db/models/query.py b/django/db/models/query.py index ee58a77886..26b93b20bf 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -44,7 +44,7 @@ class QuerySet(object): self._for_write = False self._prefetch_related_lookups = [] self._prefetch_done = False - self._known_related_object = None # (attname, rel_obj) + self._known_related_objects = {} # {rel_field, {pk: rel_obj}} ######################## # PYTHON MAGIC METHODS # @@ -221,6 +221,7 @@ class QuerySet(object): if isinstance(other, EmptyQuerySet): return other._clone() combined = self._clone() + combined._merge_known_related_objects(other) combined.query.combine(other.query, sql.AND) return combined @@ -229,6 +230,7 @@ class QuerySet(object): combined = self._clone() if isinstance(other, EmptyQuerySet): return combined + combined._merge_known_related_objects(other) combined.query.combine(other.query, sql.OR) return combined @@ -289,10 +291,9 @@ class QuerySet(object): init_list.append(field.attname) model_cls = deferred_class_factory(self.model, skip) - # Cache db, model and known_related_object outside the loop + # Cache db and model outside the loop db = self.db model = self.model - kro_attname, kro_instance = self._known_related_object or (None, None) compiler = self.query.get_compiler(using=db) if fill_cache: klass_info = get_klass_info(model, max_depth=max_depth, @@ -323,9 +324,16 @@ class QuerySet(object): for i, aggregate in enumerate(aggregate_select): setattr(obj, aggregate, row[i + aggregate_start]) - # Add the known related object to the model, if there is one - if kro_instance: - setattr(obj, kro_attname, kro_instance) + # Add the known related objects to the model, if there are any + if self._known_related_objects: + for field, rel_objs in self._known_related_objects.items(): + pk = getattr(obj, field.get_attname()) + try: + rel_obj = rel_objs[pk] + except KeyError: + pass # may happen in qs1 | qs2 scenarios + else: + setattr(obj, field.name, rel_obj) yield obj @@ -902,7 +910,7 @@ class QuerySet(object): c = klass(model=self.model, query=query, using=self._db) c._for_write = self._for_write c._prefetch_related_lookups = self._prefetch_related_lookups[:] - c._known_related_object = self._known_related_object + c._known_related_objects = self._known_related_objects c.__dict__.update(kwargs) if setup and hasattr(c, '_setup_query'): c._setup_query() @@ -942,6 +950,13 @@ class QuerySet(object): """ pass + def _merge_known_related_objects(self, other): + """ + Keep track of all known related objects from either QuerySet instance. + """ + for field, objects in other._known_related_objects.items(): + self._known_related_objects.setdefault(field, {}).update(objects) + def _setup_aggregate_query(self, aggregates): """ Prepare the query for computing a result that contains aggregate annotations. diff --git a/tests/modeltests/known_related_objects/fixtures/tournament.json b/tests/modeltests/known_related_objects/fixtures/tournament.json index 2f2b1c5627..b8f053e152 100644 --- a/tests/modeltests/known_related_objects/fixtures/tournament.json +++ b/tests/modeltests/known_related_objects/fixtures/tournament.json @@ -13,11 +13,19 @@ "name": "Tourney 2" } }, + { + "pk": 1, + "model": "known_related_objects.organiser", + "fields": { + "name": "Organiser 1" + } + }, { "pk": 1, "model": "known_related_objects.pool", "fields": { "tournament": 1, + "organiser": 1, "name": "T1 Pool 1" } }, @@ -26,6 +34,7 @@ "model": "known_related_objects.pool", "fields": { "tournament": 1, + "organiser": 1, "name": "T1 Pool 2" } }, @@ -34,6 +43,7 @@ "model": "known_related_objects.pool", "fields": { "tournament": 2, + "organiser": 1, "name": "T2 Pool 1" } }, @@ -42,6 +52,7 @@ "model": "known_related_objects.pool", "fields": { "tournament": 2, + "organiser": 1, "name": "T2 Pool 2" } }, diff --git a/tests/modeltests/known_related_objects/models.py b/tests/modeltests/known_related_objects/models.py index 4c516dd7e8..e256cc38f2 100644 --- a/tests/modeltests/known_related_objects/models.py +++ b/tests/modeltests/known_related_objects/models.py @@ -9,9 +9,13 @@ from django.db import models class Tournament(models.Model): name = models.CharField(max_length=30) +class Organiser(models.Model): + name = models.CharField(max_length=30) + class Pool(models.Model): name = models.CharField(max_length=30) tournament = models.ForeignKey(Tournament) + organiser = models.ForeignKey(Organiser) class PoolStyle(models.Model): name = models.CharField(max_length=30) diff --git a/tests/modeltests/known_related_objects/tests.py b/tests/modeltests/known_related_objects/tests.py index 24feab2241..2371ac2e20 100644 --- a/tests/modeltests/known_related_objects/tests.py +++ b/tests/modeltests/known_related_objects/tests.py @@ -2,7 +2,7 @@ from __future__ import absolute_import from django.test import TestCase -from .models import Tournament, Pool, PoolStyle +from .models import Tournament, Organiser, Pool, PoolStyle class ExistingRelatedInstancesTests(TestCase): fixtures = ['tournament.json'] @@ -27,6 +27,46 @@ class ExistingRelatedInstancesTests(TestCase): pool2 = tournaments[1].pool_set.all()[0] self.assertIs(tournaments[1], pool2.tournament) + def test_queryset_or(self): + tournament_1 = Tournament.objects.get(pk=1) + tournament_2 = Tournament.objects.get(pk=2) + with self.assertNumQueries(1): + pools = tournament_1.pool_set.all() | tournament_2.pool_set.all() + related_objects = set(pool.tournament for pool in pools) + self.assertEqual(related_objects, set((tournament_1, tournament_2))) + + def test_queryset_or_different_cached_items(self): + tournament = Tournament.objects.get(pk=1) + organiser = Organiser.objects.get(pk=1) + with self.assertNumQueries(1): + pools = tournament.pool_set.all() | organiser.pool_set.all() + first = pools.filter(pk=1)[0] + self.assertIs(first.tournament, tournament) + self.assertIs(first.organiser, organiser) + + def test_queryset_or_only_one_with_precache(self): + tournament_1 = Tournament.objects.get(pk=1) + tournament_2 = Tournament.objects.get(pk=2) + # 2 queries here as pool id 3 has tournament 2, which is not cached + with self.assertNumQueries(2): + pools = tournament_1.pool_set.all() | Pool.objects.filter(pk=3) + related_objects = set(pool.tournament for pool in pools) + self.assertEqual(related_objects, set((tournament_1, tournament_2))) + # and the other direction + with self.assertNumQueries(2): + pools = Pool.objects.filter(pk=3) | tournament_1.pool_set.all() + related_objects = set(pool.tournament for pool in pools) + self.assertEqual(related_objects, set((tournament_1, tournament_2))) + + def test_queryset_and(self): + tournament = Tournament.objects.get(pk=1) + organiser = Organiser.objects.get(pk=1) + with self.assertNumQueries(1): + pools = tournament.pool_set.all() & organiser.pool_set.all() + first = pools.filter(pk=1)[0] + self.assertIs(first.tournament, tournament) + self.assertIs(first.organiser, organiser) + def test_one_to_one(self): with self.assertNumQueries(2): style = PoolStyle.objects.get(pk=1)