diff --git a/django/contrib/contenttypes/fields.py b/django/contrib/contenttypes/fields.py index e55766aa6f9..ed98ecb48cb 100644 --- a/django/contrib/contenttypes/fields.py +++ b/django/contrib/contenttypes/fields.py @@ -545,6 +545,12 @@ def create_generic_related_manager(superclass, rel): db = self._db or router.db_for_read(self.model, instance=self.instance) return queryset.using(db).filter(**self.core_filters) + def _remove_prefetched_objects(self): + try: + self.instance._prefetched_objects_cache.pop(self.prefetch_cache_name) + except (AttributeError, KeyError): + pass # nothing to clear from cache + def get_queryset(self): try: return self.instance._prefetched_objects_cache[self.prefetch_cache_name] @@ -577,6 +583,7 @@ def create_generic_related_manager(superclass, rel): ) def add(self, *objs, bulk=True): + self._remove_prefetched_objects() db = router.db_for_write(self.model, instance=self.instance) def check_and_update_obj(obj): @@ -620,6 +627,7 @@ def create_generic_related_manager(superclass, rel): clear.alters_data = True def _clear(self, queryset, bulk): + self._remove_prefetched_objects() db = router.db_for_write(self.model, instance=self.instance) queryset = queryset.using(db) if bulk: @@ -656,6 +664,7 @@ def create_generic_related_manager(superclass, rel): set.alters_data = True def create(self, **kwargs): + self._remove_prefetched_objects() kwargs[self.content_type_field_name] = self.content_type kwargs[self.object_id_field_name] = self.pk_val db = router.db_for_write(self.model, instance=self.instance) diff --git a/tests/generic_relations/tests.py b/tests/generic_relations/tests.py index 4c27b6f3d51..2e99c5b5cf4 100644 --- a/tests/generic_relations/tests.py +++ b/tests/generic_relations/tests.py @@ -499,6 +499,53 @@ class GenericRelationsTests(TestCase): tag = TaggedItem(content_object=spinach) self.assertEqual(tag.content_object, spinach) + def test_create_after_prefetch(self): + platypus = Animal.objects.prefetch_related('tags').get(pk=self.platypus.pk) + self.assertSequenceEqual(platypus.tags.all(), []) + weird_tag = platypus.tags.create(tag='weird') + self.assertSequenceEqual(platypus.tags.all(), [weird_tag]) + + def test_add_after_prefetch(self): + platypus = Animal.objects.prefetch_related('tags').get(pk=self.platypus.pk) + self.assertSequenceEqual(platypus.tags.all(), []) + weird_tag = TaggedItem.objects.create(tag='weird', content_object=platypus) + platypus.tags.add(weird_tag) + self.assertSequenceEqual(platypus.tags.all(), [weird_tag]) + + def test_remove_after_prefetch(self): + weird_tag = self.platypus.tags.create(tag='weird') + platypus = Animal.objects.prefetch_related('tags').get(pk=self.platypus.pk) + self.assertSequenceEqual(platypus.tags.all(), [weird_tag]) + platypus.tags.remove(weird_tag) + self.assertSequenceEqual(platypus.tags.all(), []) + + def test_clear_after_prefetch(self): + weird_tag = self.platypus.tags.create(tag='weird') + platypus = Animal.objects.prefetch_related('tags').get(pk=self.platypus.pk) + self.assertSequenceEqual(platypus.tags.all(), [weird_tag]) + platypus.tags.clear() + self.assertSequenceEqual(platypus.tags.all(), []) + + def test_set_after_prefetch(self): + platypus = Animal.objects.prefetch_related('tags').get(pk=self.platypus.pk) + self.assertSequenceEqual(platypus.tags.all(), []) + furry_tag = TaggedItem.objects.create(tag='furry', content_object=platypus) + platypus.tags.set([furry_tag]) + self.assertSequenceEqual(platypus.tags.all(), [furry_tag]) + weird_tag = TaggedItem.objects.create(tag='weird', content_object=platypus) + platypus.tags.set([weird_tag]) + self.assertSequenceEqual(platypus.tags.all(), [weird_tag]) + + def test_add_then_remove_after_prefetch(self): + furry_tag = self.platypus.tags.create(tag='furry') + platypus = Animal.objects.prefetch_related('tags').get(pk=self.platypus.pk) + self.assertSequenceEqual(platypus.tags.all(), [furry_tag]) + weird_tag = self.platypus.tags.create(tag='weird') + platypus.tags.add(weird_tag) + self.assertSequenceEqual(platypus.tags.all(), [furry_tag, weird_tag]) + platypus.tags.remove(weird_tag) + self.assertSequenceEqual(platypus.tags.all(), [furry_tag]) + class ProxyRelatedModelTest(TestCase): def test_default_behavior(self):