[1.5.x] Fixed #19547 -- Caching of related instances.

When &'ing or |'ing querysets, wrong values could be cached, and crashes
could happen.

Thanks Marc Tamlyn for figuring out the problem and writing the patch.

Backport of 07fbc6a.
This commit is contained in:
Aymeric Augustin 2013-01-02 21:42:52 +01:00
parent da2cdd3a0f
commit 056ace0f39
5 changed files with 79 additions and 9 deletions

View File

@ -497,7 +497,7 @@ class ForeignRelatedObjectsDescriptor(object):
except (AttributeError, KeyError):
db = self._db or router.db_for_read(self.model, instance=self.instance)
qs = super(RelatedManager, self).get_query_set().using(db).filter(**self.core_filters)
qs._known_related_object = (rel_field.name, self.instance)
qs._known_related_objects = {rel_field: {self.instance.pk: self.instance}}
return qs
def get_prefetch_query_set(self, instances):

View File

@ -44,7 +44,7 @@ class QuerySet(object):
self._for_write = False
self._prefetch_related_lookups = []
self._prefetch_done = False
self._known_related_object = None # (attname, rel_obj)
self._known_related_objects = {} # {rel_field, {pk: rel_obj}}
########################
# PYTHON MAGIC METHODS #
@ -221,6 +221,7 @@ class QuerySet(object):
if isinstance(other, EmptyQuerySet):
return other._clone()
combined = self._clone()
combined._merge_known_related_objects(other)
combined.query.combine(other.query, sql.AND)
return combined
@ -229,6 +230,7 @@ class QuerySet(object):
combined = self._clone()
if isinstance(other, EmptyQuerySet):
return combined
combined._merge_known_related_objects(other)
combined.query.combine(other.query, sql.OR)
return combined
@ -289,10 +291,9 @@ class QuerySet(object):
init_list.append(field.attname)
model_cls = deferred_class_factory(self.model, skip)
# Cache db, model and known_related_object outside the loop
# Cache db and model outside the loop
db = self.db
model = self.model
kro_attname, kro_instance = self._known_related_object or (None, None)
compiler = self.query.get_compiler(using=db)
if fill_cache:
klass_info = get_klass_info(model, max_depth=max_depth,
@ -323,9 +324,16 @@ class QuerySet(object):
for i, aggregate in enumerate(aggregate_select):
setattr(obj, aggregate, row[i + aggregate_start])
# Add the known related object to the model, if there is one
if kro_instance:
setattr(obj, kro_attname, kro_instance)
# Add the known related objects to the model, if there are any
if self._known_related_objects:
for field, rel_objs in self._known_related_objects.items():
pk = getattr(obj, field.get_attname())
try:
rel_obj = rel_objs[pk]
except KeyError:
pass # may happen in qs1 | qs2 scenarios
else:
setattr(obj, field.name, rel_obj)
yield obj
@ -902,7 +910,7 @@ class QuerySet(object):
c = klass(model=self.model, query=query, using=self._db)
c._for_write = self._for_write
c._prefetch_related_lookups = self._prefetch_related_lookups[:]
c._known_related_object = self._known_related_object
c._known_related_objects = self._known_related_objects
c.__dict__.update(kwargs)
if setup and hasattr(c, '_setup_query'):
c._setup_query()
@ -942,6 +950,13 @@ class QuerySet(object):
"""
pass
def _merge_known_related_objects(self, other):
"""
Keep track of all known related objects from either QuerySet instance.
"""
for field, objects in other._known_related_objects.items():
self._known_related_objects.setdefault(field, {}).update(objects)
def _setup_aggregate_query(self, aggregates):
"""
Prepare the query for computing a result that contains aggregate annotations.

View File

@ -13,11 +13,19 @@
"name": "Tourney 2"
}
},
{
"pk": 1,
"model": "known_related_objects.organiser",
"fields": {
"name": "Organiser 1"
}
},
{
"pk": 1,
"model": "known_related_objects.pool",
"fields": {
"tournament": 1,
"organiser": 1,
"name": "T1 Pool 1"
}
},
@ -26,6 +34,7 @@
"model": "known_related_objects.pool",
"fields": {
"tournament": 1,
"organiser": 1,
"name": "T1 Pool 2"
}
},
@ -34,6 +43,7 @@
"model": "known_related_objects.pool",
"fields": {
"tournament": 2,
"organiser": 1,
"name": "T2 Pool 1"
}
},
@ -42,6 +52,7 @@
"model": "known_related_objects.pool",
"fields": {
"tournament": 2,
"organiser": 1,
"name": "T2 Pool 2"
}
},

View File

@ -9,9 +9,13 @@ from django.db import models
class Tournament(models.Model):
name = models.CharField(max_length=30)
class Organiser(models.Model):
name = models.CharField(max_length=30)
class Pool(models.Model):
name = models.CharField(max_length=30)
tournament = models.ForeignKey(Tournament)
organiser = models.ForeignKey(Organiser)
class PoolStyle(models.Model):
name = models.CharField(max_length=30)

View File

@ -2,7 +2,7 @@ from __future__ import absolute_import
from django.test import TestCase
from .models import Tournament, Pool, PoolStyle
from .models import Tournament, Organiser, Pool, PoolStyle
class ExistingRelatedInstancesTests(TestCase):
fixtures = ['tournament.json']
@ -27,6 +27,46 @@ class ExistingRelatedInstancesTests(TestCase):
pool2 = tournaments[1].pool_set.all()[0]
self.assertIs(tournaments[1], pool2.tournament)
def test_queryset_or(self):
tournament_1 = Tournament.objects.get(pk=1)
tournament_2 = Tournament.objects.get(pk=2)
with self.assertNumQueries(1):
pools = tournament_1.pool_set.all() | tournament_2.pool_set.all()
related_objects = set(pool.tournament for pool in pools)
self.assertEqual(related_objects, set((tournament_1, tournament_2)))
def test_queryset_or_different_cached_items(self):
tournament = Tournament.objects.get(pk=1)
organiser = Organiser.objects.get(pk=1)
with self.assertNumQueries(1):
pools = tournament.pool_set.all() | organiser.pool_set.all()
first = pools.filter(pk=1)[0]
self.assertIs(first.tournament, tournament)
self.assertIs(first.organiser, organiser)
def test_queryset_or_only_one_with_precache(self):
tournament_1 = Tournament.objects.get(pk=1)
tournament_2 = Tournament.objects.get(pk=2)
# 2 queries here as pool id 3 has tournament 2, which is not cached
with self.assertNumQueries(2):
pools = tournament_1.pool_set.all() | Pool.objects.filter(pk=3)
related_objects = set(pool.tournament for pool in pools)
self.assertEqual(related_objects, set((tournament_1, tournament_2)))
# and the other direction
with self.assertNumQueries(2):
pools = Pool.objects.filter(pk=3) | tournament_1.pool_set.all()
related_objects = set(pool.tournament for pool in pools)
self.assertEqual(related_objects, set((tournament_1, tournament_2)))
def test_queryset_and(self):
tournament = Tournament.objects.get(pk=1)
organiser = Organiser.objects.get(pk=1)
with self.assertNumQueries(1):
pools = tournament.pool_set.all() & organiser.pool_set.all()
first = pools.filter(pk=1)[0]
self.assertIs(first.tournament, tournament)
self.assertIs(first.organiser, organiser)
def test_one_to_one(self):
with self.assertNumQueries(2):
style = PoolStyle.objects.get(pk=1)