diff --git a/django/db/models/fields/related_descriptors.py b/django/db/models/fields/related_descriptors.py index 4ced1f2c400..a94a8949edc 100644 --- a/django/db/models/fields/related_descriptors.py +++ b/django/db/models/fields/related_descriptors.py @@ -577,6 +577,12 @@ def create_reverse_many_to_one_manager(superclass, rel): queryset._known_related_objects = {self.field: {self.instance.pk: self.instance}} return queryset + def _remove_prefetched_objects(self): + try: + self.instance._prefetched_objects_cache.pop(self.field.related_query_name()) + except (AttributeError, KeyError): + pass # nothing to clear from cache + def get_queryset(self): try: return self.instance._prefetched_objects_cache[self.field.related_query_name()] @@ -606,6 +612,7 @@ def create_reverse_many_to_one_manager(superclass, rel): return queryset, rel_obj_attr, instance_attr, False, cache_name def add(self, *objs, **kwargs): + self._remove_prefetched_objects() bulk = kwargs.pop('bulk', True) objs = list(objs) db = router.db_for_write(self.model, instance=self.instance) @@ -680,6 +687,7 @@ def create_reverse_many_to_one_manager(superclass, rel): clear.alters_data = True def _clear(self, queryset, bulk): + self._remove_prefetched_objects() db = router.db_for_write(self.model, instance=self.instance) queryset = queryset.using(db) if bulk: @@ -856,6 +864,12 @@ def create_forward_many_to_many_manager(superclass, rel, reverse): queryset = queryset.using(self._db) return queryset._next_is_sticky().filter(**self.core_filters) + def _remove_prefetched_objects(self): + try: + self.instance._prefetched_objects_cache.pop(self.prefetch_cache_name) + except (AttributeError, KeyError): + pass # nothing to clear from cache + def get_queryset(self): try: return self.instance._prefetched_objects_cache[self.prefetch_cache_name] @@ -909,7 +923,7 @@ def create_forward_many_to_many_manager(superclass, rel, reverse): "intermediary model. Use %s.%s's Manager instead." % (opts.app_label, opts.object_name) ) - + self._remove_prefetched_objects() db = router.db_for_write(self.through, instance=self.instance) with transaction.atomic(using=db, savepoint=False): self._add_items(self.source_field_name, self.target_field_name, *objs) @@ -927,6 +941,7 @@ def create_forward_many_to_many_manager(superclass, rel, reverse): "an intermediary model. Use %s.%s's Manager instead." % (opts.app_label, opts.object_name) ) + self._remove_prefetched_objects() self._remove_items(self.source_field_name, self.target_field_name, *objs) remove.alters_data = True @@ -938,6 +953,7 @@ def create_forward_many_to_many_manager(superclass, rel, reverse): instance=self.instance, reverse=self.reverse, model=self.model, pk_set=None, using=db, ) + self._remove_prefetched_objects() filters = self._build_remove_filters(super(ManyRelatedManager, self).get_queryset().using(db)) self.through._default_manager.using(db).filter(filters).delete() diff --git a/docs/ref/models/querysets.txt b/docs/ref/models/querysets.txt index 36d791db0e5..3f4fce1a26f 100644 --- a/docs/ref/models/querysets.txt +++ b/docs/ref/models/querysets.txt @@ -979,6 +979,18 @@ database. performance, since you have done a database query that you haven't used. So use this feature with caution! + Also, if you call the database-altering methods + :meth:`~django.db.models.fields.related.RelatedManager.add`, + :meth:`~django.db.models.fields.related.RelatedManager.remove`, + :meth:`~django.db.models.fields.related.RelatedManager.clear` or + :meth:`~django.db.models.fields.related.RelatedManager.set`, on + :class:`related managers`, + any prefetched cache for the relation will be cleared. + + .. versionchanged:: 1.11 + + The clearing of the prefetched cache described above was added. + You can also use the normal join syntax to do related fields of related fields. Suppose we have an additional model to the example above:: diff --git a/docs/ref/models/relations.txt b/docs/ref/models/relations.txt index d8100287df3..c34d3104b23 100644 --- a/docs/ref/models/relations.txt +++ b/docs/ref/models/relations.txt @@ -170,6 +170,14 @@ Related objects reference ``add()``, ``create()``, ``remove()``, and ``set()`` methods are disabled. + If you use :meth:`~django.db.models.query.QuerySet.prefetch_related`, + the ``add()``, ``remove()``, ``clear()``, and ``set()`` methods clear + the prefetched cache. + + .. versionchanged:: 1.11 + + The clearing of the prefetched cache described above was added. + Direct Assignment ================= diff --git a/docs/releases/1.11.txt b/docs/releases/1.11.txt index 381d49a42a4..45bfdcffdf4 100644 --- a/docs/releases/1.11.txt +++ b/docs/releases/1.11.txt @@ -357,6 +357,13 @@ Miscellaneous * The ``checked`` attribute rendered by form widgets now uses HTML5 boolean syntax rather than XHTML's ``checked='checked'``. +* :meth:`RelatedManager.add() + `, + :meth:`~django.db.models.fields.related.RelatedManager.remove`, + :meth:`~django.db.models.fields.related.RelatedManager.clear`, and + :meth:`~django.db.models.fields.related.RelatedManager.set` now + clear the ``prefetch_related()`` cache. + .. _deprecated-features-1.11: Features deprecated in 1.11 diff --git a/tests/many_to_many/tests.py b/tests/many_to_many/tests.py index 86ffe816eb2..2d35266f467 100644 --- a/tests/many_to_many/tests.py +++ b/tests/many_to_many/tests.py @@ -518,6 +518,40 @@ class ManyToManyTests(TestCase): self.assertQuerysetEqual(self.a4.publications.all(), []) self.assertQuerysetEqual(self.p2.article_set.all(), ['']) + def test_clear_after_prefetch(self): + a4 = Article.objects.prefetch_related('publications').get(id=self.a4.id) + self.assertQuerysetEqual(a4.publications.all(), ['']) + a4.publications.clear() + self.assertQuerysetEqual(a4.publications.all(), []) + + def test_remove_after_prefetch(self): + a4 = Article.objects.prefetch_related('publications').get(id=self.a4.id) + self.assertQuerysetEqual(a4.publications.all(), ['']) + a4.publications.remove(self.p2) + self.assertQuerysetEqual(a4.publications.all(), []) + + def test_add_after_prefetch(self): + a4 = Article.objects.prefetch_related('publications').get(id=self.a4.id) + self.assertEqual(a4.publications.count(), 1) + a4.publications.add(self.p1) + self.assertEqual(a4.publications.count(), 2) + + def test_set_after_prefetch(self): + a4 = Article.objects.prefetch_related('publications').get(id=self.a4.id) + self.assertEqual(a4.publications.count(), 1) + a4.publications.set([self.p2, self.p1]) + self.assertEqual(a4.publications.count(), 2) + a4.publications.set([self.p1]) + self.assertEqual(a4.publications.count(), 1) + + def test_add_then_remove_after_prefetch(self): + a4 = Article.objects.prefetch_related('publications').get(id=self.a4.id) + self.assertEqual(a4.publications.count(), 1) + a4.publications.add(self.p1) + self.assertEqual(a4.publications.count(), 2) + a4.publications.remove(self.p1) + self.assertQuerysetEqual(a4.publications.all(), ['']) + def test_inherited_models_selects(self): """ #24156 - Objects from child models where the parent's m2m field uses diff --git a/tests/many_to_one/models.py b/tests/many_to_one/models.py index 05491e99f1a..abc9c7d8265 100644 --- a/tests/many_to_one/models.py +++ b/tests/many_to_one/models.py @@ -43,7 +43,7 @@ class City(models.Model): @python_2_unicode_compatible class District(models.Model): - city = models.ForeignKey(City, models.CASCADE) + city = models.ForeignKey(City, models.CASCADE, related_name='districts', null=True) name = models.CharField(max_length=50) def __str__(self): diff --git a/tests/many_to_one/tests.py b/tests/many_to_one/tests.py index 20f2e0d56d4..0de6973bb83 100644 --- a/tests/many_to_one/tests.py +++ b/tests/many_to_one/tests.py @@ -624,3 +624,48 @@ class ManyToOneTests(TestCase): # doesn't exist should be an instance of a subclass of `AttributeError` # refs #21563 self.assertFalse(hasattr(Article(), 'reporter')) + + def test_clear_after_prefetch(self): + c = City.objects.create(name='Musical City') + District.objects.create(name='Ladida', city=c) + city = City.objects.prefetch_related('districts').get(id=c.id) + self.assertQuerysetEqual(city.districts.all(), ['']) + city.districts.clear() + self.assertQuerysetEqual(city.districts.all(), []) + + def test_remove_after_prefetch(self): + c = City.objects.create(name='Musical City') + d = District.objects.create(name='Ladida', city=c) + city = City.objects.prefetch_related('districts').get(id=c.id) + self.assertQuerysetEqual(city.districts.all(), ['']) + city.districts.remove(d) + self.assertQuerysetEqual(city.districts.all(), []) + + def test_add_after_prefetch(self): + c = City.objects.create(name='Musical City') + District.objects.create(name='Ladida', city=c) + d2 = District.objects.create(name='Ladidu') + city = City.objects.prefetch_related('districts').get(id=c.id) + self.assertEqual(city.districts.count(), 1) + city.districts.add(d2) + self.assertEqual(city.districts.count(), 2) + + def test_set_after_prefetch(self): + c = City.objects.create(name='Musical City') + District.objects.create(name='Ladida', city=c) + d2 = District.objects.create(name='Ladidu') + city = City.objects.prefetch_related('districts').get(id=c.id) + self.assertEqual(city.districts.count(), 1) + city.districts.set([d2]) + self.assertQuerysetEqual(city.districts.all(), ['']) + + def test_add_then_remove_after_prefetch(self): + c = City.objects.create(name='Musical City') + District.objects.create(name='Ladida', city=c) + d2 = District.objects.create(name='Ladidu') + city = City.objects.prefetch_related('districts').get(id=c.id) + self.assertEqual(city.districts.count(), 1) + city.districts.add(d2) + self.assertEqual(city.districts.count(), 2) + city.districts.remove(d2) + self.assertEqual(city.districts.count(), 1)