From 5ed20b3aa3536539b9bc1cb5b40e84e3147a2228 Mon Sep 17 00:00:00 2001 From: Jon Dufresne Date: Tue, 23 Jul 2019 05:04:06 -0700 Subject: [PATCH] Fixed #30657 -- Allowed customizing Field's descriptors with a descriptor_class attribute. Allows model fields to override the descriptor class used on the model instance attribute. --- django/db/models/fields/__init__.py | 4 +++- docs/ref/models/fields.txt | 10 ++++++++++ docs/releases/3.0.txt | 4 ++++ tests/field_subclassing/fields.py | 20 ++++++++++++++++++++ tests/field_subclassing/tests.py | 27 +++++++++++++++++++++++++-- 5 files changed, 62 insertions(+), 3 deletions(-) diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py index a16713a397..1388dffc58 100644 --- a/django/db/models/fields/__init__.py +++ b/django/db/models/fields/__init__.py @@ -123,6 +123,8 @@ class Field(RegisterLookupMixin): one_to_one = None related_model = None + descriptor_class = DeferredAttribute + # Generic field type description, usually overridden by subclasses def _description(self): return _('Field of type: %(field_type)s') % { @@ -738,7 +740,7 @@ class Field(RegisterLookupMixin): # if you have a classmethod and a field with the same name, then # such fields can't be deferred (we don't have a check for this). if not getattr(cls, self.attname, None): - setattr(cls, self.attname, DeferredAttribute(self)) + setattr(cls, self.attname, self.descriptor_class(self)) if self.choices is not None: setattr(cls, 'get_%s_display' % self.name, partialmethod(cls._get_FIELD_display, field=self)) diff --git a/docs/ref/models/fields.txt b/docs/ref/models/fields.txt index 758c21c7bc..ead39f0572 100644 --- a/docs/ref/models/fields.txt +++ b/docs/ref/models/fields.txt @@ -1793,6 +1793,16 @@ Field API reference where the arguments are interpolated from the field's ``__dict__``. + .. attribute:: descriptor_class + + .. versionadded:: 3.0 + + A class implementing the :py:ref:`descriptor protocol ` + that is instantiated and assigned to the model instance attribute. The + constructor must accept a single argument, the ``Field`` instance. + Overriding this class attribute allows for customizing the get and set + behavior. + To map a ``Field`` to a database-specific type, Django exposes several methods: diff --git a/docs/releases/3.0.txt b/docs/releases/3.0.txt index 34152573f0..07a2396c80 100644 --- a/docs/releases/3.0.txt +++ b/docs/releases/3.0.txt @@ -283,6 +283,10 @@ Models :class:`~django.db.models.Index` now support app label and class interpolation using the ``'%(app_label)s'`` and ``'%(class)s'`` placeholders. +* The new :attr:`.Field.descriptor_class` attribute allows model fields to + customize the get and set behavior by overriding their + :py:ref:`descriptors `. + Requests and Responses ~~~~~~~~~~~~~~~~~~~~~~ diff --git a/tests/field_subclassing/fields.py b/tests/field_subclassing/fields.py index 4eb170116f..1e255ba9cd 100644 --- a/tests/field_subclassing/fields.py +++ b/tests/field_subclassing/fields.py @@ -1,6 +1,26 @@ from django.db import models +from django.db.models.query_utils import DeferredAttribute class CustomTypedField(models.TextField): def db_type(self, connection): return 'custom_field' + + +class CustomDeferredAttribute(DeferredAttribute): + def __get__(self, instance, cls=None): + self._count_call(instance, 'get') + return super().__get__(instance, cls) + + def __set__(self, instance, value): + self._count_call(instance, 'set') + instance.__dict__[self.field.attname] = value + + def _count_call(self, instance, get_or_set): + count_attr = '_%s_%s_count' % (self.field.attname, get_or_set) + count = getattr(instance, count_attr, 0) + setattr(instance, count_attr, count + 1) + + +class CustomDescriptorField(models.CharField): + descriptor_class = CustomDeferredAttribute diff --git a/tests/field_subclassing/tests.py b/tests/field_subclassing/tests.py index a1371cab42..6d62a5da0e 100644 --- a/tests/field_subclassing/tests.py +++ b/tests/field_subclassing/tests.py @@ -1,7 +1,7 @@ -from django.db import connection +from django.db import connection, models from django.test import SimpleTestCase -from .fields import CustomTypedField +from .fields import CustomDescriptorField, CustomTypedField class TestDbType(SimpleTestCase): @@ -9,3 +9,26 @@ class TestDbType(SimpleTestCase): def test_db_parameters_respects_db_type(self): f = CustomTypedField() self.assertEqual(f.db_parameters(connection)['type'], 'custom_field') + + +class DescriptorClassTest(SimpleTestCase): + def test_descriptor_class(self): + class CustomDescriptorModel(models.Model): + name = CustomDescriptorField(max_length=32) + + m = CustomDescriptorModel() + self.assertFalse(hasattr(m, '_name_get_count')) + # The field is set to its default in the model constructor. + self.assertEqual(m._name_set_count, 1) + m.name = 'foo' + self.assertFalse(hasattr(m, '_name_get_count')) + self.assertEqual(m._name_set_count, 2) + self.assertEqual(m.name, 'foo') + self.assertEqual(m._name_get_count, 1) + self.assertEqual(m._name_set_count, 2) + m.name = 'bar' + self.assertEqual(m._name_get_count, 1) + self.assertEqual(m._name_set_count, 3) + self.assertEqual(m.name, 'bar') + self.assertEqual(m._name_get_count, 2) + self.assertEqual(m._name_set_count, 3)