From c7175fcdfe94be60c04f3b1ceb6d0b2def2b6f09 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Anssi=20K=C3=A4=C3=A4ri=C3=A4inen?= Date: Sat, 5 Jul 2014 09:03:52 +0300 Subject: [PATCH] Fixed #901 -- Added Model.refresh_from_db() method Thanks to github aliases dbrgn, carljm, slurms, dfunckt, and timgraham for reviews. --- django/db/models/base.py | 68 +++++++++++++++++++++++++++++++- django/db/models/query_utils.py | 10 +---- docs/ref/models/instances.txt | 67 +++++++++++++++++++++++++++++++ docs/releases/1.8.txt | 6 +++ tests/basic/tests.py | 59 ++++++++++++++++++++++++++- tests/defer/models.py | 14 +++++++ tests/defer/tests.py | 28 ++++++++++++- tests/field_subclassing/tests.py | 6 +++ tests/multiple_database/tests.py | 18 +++++++++ 9 files changed, 265 insertions(+), 11 deletions(-) diff --git a/django/db/models/base.py b/django/db/models/base.py index e4f970af05..c2f711ea08 100644 --- a/django/db/models/base.py +++ b/django/db/models/base.py @@ -13,6 +13,8 @@ from django.core.exceptions import (ObjectDoesNotExist, MultipleObjectsReturned, FieldError, ValidationError, NON_FIELD_ERRORS) from django.db import (router, connections, transaction, DatabaseError, DEFAULT_DB_ALIAS, DJANGO_VERSION_PICKLE_KEY) +from django.db.models import signals +from django.db.models.constants import LOOKUP_SEP from django.db.models.deletion import Collector from django.db.models.fields import AutoField, FieldDoesNotExist from django.db.models.fields.related import (ForeignObjectRel, ManyToOneRel, @@ -21,7 +23,6 @@ from django.db.models.manager import ensure_default_manager from django.db.models.options import Options from django.db.models.query import Q from django.db.models.query_utils import DeferredAttribute, deferred_class_factory -from django.db.models import signals from django.utils import six from django.utils.deprecation import RemovedInDjango19Warning from django.utils.encoding import force_str, force_text @@ -552,6 +553,71 @@ class Model(six.with_metaclass(ModelBase)): pk = property(_get_pk_val, _set_pk_val) + def get_deferred_fields(self): + """ + Returns a set containing names of deferred fields for this instance. + """ + return { + f.attname for f in self._meta.concrete_fields + if isinstance(self.__class__.__dict__.get(f.attname), DeferredAttribute) + } + + def refresh_from_db(self, using=None, fields=None, **kwargs): + """ + Reloads field values from the database. + + By default, the reloading happens from the database this instance was + loaded from, or by the read router if this instance wasn't loaded from + any database. The using parameter will override the default. + + Fields can be used to specify which fields to reload. The fields + should be an iterable of field attnames. If fields is None, then + all non-deferred fields are reloaded. + + When accessing deferred fields of an instance, the deferred loading + of the field will call this method. + """ + if fields is not None: + if len(fields) == 0: + return + if any(LOOKUP_SEP in f for f in fields): + raise ValueError( + 'Found "%s" in fields argument. Relations and transforms ' + 'are not allowed in fields.' % LOOKUP_SEP) + + db = using if using is not None else self._state.db + if self._deferred: + non_deferred_model = self._meta.proxy_for_model + else: + non_deferred_model = self.__class__ + db_instance_qs = non_deferred_model._default_manager.using(db).filter(pk=self.pk) + + # Use provided fields, if not set then reload all non-deferred fields. + if fields is not None: + fields = list(fields) + db_instance_qs = db_instance_qs.only(*fields) + elif self._deferred: + deferred_fields = self.get_deferred_fields() + fields = [f.attname for f in self._meta.concrete_fields + if f.attname not in deferred_fields] + db_instance_qs = db_instance_qs.only(*fields) + + db_instance = db_instance_qs.get() + non_loaded_fields = db_instance.get_deferred_fields() + for field in self._meta.concrete_fields: + if field.attname in non_loaded_fields: + # This field wasn't refreshed - skip ahead. + continue + setattr(self, field.attname, getattr(db_instance, field.attname)) + # Throw away stale foreign key references. + if field.rel and field.get_cache_name() in self.__dict__: + rel_instance = getattr(self, field.get_cache_name()) + local_val = getattr(db_instance, field.attname) + related_val = getattr(rel_instance, field.related_field.attname) + if local_val != related_val: + del self.__dict__[field.get_cache_name()] + self._state.db = db_instance._state.db + def serializable_value(self, field_name): """ Returns the value of the field name for this instance. If the field is diff --git a/django/db/models/query_utils.py b/django/db/models/query_utils.py index 6dbeb855d1..0fbe25cca7 100644 --- a/django/db/models/query_utils.py +++ b/django/db/models/query_utils.py @@ -109,14 +109,8 @@ class DeferredAttribute(object): # might be able to reuse the already loaded value. Refs #18343. val = self._check_parent_chain(instance, name) if val is None: - # We use only() instead of values() here because we want the - # various data coercion methods (to_python(), etc.) to be - # called here. - val = getattr( - non_deferred_model._base_manager.only(name).using( - instance._state.db).get(pk=instance.pk), - self.field_name - ) + instance.refresh_from_db(fields=[self.field_name]) + val = getattr(instance, self.field_name) data[self.field_name] = val return data[self.field_name] diff --git a/docs/ref/models/instances.txt b/docs/ref/models/instances.txt index 4e8a1d1ee8..21dbb5af1a 100644 --- a/docs/ref/models/instances.txt +++ b/docs/ref/models/instances.txt @@ -116,6 +116,73 @@ The example above shows a full ``from_db()`` implementation to clarify how that is done. In this case it would of course be possible to just use ``super()`` call in the ``from_db()`` method. +Refreshing objects from database +================================ + +.. method:: Model.refresh_from_db(using=None, fields=None, **kwargs) + +.. versionadded:: 1.8 + +If you need to reload a model's values from the database, you can use the +``refresh_from_db()`` method. When this method is called without arguments the +following is done: + +1. All non-deferred fields of the model are updated to the values currently + present in the database. +2. The previously loaded related instances for which the relation's value is no + longer valid are removed from the reloaded instance. For example, if you have + a foreign key from the reloaded instance to another model with name + ``Author``, then if ``obj.author_id != obj.author.id``, ``obj.author`` will + be thrown away, and when next accessed it will be reloaded with the value of + ``obj.author_id``. + +Note that only fields of the model are reloaded from the database. Other +database dependent values such as annotations are not reloaded. + +The reloading happens from the database the instance was loaded from, or from +the default database if the instance wasn't loaded from the database. The +``using`` argument can be used to force the database used for reloading. + +It is possible to force the set of fields to be loaded by using the ``fields`` +argument. + +For example, to test that an ``update()`` call resulted in the expected +update, you could write a test similar to this:: + + def test_update_result(self): + obj = MyModel.objects.create(val=1) + MyModel.objects.filter(pk=obj.pk).update(val=F('val') + 1) + # At this point obj.val is still 1, but the value in the database + # was updated to 2. The object's updated value needs to be reloaded + # from the database. + obj.refresh_from_db() + self.assertEqual(obj.val, 2) + +Note that when deferred fields are accessed, the loading of the deferred +field's value happens through this method. Thus it is possible to customize +the way deferred loading happens. The example below shows how one can reload +all of the instance's fields when a deferred field is reloaded:: + + class ExampleModel(models.Model): + def refresh_from_db(self, using=None, fields=None, **kwargs): + # fields contains the name of the deferred field to be + # loaded. + if fields is not None: + fields = set(fields) + deferred_fields = self.get_deferred_fields() + # If any deferred field is going to be loaded + if fields.intersection(deferred_fields): + # then load all of them + fields = fields.union(deferred_fields) + super(ExampleModel, self).refresh_from_db(using, fields, **kwargs) + +.. method:: Model.get_deferred_fields() + +.. versionadded:: 1.8 + +A helper method that returns a set containing the attribute names of all those +fields that are currently deferred for this model. + .. _validating-objects: Validating objects diff --git a/docs/releases/1.8.txt b/docs/releases/1.8.txt index fb4e44576c..5df0cc1bd1 100644 --- a/docs/releases/1.8.txt +++ b/docs/releases/1.8.txt @@ -391,6 +391,12 @@ Models by the database, which can lead to somewhat complex queries involving nested ``REPLACE`` function calls. +* You can now refresh model instances by using :meth:`Model.refresh_from_db() + `. + +* You can now get the set of deferred fields for a model using + :meth:`Model.get_deferred_fields() `. + Signals ^^^^^^^ diff --git a/tests/basic/tests.py b/tests/basic/tests.py index a10e38f2a4..31e8b724bc 100644 --- a/tests/basic/tests.py +++ b/tests/basic/tests.py @@ -1,6 +1,6 @@ from __future__ import unicode_literals -from datetime import datetime +from datetime import datetime, timedelta import threading from django.core.exceptions import ObjectDoesNotExist, MultipleObjectsReturned @@ -713,3 +713,60 @@ class SelectOnSaveTests(TestCase): asos.save(update_fields=['pub_date']) finally: Article._base_manager.__class__ = orig_class + + +class ModelRefreshTests(TestCase): + def _truncate_ms(self, val): + # MySQL < 5.6.4 removes microseconds from the datetimes which can cause + # problems when comparing the original value to that loaded from DB + return val - timedelta(microseconds=val.microsecond) + + def test_refresh(self): + a = Article.objects.create(pub_date=self._truncate_ms(datetime.now())) + Article.objects.create(pub_date=self._truncate_ms(datetime.now())) + Article.objects.filter(pk=a.pk).update(headline='new headline') + with self.assertNumQueries(1): + a.refresh_from_db() + self.assertEqual(a.headline, 'new headline') + + orig_pub_date = a.pub_date + new_pub_date = a.pub_date + timedelta(10) + Article.objects.update(headline='new headline 2', pub_date=new_pub_date) + with self.assertNumQueries(1): + a.refresh_from_db(fields=['headline']) + self.assertEqual(a.headline, 'new headline 2') + self.assertEqual(a.pub_date, orig_pub_date) + with self.assertNumQueries(1): + a.refresh_from_db() + self.assertEqual(a.pub_date, new_pub_date) + + def test_refresh_fk(self): + s1 = SelfRef.objects.create() + s2 = SelfRef.objects.create() + s3 = SelfRef.objects.create(selfref=s1) + s3_copy = SelfRef.objects.get(pk=s3.pk) + s3_copy.selfref.touched = True + s3.selfref = s2 + s3.save() + with self.assertNumQueries(1): + s3_copy.refresh_from_db() + with self.assertNumQueries(1): + # The old related instance was thrown away (the selfref_id has + # changed). It needs to be reloaded on access, so one query + # executed. + self.assertFalse(hasattr(s3_copy.selfref, 'touched')) + self.assertEqual(s3_copy.selfref, s2) + + def test_refresh_unsaved(self): + pub_date = self._truncate_ms(datetime.now()) + a = Article.objects.create(pub_date=pub_date) + a2 = Article(id=a.pk) + with self.assertNumQueries(1): + a2.refresh_from_db() + self.assertEqual(a2.pub_date, pub_date) + self.assertEqual(a2._state.db, "default") + + def test_refresh_no_fields(self): + a = Article.objects.create(pub_date=self._truncate_ms(datetime.now())) + with self.assertNumQueries(0): + a.refresh_from_db(fields=[]) diff --git a/tests/defer/models.py b/tests/defer/models.py index ffc8a0c2c7..ecf69c0d7f 100644 --- a/tests/defer/models.py +++ b/tests/defer/models.py @@ -32,3 +32,17 @@ class BigChild(Primary): class ChildProxy(Child): class Meta: proxy = True + + +class RefreshPrimaryProxy(Primary): + class Meta: + proxy = True + + def refresh_from_db(self, using=None, fields=None, **kwargs): + # Reloads all deferred fields if any of the fields is deferred. + if fields is not None: + fields = set(fields) + deferred_fields = self.get_deferred_fields() + if fields.intersection(deferred_fields): + fields = fields.union(deferred_fields) + super(RefreshPrimaryProxy, self).refresh_from_db(using, fields, **kwargs) diff --git a/tests/defer/tests.py b/tests/defer/tests.py index 43a088f3e2..597f871cc8 100644 --- a/tests/defer/tests.py +++ b/tests/defer/tests.py @@ -3,7 +3,7 @@ from __future__ import unicode_literals from django.db.models.query_utils import DeferredAttribute, InvalidQuery from django.test import TestCase -from .models import Secondary, Primary, Child, BigChild, ChildProxy +from .models import Secondary, Primary, Child, BigChild, ChildProxy, RefreshPrimaryProxy class DeferTests(TestCase): @@ -189,3 +189,29 @@ class DeferTests(TestCase): s1_defer = Secondary.objects.only('pk').get(pk=s1.pk) self.assertEqual(s1, s1_defer) self.assertEqual(s1_defer, s1) + + def test_refresh_not_loading_deferred_fields(self): + s = Secondary.objects.create() + rf = Primary.objects.create(name='foo', value='bar', related=s) + rf2 = Primary.objects.only('related', 'value').get() + rf.name = 'new foo' + rf.value = 'new bar' + rf.save() + with self.assertNumQueries(1): + rf2.refresh_from_db() + self.assertEqual(rf2.value, 'new bar') + with self.assertNumQueries(1): + self.assertEqual(rf2.name, 'new foo') + + def test_custom_refresh_on_deferred_loading(self): + s = Secondary.objects.create() + rf = RefreshPrimaryProxy.objects.create(name='foo', value='bar', related=s) + rf2 = RefreshPrimaryProxy.objects.only('related').get() + rf.name = 'new foo' + rf.value = 'new bar' + rf.save() + with self.assertNumQueries(1): + # Customized refresh_from_db() reloads all deferred fields on + # access of any of them. + self.assertEqual(rf2.name, 'new foo') + self.assertEqual(rf2.value, 'new bar') diff --git a/tests/field_subclassing/tests.py b/tests/field_subclassing/tests.py index 5c695a455c..9e40d92496 100644 --- a/tests/field_subclassing/tests.py +++ b/tests/field_subclassing/tests.py @@ -11,6 +11,12 @@ from .models import ChoicesModel, DataModel, MyModel, OtherModel class CustomField(TestCase): + def test_refresh(self): + d = DataModel.objects.create(data=[1, 2, 3]) + d.refresh_from_db(fields=['data']) + self.assertIsInstance(d.data, list) + self.assertEqual(d.data, [1, 2, 3]) + def test_defer(self): d = DataModel.objects.create(data=[1, 2, 3]) diff --git a/tests/multiple_database/tests.py b/tests/multiple_database/tests.py index f230311eed..17a1ef90b1 100644 --- a/tests/multiple_database/tests.py +++ b/tests/multiple_database/tests.py @@ -112,6 +112,24 @@ class QueryTestCase(TestCase): title="Dive into Python" ) + def test_refresh(self): + dive = Book() + dive.title = "Dive into Python" + dive = Book() + dive.title = "Dive into Python" + dive.published = datetime.date(2009, 5, 4) + dive.save(using='other') + dive.published = datetime.date(2009, 5, 4) + dive.save(using='other') + dive2 = Book.objects.using('other').get() + dive2.title = "Dive into Python (on default)" + dive2.save(using='default') + dive.refresh_from_db() + self.assertEqual(dive.title, "Dive into Python") + dive.refresh_from_db(using='default') + self.assertEqual(dive.title, "Dive into Python (on default)") + self.assertEqual(dive._state.db, "default") + def test_basic_queries(self): "Queries are constrained to a single database" dive = Book.objects.using('other').create(title="Dive into Python",