mirror of https://github.com/django/django.git
Fixed #34139 -- Fixed acreate(), aget_or_create(), and aupdate_or_create() methods for related managers.
Bug in 58b27e0dbb
.
This commit is contained in:
parent
76e37513e2
commit
7b94847e38
1
AUTHORS
1
AUTHORS
|
@ -495,6 +495,7 @@ answer newbie questions, and generally made Django that much better:
|
||||||
John Shaffer <jshaffer2112@gmail.com>
|
John Shaffer <jshaffer2112@gmail.com>
|
||||||
Jökull Sólberg Auðunsson <jokullsolberg@gmail.com>
|
Jökull Sólberg Auðunsson <jokullsolberg@gmail.com>
|
||||||
Jon Dufresne <jon.dufresne@gmail.com>
|
Jon Dufresne <jon.dufresne@gmail.com>
|
||||||
|
Jon Janzen <jon@jonjanzen.com>
|
||||||
Jonas Haag <jonas@lophus.org>
|
Jonas Haag <jonas@lophus.org>
|
||||||
Jonas Lundberg <jonas.lundberg@gmail.com>
|
Jonas Lundberg <jonas.lundberg@gmail.com>
|
||||||
Jonathan Davis <jonathandavis47780@gmail.com>
|
Jonathan Davis <jonathandavis47780@gmail.com>
|
||||||
|
|
|
@ -2,6 +2,8 @@ import functools
|
||||||
import itertools
|
import itertools
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
|
from asgiref.sync import sync_to_async
|
||||||
|
|
||||||
from django.contrib.contenttypes.models import ContentType
|
from django.contrib.contenttypes.models import ContentType
|
||||||
from django.core import checks
|
from django.core import checks
|
||||||
from django.core.exceptions import FieldDoesNotExist, ObjectDoesNotExist
|
from django.core.exceptions import FieldDoesNotExist, ObjectDoesNotExist
|
||||||
|
@ -747,6 +749,11 @@ def create_generic_related_manager(superclass, rel):
|
||||||
|
|
||||||
create.alters_data = True
|
create.alters_data = True
|
||||||
|
|
||||||
|
async def acreate(self, **kwargs):
|
||||||
|
return await sync_to_async(self.create)(**kwargs)
|
||||||
|
|
||||||
|
acreate.alters_data = True
|
||||||
|
|
||||||
def get_or_create(self, **kwargs):
|
def get_or_create(self, **kwargs):
|
||||||
kwargs[self.content_type_field_name] = self.content_type
|
kwargs[self.content_type_field_name] = self.content_type
|
||||||
kwargs[self.object_id_field_name] = self.pk_val
|
kwargs[self.object_id_field_name] = self.pk_val
|
||||||
|
@ -755,6 +762,11 @@ def create_generic_related_manager(superclass, rel):
|
||||||
|
|
||||||
get_or_create.alters_data = True
|
get_or_create.alters_data = True
|
||||||
|
|
||||||
|
async def aget_or_create(self, **kwargs):
|
||||||
|
return await sync_to_async(self.get_or_create)(**kwargs)
|
||||||
|
|
||||||
|
aget_or_create.alters_data = True
|
||||||
|
|
||||||
def update_or_create(self, **kwargs):
|
def update_or_create(self, **kwargs):
|
||||||
kwargs[self.content_type_field_name] = self.content_type
|
kwargs[self.content_type_field_name] = self.content_type
|
||||||
kwargs[self.object_id_field_name] = self.pk_val
|
kwargs[self.object_id_field_name] = self.pk_val
|
||||||
|
@ -763,4 +775,9 @@ def create_generic_related_manager(superclass, rel):
|
||||||
|
|
||||||
update_or_create.alters_data = True
|
update_or_create.alters_data = True
|
||||||
|
|
||||||
|
async def aupdate_or_create(self, **kwargs):
|
||||||
|
return await sync_to_async(self.update_or_create)(**kwargs)
|
||||||
|
|
||||||
|
aupdate_or_create.alters_data = True
|
||||||
|
|
||||||
return GenericRelatedObjectManager
|
return GenericRelatedObjectManager
|
||||||
|
|
|
@ -63,6 +63,8 @@ and two directions (forward and reverse) for a total of six combinations.
|
||||||
``ReverseManyToManyDescriptor``, use ``ManyToManyDescriptor`` instead.
|
``ReverseManyToManyDescriptor``, use ``ManyToManyDescriptor`` instead.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from asgiref.sync import sync_to_async
|
||||||
|
|
||||||
from django.core.exceptions import FieldError
|
from django.core.exceptions import FieldError
|
||||||
from django.db import (
|
from django.db import (
|
||||||
DEFAULT_DB_ALIAS,
|
DEFAULT_DB_ALIAS,
|
||||||
|
@ -793,6 +795,11 @@ def create_reverse_many_to_one_manager(superclass, rel):
|
||||||
|
|
||||||
create.alters_data = True
|
create.alters_data = True
|
||||||
|
|
||||||
|
async def acreate(self, **kwargs):
|
||||||
|
return await sync_to_async(self.create)(**kwargs)
|
||||||
|
|
||||||
|
acreate.alters_data = True
|
||||||
|
|
||||||
def get_or_create(self, **kwargs):
|
def get_or_create(self, **kwargs):
|
||||||
self._check_fk_val()
|
self._check_fk_val()
|
||||||
kwargs[self.field.name] = self.instance
|
kwargs[self.field.name] = self.instance
|
||||||
|
@ -801,6 +808,11 @@ def create_reverse_many_to_one_manager(superclass, rel):
|
||||||
|
|
||||||
get_or_create.alters_data = True
|
get_or_create.alters_data = True
|
||||||
|
|
||||||
|
async def aget_or_create(self, **kwargs):
|
||||||
|
return await sync_to_async(self.get_or_create)(**kwargs)
|
||||||
|
|
||||||
|
aget_or_create.alters_data = True
|
||||||
|
|
||||||
def update_or_create(self, **kwargs):
|
def update_or_create(self, **kwargs):
|
||||||
self._check_fk_val()
|
self._check_fk_val()
|
||||||
kwargs[self.field.name] = self.instance
|
kwargs[self.field.name] = self.instance
|
||||||
|
@ -809,6 +821,11 @@ def create_reverse_many_to_one_manager(superclass, rel):
|
||||||
|
|
||||||
update_or_create.alters_data = True
|
update_or_create.alters_data = True
|
||||||
|
|
||||||
|
async def aupdate_or_create(self, **kwargs):
|
||||||
|
return await sync_to_async(self.update_or_create)(**kwargs)
|
||||||
|
|
||||||
|
aupdate_or_create.alters_data = True
|
||||||
|
|
||||||
# remove() and clear() are only provided if the ForeignKey can have a
|
# remove() and clear() are only provided if the ForeignKey can have a
|
||||||
# value of null.
|
# value of null.
|
||||||
if rel.field.null:
|
if rel.field.null:
|
||||||
|
@ -1191,6 +1208,13 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
|
||||||
|
|
||||||
create.alters_data = True
|
create.alters_data = True
|
||||||
|
|
||||||
|
async def acreate(self, *, through_defaults=None, **kwargs):
|
||||||
|
return await sync_to_async(self.create)(
|
||||||
|
through_defaults=through_defaults, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
acreate.alters_data = True
|
||||||
|
|
||||||
def get_or_create(self, *, through_defaults=None, **kwargs):
|
def get_or_create(self, *, through_defaults=None, **kwargs):
|
||||||
db = router.db_for_write(self.instance.__class__, instance=self.instance)
|
db = router.db_for_write(self.instance.__class__, instance=self.instance)
|
||||||
obj, created = super(ManyRelatedManager, self.db_manager(db)).get_or_create(
|
obj, created = super(ManyRelatedManager, self.db_manager(db)).get_or_create(
|
||||||
|
@ -1204,6 +1228,13 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
|
||||||
|
|
||||||
get_or_create.alters_data = True
|
get_or_create.alters_data = True
|
||||||
|
|
||||||
|
async def aget_or_create(self, *, through_defaults=None, **kwargs):
|
||||||
|
return await sync_to_async(self.get_or_create)(
|
||||||
|
through_defaults=through_defaults, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
aget_or_create.alters_data = True
|
||||||
|
|
||||||
def update_or_create(self, *, through_defaults=None, **kwargs):
|
def update_or_create(self, *, through_defaults=None, **kwargs):
|
||||||
db = router.db_for_write(self.instance.__class__, instance=self.instance)
|
db = router.db_for_write(self.instance.__class__, instance=self.instance)
|
||||||
obj, created = super(
|
obj, created = super(
|
||||||
|
@ -1217,6 +1248,13 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
|
||||||
|
|
||||||
update_or_create.alters_data = True
|
update_or_create.alters_data = True
|
||||||
|
|
||||||
|
async def aupdate_or_create(self, *, through_defaults=None, **kwargs):
|
||||||
|
return await sync_to_async(self.update_or_create)(
|
||||||
|
through_defaults=through_defaults, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
aupdate_or_create.alters_data = True
|
||||||
|
|
||||||
def _get_target_ids(self, target_field_name, objs):
|
def _get_target_ids(self, target_field_name, objs):
|
||||||
"""
|
"""
|
||||||
Return the set of ids of `objs` that the target field references.
|
Return the set of ids of `objs` that the target field references.
|
||||||
|
|
|
@ -76,6 +76,9 @@ Related objects reference
|
||||||
intermediate instance(s).
|
intermediate instance(s).
|
||||||
|
|
||||||
.. method:: create(through_defaults=None, **kwargs)
|
.. method:: create(through_defaults=None, **kwargs)
|
||||||
|
.. method:: acreate(through_defaults=None, **kwargs)
|
||||||
|
|
||||||
|
*Asynchronous version*: ``acreate``
|
||||||
|
|
||||||
Creates a new object, saves it and puts it in the related object set.
|
Creates a new object, saves it and puts it in the related object set.
|
||||||
Returns the newly created object::
|
Returns the newly created object::
|
||||||
|
@ -110,6 +113,10 @@ Related objects reference
|
||||||
needed. You can use callables as values in the ``through_defaults``
|
needed. You can use callables as values in the ``through_defaults``
|
||||||
dictionary.
|
dictionary.
|
||||||
|
|
||||||
|
.. versionchanged:: 4.1
|
||||||
|
|
||||||
|
``acreate()`` method was added.
|
||||||
|
|
||||||
.. method:: remove(*objs, bulk=True)
|
.. method:: remove(*objs, bulk=True)
|
||||||
|
|
||||||
Removes the specified model objects from the related object set::
|
Removes the specified model objects from the related object set::
|
||||||
|
|
|
@ -16,3 +16,7 @@ Bugfixes
|
||||||
an empty :meth:`Sitemap.items() <django.contrib.sitemaps.Sitemap.items>` and
|
an empty :meth:`Sitemap.items() <django.contrib.sitemaps.Sitemap.items>` and
|
||||||
a callable :attr:`~django.contrib.sitemaps.Sitemap.lastmod`
|
a callable :attr:`~django.contrib.sitemaps.Sitemap.lastmod`
|
||||||
(:ticket:`34088`).
|
(:ticket:`34088`).
|
||||||
|
|
||||||
|
* Fixed a bug in Django 4.1 that caused a crash of ``acreate()``,
|
||||||
|
``aget_or_create()``, and ``aupdate_or_create()`` asynchronous methods for
|
||||||
|
related managers (:ticket:`34139`).
|
||||||
|
|
|
@ -9,3 +9,7 @@ class RelatedModel(models.Model):
|
||||||
class SimpleModel(models.Model):
|
class SimpleModel(models.Model):
|
||||||
field = models.IntegerField()
|
field = models.IntegerField()
|
||||||
created = models.DateTimeField(default=timezone.now)
|
created = models.DateTimeField(default=timezone.now)
|
||||||
|
|
||||||
|
|
||||||
|
class ManyToManyModel(models.Model):
|
||||||
|
simples = models.ManyToManyField("SimpleModel")
|
||||||
|
|
|
@ -0,0 +1,56 @@
|
||||||
|
from django.test import TestCase
|
||||||
|
|
||||||
|
from .models import ManyToManyModel, SimpleModel
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncRelatedManagersOperationTest(TestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpTestData(cls):
|
||||||
|
cls.mtm1 = ManyToManyModel.objects.create()
|
||||||
|
cls.s1 = SimpleModel.objects.create(field=0)
|
||||||
|
|
||||||
|
async def test_acreate(self):
|
||||||
|
await self.mtm1.simples.acreate(field=2)
|
||||||
|
new_simple = await self.mtm1.simples.aget()
|
||||||
|
self.assertEqual(new_simple.field, 2)
|
||||||
|
|
||||||
|
async def test_acreate_reverse(self):
|
||||||
|
await self.s1.relatedmodel_set.acreate()
|
||||||
|
new_relatedmodel = await self.s1.relatedmodel_set.aget()
|
||||||
|
self.assertEqual(new_relatedmodel.simple, self.s1)
|
||||||
|
|
||||||
|
async def test_aget_or_create(self):
|
||||||
|
new_simple, created = await self.mtm1.simples.aget_or_create(field=2)
|
||||||
|
self.assertIs(created, True)
|
||||||
|
self.assertEqual(await self.mtm1.simples.acount(), 1)
|
||||||
|
self.assertEqual(new_simple.field, 2)
|
||||||
|
new_simple, created = await self.mtm1.simples.aget_or_create(
|
||||||
|
id=new_simple.id, through_defaults={"field": 3}
|
||||||
|
)
|
||||||
|
self.assertIs(created, False)
|
||||||
|
self.assertEqual(await self.mtm1.simples.acount(), 1)
|
||||||
|
self.assertEqual(new_simple.field, 2)
|
||||||
|
|
||||||
|
async def test_aget_or_create_reverse(self):
|
||||||
|
new_relatedmodel, created = await self.s1.relatedmodel_set.aget_or_create()
|
||||||
|
self.assertIs(created, True)
|
||||||
|
self.assertEqual(await self.s1.relatedmodel_set.acount(), 1)
|
||||||
|
self.assertEqual(new_relatedmodel.simple, self.s1)
|
||||||
|
|
||||||
|
async def test_aupdate_or_create(self):
|
||||||
|
new_simple, created = await self.mtm1.simples.aupdate_or_create(field=2)
|
||||||
|
self.assertIs(created, True)
|
||||||
|
self.assertEqual(await self.mtm1.simples.acount(), 1)
|
||||||
|
self.assertEqual(new_simple.field, 2)
|
||||||
|
new_simple, created = await self.mtm1.simples.aupdate_or_create(
|
||||||
|
id=new_simple.id, defaults={"field": 3}
|
||||||
|
)
|
||||||
|
self.assertIs(created, False)
|
||||||
|
self.assertEqual(await self.mtm1.simples.acount(), 1)
|
||||||
|
self.assertEqual(new_simple.field, 3)
|
||||||
|
|
||||||
|
async def test_aupdate_or_create_reverse(self):
|
||||||
|
new_relatedmodel, created = await self.s1.relatedmodel_set.aupdate_or_create()
|
||||||
|
self.assertIs(created, True)
|
||||||
|
self.assertEqual(await self.s1.relatedmodel_set.acount(), 1)
|
||||||
|
self.assertEqual(new_relatedmodel.simple, self.s1)
|
|
@ -45,6 +45,10 @@ class GenericRelationsTests(TestCase):
|
||||||
# Original list of tags:
|
# Original list of tags:
|
||||||
return obj.tag, obj.content_type.model_class(), obj.object_id
|
return obj.tag, obj.content_type.model_class(), obj.object_id
|
||||||
|
|
||||||
|
async def test_generic_async_acreate(self):
|
||||||
|
await self.bacon.tags.acreate(tag="orange")
|
||||||
|
self.assertEqual(await self.bacon.tags.acount(), 3)
|
||||||
|
|
||||||
def test_generic_update_or_create_when_created(self):
|
def test_generic_update_or_create_when_created(self):
|
||||||
"""
|
"""
|
||||||
Should be able to use update_or_create from the generic related manager
|
Should be able to use update_or_create from the generic related manager
|
||||||
|
@ -70,6 +74,18 @@ class GenericRelationsTests(TestCase):
|
||||||
self.assertEqual(count + 1, self.bacon.tags.count())
|
self.assertEqual(count + 1, self.bacon.tags.count())
|
||||||
self.assertEqual(tag.tag, "juicy")
|
self.assertEqual(tag.tag, "juicy")
|
||||||
|
|
||||||
|
async def test_generic_async_aupdate_or_create(self):
|
||||||
|
tag, created = await self.bacon.tags.aupdate_or_create(
|
||||||
|
id=self.fatty.id, defaults={"tag": "orange"}
|
||||||
|
)
|
||||||
|
self.assertIs(created, False)
|
||||||
|
self.assertEqual(tag.tag, "orange")
|
||||||
|
self.assertEqual(await self.bacon.tags.acount(), 2)
|
||||||
|
tag, created = await self.bacon.tags.aupdate_or_create(tag="pink")
|
||||||
|
self.assertIs(created, True)
|
||||||
|
self.assertEqual(await self.bacon.tags.acount(), 3)
|
||||||
|
self.assertEqual(tag.tag, "pink")
|
||||||
|
|
||||||
def test_generic_get_or_create_when_created(self):
|
def test_generic_get_or_create_when_created(self):
|
||||||
"""
|
"""
|
||||||
Should be able to use get_or_create from the generic related manager
|
Should be able to use get_or_create from the generic related manager
|
||||||
|
@ -96,6 +112,18 @@ class GenericRelationsTests(TestCase):
|
||||||
# shouldn't had changed the tag
|
# shouldn't had changed the tag
|
||||||
self.assertEqual(tag.tag, "stinky")
|
self.assertEqual(tag.tag, "stinky")
|
||||||
|
|
||||||
|
async def test_generic_async_aget_or_create(self):
|
||||||
|
tag, created = await self.bacon.tags.aget_or_create(
|
||||||
|
id=self.fatty.id, defaults={"tag": "orange"}
|
||||||
|
)
|
||||||
|
self.assertIs(created, False)
|
||||||
|
self.assertEqual(tag.tag, "fatty")
|
||||||
|
self.assertEqual(await self.bacon.tags.acount(), 2)
|
||||||
|
tag, created = await self.bacon.tags.aget_or_create(tag="orange")
|
||||||
|
self.assertIs(created, True)
|
||||||
|
self.assertEqual(await self.bacon.tags.acount(), 3)
|
||||||
|
self.assertEqual(tag.tag, "orange")
|
||||||
|
|
||||||
def test_generic_relations_m2m_mimic(self):
|
def test_generic_relations_m2m_mimic(self):
|
||||||
"""
|
"""
|
||||||
Objects with declared GenericRelations can be tagged directly -- the
|
Objects with declared GenericRelations can be tagged directly -- the
|
||||||
|
|
Loading…
Reference in New Issue