mirror of https://github.com/django/django.git
Fixed #28344 -- Allowed customizing queryset in Model.refresh_from_db()/arefresh_from_db().
The from_queryset parameter can be used to: - use a custom Manager - lock the row until the end of transaction - select additional related objects
This commit is contained in:
parent
f3d10546a8
commit
f92641a636
|
@ -673,7 +673,7 @@ class Model(AltersData, metaclass=ModelBase):
|
||||||
if f.attname not in self.__dict__
|
if f.attname not in self.__dict__
|
||||||
}
|
}
|
||||||
|
|
||||||
def refresh_from_db(self, using=None, fields=None):
|
def refresh_from_db(self, using=None, fields=None, from_queryset=None):
|
||||||
"""
|
"""
|
||||||
Reload field values from the database.
|
Reload field values from the database.
|
||||||
|
|
||||||
|
@ -705,10 +705,13 @@ class Model(AltersData, metaclass=ModelBase):
|
||||||
"are not allowed in fields." % LOOKUP_SEP
|
"are not allowed in fields." % LOOKUP_SEP
|
||||||
)
|
)
|
||||||
|
|
||||||
hints = {"instance": self}
|
if from_queryset is None:
|
||||||
db_instance_qs = self.__class__._base_manager.db_manager(
|
hints = {"instance": self}
|
||||||
using, hints=hints
|
from_queryset = self.__class__._base_manager.db_manager(using, hints=hints)
|
||||||
).filter(pk=self.pk)
|
elif using is not None:
|
||||||
|
from_queryset = from_queryset.using(using)
|
||||||
|
|
||||||
|
db_instance_qs = from_queryset.filter(pk=self.pk)
|
||||||
|
|
||||||
# Use provided fields, if not set then reload all non-deferred fields.
|
# Use provided fields, if not set then reload all non-deferred fields.
|
||||||
deferred_fields = self.get_deferred_fields()
|
deferred_fields = self.get_deferred_fields()
|
||||||
|
@ -729,9 +732,12 @@ class Model(AltersData, metaclass=ModelBase):
|
||||||
# This field wasn't refreshed - skip ahead.
|
# This field wasn't refreshed - skip ahead.
|
||||||
continue
|
continue
|
||||||
setattr(self, field.attname, getattr(db_instance, field.attname))
|
setattr(self, field.attname, getattr(db_instance, field.attname))
|
||||||
# Clear cached foreign keys.
|
# Clear or copy cached foreign keys.
|
||||||
if field.is_relation and field.is_cached(self):
|
if field.is_relation:
|
||||||
field.delete_cached_value(self)
|
if field.is_cached(db_instance):
|
||||||
|
field.set_cached_value(self, field.get_cached_value(db_instance))
|
||||||
|
elif field.is_cached(self):
|
||||||
|
field.delete_cached_value(self)
|
||||||
|
|
||||||
# Clear cached relations.
|
# Clear cached relations.
|
||||||
for field in self._meta.related_objects:
|
for field in self._meta.related_objects:
|
||||||
|
@ -745,8 +751,10 @@ class Model(AltersData, metaclass=ModelBase):
|
||||||
|
|
||||||
self._state.db = db_instance._state.db
|
self._state.db = db_instance._state.db
|
||||||
|
|
||||||
async def arefresh_from_db(self, using=None, fields=None):
|
async def arefresh_from_db(self, using=None, fields=None, from_queryset=None):
|
||||||
return await sync_to_async(self.refresh_from_db)(using=using, fields=fields)
|
return await sync_to_async(self.refresh_from_db)(
|
||||||
|
using=using, fields=fields, from_queryset=from_queryset
|
||||||
|
)
|
||||||
|
|
||||||
def serializable_value(self, field_name):
|
def serializable_value(self, field_name):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -142,8 +142,8 @@ value from the database:
|
||||||
>>> del obj.field
|
>>> del obj.field
|
||||||
>>> obj.field # Loads the field from the database
|
>>> obj.field # Loads the field from the database
|
||||||
|
|
||||||
.. method:: Model.refresh_from_db(using=None, fields=None)
|
.. method:: Model.refresh_from_db(using=None, fields=None, from_queryset=None)
|
||||||
.. method:: Model.arefresh_from_db(using=None, fields=None)
|
.. method:: Model.arefresh_from_db(using=None, fields=None, from_queryset=None)
|
||||||
|
|
||||||
*Asynchronous version*: ``arefresh_from_db()``
|
*Asynchronous version*: ``arefresh_from_db()``
|
||||||
|
|
||||||
|
@ -197,6 +197,27 @@ all of the instance's fields when a deferred field is reloaded::
|
||||||
fields = fields.union(deferred_fields)
|
fields = fields.union(deferred_fields)
|
||||||
super().refresh_from_db(using, fields, **kwargs)
|
super().refresh_from_db(using, fields, **kwargs)
|
||||||
|
|
||||||
|
The ``from_queryset`` argument allows using a different queryset than the one
|
||||||
|
created from :attr:`~django.db.models.Model._base_manager`. It gives you more
|
||||||
|
control over how the model is reloaded. For example, when your model uses soft
|
||||||
|
deletion you can make ``refresh_from_db()`` to take this into account::
|
||||||
|
|
||||||
|
obj.refresh_from_db(from_queryset=MyModel.active_objects.all())
|
||||||
|
|
||||||
|
You can cache related objects that otherwise would be cleared from the reloaded
|
||||||
|
instance::
|
||||||
|
|
||||||
|
obj.refresh_from_db(from_queryset=MyModel.objects.select_related("related_field"))
|
||||||
|
|
||||||
|
You can lock the row until the end of transaction before reloading a model's
|
||||||
|
values::
|
||||||
|
|
||||||
|
obj.refresh_from_db(from_queryset=MyModel.objects.select_for_update())
|
||||||
|
|
||||||
|
.. versionchanged:: 5.1
|
||||||
|
|
||||||
|
The ``from_queryset`` argument was added.
|
||||||
|
|
||||||
.. method:: Model.get_deferred_fields()
|
.. method:: Model.get_deferred_fields()
|
||||||
|
|
||||||
A helper method that returns a set containing the attribute names of all those
|
A helper method that returns a set containing the attribute names of all those
|
||||||
|
|
|
@ -208,6 +208,11 @@ Models
|
||||||
:class:`~django.contrib.postgres.fields.ArrayField` can now be :ref:`sliced
|
:class:`~django.contrib.postgres.fields.ArrayField` can now be :ref:`sliced
|
||||||
<slicing-using-f>`.
|
<slicing-using-f>`.
|
||||||
|
|
||||||
|
* The new ``from_queryset`` argument of :meth:`.Model.refresh_from_db` and
|
||||||
|
:meth:`.Model.arefresh_from_db` allows customizing the queryset used to
|
||||||
|
reload a model's value. This can be used to lock the row before reloading or
|
||||||
|
to select related objects.
|
||||||
|
|
||||||
Requests and Responses
|
Requests and Responses
|
||||||
~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
|
|
@ -23,3 +23,14 @@ class AsyncModelOperationTest(TestCase):
|
||||||
await SimpleModel.objects.filter(pk=self.s1.pk).aupdate(field=20)
|
await SimpleModel.objects.filter(pk=self.s1.pk).aupdate(field=20)
|
||||||
await self.s1.arefresh_from_db()
|
await self.s1.arefresh_from_db()
|
||||||
self.assertEqual(self.s1.field, 20)
|
self.assertEqual(self.s1.field, 20)
|
||||||
|
|
||||||
|
async def test_arefresh_from_db_from_queryset(self):
|
||||||
|
await SimpleModel.objects.filter(pk=self.s1.pk).aupdate(field=20)
|
||||||
|
with self.assertRaises(SimpleModel.DoesNotExist):
|
||||||
|
await self.s1.arefresh_from_db(
|
||||||
|
from_queryset=SimpleModel.objects.filter(field=0)
|
||||||
|
)
|
||||||
|
await self.s1.arefresh_from_db(
|
||||||
|
from_queryset=SimpleModel.objects.filter(field__gt=0)
|
||||||
|
)
|
||||||
|
self.assertEqual(self.s1.field, 20)
|
||||||
|
|
|
@ -4,7 +4,14 @@ from datetime import datetime, timedelta
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
from django.core.exceptions import MultipleObjectsReturned, ObjectDoesNotExist
|
from django.core.exceptions import MultipleObjectsReturned, ObjectDoesNotExist
|
||||||
from django.db import DEFAULT_DB_ALIAS, DatabaseError, connections, models
|
from django.db import (
|
||||||
|
DEFAULT_DB_ALIAS,
|
||||||
|
DatabaseError,
|
||||||
|
connection,
|
||||||
|
connections,
|
||||||
|
models,
|
||||||
|
transaction,
|
||||||
|
)
|
||||||
from django.db.models.manager import BaseManager
|
from django.db.models.manager import BaseManager
|
||||||
from django.db.models.query import MAX_GET_RESULTS, EmptyQuerySet
|
from django.db.models.query import MAX_GET_RESULTS, EmptyQuerySet
|
||||||
from django.test import (
|
from django.test import (
|
||||||
|
@ -13,7 +20,8 @@ from django.test import (
|
||||||
TransactionTestCase,
|
TransactionTestCase,
|
||||||
skipUnlessDBFeature,
|
skipUnlessDBFeature,
|
||||||
)
|
)
|
||||||
from django.test.utils import ignore_warnings
|
from django.test.utils import CaptureQueriesContext, ignore_warnings
|
||||||
|
from django.utils.connection import ConnectionDoesNotExist
|
||||||
from django.utils.deprecation import RemovedInDjango60Warning
|
from django.utils.deprecation import RemovedInDjango60Warning
|
||||||
from django.utils.translation import gettext_lazy
|
from django.utils.translation import gettext_lazy
|
||||||
|
|
||||||
|
@ -1003,3 +1011,47 @@ class ModelRefreshTests(TestCase):
|
||||||
# Cache was cleared and new results are available.
|
# Cache was cleared and new results are available.
|
||||||
self.assertCountEqual(a2_prefetched.selfref_set.all(), [s])
|
self.assertCountEqual(a2_prefetched.selfref_set.all(), [s])
|
||||||
self.assertCountEqual(a2_prefetched.cited.all(), [s])
|
self.assertCountEqual(a2_prefetched.cited.all(), [s])
|
||||||
|
|
||||||
|
@skipUnlessDBFeature("has_select_for_update")
|
||||||
|
def test_refresh_for_update(self):
|
||||||
|
a = Article.objects.create(pub_date=datetime.now())
|
||||||
|
for_update_sql = connection.ops.for_update_sql()
|
||||||
|
|
||||||
|
with transaction.atomic(), CaptureQueriesContext(connection) as ctx:
|
||||||
|
a.refresh_from_db(from_queryset=Article.objects.select_for_update())
|
||||||
|
self.assertTrue(
|
||||||
|
any(for_update_sql in query["sql"] for query in ctx.captured_queries)
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_refresh_with_related(self):
|
||||||
|
a = Article.objects.create(pub_date=datetime.now())
|
||||||
|
fa = FeaturedArticle.objects.create(article=a)
|
||||||
|
|
||||||
|
from_queryset = FeaturedArticle.objects.select_related("article")
|
||||||
|
with self.assertNumQueries(1):
|
||||||
|
fa.refresh_from_db(from_queryset=from_queryset)
|
||||||
|
self.assertEqual(fa.article.pub_date, a.pub_date)
|
||||||
|
with self.assertNumQueries(2):
|
||||||
|
fa.refresh_from_db()
|
||||||
|
self.assertEqual(fa.article.pub_date, a.pub_date)
|
||||||
|
|
||||||
|
def test_refresh_overwrites_queryset_using(self):
|
||||||
|
a = Article.objects.create(pub_date=datetime.now())
|
||||||
|
|
||||||
|
from_queryset = Article.objects.using("nonexistent")
|
||||||
|
with self.assertRaises(ConnectionDoesNotExist):
|
||||||
|
a.refresh_from_db(from_queryset=from_queryset)
|
||||||
|
a.refresh_from_db(using="default", from_queryset=from_queryset)
|
||||||
|
|
||||||
|
def test_refresh_overwrites_queryset_fields(self):
|
||||||
|
a = Article.objects.create(pub_date=datetime.now())
|
||||||
|
headline = "headline"
|
||||||
|
Article.objects.filter(pk=a.pk).update(headline=headline)
|
||||||
|
|
||||||
|
from_queryset = Article.objects.only("pub_date")
|
||||||
|
with self.assertNumQueries(1):
|
||||||
|
a.refresh_from_db(from_queryset=from_queryset)
|
||||||
|
self.assertNotEqual(a.headline, headline)
|
||||||
|
with self.assertNumQueries(1):
|
||||||
|
a.refresh_from_db(fields=["headline"], from_queryset=from_queryset)
|
||||||
|
self.assertEqual(a.headline, headline)
|
||||||
|
|
Loading…
Reference in New Issue