Fixed #29612 -- Added GenericRelation prefetch_related() cache invalidation.

This commit is contained in:
Tom Forbes 2018-08-07 23:08:23 +01:00 committed by Tim Graham
parent bf17f5e884
commit c02d473781
2 changed files with 56 additions and 0 deletions

View File

@ -545,6 +545,12 @@ def create_generic_related_manager(superclass, rel):
db = self._db or router.db_for_read(self.model, instance=self.instance) db = self._db or router.db_for_read(self.model, instance=self.instance)
return queryset.using(db).filter(**self.core_filters) return queryset.using(db).filter(**self.core_filters)
def _remove_prefetched_objects(self):
try:
self.instance._prefetched_objects_cache.pop(self.prefetch_cache_name)
except (AttributeError, KeyError):
pass # nothing to clear from cache
def get_queryset(self): def get_queryset(self):
try: try:
return self.instance._prefetched_objects_cache[self.prefetch_cache_name] return self.instance._prefetched_objects_cache[self.prefetch_cache_name]
@ -577,6 +583,7 @@ def create_generic_related_manager(superclass, rel):
) )
def add(self, *objs, bulk=True): def add(self, *objs, bulk=True):
self._remove_prefetched_objects()
db = router.db_for_write(self.model, instance=self.instance) db = router.db_for_write(self.model, instance=self.instance)
def check_and_update_obj(obj): def check_and_update_obj(obj):
@ -620,6 +627,7 @@ def create_generic_related_manager(superclass, rel):
clear.alters_data = True clear.alters_data = True
def _clear(self, queryset, bulk): def _clear(self, queryset, bulk):
self._remove_prefetched_objects()
db = router.db_for_write(self.model, instance=self.instance) db = router.db_for_write(self.model, instance=self.instance)
queryset = queryset.using(db) queryset = queryset.using(db)
if bulk: if bulk:
@ -656,6 +664,7 @@ def create_generic_related_manager(superclass, rel):
set.alters_data = True set.alters_data = True
def create(self, **kwargs): def create(self, **kwargs):
self._remove_prefetched_objects()
kwargs[self.content_type_field_name] = self.content_type kwargs[self.content_type_field_name] = self.content_type
kwargs[self.object_id_field_name] = self.pk_val kwargs[self.object_id_field_name] = self.pk_val
db = router.db_for_write(self.model, instance=self.instance) db = router.db_for_write(self.model, instance=self.instance)

View File

@ -499,6 +499,53 @@ class GenericRelationsTests(TestCase):
tag = TaggedItem(content_object=spinach) tag = TaggedItem(content_object=spinach)
self.assertEqual(tag.content_object, spinach) self.assertEqual(tag.content_object, spinach)
def test_create_after_prefetch(self):
platypus = Animal.objects.prefetch_related('tags').get(pk=self.platypus.pk)
self.assertSequenceEqual(platypus.tags.all(), [])
weird_tag = platypus.tags.create(tag='weird')
self.assertSequenceEqual(platypus.tags.all(), [weird_tag])
def test_add_after_prefetch(self):
platypus = Animal.objects.prefetch_related('tags').get(pk=self.platypus.pk)
self.assertSequenceEqual(platypus.tags.all(), [])
weird_tag = TaggedItem.objects.create(tag='weird', content_object=platypus)
platypus.tags.add(weird_tag)
self.assertSequenceEqual(platypus.tags.all(), [weird_tag])
def test_remove_after_prefetch(self):
weird_tag = self.platypus.tags.create(tag='weird')
platypus = Animal.objects.prefetch_related('tags').get(pk=self.platypus.pk)
self.assertSequenceEqual(platypus.tags.all(), [weird_tag])
platypus.tags.remove(weird_tag)
self.assertSequenceEqual(platypus.tags.all(), [])
def test_clear_after_prefetch(self):
weird_tag = self.platypus.tags.create(tag='weird')
platypus = Animal.objects.prefetch_related('tags').get(pk=self.platypus.pk)
self.assertSequenceEqual(platypus.tags.all(), [weird_tag])
platypus.tags.clear()
self.assertSequenceEqual(platypus.tags.all(), [])
def test_set_after_prefetch(self):
platypus = Animal.objects.prefetch_related('tags').get(pk=self.platypus.pk)
self.assertSequenceEqual(platypus.tags.all(), [])
furry_tag = TaggedItem.objects.create(tag='furry', content_object=platypus)
platypus.tags.set([furry_tag])
self.assertSequenceEqual(platypus.tags.all(), [furry_tag])
weird_tag = TaggedItem.objects.create(tag='weird', content_object=platypus)
platypus.tags.set([weird_tag])
self.assertSequenceEqual(platypus.tags.all(), [weird_tag])
def test_add_then_remove_after_prefetch(self):
furry_tag = self.platypus.tags.create(tag='furry')
platypus = Animal.objects.prefetch_related('tags').get(pk=self.platypus.pk)
self.assertSequenceEqual(platypus.tags.all(), [furry_tag])
weird_tag = self.platypus.tags.create(tag='weird')
platypus.tags.add(weird_tag)
self.assertSequenceEqual(platypus.tags.all(), [furry_tag, weird_tag])
platypus.tags.remove(weird_tag)
self.assertSequenceEqual(platypus.tags.all(), [furry_tag])
class ProxyRelatedModelTest(TestCase): class ProxyRelatedModelTest(TestCase):
def test_default_behavior(self): def test_default_behavior(self):