Fixed #19547 -- Caching of related instances.

When &'ing or |'ing querysets, wrong values could be cached, and crashes
could happen.

Thanks Marc Tamlyn for figuring out the problem and writing the patch.
This commit is contained in:
Aymeric Augustin 2013-01-02 21:42:52 +01:00
parent 695b2089e7
commit 07fbc6ae0e
5 changed files with 79 additions and 9 deletions

View File

@ -496,7 +496,7 @@ class ForeignRelatedObjectsDescriptor(object):
except (AttributeError, KeyError): except (AttributeError, KeyError):
db = self._db or router.db_for_read(self.model, instance=self.instance) db = self._db or router.db_for_read(self.model, instance=self.instance)
qs = super(RelatedManager, self).get_query_set().using(db).filter(**self.core_filters) qs = super(RelatedManager, self).get_query_set().using(db).filter(**self.core_filters)
qs._known_related_object = (rel_field.name, self.instance) qs._known_related_objects = {rel_field: {self.instance.pk: self.instance}}
return qs return qs
def get_prefetch_query_set(self, instances): def get_prefetch_query_set(self, instances):

View File

@ -44,7 +44,7 @@ class QuerySet(object):
self._for_write = False self._for_write = False
self._prefetch_related_lookups = [] self._prefetch_related_lookups = []
self._prefetch_done = False self._prefetch_done = False
self._known_related_object = None # (attname, rel_obj) self._known_related_objects = {} # {rel_field, {pk: rel_obj}}
######################## ########################
# PYTHON MAGIC METHODS # # PYTHON MAGIC METHODS #
@ -221,6 +221,7 @@ class QuerySet(object):
if isinstance(other, EmptyQuerySet): if isinstance(other, EmptyQuerySet):
return other._clone() return other._clone()
combined = self._clone() combined = self._clone()
combined._merge_known_related_objects(other)
combined.query.combine(other.query, sql.AND) combined.query.combine(other.query, sql.AND)
return combined return combined
@ -229,6 +230,7 @@ class QuerySet(object):
combined = self._clone() combined = self._clone()
if isinstance(other, EmptyQuerySet): if isinstance(other, EmptyQuerySet):
return combined return combined
combined._merge_known_related_objects(other)
combined.query.combine(other.query, sql.OR) combined.query.combine(other.query, sql.OR)
return combined return combined
@ -289,10 +291,9 @@ class QuerySet(object):
init_list.append(field.attname) init_list.append(field.attname)
model_cls = deferred_class_factory(self.model, skip) model_cls = deferred_class_factory(self.model, skip)
# Cache db, model and known_related_object outside the loop # Cache db and model outside the loop
db = self.db db = self.db
model = self.model model = self.model
kro_attname, kro_instance = self._known_related_object or (None, None)
compiler = self.query.get_compiler(using=db) compiler = self.query.get_compiler(using=db)
if fill_cache: if fill_cache:
klass_info = get_klass_info(model, max_depth=max_depth, klass_info = get_klass_info(model, max_depth=max_depth,
@ -323,9 +324,16 @@ class QuerySet(object):
for i, aggregate in enumerate(aggregate_select): for i, aggregate in enumerate(aggregate_select):
setattr(obj, aggregate, row[i + aggregate_start]) setattr(obj, aggregate, row[i + aggregate_start])
# Add the known related object to the model, if there is one # Add the known related objects to the model, if there are any
if kro_instance: if self._known_related_objects:
setattr(obj, kro_attname, kro_instance) for field, rel_objs in self._known_related_objects.items():
pk = getattr(obj, field.get_attname())
try:
rel_obj = rel_objs[pk]
except KeyError:
pass # may happen in qs1 | qs2 scenarios
else:
setattr(obj, field.name, rel_obj)
yield obj yield obj
@ -902,7 +910,7 @@ class QuerySet(object):
c = klass(model=self.model, query=query, using=self._db) c = klass(model=self.model, query=query, using=self._db)
c._for_write = self._for_write c._for_write = self._for_write
c._prefetch_related_lookups = self._prefetch_related_lookups[:] c._prefetch_related_lookups = self._prefetch_related_lookups[:]
c._known_related_object = self._known_related_object c._known_related_objects = self._known_related_objects
c.__dict__.update(kwargs) c.__dict__.update(kwargs)
if setup and hasattr(c, '_setup_query'): if setup and hasattr(c, '_setup_query'):
c._setup_query() c._setup_query()
@ -942,6 +950,13 @@ class QuerySet(object):
""" """
pass pass
def _merge_known_related_objects(self, other):
"""
Keep track of all known related objects from either QuerySet instance.
"""
for field, objects in other._known_related_objects.items():
self._known_related_objects.setdefault(field, {}).update(objects)
def _setup_aggregate_query(self, aggregates): def _setup_aggregate_query(self, aggregates):
""" """
Prepare the query for computing a result that contains aggregate annotations. Prepare the query for computing a result that contains aggregate annotations.

View File

@ -13,11 +13,19 @@
"name": "Tourney 2" "name": "Tourney 2"
} }
}, },
{
"pk": 1,
"model": "known_related_objects.organiser",
"fields": {
"name": "Organiser 1"
}
},
{ {
"pk": 1, "pk": 1,
"model": "known_related_objects.pool", "model": "known_related_objects.pool",
"fields": { "fields": {
"tournament": 1, "tournament": 1,
"organiser": 1,
"name": "T1 Pool 1" "name": "T1 Pool 1"
} }
}, },
@ -26,6 +34,7 @@
"model": "known_related_objects.pool", "model": "known_related_objects.pool",
"fields": { "fields": {
"tournament": 1, "tournament": 1,
"organiser": 1,
"name": "T1 Pool 2" "name": "T1 Pool 2"
} }
}, },
@ -34,6 +43,7 @@
"model": "known_related_objects.pool", "model": "known_related_objects.pool",
"fields": { "fields": {
"tournament": 2, "tournament": 2,
"organiser": 1,
"name": "T2 Pool 1" "name": "T2 Pool 1"
} }
}, },
@ -42,6 +52,7 @@
"model": "known_related_objects.pool", "model": "known_related_objects.pool",
"fields": { "fields": {
"tournament": 2, "tournament": 2,
"organiser": 1,
"name": "T2 Pool 2" "name": "T2 Pool 2"
} }
}, },

View File

@ -9,9 +9,13 @@ from django.db import models
class Tournament(models.Model): class Tournament(models.Model):
name = models.CharField(max_length=30) name = models.CharField(max_length=30)
class Organiser(models.Model):
name = models.CharField(max_length=30)
class Pool(models.Model): class Pool(models.Model):
name = models.CharField(max_length=30) name = models.CharField(max_length=30)
tournament = models.ForeignKey(Tournament) tournament = models.ForeignKey(Tournament)
organiser = models.ForeignKey(Organiser)
class PoolStyle(models.Model): class PoolStyle(models.Model):
name = models.CharField(max_length=30) name = models.CharField(max_length=30)

View File

@ -2,7 +2,7 @@ from __future__ import absolute_import
from django.test import TestCase from django.test import TestCase
from .models import Tournament, Pool, PoolStyle from .models import Tournament, Organiser, Pool, PoolStyle
class ExistingRelatedInstancesTests(TestCase): class ExistingRelatedInstancesTests(TestCase):
fixtures = ['tournament.json'] fixtures = ['tournament.json']
@ -27,6 +27,46 @@ class ExistingRelatedInstancesTests(TestCase):
pool2 = tournaments[1].pool_set.all()[0] pool2 = tournaments[1].pool_set.all()[0]
self.assertIs(tournaments[1], pool2.tournament) self.assertIs(tournaments[1], pool2.tournament)
def test_queryset_or(self):
tournament_1 = Tournament.objects.get(pk=1)
tournament_2 = Tournament.objects.get(pk=2)
with self.assertNumQueries(1):
pools = tournament_1.pool_set.all() | tournament_2.pool_set.all()
related_objects = set(pool.tournament for pool in pools)
self.assertEqual(related_objects, set((tournament_1, tournament_2)))
def test_queryset_or_different_cached_items(self):
tournament = Tournament.objects.get(pk=1)
organiser = Organiser.objects.get(pk=1)
with self.assertNumQueries(1):
pools = tournament.pool_set.all() | organiser.pool_set.all()
first = pools.filter(pk=1)[0]
self.assertIs(first.tournament, tournament)
self.assertIs(first.organiser, organiser)
def test_queryset_or_only_one_with_precache(self):
tournament_1 = Tournament.objects.get(pk=1)
tournament_2 = Tournament.objects.get(pk=2)
# 2 queries here as pool id 3 has tournament 2, which is not cached
with self.assertNumQueries(2):
pools = tournament_1.pool_set.all() | Pool.objects.filter(pk=3)
related_objects = set(pool.tournament for pool in pools)
self.assertEqual(related_objects, set((tournament_1, tournament_2)))
# and the other direction
with self.assertNumQueries(2):
pools = Pool.objects.filter(pk=3) | tournament_1.pool_set.all()
related_objects = set(pool.tournament for pool in pools)
self.assertEqual(related_objects, set((tournament_1, tournament_2)))
def test_queryset_and(self):
tournament = Tournament.objects.get(pk=1)
organiser = Organiser.objects.get(pk=1)
with self.assertNumQueries(1):
pools = tournament.pool_set.all() & organiser.pool_set.all()
first = pools.filter(pk=1)[0]
self.assertIs(first.tournament, tournament)
self.assertIs(first.organiser, organiser)
def test_one_to_one(self): def test_one_to_one(self):
with self.assertNumQueries(2): with self.assertNumQueries(2):
style = PoolStyle.objects.get(pk=1) style = PoolStyle.objects.get(pk=1)