From 72ebe85a269aab4bdb3829de4846b41f90973c5d Mon Sep 17 00:00:00 2001 From: Shai Berger Date: Mon, 31 Dec 2018 19:57:35 +0200 Subject: [PATCH] Fixed #27910 -- Added enumeration helpers for use in Field.choices. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit These classes can serve as a base class for user enums, supporting translatable human-readable names, or names automatically inferred from the enum member name. Additional properties make it easy to access the list of names, values and display labels. Thanks to the following for ideas and reviews: Carlton Gibson, Fran Hrženjak, Ian Foote, Mariusz Felisiak, Shai Berger. Co-authored-by: Shai Berger Co-authored-by: Nick Pope Co-authored-by: Mariusz Felisiak --- django/db/migrations/serializer.py | 6 + django/db/models/__init__.py | 4 +- django/db/models/enums.py | 75 +++++++ docs/ref/models/fields.txt | 111 ++++++++++- docs/releases/3.0.txt | 12 ++ docs/spelling_wordlist | 1 + docs/topics/db/models.txt | 13 ++ tests/migrations/test_writer.py | 42 ++++ tests/model_enums/tests.py | 253 ++++++++++++++++++++++++ tests/model_fields/test_charfield.py | 27 +++ tests/model_fields/test_integerfield.py | 14 ++ 11 files changed, 554 insertions(+), 4 deletions(-) create mode 100644 django/db/models/enums.py create mode 100644 tests/model_enums/tests.py diff --git a/django/db/migrations/serializer.py b/django/db/migrations/serializer.py index 1f1b3f4f20a..27b5cbd379d 100644 --- a/django/db/migrations/serializer.py +++ b/django/db/migrations/serializer.py @@ -46,6 +46,11 @@ class BaseSimpleSerializer(BaseSerializer): return repr(self.value), set() +class ChoicesSerializer(BaseSerializer): + def serialize(self): + return serializer_factory(self.value.value).serialize() + + class DateTimeSerializer(BaseSerializer): """For datetime.*, except datetime.datetime.""" def serialize(self): @@ -279,6 +284,7 @@ class Serializer: set: SetSerializer, tuple: TupleSerializer, dict: DictionarySerializer, + models.Choices: ChoicesSerializer, enum.Enum: EnumSerializer, datetime.datetime: DatetimeDatetimeSerializer, (datetime.date, datetime.timedelta, datetime.time): DateTimeSerializer, diff --git a/django/db/models/__init__.py b/django/db/models/__init__.py index 79b175c1d52..cca87ed0e73 100644 --- a/django/db/models/__init__.py +++ b/django/db/models/__init__.py @@ -7,6 +7,8 @@ from django.db.models.constraints import __all__ as constraints_all from django.db.models.deletion import ( CASCADE, DO_NOTHING, PROTECT, SET, SET_DEFAULT, SET_NULL, ProtectedError, ) +from django.db.models.enums import * # NOQA +from django.db.models.enums import __all__ as enums_all from django.db.models.expressions import ( Case, Exists, Expression, ExpressionList, ExpressionWrapper, F, Func, OuterRef, RowRange, Subquery, Value, ValueRange, When, Window, WindowFrame, @@ -32,7 +34,7 @@ from django.db.models.fields.related import ( # isort:skip ) -__all__ = aggregates_all + constraints_all + fields_all + indexes_all +__all__ = aggregates_all + constraints_all + enums_all + fields_all + indexes_all __all__ += [ 'ObjectDoesNotExist', 'signals', 'CASCADE', 'DO_NOTHING', 'PROTECT', 'SET', 'SET_DEFAULT', 'SET_NULL', diff --git a/django/db/models/enums.py b/django/db/models/enums.py new file mode 100644 index 00000000000..bbe362a6abf --- /dev/null +++ b/django/db/models/enums.py @@ -0,0 +1,75 @@ +import enum + +from django.utils.functional import Promise + +__all__ = ['Choices', 'IntegerChoices', 'TextChoices'] + + +class ChoicesMeta(enum.EnumMeta): + """A metaclass for creating a enum choices.""" + + def __new__(metacls, classname, bases, classdict): + labels = [] + for key in classdict._member_names: + value = classdict[key] + if ( + isinstance(value, (list, tuple)) and + len(value) > 1 and + isinstance(value[-1], (Promise, str)) + ): + *value, label = value + value = tuple(value) + else: + label = key.replace('_', ' ').title() + labels.append(label) + # Use dict.__setitem__() to suppress defenses against double + # assignment in enum's classdict. + dict.__setitem__(classdict, key, value) + cls = super().__new__(metacls, classname, bases, classdict) + cls._value2label_map_ = dict(zip(cls._value2member_map_, labels)) + # Add a label property to instances of enum which uses the enum member + # that is passed in as "self" as the value to use when looking up the + # label in the choices. + cls.label = property(lambda self: cls._value2label_map_.get(self.value)) + return enum.unique(cls) + + def __contains__(cls, member): + if not isinstance(member, enum.Enum): + # Allow non-enums to match against member values. + return member in {x.value for x in cls} + return super().__contains__(member) + + @property + def names(cls): + empty = ['__empty__'] if hasattr(cls, '__empty__') else [] + return empty + [member.name for member in cls] + + @property + def choices(cls): + empty = [(None, cls.__empty__)] if hasattr(cls, '__empty__') else [] + return empty + [(member.value, member.label) for member in cls] + + @property + def labels(cls): + return [label for _, label in cls.choices] + + @property + def values(cls): + return [value for value, _ in cls.choices] + + +class Choices(enum.Enum, metaclass=ChoicesMeta): + """Class for creating enumerated choices.""" + pass + + +class IntegerChoices(int, Choices): + """Class for creating enumerated integer choices.""" + pass + + +class TextChoices(str, Choices): + """Class for creating enumerated string choices.""" + + def _generate_next_value_(name, start, count, last_values): + return name diff --git a/docs/ref/models/fields.txt b/docs/ref/models/fields.txt index 01a56c13128..49494186cea 100644 --- a/docs/ref/models/fields.txt +++ b/docs/ref/models/fields.txt @@ -94,6 +94,7 @@ and the second element is the human-readable name. For example:: ('SO', 'Sophomore'), ('JR', 'Junior'), ('SR', 'Senior'), + ('GR', 'Graduate'), ] Generally, it's best to define choices inside a model class, and to @@ -106,11 +107,13 @@ define a suitably-named constant for each value:: SOPHOMORE = 'SO' JUNIOR = 'JR' SENIOR = 'SR' + GRADUATE = 'GR' YEAR_IN_SCHOOL_CHOICES = [ (FRESHMAN, 'Freshman'), (SOPHOMORE, 'Sophomore'), (JUNIOR, 'Junior'), (SENIOR, 'Senior'), + (GRADUATE, 'Graduate'), ] year_in_school = models.CharField( max_length=2, @@ -119,7 +122,7 @@ define a suitably-named constant for each value:: ) def is_upperclass(self): - return self.year_in_school in (self.JUNIOR, self.SENIOR) + return self.year_in_school in {self.JUNIOR, self.SENIOR} Though you can define a choices list outside of a model class and then refer to it, defining the choices and names for each choice inside the @@ -127,6 +130,95 @@ model class keeps all of that information with the class that uses it, and makes the choices easy to reference (e.g, ``Student.SOPHOMORE`` will work anywhere that the ``Student`` model has been imported). +In addition, Django provides enumeration types that you can subclass to define +choices in a concise way:: + + from django.utils.translation import gettext_lazy as _ + + class Student(models.Model): + + class YearInSchool(models.TextChoices): + FRESHMAN = 'FR', _('Freshman') + SOPHOMORE = 'SO', _('Sophomore') + JUNIOR = 'JR', _('Junior') + SENIOR = 'SR', _('Senior') + GRADUATE = 'GR', _('Graduate') + + year_in_school = models.CharField( + max_length=2, + choices=YearInSchool.choices, + default=YearInSchool.FRESHMAN, + ) + + def is_upperclass(self): + return self.year_in_school in {YearInSchool.JUNIOR, YearInSchool.SENIOR} + +These work similar to :mod:`enum` from Python's standard library, but with some +modifications: + +* Instead of values in the ``enum``, Django uses ``(value, label)`` tuples. The + ``label`` can be a lazy translatable string. If a tuple is not provided, the + label is automatically generated from the member name. +* ``.label`` property is added on values, to return the label specified. +* Number of custom properties are added to the enumeration classes -- + ``.choices``, ``.labels``, ``.values``, and ``.names`` -- to make it easier + to access lists of those separate parts of the enumeration. Use ``.choices`` + as a suitable value to pass to :attr:`~Field.choices` in a field definition. +* The use of :func:`enum.unique()` is enforced to ensure that values cannot be + defined multiple times. This is unlikely to be expected in choices for a + field. + +Note that ``YearInSchool.SENIOR``, ``YearInSchool['SENIOR']``, +``YearInSchool('SR')`` work as expected, while ``YearInSchool.SENIOR.label`` is +a translatable string. + +If you don't need to have the human-readable names translated, you can have +them inferred from the member name (replacing underscores to spaces and using +title-case):: + + class YearInSchool(models.TextChoices): + FRESHMAN = 'FR' + SOPHOMORE = 'SO' + JUNIOR = 'JR' + SENIOR = 'SR' + GRADUATE = 'GR' + +Since the case where the enum values need to be integers is extremely common, +Django provides a ``IntegerChoices`` class. For example:: + + class Card(models.Model): + + class Suit(models.IntegerChoices): + DIAMOND = 1 + SPADE = 2 + HEART = 3 + CLUB = 4 + + suit = models.IntegerField(choices=Suit.choices) + +It is also possible to make use of the `Enum Functional API +`_ with the caveat +that labels are automatically generated as highlighted above:: + + >>> MedalType = models.TextChoices('MedalType', 'GOLD SILVER BRONZE') + >>> MedalType.choices + [('GOLD', 'Gold'), ('SILVER', 'Silver'), ('BRONZE', 'Bronze')] + >>> Place = models.IntegerChoices('Place', 'FIRST SECOND THIRD') + >>> Place.choices + [(1, 'First'), (2, 'Second'), (3, 'Third')] + +If you require support for a concrete data type other than ``int`` or ``str``, +you can subclass ``Choices`` and the required concrete data type, e.g. +:class:``datetime.date`` for use with :class:`~django.db.models.DateField`:: + + class MoonLandings(datetime.date, models.Choices): + APOLLO_11 = 1969, 7, 20, 'Apollo 11 (Eagle)' + APOLLO_12 = 1969, 11, 19, 'Apollo 12 (Intrepid)' + APOLLO_14 = 1971, 2, 5, 'Apollo 14 (Antares)' + APOLLO_15 = 1971, 7, 30, 'Apollo 15 (Falcon)' + APOLLO_16 = 1972, 4, 21, 'Apollo 16 (Orion)' + APOLLO_17 = 1972, 12, 11, 'Apollo 17 (Challenger)' + You can also collect your available choices into named groups that can be used for organizational purposes:: @@ -148,7 +240,8 @@ The first element in each tuple is the name to apply to the group. The second element is an iterable of 2-tuples, with each 2-tuple containing a value and a human-readable name for an option. Grouped options may be combined with ungrouped options within a single list (such as the -`unknown` option in this example). +`unknown` option in this example). Grouping is not supported by the custom +enumeration types for managing choices. For each model field that has :attr:`~Field.choices` set, Django will add a method to retrieve the human-readable name for the field's current value. See @@ -169,7 +262,19 @@ Unless :attr:`blank=False` is set on the field along with a with the select box. To override this behavior, add a tuple to ``choices`` containing ``None``; e.g. ``(None, 'Your String For Display')``. Alternatively, you can use an empty string instead of ``None`` where this makes -sense - such as on a :class:`~django.db.models.CharField`. +sense - such as on a :class:`~django.db.models.CharField`. To change the label +when using one of the custom enumeration types, set the ``__empty__`` attribute +on the class:: + + class Answer(models.IntegerChoices): + NO = 0, _('No') + YES = 1, _('Yes') + + __empty__ = _('(Unknown)') + +.. versionadded:: 3.0 + + The ``TextChoices``, ``IntegerChoices``, and ``Choices`` classes were added. ``db_column`` ------------- diff --git a/docs/releases/3.0.txt b/docs/releases/3.0.txt index 9745a346600..2868e37e161 100644 --- a/docs/releases/3.0.txt +++ b/docs/releases/3.0.txt @@ -81,6 +81,18 @@ Expressions that outputs :class:`~django.db.models.BooleanField` may now be used directly in ``QuerySet`` filters, without having to first annotate and then filter against the annotation. +Enumerations for model field choices +------------------------------------ + +Custom enumeration types ``TextChoices``, ``IntegerChoices``, and ``Choices`` +are now available as a way to define :attr:`.Field.choices`. ``TextChoices`` +and ``IntegerChoices`` types are provided for text and integer fields. The +``Choices`` class allows defining a compatible enumeration for other concrete +data types. These custom enumeration types support human-readable labels that +can be translated and accessed via a property on the enumeration or its +members. See :ref:`Field.choices documentation ` for more +details and examples. + Minor features -------------- diff --git a/docs/spelling_wordlist b/docs/spelling_wordlist index 445a64adfc9..43e8a683858 100644 --- a/docs/spelling_wordlist +++ b/docs/spelling_wordlist @@ -193,6 +193,7 @@ elidable encodings Endian Enero +enum environ esque Ess diff --git a/docs/topics/db/models.txt b/docs/topics/db/models.txt index f47490877f3..0aab0bb35fe 100644 --- a/docs/topics/db/models.txt +++ b/docs/topics/db/models.txt @@ -198,6 +198,19 @@ ones: >>> p.get_shirt_size_display() 'Large' + You can also use enumeration classes to define ``choices`` in a concise + way:: + + from django.db import models + + class Runner(models.Model): + MedalType = models.TextChoices('MedalType', 'GOLD SILVER BRONZE') + name = models.CharField(max_length=60) + medal = models.CharField(blank=True, choices=MedalType.choices, max_length=10) + + Further examples are available in the :ref:`model field reference + `. + :attr:`~Field.default` The default value for the field. This can be a value or a callable object. If callable it will be called every time a new object is diff --git a/tests/migrations/test_writer.py b/tests/migrations/test_writer.py index 790d6a3efac..f97d76e9adc 100644 --- a/tests/migrations/test_writer.py +++ b/tests/migrations/test_writer.py @@ -306,6 +306,48 @@ class WriterTests(SimpleTestCase): "default=migrations.test_writer.IntEnum(1))" ) + def test_serialize_choices(self): + class TextChoices(models.TextChoices): + A = 'A', 'A value' + B = 'B', 'B value' + + class IntegerChoices(models.IntegerChoices): + A = 1, 'One' + B = 2, 'Two' + + class DateChoices(datetime.date, models.Choices): + DATE_1 = 1969, 7, 20, 'First date' + DATE_2 = 1969, 11, 19, 'Second date' + + self.assertSerializedResultEqual(TextChoices.A, ("'A'", set())) + self.assertSerializedResultEqual(IntegerChoices.A, ('1', set())) + self.assertSerializedResultEqual( + DateChoices.DATE_1, + ('datetime.date(1969, 7, 20)', {'import datetime'}), + ) + field = models.CharField(default=TextChoices.B, choices=TextChoices.choices) + string = MigrationWriter.serialize(field)[0] + self.assertEqual( + string, + "models.CharField(choices=[('A', 'A value'), ('B', 'B value')], " + "default='B')", + ) + field = models.IntegerField(default=IntegerChoices.B, choices=IntegerChoices.choices) + string = MigrationWriter.serialize(field)[0] + self.assertEqual( + string, + "models.IntegerField(choices=[(1, 'One'), (2, 'Two')], default=2)", + ) + field = models.DateField(default=DateChoices.DATE_2, choices=DateChoices.choices) + string = MigrationWriter.serialize(field)[0] + self.assertEqual( + string, + "models.DateField(choices=[" + "(datetime.date(1969, 7, 20), 'First date'), " + "(datetime.date(1969, 11, 19), 'Second date')], " + "default=datetime.date(1969, 11, 19))" + ) + def test_serialize_uuid(self): self.assertSerializedEqual(uuid.uuid1()) self.assertSerializedEqual(uuid.uuid4()) diff --git a/tests/model_enums/tests.py b/tests/model_enums/tests.py new file mode 100644 index 00000000000..6b4bd6e7fdd --- /dev/null +++ b/tests/model_enums/tests.py @@ -0,0 +1,253 @@ +import datetime +import decimal +import ipaddress +import uuid + +from django.db import models +from django.test import SimpleTestCase +from django.utils.functional import Promise +from django.utils.translation import gettext_lazy as _ + + +class Suit(models.IntegerChoices): + DIAMOND = 1, _('Diamond') + SPADE = 2, _('Spade') + HEART = 3, _('Heart') + CLUB = 4, _('Club') + + +class YearInSchool(models.TextChoices): + FRESHMAN = 'FR', _('Freshman') + SOPHOMORE = 'SO', _('Sophomore') + JUNIOR = 'JR', _('Junior') + SENIOR = 'SR', _('Senior') + GRADUATE = 'GR', _('Graduate') + + +class Vehicle(models.IntegerChoices): + CAR = 1, 'Carriage' + TRUCK = 2 + JET_SKI = 3 + + __empty__ = _('(Unknown)') + + +class Gender(models.TextChoices): + MALE = 'M' + FEMALE = 'F' + NOT_SPECIFIED = 'X' + + __empty__ = '(Undeclared)' + + +class ChoicesTests(SimpleTestCase): + def test_integerchoices(self): + self.assertEqual(Suit.choices, [(1, 'Diamond'), (2, 'Spade'), (3, 'Heart'), (4, 'Club')]) + self.assertEqual(Suit.labels, ['Diamond', 'Spade', 'Heart', 'Club']) + self.assertEqual(Suit.values, [1, 2, 3, 4]) + self.assertEqual(Suit.names, ['DIAMOND', 'SPADE', 'HEART', 'CLUB']) + + self.assertEqual(repr(Suit.DIAMOND), '') + self.assertEqual(Suit.DIAMOND.label, 'Diamond') + self.assertEqual(Suit.DIAMOND.value, 1) + self.assertEqual(Suit['DIAMOND'], Suit.DIAMOND) + self.assertEqual(Suit(1), Suit.DIAMOND) + + self.assertIsInstance(Suit, type(models.Choices)) + self.assertIsInstance(Suit.DIAMOND, Suit) + self.assertIsInstance(Suit.DIAMOND.label, Promise) + self.assertIsInstance(Suit.DIAMOND.value, int) + + def test_integerchoices_auto_label(self): + self.assertEqual(Vehicle.CAR.label, 'Carriage') + self.assertEqual(Vehicle.TRUCK.label, 'Truck') + self.assertEqual(Vehicle.JET_SKI.label, 'Jet Ski') + + def test_integerchoices_empty_label(self): + self.assertEqual(Vehicle.choices[0], (None, '(Unknown)')) + self.assertEqual(Vehicle.labels[0], '(Unknown)') + self.assertEqual(Vehicle.values[0], None) + self.assertEqual(Vehicle.names[0], '__empty__') + + def test_integerchoices_functional_api(self): + Place = models.IntegerChoices('Place', 'FIRST SECOND THIRD') + self.assertEqual(Place.labels, ['First', 'Second', 'Third']) + self.assertEqual(Place.values, [1, 2, 3]) + self.assertEqual(Place.names, ['FIRST', 'SECOND', 'THIRD']) + + def test_integerchoices_containment(self): + self.assertIn(Suit.DIAMOND, Suit) + self.assertIn(1, Suit) + self.assertNotIn(0, Suit) + + def test_textchoices(self): + self.assertEqual(YearInSchool.choices, [ + ('FR', 'Freshman'), ('SO', 'Sophomore'), ('JR', 'Junior'), ('SR', 'Senior'), ('GR', 'Graduate'), + ]) + self.assertEqual(YearInSchool.labels, ['Freshman', 'Sophomore', 'Junior', 'Senior', 'Graduate']) + self.assertEqual(YearInSchool.values, ['FR', 'SO', 'JR', 'SR', 'GR']) + self.assertEqual(YearInSchool.names, ['FRESHMAN', 'SOPHOMORE', 'JUNIOR', 'SENIOR', 'GRADUATE']) + + self.assertEqual(repr(YearInSchool.FRESHMAN), "") + self.assertEqual(YearInSchool.FRESHMAN.label, 'Freshman') + self.assertEqual(YearInSchool.FRESHMAN.value, 'FR') + self.assertEqual(YearInSchool['FRESHMAN'], YearInSchool.FRESHMAN) + self.assertEqual(YearInSchool('FR'), YearInSchool.FRESHMAN) + + self.assertIsInstance(YearInSchool, type(models.Choices)) + self.assertIsInstance(YearInSchool.FRESHMAN, YearInSchool) + self.assertIsInstance(YearInSchool.FRESHMAN.label, Promise) + self.assertIsInstance(YearInSchool.FRESHMAN.value, str) + + def test_textchoices_auto_label(self): + self.assertEqual(Gender.MALE.label, 'Male') + self.assertEqual(Gender.FEMALE.label, 'Female') + self.assertEqual(Gender.NOT_SPECIFIED.label, 'Not Specified') + + def test_textchoices_empty_label(self): + self.assertEqual(Gender.choices[0], (None, '(Undeclared)')) + self.assertEqual(Gender.labels[0], '(Undeclared)') + self.assertEqual(Gender.values[0], None) + self.assertEqual(Gender.names[0], '__empty__') + + def test_textchoices_functional_api(self): + Medal = models.TextChoices('Medal', 'GOLD SILVER BRONZE') + self.assertEqual(Medal.labels, ['Gold', 'Silver', 'Bronze']) + self.assertEqual(Medal.values, ['GOLD', 'SILVER', 'BRONZE']) + self.assertEqual(Medal.names, ['GOLD', 'SILVER', 'BRONZE']) + + def test_textchoices_containment(self): + self.assertIn(YearInSchool.FRESHMAN, YearInSchool) + self.assertIn('FR', YearInSchool) + self.assertNotIn('XX', YearInSchool) + + def test_textchoices_blank_value(self): + class BlankStr(models.TextChoices): + EMPTY = '', '(Empty)' + ONE = 'ONE', 'One' + + self.assertEqual(BlankStr.labels, ['(Empty)', 'One']) + self.assertEqual(BlankStr.values, ['', 'ONE']) + self.assertEqual(BlankStr.names, ['EMPTY', 'ONE']) + + def test_invalid_definition(self): + msg = "'str' object cannot be interpreted as an integer" + with self.assertRaisesMessage(TypeError, msg): + class InvalidArgumentEnum(models.IntegerChoices): + # A string is not permitted as the second argument to int(). + ONE = 1, 'X', 'Invalid' + + msg = "duplicate values found in : PINEAPPLE -> APPLE" + with self.assertRaisesMessage(ValueError, msg): + class Fruit(models.IntegerChoices): + APPLE = 1, 'Apple' + PINEAPPLE = 1, 'Pineapple' + + +class Separator(bytes, models.Choices): + FS = b'\x1c', 'File Separator' + GS = b'\x1d', 'Group Separator' + RS = b'\x1e', 'Record Separator' + US = b'\x1f', 'Unit Separator' + + +class Constants(float, models.Choices): + PI = 3.141592653589793, 'π' + TAU = 6.283185307179586, 'τ' + + +class Set(frozenset, models.Choices): + A = {1, 2} + B = {2, 3} + UNION = A | B + DIFFERENCE = A - B + INTERSECTION = A & B + + +class MoonLandings(datetime.date, models.Choices): + APOLLO_11 = 1969, 7, 20, 'Apollo 11 (Eagle)' + APOLLO_12 = 1969, 11, 19, 'Apollo 12 (Intrepid)' + APOLLO_14 = 1971, 2, 5, 'Apollo 14 (Antares)' + APOLLO_15 = 1971, 7, 30, 'Apollo 15 (Falcon)' + APOLLO_16 = 1972, 4, 21, 'Apollo 16 (Orion)' + APOLLO_17 = 1972, 12, 11, 'Apollo 17 (Challenger)' + + +class DateAndTime(datetime.datetime, models.Choices): + A = 2010, 10, 10, 10, 10, 10 + B = 2011, 11, 11, 11, 11, 11 + C = 2012, 12, 12, 12, 12, 12 + + +class MealTimes(datetime.time, models.Choices): + BREAKFAST = 7, 0 + LUNCH = 13, 0 + DINNER = 18, 30 + + +class Frequency(datetime.timedelta, models.Choices): + WEEK = 0, 0, 0, 0, 0, 0, 1, 'Week' + DAY = 1, 'Day' + HOUR = 0, 0, 0, 0, 0, 1, 'Hour' + MINUTE = 0, 0, 0, 0, 1, 'Hour' + SECOND = 0, 1, 'Second' + + +class Number(decimal.Decimal, models.Choices): + E = 2.718281828459045, 'e' + PI = '3.141592653589793', 'π' + TAU = decimal.Decimal('6.283185307179586'), 'τ' + + +class IPv4Address(ipaddress.IPv4Address, models.Choices): + LOCALHOST = '127.0.0.1', 'Localhost' + GATEWAY = '192.168.0.1', 'Gateway' + BROADCAST = '192.168.0.255', 'Broadcast' + + +class IPv6Address(ipaddress.IPv6Address, models.Choices): + LOCALHOST = '::1', 'Localhost' + UNSPECIFIED = '::', 'Unspecified' + + +class IPv4Network(ipaddress.IPv4Network, models.Choices): + LOOPBACK = '127.0.0.0/8', 'Loopback' + LINK_LOCAL = '169.254.0.0/16', 'Link-Local' + PRIVATE_USE_A = '10.0.0.0/8', 'Private-Use (Class A)' + + +class IPv6Network(ipaddress.IPv6Network, models.Choices): + LOOPBACK = '::1/128', 'Loopback' + UNSPECIFIED = '::/128', 'Unspecified' + UNIQUE_LOCAL = 'fc00::/7', 'Unique-Local' + LINK_LOCAL_UNICAST = 'fe80::/10', 'Link-Local Unicast' + + +class CustomChoicesTests(SimpleTestCase): + def test_labels_valid(self): + enums = ( + Separator, Constants, Set, MoonLandings, DateAndTime, MealTimes, + Frequency, Number, IPv4Address, IPv6Address, IPv4Network, + IPv6Network, + ) + for choice_enum in enums: + with self.subTest(choice_enum.__name__): + self.assertNotIn(None, choice_enum.labels) + + def test_bool_unsupported(self): + msg = "type 'bool' is not an acceptable base type" + with self.assertRaisesMessage(TypeError, msg): + class Boolean(bool, models.Choices): + pass + + def test_timezone_unsupported(self): + msg = "type 'datetime.timezone' is not an acceptable base type" + with self.assertRaisesMessage(TypeError, msg): + class Timezone(datetime.timezone, models.Choices): + pass + + def test_uuid_unsupported(self): + msg = 'UUID objects are immutable' + with self.assertRaisesMessage(TypeError, msg): + class Identifier(uuid.UUID, models.Choices): + A = '972ce4eb-a95f-4a56-9339-68c208a76f18' diff --git a/tests/model_fields/test_charfield.py b/tests/model_fields/test_charfield.py index 1b50f01f3af..051be2eeec9 100644 --- a/tests/model_fields/test_charfield.py +++ b/tests/model_fields/test_charfield.py @@ -28,9 +28,27 @@ class TestCharField(TestCase): p.refresh_from_db() self.assertEqual(p.title, 'Smile 😀') + def test_assignment_from_choice_enum(self): + class Event(models.TextChoices): + C = 'Carnival!' + F = 'Festival!' + + p1 = Post.objects.create(title=Event.C, body=Event.F) + p1.refresh_from_db() + self.assertEqual(p1.title, 'Carnival!') + self.assertEqual(p1.body, 'Festival!') + self.assertEqual(p1.title, Event.C) + self.assertEqual(p1.body, Event.F) + p2 = Post.objects.get(title='Carnival!') + self.assertEquals(p1, p2) + self.assertEquals(p2.title, Event.C) + class ValidationTests(SimpleTestCase): + class Choices(models.TextChoices): + C = 'c', 'C' + def test_charfield_raises_error_on_empty_string(self): f = models.CharField() with self.assertRaises(ValidationError): @@ -49,6 +67,15 @@ class ValidationTests(SimpleTestCase): with self.assertRaises(ValidationError): f.clean('not a', None) + def test_enum_choices_cleans_valid_string(self): + f = models.CharField(choices=self.Choices.choices, max_length=1) + self.assertEqual(f.clean('c', None), 'c') + + def test_enum_choices_invalid_input(self): + f = models.CharField(choices=self.Choices.choices, max_length=1) + with self.assertRaises(ValidationError): + f.clean('a', None) + def test_charfield_raises_error_on_empty_input(self): f = models.CharField(null=False) with self.assertRaises(ValidationError): diff --git a/tests/model_fields/test_integerfield.py b/tests/model_fields/test_integerfield.py index c0bb7627cfe..606d71057aa 100644 --- a/tests/model_fields/test_integerfield.py +++ b/tests/model_fields/test_integerfield.py @@ -184,6 +184,9 @@ class PositiveIntegerFieldTests(IntegerFieldTests): class ValidationTests(SimpleTestCase): + class Choices(models.IntegerChoices): + A = 1 + def test_integerfield_cleans_valid_string(self): f = models.IntegerField() self.assertEqual(f.clean('2', None), 2) @@ -217,3 +220,14 @@ class ValidationTests(SimpleTestCase): f = models.IntegerField(choices=((1, 1),)) with self.assertRaises(ValidationError): f.clean('0', None) + + def test_enum_choices_cleans_valid_string(self): + f = models.IntegerField(choices=self.Choices.choices) + self.assertEqual(f.clean('1', None), 1) + + def test_enum_choices_invalid_input(self): + f = models.IntegerField(choices=self.Choices.choices) + with self.assertRaises(ValidationError): + f.clean('A', None) + with self.assertRaises(ValidationError): + f.clean('3', None)