Fixed #901 -- Added Model.refresh_from_db() method
Thanks to github aliases dbrgn, carljm, slurms, dfunckt, and timgraham for reviews.
This commit is contained in:
parent
912ad03226
commit
c7175fcdfe
|
@ -13,6 +13,8 @@ from django.core.exceptions import (ObjectDoesNotExist,
|
|||
MultipleObjectsReturned, FieldError, ValidationError, NON_FIELD_ERRORS)
|
||||
from django.db import (router, connections, transaction, DatabaseError,
|
||||
DEFAULT_DB_ALIAS, DJANGO_VERSION_PICKLE_KEY)
|
||||
from django.db.models import signals
|
||||
from django.db.models.constants import LOOKUP_SEP
|
||||
from django.db.models.deletion import Collector
|
||||
from django.db.models.fields import AutoField, FieldDoesNotExist
|
||||
from django.db.models.fields.related import (ForeignObjectRel, ManyToOneRel,
|
||||
|
@ -21,7 +23,6 @@ from django.db.models.manager import ensure_default_manager
|
|||
from django.db.models.options import Options
|
||||
from django.db.models.query import Q
|
||||
from django.db.models.query_utils import DeferredAttribute, deferred_class_factory
|
||||
from django.db.models import signals
|
||||
from django.utils import six
|
||||
from django.utils.deprecation import RemovedInDjango19Warning
|
||||
from django.utils.encoding import force_str, force_text
|
||||
|
@ -552,6 +553,71 @@ class Model(six.with_metaclass(ModelBase)):
|
|||
|
||||
pk = property(_get_pk_val, _set_pk_val)
|
||||
|
||||
def get_deferred_fields(self):
|
||||
"""
|
||||
Returns a set containing names of deferred fields for this instance.
|
||||
"""
|
||||
return {
|
||||
f.attname for f in self._meta.concrete_fields
|
||||
if isinstance(self.__class__.__dict__.get(f.attname), DeferredAttribute)
|
||||
}
|
||||
|
||||
def refresh_from_db(self, using=None, fields=None, **kwargs):
|
||||
"""
|
||||
Reloads field values from the database.
|
||||
|
||||
By default, the reloading happens from the database this instance was
|
||||
loaded from, or by the read router if this instance wasn't loaded from
|
||||
any database. The using parameter will override the default.
|
||||
|
||||
Fields can be used to specify which fields to reload. The fields
|
||||
should be an iterable of field attnames. If fields is None, then
|
||||
all non-deferred fields are reloaded.
|
||||
|
||||
When accessing deferred fields of an instance, the deferred loading
|
||||
of the field will call this method.
|
||||
"""
|
||||
if fields is not None:
|
||||
if len(fields) == 0:
|
||||
return
|
||||
if any(LOOKUP_SEP in f for f in fields):
|
||||
raise ValueError(
|
||||
'Found "%s" in fields argument. Relations and transforms '
|
||||
'are not allowed in fields.' % LOOKUP_SEP)
|
||||
|
||||
db = using if using is not None else self._state.db
|
||||
if self._deferred:
|
||||
non_deferred_model = self._meta.proxy_for_model
|
||||
else:
|
||||
non_deferred_model = self.__class__
|
||||
db_instance_qs = non_deferred_model._default_manager.using(db).filter(pk=self.pk)
|
||||
|
||||
# Use provided fields, if not set then reload all non-deferred fields.
|
||||
if fields is not None:
|
||||
fields = list(fields)
|
||||
db_instance_qs = db_instance_qs.only(*fields)
|
||||
elif self._deferred:
|
||||
deferred_fields = self.get_deferred_fields()
|
||||
fields = [f.attname for f in self._meta.concrete_fields
|
||||
if f.attname not in deferred_fields]
|
||||
db_instance_qs = db_instance_qs.only(*fields)
|
||||
|
||||
db_instance = db_instance_qs.get()
|
||||
non_loaded_fields = db_instance.get_deferred_fields()
|
||||
for field in self._meta.concrete_fields:
|
||||
if field.attname in non_loaded_fields:
|
||||
# This field wasn't refreshed - skip ahead.
|
||||
continue
|
||||
setattr(self, field.attname, getattr(db_instance, field.attname))
|
||||
# Throw away stale foreign key references.
|
||||
if field.rel and field.get_cache_name() in self.__dict__:
|
||||
rel_instance = getattr(self, field.get_cache_name())
|
||||
local_val = getattr(db_instance, field.attname)
|
||||
related_val = getattr(rel_instance, field.related_field.attname)
|
||||
if local_val != related_val:
|
||||
del self.__dict__[field.get_cache_name()]
|
||||
self._state.db = db_instance._state.db
|
||||
|
||||
def serializable_value(self, field_name):
|
||||
"""
|
||||
Returns the value of the field name for this instance. If the field is
|
||||
|
|
|
@ -109,14 +109,8 @@ class DeferredAttribute(object):
|
|||
# might be able to reuse the already loaded value. Refs #18343.
|
||||
val = self._check_parent_chain(instance, name)
|
||||
if val is None:
|
||||
# We use only() instead of values() here because we want the
|
||||
# various data coercion methods (to_python(), etc.) to be
|
||||
# called here.
|
||||
val = getattr(
|
||||
non_deferred_model._base_manager.only(name).using(
|
||||
instance._state.db).get(pk=instance.pk),
|
||||
self.field_name
|
||||
)
|
||||
instance.refresh_from_db(fields=[self.field_name])
|
||||
val = getattr(instance, self.field_name)
|
||||
data[self.field_name] = val
|
||||
return data[self.field_name]
|
||||
|
||||
|
|
|
@ -116,6 +116,73 @@ The example above shows a full ``from_db()`` implementation to clarify how that
|
|||
is done. In this case it would of course be possible to just use ``super()`` call
|
||||
in the ``from_db()`` method.
|
||||
|
||||
Refreshing objects from database
|
||||
================================
|
||||
|
||||
.. method:: Model.refresh_from_db(using=None, fields=None, **kwargs)
|
||||
|
||||
.. versionadded:: 1.8
|
||||
|
||||
If you need to reload a model's values from the database, you can use the
|
||||
``refresh_from_db()`` method. When this method is called without arguments the
|
||||
following is done:
|
||||
|
||||
1. All non-deferred fields of the model are updated to the values currently
|
||||
present in the database.
|
||||
2. The previously loaded related instances for which the relation's value is no
|
||||
longer valid are removed from the reloaded instance. For example, if you have
|
||||
a foreign key from the reloaded instance to another model with name
|
||||
``Author``, then if ``obj.author_id != obj.author.id``, ``obj.author`` will
|
||||
be thrown away, and when next accessed it will be reloaded with the value of
|
||||
``obj.author_id``.
|
||||
|
||||
Note that only fields of the model are reloaded from the database. Other
|
||||
database dependent values such as annotations are not reloaded.
|
||||
|
||||
The reloading happens from the database the instance was loaded from, or from
|
||||
the default database if the instance wasn't loaded from the database. The
|
||||
``using`` argument can be used to force the database used for reloading.
|
||||
|
||||
It is possible to force the set of fields to be loaded by using the ``fields``
|
||||
argument.
|
||||
|
||||
For example, to test that an ``update()`` call resulted in the expected
|
||||
update, you could write a test similar to this::
|
||||
|
||||
def test_update_result(self):
|
||||
obj = MyModel.objects.create(val=1)
|
||||
MyModel.objects.filter(pk=obj.pk).update(val=F('val') + 1)
|
||||
# At this point obj.val is still 1, but the value in the database
|
||||
# was updated to 2. The object's updated value needs to be reloaded
|
||||
# from the database.
|
||||
obj.refresh_from_db()
|
||||
self.assertEqual(obj.val, 2)
|
||||
|
||||
Note that when deferred fields are accessed, the loading of the deferred
|
||||
field's value happens through this method. Thus it is possible to customize
|
||||
the way deferred loading happens. The example below shows how one can reload
|
||||
all of the instance's fields when a deferred field is reloaded::
|
||||
|
||||
class ExampleModel(models.Model):
|
||||
def refresh_from_db(self, using=None, fields=None, **kwargs):
|
||||
# fields contains the name of the deferred field to be
|
||||
# loaded.
|
||||
if fields is not None:
|
||||
fields = set(fields)
|
||||
deferred_fields = self.get_deferred_fields()
|
||||
# If any deferred field is going to be loaded
|
||||
if fields.intersection(deferred_fields):
|
||||
# then load all of them
|
||||
fields = fields.union(deferred_fields)
|
||||
super(ExampleModel, self).refresh_from_db(using, fields, **kwargs)
|
||||
|
||||
.. method:: Model.get_deferred_fields()
|
||||
|
||||
.. versionadded:: 1.8
|
||||
|
||||
A helper method that returns a set containing the attribute names of all those
|
||||
fields that are currently deferred for this model.
|
||||
|
||||
.. _validating-objects:
|
||||
|
||||
Validating objects
|
||||
|
|
|
@ -391,6 +391,12 @@ Models
|
|||
by the database, which can lead to somewhat complex queries involving nested
|
||||
``REPLACE`` function calls.
|
||||
|
||||
* You can now refresh model instances by using :meth:`Model.refresh_from_db()
|
||||
<django.db.models.Model.refresh_from_db>`.
|
||||
|
||||
* You can now get the set of deferred fields for a model using
|
||||
:meth:`Model.get_deferred_fields() <django.db.models.Model.get_deferred_fields>`.
|
||||
|
||||
Signals
|
||||
^^^^^^^
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from __future__ import unicode_literals
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
import threading
|
||||
|
||||
from django.core.exceptions import ObjectDoesNotExist, MultipleObjectsReturned
|
||||
|
@ -713,3 +713,60 @@ class SelectOnSaveTests(TestCase):
|
|||
asos.save(update_fields=['pub_date'])
|
||||
finally:
|
||||
Article._base_manager.__class__ = orig_class
|
||||
|
||||
|
||||
class ModelRefreshTests(TestCase):
|
||||
def _truncate_ms(self, val):
|
||||
# MySQL < 5.6.4 removes microseconds from the datetimes which can cause
|
||||
# problems when comparing the original value to that loaded from DB
|
||||
return val - timedelta(microseconds=val.microsecond)
|
||||
|
||||
def test_refresh(self):
|
||||
a = Article.objects.create(pub_date=self._truncate_ms(datetime.now()))
|
||||
Article.objects.create(pub_date=self._truncate_ms(datetime.now()))
|
||||
Article.objects.filter(pk=a.pk).update(headline='new headline')
|
||||
with self.assertNumQueries(1):
|
||||
a.refresh_from_db()
|
||||
self.assertEqual(a.headline, 'new headline')
|
||||
|
||||
orig_pub_date = a.pub_date
|
||||
new_pub_date = a.pub_date + timedelta(10)
|
||||
Article.objects.update(headline='new headline 2', pub_date=new_pub_date)
|
||||
with self.assertNumQueries(1):
|
||||
a.refresh_from_db(fields=['headline'])
|
||||
self.assertEqual(a.headline, 'new headline 2')
|
||||
self.assertEqual(a.pub_date, orig_pub_date)
|
||||
with self.assertNumQueries(1):
|
||||
a.refresh_from_db()
|
||||
self.assertEqual(a.pub_date, new_pub_date)
|
||||
|
||||
def test_refresh_fk(self):
|
||||
s1 = SelfRef.objects.create()
|
||||
s2 = SelfRef.objects.create()
|
||||
s3 = SelfRef.objects.create(selfref=s1)
|
||||
s3_copy = SelfRef.objects.get(pk=s3.pk)
|
||||
s3_copy.selfref.touched = True
|
||||
s3.selfref = s2
|
||||
s3.save()
|
||||
with self.assertNumQueries(1):
|
||||
s3_copy.refresh_from_db()
|
||||
with self.assertNumQueries(1):
|
||||
# The old related instance was thrown away (the selfref_id has
|
||||
# changed). It needs to be reloaded on access, so one query
|
||||
# executed.
|
||||
self.assertFalse(hasattr(s3_copy.selfref, 'touched'))
|
||||
self.assertEqual(s3_copy.selfref, s2)
|
||||
|
||||
def test_refresh_unsaved(self):
|
||||
pub_date = self._truncate_ms(datetime.now())
|
||||
a = Article.objects.create(pub_date=pub_date)
|
||||
a2 = Article(id=a.pk)
|
||||
with self.assertNumQueries(1):
|
||||
a2.refresh_from_db()
|
||||
self.assertEqual(a2.pub_date, pub_date)
|
||||
self.assertEqual(a2._state.db, "default")
|
||||
|
||||
def test_refresh_no_fields(self):
|
||||
a = Article.objects.create(pub_date=self._truncate_ms(datetime.now()))
|
||||
with self.assertNumQueries(0):
|
||||
a.refresh_from_db(fields=[])
|
||||
|
|
|
@ -32,3 +32,17 @@ class BigChild(Primary):
|
|||
class ChildProxy(Child):
|
||||
class Meta:
|
||||
proxy = True
|
||||
|
||||
|
||||
class RefreshPrimaryProxy(Primary):
|
||||
class Meta:
|
||||
proxy = True
|
||||
|
||||
def refresh_from_db(self, using=None, fields=None, **kwargs):
|
||||
# Reloads all deferred fields if any of the fields is deferred.
|
||||
if fields is not None:
|
||||
fields = set(fields)
|
||||
deferred_fields = self.get_deferred_fields()
|
||||
if fields.intersection(deferred_fields):
|
||||
fields = fields.union(deferred_fields)
|
||||
super(RefreshPrimaryProxy, self).refresh_from_db(using, fields, **kwargs)
|
||||
|
|
|
@ -3,7 +3,7 @@ from __future__ import unicode_literals
|
|||
from django.db.models.query_utils import DeferredAttribute, InvalidQuery
|
||||
from django.test import TestCase
|
||||
|
||||
from .models import Secondary, Primary, Child, BigChild, ChildProxy
|
||||
from .models import Secondary, Primary, Child, BigChild, ChildProxy, RefreshPrimaryProxy
|
||||
|
||||
|
||||
class DeferTests(TestCase):
|
||||
|
@ -189,3 +189,29 @@ class DeferTests(TestCase):
|
|||
s1_defer = Secondary.objects.only('pk').get(pk=s1.pk)
|
||||
self.assertEqual(s1, s1_defer)
|
||||
self.assertEqual(s1_defer, s1)
|
||||
|
||||
def test_refresh_not_loading_deferred_fields(self):
|
||||
s = Secondary.objects.create()
|
||||
rf = Primary.objects.create(name='foo', value='bar', related=s)
|
||||
rf2 = Primary.objects.only('related', 'value').get()
|
||||
rf.name = 'new foo'
|
||||
rf.value = 'new bar'
|
||||
rf.save()
|
||||
with self.assertNumQueries(1):
|
||||
rf2.refresh_from_db()
|
||||
self.assertEqual(rf2.value, 'new bar')
|
||||
with self.assertNumQueries(1):
|
||||
self.assertEqual(rf2.name, 'new foo')
|
||||
|
||||
def test_custom_refresh_on_deferred_loading(self):
|
||||
s = Secondary.objects.create()
|
||||
rf = RefreshPrimaryProxy.objects.create(name='foo', value='bar', related=s)
|
||||
rf2 = RefreshPrimaryProxy.objects.only('related').get()
|
||||
rf.name = 'new foo'
|
||||
rf.value = 'new bar'
|
||||
rf.save()
|
||||
with self.assertNumQueries(1):
|
||||
# Customized refresh_from_db() reloads all deferred fields on
|
||||
# access of any of them.
|
||||
self.assertEqual(rf2.name, 'new foo')
|
||||
self.assertEqual(rf2.value, 'new bar')
|
||||
|
|
|
@ -11,6 +11,12 @@ from .models import ChoicesModel, DataModel, MyModel, OtherModel
|
|||
|
||||
|
||||
class CustomField(TestCase):
|
||||
def test_refresh(self):
|
||||
d = DataModel.objects.create(data=[1, 2, 3])
|
||||
d.refresh_from_db(fields=['data'])
|
||||
self.assertIsInstance(d.data, list)
|
||||
self.assertEqual(d.data, [1, 2, 3])
|
||||
|
||||
def test_defer(self):
|
||||
d = DataModel.objects.create(data=[1, 2, 3])
|
||||
|
||||
|
|
|
@ -112,6 +112,24 @@ class QueryTestCase(TestCase):
|
|||
title="Dive into Python"
|
||||
)
|
||||
|
||||
def test_refresh(self):
|
||||
dive = Book()
|
||||
dive.title = "Dive into Python"
|
||||
dive = Book()
|
||||
dive.title = "Dive into Python"
|
||||
dive.published = datetime.date(2009, 5, 4)
|
||||
dive.save(using='other')
|
||||
dive.published = datetime.date(2009, 5, 4)
|
||||
dive.save(using='other')
|
||||
dive2 = Book.objects.using('other').get()
|
||||
dive2.title = "Dive into Python (on default)"
|
||||
dive2.save(using='default')
|
||||
dive.refresh_from_db()
|
||||
self.assertEqual(dive.title, "Dive into Python")
|
||||
dive.refresh_from_db(using='default')
|
||||
self.assertEqual(dive.title, "Dive into Python (on default)")
|
||||
self.assertEqual(dive._state.db, "default")
|
||||
|
||||
def test_basic_queries(self):
|
||||
"Queries are constrained to a single database"
|
||||
dive = Book.objects.using('other').create(title="Dive into Python",
|
||||
|
|
Loading…
Reference in New Issue