From 321ecb40f4da842926e1bc07e11df4aabe53ca4b Mon Sep 17 00:00:00 2001 From: Jon Janzen Date: Thu, 3 Nov 2022 19:57:33 +0100 Subject: [PATCH] Fixed #34135 -- Added async-compatible interface to related managers. --- django/contrib/contenttypes/fields.py | 20 +++++++ .../db/models/fields/related_descriptors.py | 44 ++++++++++++++++ docs/ref/models/relations.txt | 42 ++++++++++++--- docs/releases/4.2.txt | 5 ++ docs/topics/async.txt | 6 ++- tests/async/test_async_related_managers.py | 52 ++++++++++++++++++- tests/generic_relations/tests.py | 27 ++++++++++ 7 files changed, 188 insertions(+), 8 deletions(-) diff --git a/django/contrib/contenttypes/fields.py b/django/contrib/contenttypes/fields.py index ce2a096cc2c..72b9e7631a8 100644 --- a/django/contrib/contenttypes/fields.py +++ b/django/contrib/contenttypes/fields.py @@ -689,6 +689,11 @@ def create_generic_related_manager(superclass, rel): add.alters_data = True + async def aadd(self, *objs, bulk=True): + return await sync_to_async(self.add)(*objs, bulk=bulk) + + aadd.alters_data = True + def remove(self, *objs, bulk=True): if not objs: return @@ -696,11 +701,21 @@ def create_generic_related_manager(superclass, rel): remove.alters_data = True + async def aremove(self, *objs, bulk=True): + return await sync_to_async(self.remove)(*objs, bulk=bulk) + + aremove.alters_data = True + def clear(self, *, bulk=True): self._clear(self, bulk) clear.alters_data = True + async def aclear(self, *, bulk=True): + return await sync_to_async(self.clear)(bulk=bulk) + + aclear.alters_data = True + def _clear(self, queryset, bulk): self._remove_prefetched_objects() db = router.db_for_write(self.model, instance=self.instance) @@ -740,6 +755,11 @@ def create_generic_related_manager(superclass, rel): set.alters_data = True + async def aset(self, objs, *, bulk=True, clear=False): + return await sync_to_async(self.set)(objs, bulk=bulk, clear=clear) + + aset.alters_data = True + def create(self, **kwargs): self._remove_prefetched_objects() kwargs[self.content_type_field_name] = self.content_type diff --git a/django/db/models/fields/related_descriptors.py b/django/db/models/fields/related_descriptors.py index f1c8a73f494..422b08e6ca9 100644 --- a/django/db/models/fields/related_descriptors.py +++ b/django/db/models/fields/related_descriptors.py @@ -787,6 +787,11 @@ def create_reverse_many_to_one_manager(superclass, rel): add.alters_data = True + async def aadd(self, *objs, bulk=True): + return await sync_to_async(self.add)(*objs, bulk=bulk) + + aadd.alters_data = True + def create(self, **kwargs): self._check_fk_val() kwargs[self.field.name] = self.instance @@ -856,12 +861,22 @@ def create_reverse_many_to_one_manager(superclass, rel): remove.alters_data = True + async def aremove(self, *objs, bulk=True): + return await sync_to_async(self.remove)(*objs, bulk=bulk) + + aremove.alters_data = True + def clear(self, *, bulk=True): self._check_fk_val() self._clear(self, bulk) clear.alters_data = True + async def aclear(self, *, bulk=True): + return await sync_to_async(self.clear)(bulk=bulk) + + aclear.alters_data = True + def _clear(self, queryset, bulk): self._remove_prefetched_objects() db = router.db_for_write(self.model, instance=self.instance) @@ -905,6 +920,11 @@ def create_reverse_many_to_one_manager(superclass, rel): set.alters_data = True + async def aset(self, objs, *, bulk=True, clear=False): + return await sync_to_async(self.set)(objs=objs, bulk=bulk, clear=clear) + + aset.alters_data = True + return RelatedManager @@ -1132,12 +1152,24 @@ def create_forward_many_to_many_manager(superclass, rel, reverse): add.alters_data = True + async def aadd(self, *objs, through_defaults=None): + return await sync_to_async(self.add)( + *objs, through_defaults=through_defaults + ) + + aadd.alters_data = True + def remove(self, *objs): self._remove_prefetched_objects() self._remove_items(self.source_field_name, self.target_field_name, *objs) remove.alters_data = True + async def aremove(self, *objs): + return await sync_to_async(self.remove)(*objs) + + aremove.alters_data = True + def clear(self): db = router.db_for_write(self.through, instance=self.instance) with transaction.atomic(using=db, savepoint=False): @@ -1166,6 +1198,11 @@ def create_forward_many_to_many_manager(superclass, rel, reverse): clear.alters_data = True + async def aclear(self): + return await sync_to_async(self.clear)() + + aclear.alters_data = True + def set(self, objs, *, clear=False, through_defaults=None): # Force evaluation of `objs` in case it's a queryset whose value # could be affected by `manager.clear()`. Refs #19816. @@ -1200,6 +1237,13 @@ def create_forward_many_to_many_manager(superclass, rel, reverse): set.alters_data = True + async def aset(self, objs, *, clear=False, through_defaults=None): + return await sync_to_async(self.set)( + objs=objs, clear=clear, through_defaults=through_defaults + ) + + aset.alters_data = True + def create(self, *, through_defaults=None, **kwargs): db = router.db_for_write(self.instance.__class__, instance=self.instance) new_obj = super(ManyRelatedManager, self.db_manager(db)).create(**kwargs) diff --git a/docs/ref/models/relations.txt b/docs/ref/models/relations.txt index c091612c6ec..1b1aad7425c 100644 --- a/docs/ref/models/relations.txt +++ b/docs/ref/models/relations.txt @@ -37,6 +37,9 @@ Related objects reference ``topping.pizza_set`` and on ``pizza.toppings``. .. method:: add(*objs, bulk=True, through_defaults=None) + .. method:: aadd(*objs, bulk=True, through_defaults=None) + + *Asynchronous version*: ``aadd`` Adds the specified model objects to the related object set. @@ -75,6 +78,10 @@ Related objects reference dictionary and they will be evaluated once before creating any intermediate instance(s). + .. versionchanged:: 4.2 + + ``aadd()`` method was added. + .. method:: create(through_defaults=None, **kwargs) .. method:: acreate(through_defaults=None, **kwargs) @@ -118,6 +125,9 @@ Related objects reference ``acreate()`` method was added. .. method:: remove(*objs, bulk=True) + .. method:: aremove(*objs, bulk=True) + + *Asynchronous version*: ``aremove`` Removes the specified model objects from the related object set:: @@ -157,7 +167,14 @@ Related objects reference For many-to-many relationships, the ``bulk`` keyword argument doesn't exist. + .. versionchanged:: 4.2 + + ``aremove()`` method was added. + .. method:: clear(bulk=True) + .. method:: aclear(bulk=True) + + *Asynchronous version*: ``aclear`` Removes all objects from the related object set:: @@ -174,7 +191,14 @@ Related objects reference For many-to-many relationships, the ``bulk`` keyword argument doesn't exist. + .. versionchanged:: 4.2 + + ``aclear()`` method was added. + .. method:: set(objs, bulk=True, clear=False, through_defaults=None) + .. method:: aset(objs, bulk=True, clear=False, through_defaults=None) + + *Asynchronous version*: ``aset`` Replace the set of related objects:: @@ -207,13 +231,19 @@ Related objects reference dictionary and they will be evaluated once before creating any intermediate instance(s). + .. versionchanged:: 4.2 + + ``aset()`` method was added. + .. note:: - Note that ``add()``, ``create()``, ``remove()``, ``clear()``, and - ``set()`` all apply database changes immediately for all types of - related fields. In other words, there is no need to call ``save()`` - on either end of the relationship. + Note that ``add()``, ``aadd()``, ``create()``, ``acreate()``, + ``remove()``, ``aremove()``, ``clear()``, ``aclear()``, ``set()``, and + ``aset()`` all apply database changes immediately for all types of + related fields. In other words, there is no need to call + ``save()``/``asave()`` on either end of the relationship. If you use :meth:`~django.db.models.query.QuerySet.prefetch_related`, - the ``add()``, ``remove()``, ``clear()``, and ``set()`` methods clear - the prefetched cache. + the ``add()``, ``aadd()``, ``remove()``, ``aremove()``, ``clear()``, + ``aclear()``, ``set()``, and ``aset()`` methods clear the prefetched + cache. diff --git a/docs/releases/4.2.txt b/docs/releases/4.2.txt index bf0a36fce61..39cf05f23bd 100644 --- a/docs/releases/4.2.txt +++ b/docs/releases/4.2.txt @@ -243,6 +243,11 @@ Models database, using an ``a`` prefix: :meth:`~.Model.adelete`, :meth:`~.Model.arefresh_from_db`, and :meth:`~.Model.asave`. +* Related managers now provide asynchronous versions of methods that change a + set of related objects, using an ``a`` prefix: :meth:`~.RelatedManager.aadd`, + :meth:`~.RelatedManager.aclear`, :meth:`~.RelatedManager.aremove`, and + :meth:`~.RelatedManager.aset`. + Requests and Responses ~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/topics/async.txt b/docs/topics/async.txt index 2b9b1a85d9e..0570d59db23 100644 --- a/docs/topics/async.txt +++ b/docs/topics/async.txt @@ -97,13 +97,17 @@ Django also supports some asynchronous model methods that use the database:: book = Book(...) await book.asave(using="secondary") + async def make_book_with_tags(tags, ...): + book = await Book.objects.acreate(...) + await book.tags.aset(tags) + Transactions do not yet work in async mode. If you have a piece of code that needs transactions behavior, we recommend you write that piece as a single synchronous function and call it using :func:`sync_to_async`. .. versionchanged:: 4.2 - Asynchronous model interface was added. + Asynchronous model and related manager interfaces were added. Performance ----------- diff --git a/tests/async/test_async_related_managers.py b/tests/async/test_async_related_managers.py index dd573f59897..c475b54899b 100644 --- a/tests/async/test_async_related_managers.py +++ b/tests/async/test_async_related_managers.py @@ -1,6 +1,6 @@ from django.test import TestCase -from .models import ManyToManyModel, SimpleModel +from .models import ManyToManyModel, RelatedModel, SimpleModel class AsyncRelatedManagersOperationTest(TestCase): @@ -8,6 +8,8 @@ class AsyncRelatedManagersOperationTest(TestCase): def setUpTestData(cls): cls.mtm1 = ManyToManyModel.objects.create() cls.s1 = SimpleModel.objects.create(field=0) + cls.mtm2 = ManyToManyModel.objects.create() + cls.mtm2.simples.set([cls.s1]) async def test_acreate(self): await self.mtm1.simples.acreate(field=2) @@ -54,3 +56,51 @@ class AsyncRelatedManagersOperationTest(TestCase): self.assertIs(created, True) self.assertEqual(await self.s1.relatedmodel_set.acount(), 1) self.assertEqual(new_relatedmodel.simple, self.s1) + + async def test_aadd(self): + await self.mtm1.simples.aadd(self.s1) + self.assertEqual(await self.mtm1.simples.aget(), self.s1) + + async def test_aadd_reverse(self): + r1 = await RelatedModel.objects.acreate() + await self.s1.relatedmodel_set.aadd(r1, bulk=False) + self.assertEqual(await self.s1.relatedmodel_set.aget(), r1) + + async def test_aremove(self): + self.assertEqual(await self.mtm2.simples.acount(), 1) + await self.mtm2.simples.aremove(self.s1) + self.assertEqual(await self.mtm2.simples.acount(), 0) + + async def test_aremove_reverse(self): + r1 = await RelatedModel.objects.acreate(simple=self.s1) + self.assertEqual(await self.s1.relatedmodel_set.acount(), 1) + await self.s1.relatedmodel_set.aremove(r1) + self.assertEqual(await self.s1.relatedmodel_set.acount(), 0) + + async def test_aset(self): + await self.mtm1.simples.aset([self.s1]) + self.assertEqual(await self.mtm1.simples.aget(), self.s1) + await self.mtm1.simples.aset([]) + self.assertEqual(await self.mtm1.simples.acount(), 0) + await self.mtm1.simples.aset([self.s1], clear=True) + self.assertEqual(await self.mtm1.simples.aget(), self.s1) + + async def test_aset_reverse(self): + r1 = await RelatedModel.objects.acreate() + await self.s1.relatedmodel_set.aset([r1]) + self.assertEqual(await self.s1.relatedmodel_set.aget(), r1) + await self.s1.relatedmodel_set.aset([]) + self.assertEqual(await self.s1.relatedmodel_set.acount(), 0) + await self.s1.relatedmodel_set.aset([r1], bulk=False, clear=True) + self.assertEqual(await self.s1.relatedmodel_set.aget(), r1) + + async def test_aclear(self): + self.assertEqual(await self.mtm2.simples.acount(), 1) + await self.mtm2.simples.aclear() + self.assertEqual(await self.mtm2.simples.acount(), 0) + + async def test_aclear_reverse(self): + await RelatedModel.objects.acreate(simple=self.s1) + self.assertEqual(await self.s1.relatedmodel_set.acount(), 1) + await self.s1.relatedmodel_set.aclear(bulk=False) + self.assertEqual(await self.s1.relatedmodel_set.acount(), 0) diff --git a/tests/generic_relations/tests.py b/tests/generic_relations/tests.py index e6bee11cdf2..18e3578f60c 100644 --- a/tests/generic_relations/tests.py +++ b/tests/generic_relations/tests.py @@ -324,6 +324,13 @@ class GenericRelationsTests(TestCase): with self.assertRaisesMessage(TypeError, msg): self.bacon.tags.add(self.lion) + async def test_aadd(self): + bacon = await Vegetable.objects.acreate(name="Bacon", is_yucky=False) + t1 = await TaggedItem.objects.acreate(content_object=self.quartz, tag="shiny") + t2 = await TaggedItem.objects.acreate(content_object=self.quartz, tag="fatty") + await bacon.tags.aadd(t1, t2, bulk=False) + self.assertEqual(await bacon.tags.acount(), 2) + def test_set(self): bacon = Vegetable.objects.create(name="Bacon", is_yucky=False) fatty = bacon.tags.create(tag="fatty") @@ -347,6 +354,16 @@ class GenericRelationsTests(TestCase): bacon.tags.set([], clear=True) self.assertSequenceEqual(bacon.tags.all(), []) + async def test_aset(self): + bacon = await Vegetable.objects.acreate(name="Bacon", is_yucky=False) + fatty = await bacon.tags.acreate(tag="fatty") + await bacon.tags.aset([fatty]) + self.assertEqual(await bacon.tags.acount(), 1) + await bacon.tags.aset([]) + self.assertEqual(await bacon.tags.acount(), 0) + await bacon.tags.aset([fatty], bulk=False, clear=True) + self.assertEqual(await bacon.tags.acount(), 1) + def test_assign(self): bacon = Vegetable.objects.create(name="Bacon", is_yucky=False) fatty = bacon.tags.create(tag="fatty") @@ -388,6 +405,10 @@ class GenericRelationsTests(TestCase): [self.hairy, self.yellow], ) + async def test_aclear(self): + await self.bacon.tags.aclear() + self.assertEqual(await self.bacon.tags.acount(), 0) + def test_remove(self): self.assertSequenceEqual( TaggedItem.objects.order_by("tag"), @@ -400,6 +421,12 @@ class GenericRelationsTests(TestCase): [self.hairy, self.salty, self.yellow], ) + async def test_aremove(self): + await self.bacon.tags.aremove(self.fatty) + self.assertEqual(await self.bacon.tags.acount(), 1) + await self.bacon.tags.aremove(self.salty) + self.assertEqual(await self.bacon.tags.acount(), 0) + def test_generic_relation_related_name_default(self): # GenericRelation isn't usable from the reverse side by default. msg = (