Fixed #901 -- Added Model.refresh_from_db() method

Thanks to github aliases dbrgn, carljm, slurms, dfunckt, and timgraham
for reviews.
This commit is contained in:
Anssi Kääriäinen 2014-07-05 09:03:52 +03:00 committed by Tim Graham
parent 912ad03226
commit c7175fcdfe
9 changed files with 265 additions and 11 deletions

View File

@ -13,6 +13,8 @@ from django.core.exceptions import (ObjectDoesNotExist,
MultipleObjectsReturned, FieldError, ValidationError, NON_FIELD_ERRORS) MultipleObjectsReturned, FieldError, ValidationError, NON_FIELD_ERRORS)
from django.db import (router, connections, transaction, DatabaseError, from django.db import (router, connections, transaction, DatabaseError,
DEFAULT_DB_ALIAS, DJANGO_VERSION_PICKLE_KEY) DEFAULT_DB_ALIAS, DJANGO_VERSION_PICKLE_KEY)
from django.db.models import signals
from django.db.models.constants import LOOKUP_SEP
from django.db.models.deletion import Collector from django.db.models.deletion import Collector
from django.db.models.fields import AutoField, FieldDoesNotExist from django.db.models.fields import AutoField, FieldDoesNotExist
from django.db.models.fields.related import (ForeignObjectRel, ManyToOneRel, from django.db.models.fields.related import (ForeignObjectRel, ManyToOneRel,
@ -21,7 +23,6 @@ from django.db.models.manager import ensure_default_manager
from django.db.models.options import Options from django.db.models.options import Options
from django.db.models.query import Q from django.db.models.query import Q
from django.db.models.query_utils import DeferredAttribute, deferred_class_factory from django.db.models.query_utils import DeferredAttribute, deferred_class_factory
from django.db.models import signals
from django.utils import six from django.utils import six
from django.utils.deprecation import RemovedInDjango19Warning from django.utils.deprecation import RemovedInDjango19Warning
from django.utils.encoding import force_str, force_text from django.utils.encoding import force_str, force_text
@ -552,6 +553,71 @@ class Model(six.with_metaclass(ModelBase)):
pk = property(_get_pk_val, _set_pk_val) pk = property(_get_pk_val, _set_pk_val)
def get_deferred_fields(self):
"""
Returns a set containing names of deferred fields for this instance.
"""
return {
f.attname for f in self._meta.concrete_fields
if isinstance(self.__class__.__dict__.get(f.attname), DeferredAttribute)
}
def refresh_from_db(self, using=None, fields=None, **kwargs):
"""
Reloads field values from the database.
By default, the reloading happens from the database this instance was
loaded from, or by the read router if this instance wasn't loaded from
any database. The using parameter will override the default.
Fields can be used to specify which fields to reload. The fields
should be an iterable of field attnames. If fields is None, then
all non-deferred fields are reloaded.
When accessing deferred fields of an instance, the deferred loading
of the field will call this method.
"""
if fields is not None:
if len(fields) == 0:
return
if any(LOOKUP_SEP in f for f in fields):
raise ValueError(
'Found "%s" in fields argument. Relations and transforms '
'are not allowed in fields.' % LOOKUP_SEP)
db = using if using is not None else self._state.db
if self._deferred:
non_deferred_model = self._meta.proxy_for_model
else:
non_deferred_model = self.__class__
db_instance_qs = non_deferred_model._default_manager.using(db).filter(pk=self.pk)
# Use provided fields, if not set then reload all non-deferred fields.
if fields is not None:
fields = list(fields)
db_instance_qs = db_instance_qs.only(*fields)
elif self._deferred:
deferred_fields = self.get_deferred_fields()
fields = [f.attname for f in self._meta.concrete_fields
if f.attname not in deferred_fields]
db_instance_qs = db_instance_qs.only(*fields)
db_instance = db_instance_qs.get()
non_loaded_fields = db_instance.get_deferred_fields()
for field in self._meta.concrete_fields:
if field.attname in non_loaded_fields:
# This field wasn't refreshed - skip ahead.
continue
setattr(self, field.attname, getattr(db_instance, field.attname))
# Throw away stale foreign key references.
if field.rel and field.get_cache_name() in self.__dict__:
rel_instance = getattr(self, field.get_cache_name())
local_val = getattr(db_instance, field.attname)
related_val = getattr(rel_instance, field.related_field.attname)
if local_val != related_val:
del self.__dict__[field.get_cache_name()]
self._state.db = db_instance._state.db
def serializable_value(self, field_name): def serializable_value(self, field_name):
""" """
Returns the value of the field name for this instance. If the field is Returns the value of the field name for this instance. If the field is

View File

@ -109,14 +109,8 @@ class DeferredAttribute(object):
# might be able to reuse the already loaded value. Refs #18343. # might be able to reuse the already loaded value. Refs #18343.
val = self._check_parent_chain(instance, name) val = self._check_parent_chain(instance, name)
if val is None: if val is None:
# We use only() instead of values() here because we want the instance.refresh_from_db(fields=[self.field_name])
# various data coercion methods (to_python(), etc.) to be val = getattr(instance, self.field_name)
# called here.
val = getattr(
non_deferred_model._base_manager.only(name).using(
instance._state.db).get(pk=instance.pk),
self.field_name
)
data[self.field_name] = val data[self.field_name] = val
return data[self.field_name] return data[self.field_name]

View File

@ -116,6 +116,73 @@ The example above shows a full ``from_db()`` implementation to clarify how that
is done. In this case it would of course be possible to just use ``super()`` call is done. In this case it would of course be possible to just use ``super()`` call
in the ``from_db()`` method. in the ``from_db()`` method.
Refreshing objects from database
================================
.. method:: Model.refresh_from_db(using=None, fields=None, **kwargs)
.. versionadded:: 1.8
If you need to reload a model's values from the database, you can use the
``refresh_from_db()`` method. When this method is called without arguments the
following is done:
1. All non-deferred fields of the model are updated to the values currently
present in the database.
2. The previously loaded related instances for which the relation's value is no
longer valid are removed from the reloaded instance. For example, if you have
a foreign key from the reloaded instance to another model with name
``Author``, then if ``obj.author_id != obj.author.id``, ``obj.author`` will
be thrown away, and when next accessed it will be reloaded with the value of
``obj.author_id``.
Note that only fields of the model are reloaded from the database. Other
database dependent values such as annotations are not reloaded.
The reloading happens from the database the instance was loaded from, or from
the default database if the instance wasn't loaded from the database. The
``using`` argument can be used to force the database used for reloading.
It is possible to force the set of fields to be loaded by using the ``fields``
argument.
For example, to test that an ``update()`` call resulted in the expected
update, you could write a test similar to this::
def test_update_result(self):
obj = MyModel.objects.create(val=1)
MyModel.objects.filter(pk=obj.pk).update(val=F('val') + 1)
# At this point obj.val is still 1, but the value in the database
# was updated to 2. The object's updated value needs to be reloaded
# from the database.
obj.refresh_from_db()
self.assertEqual(obj.val, 2)
Note that when deferred fields are accessed, the loading of the deferred
field's value happens through this method. Thus it is possible to customize
the way deferred loading happens. The example below shows how one can reload
all of the instance's fields when a deferred field is reloaded::
class ExampleModel(models.Model):
def refresh_from_db(self, using=None, fields=None, **kwargs):
# fields contains the name of the deferred field to be
# loaded.
if fields is not None:
fields = set(fields)
deferred_fields = self.get_deferred_fields()
# If any deferred field is going to be loaded
if fields.intersection(deferred_fields):
# then load all of them
fields = fields.union(deferred_fields)
super(ExampleModel, self).refresh_from_db(using, fields, **kwargs)
.. method:: Model.get_deferred_fields()
.. versionadded:: 1.8
A helper method that returns a set containing the attribute names of all those
fields that are currently deferred for this model.
.. _validating-objects: .. _validating-objects:
Validating objects Validating objects

View File

@ -391,6 +391,12 @@ Models
by the database, which can lead to somewhat complex queries involving nested by the database, which can lead to somewhat complex queries involving nested
``REPLACE`` function calls. ``REPLACE`` function calls.
* You can now refresh model instances by using :meth:`Model.refresh_from_db()
<django.db.models.Model.refresh_from_db>`.
* You can now get the set of deferred fields for a model using
:meth:`Model.get_deferred_fields() <django.db.models.Model.get_deferred_fields>`.
Signals Signals
^^^^^^^ ^^^^^^^

View File

@ -1,6 +1,6 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from datetime import datetime from datetime import datetime, timedelta
import threading import threading
from django.core.exceptions import ObjectDoesNotExist, MultipleObjectsReturned from django.core.exceptions import ObjectDoesNotExist, MultipleObjectsReturned
@ -713,3 +713,60 @@ class SelectOnSaveTests(TestCase):
asos.save(update_fields=['pub_date']) asos.save(update_fields=['pub_date'])
finally: finally:
Article._base_manager.__class__ = orig_class Article._base_manager.__class__ = orig_class
class ModelRefreshTests(TestCase):
def _truncate_ms(self, val):
# MySQL < 5.6.4 removes microseconds from the datetimes which can cause
# problems when comparing the original value to that loaded from DB
return val - timedelta(microseconds=val.microsecond)
def test_refresh(self):
a = Article.objects.create(pub_date=self._truncate_ms(datetime.now()))
Article.objects.create(pub_date=self._truncate_ms(datetime.now()))
Article.objects.filter(pk=a.pk).update(headline='new headline')
with self.assertNumQueries(1):
a.refresh_from_db()
self.assertEqual(a.headline, 'new headline')
orig_pub_date = a.pub_date
new_pub_date = a.pub_date + timedelta(10)
Article.objects.update(headline='new headline 2', pub_date=new_pub_date)
with self.assertNumQueries(1):
a.refresh_from_db(fields=['headline'])
self.assertEqual(a.headline, 'new headline 2')
self.assertEqual(a.pub_date, orig_pub_date)
with self.assertNumQueries(1):
a.refresh_from_db()
self.assertEqual(a.pub_date, new_pub_date)
def test_refresh_fk(self):
s1 = SelfRef.objects.create()
s2 = SelfRef.objects.create()
s3 = SelfRef.objects.create(selfref=s1)
s3_copy = SelfRef.objects.get(pk=s3.pk)
s3_copy.selfref.touched = True
s3.selfref = s2
s3.save()
with self.assertNumQueries(1):
s3_copy.refresh_from_db()
with self.assertNumQueries(1):
# The old related instance was thrown away (the selfref_id has
# changed). It needs to be reloaded on access, so one query
# executed.
self.assertFalse(hasattr(s3_copy.selfref, 'touched'))
self.assertEqual(s3_copy.selfref, s2)
def test_refresh_unsaved(self):
pub_date = self._truncate_ms(datetime.now())
a = Article.objects.create(pub_date=pub_date)
a2 = Article(id=a.pk)
with self.assertNumQueries(1):
a2.refresh_from_db()
self.assertEqual(a2.pub_date, pub_date)
self.assertEqual(a2._state.db, "default")
def test_refresh_no_fields(self):
a = Article.objects.create(pub_date=self._truncate_ms(datetime.now()))
with self.assertNumQueries(0):
a.refresh_from_db(fields=[])

View File

@ -32,3 +32,17 @@ class BigChild(Primary):
class ChildProxy(Child): class ChildProxy(Child):
class Meta: class Meta:
proxy = True proxy = True
class RefreshPrimaryProxy(Primary):
class Meta:
proxy = True
def refresh_from_db(self, using=None, fields=None, **kwargs):
# Reloads all deferred fields if any of the fields is deferred.
if fields is not None:
fields = set(fields)
deferred_fields = self.get_deferred_fields()
if fields.intersection(deferred_fields):
fields = fields.union(deferred_fields)
super(RefreshPrimaryProxy, self).refresh_from_db(using, fields, **kwargs)

View File

@ -3,7 +3,7 @@ from __future__ import unicode_literals
from django.db.models.query_utils import DeferredAttribute, InvalidQuery from django.db.models.query_utils import DeferredAttribute, InvalidQuery
from django.test import TestCase from django.test import TestCase
from .models import Secondary, Primary, Child, BigChild, ChildProxy from .models import Secondary, Primary, Child, BigChild, ChildProxy, RefreshPrimaryProxy
class DeferTests(TestCase): class DeferTests(TestCase):
@ -189,3 +189,29 @@ class DeferTests(TestCase):
s1_defer = Secondary.objects.only('pk').get(pk=s1.pk) s1_defer = Secondary.objects.only('pk').get(pk=s1.pk)
self.assertEqual(s1, s1_defer) self.assertEqual(s1, s1_defer)
self.assertEqual(s1_defer, s1) self.assertEqual(s1_defer, s1)
def test_refresh_not_loading_deferred_fields(self):
s = Secondary.objects.create()
rf = Primary.objects.create(name='foo', value='bar', related=s)
rf2 = Primary.objects.only('related', 'value').get()
rf.name = 'new foo'
rf.value = 'new bar'
rf.save()
with self.assertNumQueries(1):
rf2.refresh_from_db()
self.assertEqual(rf2.value, 'new bar')
with self.assertNumQueries(1):
self.assertEqual(rf2.name, 'new foo')
def test_custom_refresh_on_deferred_loading(self):
s = Secondary.objects.create()
rf = RefreshPrimaryProxy.objects.create(name='foo', value='bar', related=s)
rf2 = RefreshPrimaryProxy.objects.only('related').get()
rf.name = 'new foo'
rf.value = 'new bar'
rf.save()
with self.assertNumQueries(1):
# Customized refresh_from_db() reloads all deferred fields on
# access of any of them.
self.assertEqual(rf2.name, 'new foo')
self.assertEqual(rf2.value, 'new bar')

View File

@ -11,6 +11,12 @@ from .models import ChoicesModel, DataModel, MyModel, OtherModel
class CustomField(TestCase): class CustomField(TestCase):
def test_refresh(self):
d = DataModel.objects.create(data=[1, 2, 3])
d.refresh_from_db(fields=['data'])
self.assertIsInstance(d.data, list)
self.assertEqual(d.data, [1, 2, 3])
def test_defer(self): def test_defer(self):
d = DataModel.objects.create(data=[1, 2, 3]) d = DataModel.objects.create(data=[1, 2, 3])

View File

@ -112,6 +112,24 @@ class QueryTestCase(TestCase):
title="Dive into Python" title="Dive into Python"
) )
def test_refresh(self):
dive = Book()
dive.title = "Dive into Python"
dive = Book()
dive.title = "Dive into Python"
dive.published = datetime.date(2009, 5, 4)
dive.save(using='other')
dive.published = datetime.date(2009, 5, 4)
dive.save(using='other')
dive2 = Book.objects.using('other').get()
dive2.title = "Dive into Python (on default)"
dive2.save(using='default')
dive.refresh_from_db()
self.assertEqual(dive.title, "Dive into Python")
dive.refresh_from_db(using='default')
self.assertEqual(dive.title, "Dive into Python (on default)")
self.assertEqual(dive._state.db, "default")
def test_basic_queries(self): def test_basic_queries(self):
"Queries are constrained to a single database" "Queries are constrained to a single database"
dive = Book.objects.using('other').create(title="Dive into Python", dive = Book.objects.using('other').create(title="Dive into Python",