Fixed #24858 -- Added support for get_FOO_display() to ArrayField and RangeFields.

_get_FIELD_display() crashed when Field.choices was unhashable.
This commit is contained in:
Hasan Ramezani 2019-11-07 15:35:33 +01:00 committed by Mariusz Felisiak
parent 8058d9d7ad
commit 153c7956f8
5 changed files with 84 additions and 1 deletions

View File

@ -33,6 +33,7 @@ from django.db.models.signals import (
) )
from django.db.models.utils import make_model_tuple from django.db.models.utils import make_model_tuple
from django.utils.encoding import force_str from django.utils.encoding import force_str
from django.utils.hashable import make_hashable
from django.utils.text import capfirst, get_text_list from django.utils.text import capfirst, get_text_list
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from django.utils.version import get_version from django.utils.version import get_version
@ -940,8 +941,9 @@ class Model(metaclass=ModelBase):
def _get_FIELD_display(self, field): def _get_FIELD_display(self, field):
value = getattr(self, field.attname) value = getattr(self, field.attname)
choices_dict = dict(make_hashable(field.flatchoices))
# force_str() to coerce lazy strings. # force_str() to coerce lazy strings.
return force_str(dict(field.flatchoices).get(value, value), strings_only=True) return force_str(choices_dict.get(make_hashable(value), value), strings_only=True)
def _get_next_or_previous_by_FIELD(self, field, is_next, **kwargs): def _get_next_or_previous_by_FIELD(self, field, is_next, **kwargs):
if not self.pk: if not self.pk:

View File

@ -797,6 +797,11 @@ For example::
>>> p.get_shirt_size_display() >>> p.get_shirt_size_display()
'Large' 'Large'
.. versionchanged:: 3.1
Support for :class:`~django.contrib.postgres.fields.ArrayField` and
:class:`~django.contrib.postgres.fields.RangeField` was added.
.. method:: Model.get_next_by_FOO(**kwargs) .. method:: Model.get_next_by_FOO(**kwargs)
.. method:: Model.get_previous_by_FOO(**kwargs) .. method:: Model.get_previous_by_FOO(**kwargs)

View File

@ -76,6 +76,10 @@ Minor features
:class:`~django.contrib.postgres.operations.BloomExtension` migration :class:`~django.contrib.postgres.operations.BloomExtension` migration
operation installs the ``bloom`` extension to add support for this index. operation installs the ``bloom`` extension to add support for this index.
* :meth:`~django.db.models.Model.get_FOO_display` now supports
:class:`~django.contrib.postgres.fields.ArrayField` and
:class:`~django.contrib.postgres.fields.RangeField`.
:mod:`django.contrib.redirects` :mod:`django.contrib.redirects`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -37,6 +37,53 @@ except ImportError:
pass pass
@isolate_apps('postgres_tests')
class BasicTests(PostgreSQLSimpleTestCase):
def test_get_field_display(self):
class MyModel(PostgreSQLModel):
field = ArrayField(
models.CharField(max_length=16),
choices=[
['Media', [(['vinyl', 'cd'], 'Audio')]],
(('mp3', 'mp4'), 'Digital'),
],
)
tests = (
(['vinyl', 'cd'], 'Audio'),
(('mp3', 'mp4'), 'Digital'),
(('a', 'b'), "('a', 'b')"),
(['c', 'd'], "['c', 'd']"),
)
for value, display in tests:
with self.subTest(value=value, display=display):
instance = MyModel(field=value)
self.assertEqual(instance.get_field_display(), display)
def test_get_field_display_nested_array(self):
class MyModel(PostgreSQLModel):
field = ArrayField(
ArrayField(models.CharField(max_length=16)),
choices=[
[
'Media',
[([['vinyl', 'cd'], ('x',)], 'Audio')],
],
((['mp3'], ('mp4',)), 'Digital'),
],
)
tests = (
([['vinyl', 'cd'], ('x',)], 'Audio'),
((['mp3'], ('mp4',)), 'Digital'),
((('a', 'b'), ('c',)), "(('a', 'b'), ('c',))"),
([['a', 'b'], ['c']], "[['a', 'b'], ['c']]"),
)
for value, display in tests:
with self.subTest(value=value, display=display):
instance = MyModel(field=value)
self.assertEqual(instance.get_field_display(), display)
class TestSaveLoad(PostgreSQLTestCase): class TestSaveLoad(PostgreSQLTestCase):
def test_integer(self): def test_integer(self):

View File

@ -7,6 +7,7 @@ from django.core import exceptions, serializers
from django.db.models import DateField, DateTimeField, F, Func, Value from django.db.models import DateField, DateTimeField, F, Func, Value
from django.http import QueryDict from django.http import QueryDict
from django.test import override_settings from django.test import override_settings
from django.test.utils import isolate_apps
from django.utils import timezone from django.utils import timezone
from . import PostgreSQLSimpleTestCase, PostgreSQLTestCase from . import PostgreSQLSimpleTestCase, PostgreSQLTestCase
@ -22,6 +23,30 @@ except ImportError:
pass pass
@isolate_apps('postgres_tests')
class BasicTests(PostgreSQLSimpleTestCase):
def test_get_field_display(self):
class Model(PostgreSQLModel):
field = pg_fields.IntegerRangeField(
choices=[
['1-50', [((1, 25), '1-25'), ([26, 50], '26-50')]],
((51, 100), '51-100'),
],
)
tests = (
((1, 25), '1-25'),
([26, 50], '26-50'),
((51, 100), '51-100'),
((1, 2), '(1, 2)'),
([1, 2], '[1, 2]'),
)
for value, display in tests:
with self.subTest(value=value, display=display):
instance = Model(field=value)
self.assertEqual(instance.get_field_display(), display)
class TestSaveLoad(PostgreSQLTestCase): class TestSaveLoad(PostgreSQLTestCase):
def test_all_fields(self): def test_all_fields(self):