Fixed #34139 -- Fixed acreate(), aget_or_create(), and aupdate_or_create() methods for related managers.

Bug in 58b27e0dbb.
This commit is contained in:
Jon Janzen 2022-11-04 15:22:32 +01:00 committed by Mariusz Felisiak
parent 76e37513e2
commit 7b94847e38
8 changed files with 155 additions and 0 deletions

View File

@ -495,6 +495,7 @@ answer newbie questions, and generally made Django that much better:
John Shaffer <jshaffer2112@gmail.com> John Shaffer <jshaffer2112@gmail.com>
Jökull Sólberg Auðunsson <jokullsolberg@gmail.com> Jökull Sólberg Auðunsson <jokullsolberg@gmail.com>
Jon Dufresne <jon.dufresne@gmail.com> Jon Dufresne <jon.dufresne@gmail.com>
Jon Janzen <jon@jonjanzen.com>
Jonas Haag <jonas@lophus.org> Jonas Haag <jonas@lophus.org>
Jonas Lundberg <jonas.lundberg@gmail.com> Jonas Lundberg <jonas.lundberg@gmail.com>
Jonathan Davis <jonathandavis47780@gmail.com> Jonathan Davis <jonathandavis47780@gmail.com>

View File

@ -2,6 +2,8 @@ import functools
import itertools import itertools
from collections import defaultdict from collections import defaultdict
from asgiref.sync import sync_to_async
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
from django.core import checks from django.core import checks
from django.core.exceptions import FieldDoesNotExist, ObjectDoesNotExist from django.core.exceptions import FieldDoesNotExist, ObjectDoesNotExist
@ -747,6 +749,11 @@ def create_generic_related_manager(superclass, rel):
create.alters_data = True create.alters_data = True
async def acreate(self, **kwargs):
return await sync_to_async(self.create)(**kwargs)
acreate.alters_data = True
def get_or_create(self, **kwargs): def get_or_create(self, **kwargs):
kwargs[self.content_type_field_name] = self.content_type kwargs[self.content_type_field_name] = self.content_type
kwargs[self.object_id_field_name] = self.pk_val kwargs[self.object_id_field_name] = self.pk_val
@ -755,6 +762,11 @@ def create_generic_related_manager(superclass, rel):
get_or_create.alters_data = True get_or_create.alters_data = True
async def aget_or_create(self, **kwargs):
return await sync_to_async(self.get_or_create)(**kwargs)
aget_or_create.alters_data = True
def update_or_create(self, **kwargs): def update_or_create(self, **kwargs):
kwargs[self.content_type_field_name] = self.content_type kwargs[self.content_type_field_name] = self.content_type
kwargs[self.object_id_field_name] = self.pk_val kwargs[self.object_id_field_name] = self.pk_val
@ -763,4 +775,9 @@ def create_generic_related_manager(superclass, rel):
update_or_create.alters_data = True update_or_create.alters_data = True
async def aupdate_or_create(self, **kwargs):
return await sync_to_async(self.update_or_create)(**kwargs)
aupdate_or_create.alters_data = True
return GenericRelatedObjectManager return GenericRelatedObjectManager

View File

@ -63,6 +63,8 @@ and two directions (forward and reverse) for a total of six combinations.
``ReverseManyToManyDescriptor``, use ``ManyToManyDescriptor`` instead. ``ReverseManyToManyDescriptor``, use ``ManyToManyDescriptor`` instead.
""" """
from asgiref.sync import sync_to_async
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db import ( from django.db import (
DEFAULT_DB_ALIAS, DEFAULT_DB_ALIAS,
@ -793,6 +795,11 @@ def create_reverse_many_to_one_manager(superclass, rel):
create.alters_data = True create.alters_data = True
async def acreate(self, **kwargs):
return await sync_to_async(self.create)(**kwargs)
acreate.alters_data = True
def get_or_create(self, **kwargs): def get_or_create(self, **kwargs):
self._check_fk_val() self._check_fk_val()
kwargs[self.field.name] = self.instance kwargs[self.field.name] = self.instance
@ -801,6 +808,11 @@ def create_reverse_many_to_one_manager(superclass, rel):
get_or_create.alters_data = True get_or_create.alters_data = True
async def aget_or_create(self, **kwargs):
return await sync_to_async(self.get_or_create)(**kwargs)
aget_or_create.alters_data = True
def update_or_create(self, **kwargs): def update_or_create(self, **kwargs):
self._check_fk_val() self._check_fk_val()
kwargs[self.field.name] = self.instance kwargs[self.field.name] = self.instance
@ -809,6 +821,11 @@ def create_reverse_many_to_one_manager(superclass, rel):
update_or_create.alters_data = True update_or_create.alters_data = True
async def aupdate_or_create(self, **kwargs):
return await sync_to_async(self.update_or_create)(**kwargs)
aupdate_or_create.alters_data = True
# remove() and clear() are only provided if the ForeignKey can have a # remove() and clear() are only provided if the ForeignKey can have a
# value of null. # value of null.
if rel.field.null: if rel.field.null:
@ -1191,6 +1208,13 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
create.alters_data = True create.alters_data = True
async def acreate(self, *, through_defaults=None, **kwargs):
return await sync_to_async(self.create)(
through_defaults=through_defaults, **kwargs
)
acreate.alters_data = True
def get_or_create(self, *, through_defaults=None, **kwargs): def get_or_create(self, *, through_defaults=None, **kwargs):
db = router.db_for_write(self.instance.__class__, instance=self.instance) db = router.db_for_write(self.instance.__class__, instance=self.instance)
obj, created = super(ManyRelatedManager, self.db_manager(db)).get_or_create( obj, created = super(ManyRelatedManager, self.db_manager(db)).get_or_create(
@ -1204,6 +1228,13 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
get_or_create.alters_data = True get_or_create.alters_data = True
async def aget_or_create(self, *, through_defaults=None, **kwargs):
return await sync_to_async(self.get_or_create)(
through_defaults=through_defaults, **kwargs
)
aget_or_create.alters_data = True
def update_or_create(self, *, through_defaults=None, **kwargs): def update_or_create(self, *, through_defaults=None, **kwargs):
db = router.db_for_write(self.instance.__class__, instance=self.instance) db = router.db_for_write(self.instance.__class__, instance=self.instance)
obj, created = super( obj, created = super(
@ -1217,6 +1248,13 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
update_or_create.alters_data = True update_or_create.alters_data = True
async def aupdate_or_create(self, *, through_defaults=None, **kwargs):
return await sync_to_async(self.update_or_create)(
through_defaults=through_defaults, **kwargs
)
aupdate_or_create.alters_data = True
def _get_target_ids(self, target_field_name, objs): def _get_target_ids(self, target_field_name, objs):
""" """
Return the set of ids of `objs` that the target field references. Return the set of ids of `objs` that the target field references.

View File

@ -76,6 +76,9 @@ Related objects reference
intermediate instance(s). intermediate instance(s).
.. method:: create(through_defaults=None, **kwargs) .. method:: create(through_defaults=None, **kwargs)
.. method:: acreate(through_defaults=None, **kwargs)
*Asynchronous version*: ``acreate``
Creates a new object, saves it and puts it in the related object set. Creates a new object, saves it and puts it in the related object set.
Returns the newly created object:: Returns the newly created object::
@ -110,6 +113,10 @@ Related objects reference
needed. You can use callables as values in the ``through_defaults`` needed. You can use callables as values in the ``through_defaults``
dictionary. dictionary.
.. versionchanged:: 4.1
``acreate()`` method was added.
.. method:: remove(*objs, bulk=True) .. method:: remove(*objs, bulk=True)
Removes the specified model objects from the related object set:: Removes the specified model objects from the related object set::

View File

@ -16,3 +16,7 @@ Bugfixes
an empty :meth:`Sitemap.items() <django.contrib.sitemaps.Sitemap.items>` and an empty :meth:`Sitemap.items() <django.contrib.sitemaps.Sitemap.items>` and
a callable :attr:`~django.contrib.sitemaps.Sitemap.lastmod` a callable :attr:`~django.contrib.sitemaps.Sitemap.lastmod`
(:ticket:`34088`). (:ticket:`34088`).
* Fixed a bug in Django 4.1 that caused a crash of ``acreate()``,
``aget_or_create()``, and ``aupdate_or_create()`` asynchronous methods for
related managers (:ticket:`34139`).

View File

@ -9,3 +9,7 @@ class RelatedModel(models.Model):
class SimpleModel(models.Model): class SimpleModel(models.Model):
field = models.IntegerField() field = models.IntegerField()
created = models.DateTimeField(default=timezone.now) created = models.DateTimeField(default=timezone.now)
class ManyToManyModel(models.Model):
simples = models.ManyToManyField("SimpleModel")

View File

@ -0,0 +1,56 @@
from django.test import TestCase
from .models import ManyToManyModel, SimpleModel
class AsyncRelatedManagersOperationTest(TestCase):
@classmethod
def setUpTestData(cls):
cls.mtm1 = ManyToManyModel.objects.create()
cls.s1 = SimpleModel.objects.create(field=0)
async def test_acreate(self):
await self.mtm1.simples.acreate(field=2)
new_simple = await self.mtm1.simples.aget()
self.assertEqual(new_simple.field, 2)
async def test_acreate_reverse(self):
await self.s1.relatedmodel_set.acreate()
new_relatedmodel = await self.s1.relatedmodel_set.aget()
self.assertEqual(new_relatedmodel.simple, self.s1)
async def test_aget_or_create(self):
new_simple, created = await self.mtm1.simples.aget_or_create(field=2)
self.assertIs(created, True)
self.assertEqual(await self.mtm1.simples.acount(), 1)
self.assertEqual(new_simple.field, 2)
new_simple, created = await self.mtm1.simples.aget_or_create(
id=new_simple.id, through_defaults={"field": 3}
)
self.assertIs(created, False)
self.assertEqual(await self.mtm1.simples.acount(), 1)
self.assertEqual(new_simple.field, 2)
async def test_aget_or_create_reverse(self):
new_relatedmodel, created = await self.s1.relatedmodel_set.aget_or_create()
self.assertIs(created, True)
self.assertEqual(await self.s1.relatedmodel_set.acount(), 1)
self.assertEqual(new_relatedmodel.simple, self.s1)
async def test_aupdate_or_create(self):
new_simple, created = await self.mtm1.simples.aupdate_or_create(field=2)
self.assertIs(created, True)
self.assertEqual(await self.mtm1.simples.acount(), 1)
self.assertEqual(new_simple.field, 2)
new_simple, created = await self.mtm1.simples.aupdate_or_create(
id=new_simple.id, defaults={"field": 3}
)
self.assertIs(created, False)
self.assertEqual(await self.mtm1.simples.acount(), 1)
self.assertEqual(new_simple.field, 3)
async def test_aupdate_or_create_reverse(self):
new_relatedmodel, created = await self.s1.relatedmodel_set.aupdate_or_create()
self.assertIs(created, True)
self.assertEqual(await self.s1.relatedmodel_set.acount(), 1)
self.assertEqual(new_relatedmodel.simple, self.s1)

View File

@ -45,6 +45,10 @@ class GenericRelationsTests(TestCase):
# Original list of tags: # Original list of tags:
return obj.tag, obj.content_type.model_class(), obj.object_id return obj.tag, obj.content_type.model_class(), obj.object_id
async def test_generic_async_acreate(self):
await self.bacon.tags.acreate(tag="orange")
self.assertEqual(await self.bacon.tags.acount(), 3)
def test_generic_update_or_create_when_created(self): def test_generic_update_or_create_when_created(self):
""" """
Should be able to use update_or_create from the generic related manager Should be able to use update_or_create from the generic related manager
@ -70,6 +74,18 @@ class GenericRelationsTests(TestCase):
self.assertEqual(count + 1, self.bacon.tags.count()) self.assertEqual(count + 1, self.bacon.tags.count())
self.assertEqual(tag.tag, "juicy") self.assertEqual(tag.tag, "juicy")
async def test_generic_async_aupdate_or_create(self):
tag, created = await self.bacon.tags.aupdate_or_create(
id=self.fatty.id, defaults={"tag": "orange"}
)
self.assertIs(created, False)
self.assertEqual(tag.tag, "orange")
self.assertEqual(await self.bacon.tags.acount(), 2)
tag, created = await self.bacon.tags.aupdate_or_create(tag="pink")
self.assertIs(created, True)
self.assertEqual(await self.bacon.tags.acount(), 3)
self.assertEqual(tag.tag, "pink")
def test_generic_get_or_create_when_created(self): def test_generic_get_or_create_when_created(self):
""" """
Should be able to use get_or_create from the generic related manager Should be able to use get_or_create from the generic related manager
@ -96,6 +112,18 @@ class GenericRelationsTests(TestCase):
# shouldn't had changed the tag # shouldn't had changed the tag
self.assertEqual(tag.tag, "stinky") self.assertEqual(tag.tag, "stinky")
async def test_generic_async_aget_or_create(self):
tag, created = await self.bacon.tags.aget_or_create(
id=self.fatty.id, defaults={"tag": "orange"}
)
self.assertIs(created, False)
self.assertEqual(tag.tag, "fatty")
self.assertEqual(await self.bacon.tags.acount(), 2)
tag, created = await self.bacon.tags.aget_or_create(tag="orange")
self.assertIs(created, True)
self.assertEqual(await self.bacon.tags.acount(), 3)
self.assertEqual(tag.tag, "orange")
def test_generic_relations_m2m_mimic(self): def test_generic_relations_m2m_mimic(self):
""" """
Objects with declared GenericRelations can be tagged directly -- the Objects with declared GenericRelations can be tagged directly -- the