Fixed #901 -- Added Model.refresh_from_db() method
Thanks to github aliases dbrgn, carljm, slurms, dfunckt, and timgraham for reviews.
This commit is contained in:
parent
912ad03226
commit
c7175fcdfe
|
@ -13,6 +13,8 @@ from django.core.exceptions import (ObjectDoesNotExist,
|
||||||
MultipleObjectsReturned, FieldError, ValidationError, NON_FIELD_ERRORS)
|
MultipleObjectsReturned, FieldError, ValidationError, NON_FIELD_ERRORS)
|
||||||
from django.db import (router, connections, transaction, DatabaseError,
|
from django.db import (router, connections, transaction, DatabaseError,
|
||||||
DEFAULT_DB_ALIAS, DJANGO_VERSION_PICKLE_KEY)
|
DEFAULT_DB_ALIAS, DJANGO_VERSION_PICKLE_KEY)
|
||||||
|
from django.db.models import signals
|
||||||
|
from django.db.models.constants import LOOKUP_SEP
|
||||||
from django.db.models.deletion import Collector
|
from django.db.models.deletion import Collector
|
||||||
from django.db.models.fields import AutoField, FieldDoesNotExist
|
from django.db.models.fields import AutoField, FieldDoesNotExist
|
||||||
from django.db.models.fields.related import (ForeignObjectRel, ManyToOneRel,
|
from django.db.models.fields.related import (ForeignObjectRel, ManyToOneRel,
|
||||||
|
@ -21,7 +23,6 @@ from django.db.models.manager import ensure_default_manager
|
||||||
from django.db.models.options import Options
|
from django.db.models.options import Options
|
||||||
from django.db.models.query import Q
|
from django.db.models.query import Q
|
||||||
from django.db.models.query_utils import DeferredAttribute, deferred_class_factory
|
from django.db.models.query_utils import DeferredAttribute, deferred_class_factory
|
||||||
from django.db.models import signals
|
|
||||||
from django.utils import six
|
from django.utils import six
|
||||||
from django.utils.deprecation import RemovedInDjango19Warning
|
from django.utils.deprecation import RemovedInDjango19Warning
|
||||||
from django.utils.encoding import force_str, force_text
|
from django.utils.encoding import force_str, force_text
|
||||||
|
@ -552,6 +553,71 @@ class Model(six.with_metaclass(ModelBase)):
|
||||||
|
|
||||||
pk = property(_get_pk_val, _set_pk_val)
|
pk = property(_get_pk_val, _set_pk_val)
|
||||||
|
|
||||||
|
def get_deferred_fields(self):
|
||||||
|
"""
|
||||||
|
Returns a set containing names of deferred fields for this instance.
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
f.attname for f in self._meta.concrete_fields
|
||||||
|
if isinstance(self.__class__.__dict__.get(f.attname), DeferredAttribute)
|
||||||
|
}
|
||||||
|
|
||||||
|
def refresh_from_db(self, using=None, fields=None, **kwargs):
|
||||||
|
"""
|
||||||
|
Reloads field values from the database.
|
||||||
|
|
||||||
|
By default, the reloading happens from the database this instance was
|
||||||
|
loaded from, or by the read router if this instance wasn't loaded from
|
||||||
|
any database. The using parameter will override the default.
|
||||||
|
|
||||||
|
Fields can be used to specify which fields to reload. The fields
|
||||||
|
should be an iterable of field attnames. If fields is None, then
|
||||||
|
all non-deferred fields are reloaded.
|
||||||
|
|
||||||
|
When accessing deferred fields of an instance, the deferred loading
|
||||||
|
of the field will call this method.
|
||||||
|
"""
|
||||||
|
if fields is not None:
|
||||||
|
if len(fields) == 0:
|
||||||
|
return
|
||||||
|
if any(LOOKUP_SEP in f for f in fields):
|
||||||
|
raise ValueError(
|
||||||
|
'Found "%s" in fields argument. Relations and transforms '
|
||||||
|
'are not allowed in fields.' % LOOKUP_SEP)
|
||||||
|
|
||||||
|
db = using if using is not None else self._state.db
|
||||||
|
if self._deferred:
|
||||||
|
non_deferred_model = self._meta.proxy_for_model
|
||||||
|
else:
|
||||||
|
non_deferred_model = self.__class__
|
||||||
|
db_instance_qs = non_deferred_model._default_manager.using(db).filter(pk=self.pk)
|
||||||
|
|
||||||
|
# Use provided fields, if not set then reload all non-deferred fields.
|
||||||
|
if fields is not None:
|
||||||
|
fields = list(fields)
|
||||||
|
db_instance_qs = db_instance_qs.only(*fields)
|
||||||
|
elif self._deferred:
|
||||||
|
deferred_fields = self.get_deferred_fields()
|
||||||
|
fields = [f.attname for f in self._meta.concrete_fields
|
||||||
|
if f.attname not in deferred_fields]
|
||||||
|
db_instance_qs = db_instance_qs.only(*fields)
|
||||||
|
|
||||||
|
db_instance = db_instance_qs.get()
|
||||||
|
non_loaded_fields = db_instance.get_deferred_fields()
|
||||||
|
for field in self._meta.concrete_fields:
|
||||||
|
if field.attname in non_loaded_fields:
|
||||||
|
# This field wasn't refreshed - skip ahead.
|
||||||
|
continue
|
||||||
|
setattr(self, field.attname, getattr(db_instance, field.attname))
|
||||||
|
# Throw away stale foreign key references.
|
||||||
|
if field.rel and field.get_cache_name() in self.__dict__:
|
||||||
|
rel_instance = getattr(self, field.get_cache_name())
|
||||||
|
local_val = getattr(db_instance, field.attname)
|
||||||
|
related_val = getattr(rel_instance, field.related_field.attname)
|
||||||
|
if local_val != related_val:
|
||||||
|
del self.__dict__[field.get_cache_name()]
|
||||||
|
self._state.db = db_instance._state.db
|
||||||
|
|
||||||
def serializable_value(self, field_name):
|
def serializable_value(self, field_name):
|
||||||
"""
|
"""
|
||||||
Returns the value of the field name for this instance. If the field is
|
Returns the value of the field name for this instance. If the field is
|
||||||
|
|
|
@ -109,14 +109,8 @@ class DeferredAttribute(object):
|
||||||
# might be able to reuse the already loaded value. Refs #18343.
|
# might be able to reuse the already loaded value. Refs #18343.
|
||||||
val = self._check_parent_chain(instance, name)
|
val = self._check_parent_chain(instance, name)
|
||||||
if val is None:
|
if val is None:
|
||||||
# We use only() instead of values() here because we want the
|
instance.refresh_from_db(fields=[self.field_name])
|
||||||
# various data coercion methods (to_python(), etc.) to be
|
val = getattr(instance, self.field_name)
|
||||||
# called here.
|
|
||||||
val = getattr(
|
|
||||||
non_deferred_model._base_manager.only(name).using(
|
|
||||||
instance._state.db).get(pk=instance.pk),
|
|
||||||
self.field_name
|
|
||||||
)
|
|
||||||
data[self.field_name] = val
|
data[self.field_name] = val
|
||||||
return data[self.field_name]
|
return data[self.field_name]
|
||||||
|
|
||||||
|
|
|
@ -116,6 +116,73 @@ The example above shows a full ``from_db()`` implementation to clarify how that
|
||||||
is done. In this case it would of course be possible to just use ``super()`` call
|
is done. In this case it would of course be possible to just use ``super()`` call
|
||||||
in the ``from_db()`` method.
|
in the ``from_db()`` method.
|
||||||
|
|
||||||
|
Refreshing objects from database
|
||||||
|
================================
|
||||||
|
|
||||||
|
.. method:: Model.refresh_from_db(using=None, fields=None, **kwargs)
|
||||||
|
|
||||||
|
.. versionadded:: 1.8
|
||||||
|
|
||||||
|
If you need to reload a model's values from the database, you can use the
|
||||||
|
``refresh_from_db()`` method. When this method is called without arguments the
|
||||||
|
following is done:
|
||||||
|
|
||||||
|
1. All non-deferred fields of the model are updated to the values currently
|
||||||
|
present in the database.
|
||||||
|
2. The previously loaded related instances for which the relation's value is no
|
||||||
|
longer valid are removed from the reloaded instance. For example, if you have
|
||||||
|
a foreign key from the reloaded instance to another model with name
|
||||||
|
``Author``, then if ``obj.author_id != obj.author.id``, ``obj.author`` will
|
||||||
|
be thrown away, and when next accessed it will be reloaded with the value of
|
||||||
|
``obj.author_id``.
|
||||||
|
|
||||||
|
Note that only fields of the model are reloaded from the database. Other
|
||||||
|
database dependent values such as annotations are not reloaded.
|
||||||
|
|
||||||
|
The reloading happens from the database the instance was loaded from, or from
|
||||||
|
the default database if the instance wasn't loaded from the database. The
|
||||||
|
``using`` argument can be used to force the database used for reloading.
|
||||||
|
|
||||||
|
It is possible to force the set of fields to be loaded by using the ``fields``
|
||||||
|
argument.
|
||||||
|
|
||||||
|
For example, to test that an ``update()`` call resulted in the expected
|
||||||
|
update, you could write a test similar to this::
|
||||||
|
|
||||||
|
def test_update_result(self):
|
||||||
|
obj = MyModel.objects.create(val=1)
|
||||||
|
MyModel.objects.filter(pk=obj.pk).update(val=F('val') + 1)
|
||||||
|
# At this point obj.val is still 1, but the value in the database
|
||||||
|
# was updated to 2. The object's updated value needs to be reloaded
|
||||||
|
# from the database.
|
||||||
|
obj.refresh_from_db()
|
||||||
|
self.assertEqual(obj.val, 2)
|
||||||
|
|
||||||
|
Note that when deferred fields are accessed, the loading of the deferred
|
||||||
|
field's value happens through this method. Thus it is possible to customize
|
||||||
|
the way deferred loading happens. The example below shows how one can reload
|
||||||
|
all of the instance's fields when a deferred field is reloaded::
|
||||||
|
|
||||||
|
class ExampleModel(models.Model):
|
||||||
|
def refresh_from_db(self, using=None, fields=None, **kwargs):
|
||||||
|
# fields contains the name of the deferred field to be
|
||||||
|
# loaded.
|
||||||
|
if fields is not None:
|
||||||
|
fields = set(fields)
|
||||||
|
deferred_fields = self.get_deferred_fields()
|
||||||
|
# If any deferred field is going to be loaded
|
||||||
|
if fields.intersection(deferred_fields):
|
||||||
|
# then load all of them
|
||||||
|
fields = fields.union(deferred_fields)
|
||||||
|
super(ExampleModel, self).refresh_from_db(using, fields, **kwargs)
|
||||||
|
|
||||||
|
.. method:: Model.get_deferred_fields()
|
||||||
|
|
||||||
|
.. versionadded:: 1.8
|
||||||
|
|
||||||
|
A helper method that returns a set containing the attribute names of all those
|
||||||
|
fields that are currently deferred for this model.
|
||||||
|
|
||||||
.. _validating-objects:
|
.. _validating-objects:
|
||||||
|
|
||||||
Validating objects
|
Validating objects
|
||||||
|
|
|
@ -391,6 +391,12 @@ Models
|
||||||
by the database, which can lead to somewhat complex queries involving nested
|
by the database, which can lead to somewhat complex queries involving nested
|
||||||
``REPLACE`` function calls.
|
``REPLACE`` function calls.
|
||||||
|
|
||||||
|
* You can now refresh model instances by using :meth:`Model.refresh_from_db()
|
||||||
|
<django.db.models.Model.refresh_from_db>`.
|
||||||
|
|
||||||
|
* You can now get the set of deferred fields for a model using
|
||||||
|
:meth:`Model.get_deferred_fields() <django.db.models.Model.get_deferred_fields>`.
|
||||||
|
|
||||||
Signals
|
Signals
|
||||||
^^^^^^^
|
^^^^^^^
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime, timedelta
|
||||||
import threading
|
import threading
|
||||||
|
|
||||||
from django.core.exceptions import ObjectDoesNotExist, MultipleObjectsReturned
|
from django.core.exceptions import ObjectDoesNotExist, MultipleObjectsReturned
|
||||||
|
@ -713,3 +713,60 @@ class SelectOnSaveTests(TestCase):
|
||||||
asos.save(update_fields=['pub_date'])
|
asos.save(update_fields=['pub_date'])
|
||||||
finally:
|
finally:
|
||||||
Article._base_manager.__class__ = orig_class
|
Article._base_manager.__class__ = orig_class
|
||||||
|
|
||||||
|
|
||||||
|
class ModelRefreshTests(TestCase):
|
||||||
|
def _truncate_ms(self, val):
|
||||||
|
# MySQL < 5.6.4 removes microseconds from the datetimes which can cause
|
||||||
|
# problems when comparing the original value to that loaded from DB
|
||||||
|
return val - timedelta(microseconds=val.microsecond)
|
||||||
|
|
||||||
|
def test_refresh(self):
|
||||||
|
a = Article.objects.create(pub_date=self._truncate_ms(datetime.now()))
|
||||||
|
Article.objects.create(pub_date=self._truncate_ms(datetime.now()))
|
||||||
|
Article.objects.filter(pk=a.pk).update(headline='new headline')
|
||||||
|
with self.assertNumQueries(1):
|
||||||
|
a.refresh_from_db()
|
||||||
|
self.assertEqual(a.headline, 'new headline')
|
||||||
|
|
||||||
|
orig_pub_date = a.pub_date
|
||||||
|
new_pub_date = a.pub_date + timedelta(10)
|
||||||
|
Article.objects.update(headline='new headline 2', pub_date=new_pub_date)
|
||||||
|
with self.assertNumQueries(1):
|
||||||
|
a.refresh_from_db(fields=['headline'])
|
||||||
|
self.assertEqual(a.headline, 'new headline 2')
|
||||||
|
self.assertEqual(a.pub_date, orig_pub_date)
|
||||||
|
with self.assertNumQueries(1):
|
||||||
|
a.refresh_from_db()
|
||||||
|
self.assertEqual(a.pub_date, new_pub_date)
|
||||||
|
|
||||||
|
def test_refresh_fk(self):
|
||||||
|
s1 = SelfRef.objects.create()
|
||||||
|
s2 = SelfRef.objects.create()
|
||||||
|
s3 = SelfRef.objects.create(selfref=s1)
|
||||||
|
s3_copy = SelfRef.objects.get(pk=s3.pk)
|
||||||
|
s3_copy.selfref.touched = True
|
||||||
|
s3.selfref = s2
|
||||||
|
s3.save()
|
||||||
|
with self.assertNumQueries(1):
|
||||||
|
s3_copy.refresh_from_db()
|
||||||
|
with self.assertNumQueries(1):
|
||||||
|
# The old related instance was thrown away (the selfref_id has
|
||||||
|
# changed). It needs to be reloaded on access, so one query
|
||||||
|
# executed.
|
||||||
|
self.assertFalse(hasattr(s3_copy.selfref, 'touched'))
|
||||||
|
self.assertEqual(s3_copy.selfref, s2)
|
||||||
|
|
||||||
|
def test_refresh_unsaved(self):
|
||||||
|
pub_date = self._truncate_ms(datetime.now())
|
||||||
|
a = Article.objects.create(pub_date=pub_date)
|
||||||
|
a2 = Article(id=a.pk)
|
||||||
|
with self.assertNumQueries(1):
|
||||||
|
a2.refresh_from_db()
|
||||||
|
self.assertEqual(a2.pub_date, pub_date)
|
||||||
|
self.assertEqual(a2._state.db, "default")
|
||||||
|
|
||||||
|
def test_refresh_no_fields(self):
|
||||||
|
a = Article.objects.create(pub_date=self._truncate_ms(datetime.now()))
|
||||||
|
with self.assertNumQueries(0):
|
||||||
|
a.refresh_from_db(fields=[])
|
||||||
|
|
|
@ -32,3 +32,17 @@ class BigChild(Primary):
|
||||||
class ChildProxy(Child):
|
class ChildProxy(Child):
|
||||||
class Meta:
|
class Meta:
|
||||||
proxy = True
|
proxy = True
|
||||||
|
|
||||||
|
|
||||||
|
class RefreshPrimaryProxy(Primary):
|
||||||
|
class Meta:
|
||||||
|
proxy = True
|
||||||
|
|
||||||
|
def refresh_from_db(self, using=None, fields=None, **kwargs):
|
||||||
|
# Reloads all deferred fields if any of the fields is deferred.
|
||||||
|
if fields is not None:
|
||||||
|
fields = set(fields)
|
||||||
|
deferred_fields = self.get_deferred_fields()
|
||||||
|
if fields.intersection(deferred_fields):
|
||||||
|
fields = fields.union(deferred_fields)
|
||||||
|
super(RefreshPrimaryProxy, self).refresh_from_db(using, fields, **kwargs)
|
||||||
|
|
|
@ -3,7 +3,7 @@ from __future__ import unicode_literals
|
||||||
from django.db.models.query_utils import DeferredAttribute, InvalidQuery
|
from django.db.models.query_utils import DeferredAttribute, InvalidQuery
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
|
|
||||||
from .models import Secondary, Primary, Child, BigChild, ChildProxy
|
from .models import Secondary, Primary, Child, BigChild, ChildProxy, RefreshPrimaryProxy
|
||||||
|
|
||||||
|
|
||||||
class DeferTests(TestCase):
|
class DeferTests(TestCase):
|
||||||
|
@ -189,3 +189,29 @@ class DeferTests(TestCase):
|
||||||
s1_defer = Secondary.objects.only('pk').get(pk=s1.pk)
|
s1_defer = Secondary.objects.only('pk').get(pk=s1.pk)
|
||||||
self.assertEqual(s1, s1_defer)
|
self.assertEqual(s1, s1_defer)
|
||||||
self.assertEqual(s1_defer, s1)
|
self.assertEqual(s1_defer, s1)
|
||||||
|
|
||||||
|
def test_refresh_not_loading_deferred_fields(self):
|
||||||
|
s = Secondary.objects.create()
|
||||||
|
rf = Primary.objects.create(name='foo', value='bar', related=s)
|
||||||
|
rf2 = Primary.objects.only('related', 'value').get()
|
||||||
|
rf.name = 'new foo'
|
||||||
|
rf.value = 'new bar'
|
||||||
|
rf.save()
|
||||||
|
with self.assertNumQueries(1):
|
||||||
|
rf2.refresh_from_db()
|
||||||
|
self.assertEqual(rf2.value, 'new bar')
|
||||||
|
with self.assertNumQueries(1):
|
||||||
|
self.assertEqual(rf2.name, 'new foo')
|
||||||
|
|
||||||
|
def test_custom_refresh_on_deferred_loading(self):
|
||||||
|
s = Secondary.objects.create()
|
||||||
|
rf = RefreshPrimaryProxy.objects.create(name='foo', value='bar', related=s)
|
||||||
|
rf2 = RefreshPrimaryProxy.objects.only('related').get()
|
||||||
|
rf.name = 'new foo'
|
||||||
|
rf.value = 'new bar'
|
||||||
|
rf.save()
|
||||||
|
with self.assertNumQueries(1):
|
||||||
|
# Customized refresh_from_db() reloads all deferred fields on
|
||||||
|
# access of any of them.
|
||||||
|
self.assertEqual(rf2.name, 'new foo')
|
||||||
|
self.assertEqual(rf2.value, 'new bar')
|
||||||
|
|
|
@ -11,6 +11,12 @@ from .models import ChoicesModel, DataModel, MyModel, OtherModel
|
||||||
|
|
||||||
|
|
||||||
class CustomField(TestCase):
|
class CustomField(TestCase):
|
||||||
|
def test_refresh(self):
|
||||||
|
d = DataModel.objects.create(data=[1, 2, 3])
|
||||||
|
d.refresh_from_db(fields=['data'])
|
||||||
|
self.assertIsInstance(d.data, list)
|
||||||
|
self.assertEqual(d.data, [1, 2, 3])
|
||||||
|
|
||||||
def test_defer(self):
|
def test_defer(self):
|
||||||
d = DataModel.objects.create(data=[1, 2, 3])
|
d = DataModel.objects.create(data=[1, 2, 3])
|
||||||
|
|
||||||
|
|
|
@ -112,6 +112,24 @@ class QueryTestCase(TestCase):
|
||||||
title="Dive into Python"
|
title="Dive into Python"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_refresh(self):
|
||||||
|
dive = Book()
|
||||||
|
dive.title = "Dive into Python"
|
||||||
|
dive = Book()
|
||||||
|
dive.title = "Dive into Python"
|
||||||
|
dive.published = datetime.date(2009, 5, 4)
|
||||||
|
dive.save(using='other')
|
||||||
|
dive.published = datetime.date(2009, 5, 4)
|
||||||
|
dive.save(using='other')
|
||||||
|
dive2 = Book.objects.using('other').get()
|
||||||
|
dive2.title = "Dive into Python (on default)"
|
||||||
|
dive2.save(using='default')
|
||||||
|
dive.refresh_from_db()
|
||||||
|
self.assertEqual(dive.title, "Dive into Python")
|
||||||
|
dive.refresh_from_db(using='default')
|
||||||
|
self.assertEqual(dive.title, "Dive into Python (on default)")
|
||||||
|
self.assertEqual(dive._state.db, "default")
|
||||||
|
|
||||||
def test_basic_queries(self):
|
def test_basic_queries(self):
|
||||||
"Queries are constrained to a single database"
|
"Queries are constrained to a single database"
|
||||||
dive = Book.objects.using('other').create(title="Dive into Python",
|
dive = Book.objects.using('other').create(title="Dive into Python",
|
||||||
|
|
Loading…
Reference in New Issue