diff --git a/django/db/migrations/serializer.py b/django/db/migrations/serializer.py index 1f1b3f4f20..27b5cbd379 100644 --- a/django/db/migrations/serializer.py +++ b/django/db/migrations/serializer.py @@ -46,6 +46,11 @@ class BaseSimpleSerializer(BaseSerializer): return repr(self.value), set() +class ChoicesSerializer(BaseSerializer): + def serialize(self): + return serializer_factory(self.value.value).serialize() + + class DateTimeSerializer(BaseSerializer): """For datetime.*, except datetime.datetime.""" def serialize(self): @@ -279,6 +284,7 @@ class Serializer: set: SetSerializer, tuple: TupleSerializer, dict: DictionarySerializer, + models.Choices: ChoicesSerializer, enum.Enum: EnumSerializer, datetime.datetime: DatetimeDatetimeSerializer, (datetime.date, datetime.timedelta, datetime.time): DateTimeSerializer, diff --git a/django/db/models/__init__.py b/django/db/models/__init__.py index 79b175c1d5..cca87ed0e7 100644 --- a/django/db/models/__init__.py +++ b/django/db/models/__init__.py @@ -7,6 +7,8 @@ from django.db.models.constraints import __all__ as constraints_all from django.db.models.deletion import ( CASCADE, DO_NOTHING, PROTECT, SET, SET_DEFAULT, SET_NULL, ProtectedError, ) +from django.db.models.enums import * # NOQA +from django.db.models.enums import __all__ as enums_all from django.db.models.expressions import ( Case, Exists, Expression, ExpressionList, ExpressionWrapper, F, Func, OuterRef, RowRange, Subquery, Value, ValueRange, When, Window, WindowFrame, @@ -32,7 +34,7 @@ from django.db.models.fields.related import ( # isort:skip ) -__all__ = aggregates_all + constraints_all + fields_all + indexes_all +__all__ = aggregates_all + constraints_all + enums_all + fields_all + indexes_all __all__ += [ 'ObjectDoesNotExist', 'signals', 'CASCADE', 'DO_NOTHING', 'PROTECT', 'SET', 'SET_DEFAULT', 'SET_NULL', diff --git a/django/db/models/enums.py b/django/db/models/enums.py new file mode 100644 index 0000000000..bbe362a6ab --- /dev/null +++ b/django/db/models/enums.py @@ -0,0 +1,75 @@ +import enum + +from django.utils.functional import Promise + +__all__ = ['Choices', 'IntegerChoices', 'TextChoices'] + + +class ChoicesMeta(enum.EnumMeta): + """A metaclass for creating a enum choices.""" + + def __new__(metacls, classname, bases, classdict): + labels = [] + for key in classdict._member_names: + value = classdict[key] + if ( + isinstance(value, (list, tuple)) and + len(value) > 1 and + isinstance(value[-1], (Promise, str)) + ): + *value, label = value + value = tuple(value) + else: + label = key.replace('_', ' ').title() + labels.append(label) + # Use dict.__setitem__() to suppress defenses against double + # assignment in enum's classdict. + dict.__setitem__(classdict, key, value) + cls = super().__new__(metacls, classname, bases, classdict) + cls._value2label_map_ = dict(zip(cls._value2member_map_, labels)) + # Add a label property to instances of enum which uses the enum member + # that is passed in as "self" as the value to use when looking up the + # label in the choices. + cls.label = property(lambda self: cls._value2label_map_.get(self.value)) + return enum.unique(cls) + + def __contains__(cls, member): + if not isinstance(member, enum.Enum): + # Allow non-enums to match against member values. + return member in {x.value for x in cls} + return super().__contains__(member) + + @property + def names(cls): + empty = ['__empty__'] if hasattr(cls, '__empty__') else [] + return empty + [member.name for member in cls] + + @property + def choices(cls): + empty = [(None, cls.__empty__)] if hasattr(cls, '__empty__') else [] + return empty + [(member.value, member.label) for member in cls] + + @property + def labels(cls): + return [label for _, label in cls.choices] + + @property + def values(cls): + return [value for value, _ in cls.choices] + + +class Choices(enum.Enum, metaclass=ChoicesMeta): + """Class for creating enumerated choices.""" + pass + + +class IntegerChoices(int, Choices): + """Class for creating enumerated integer choices.""" + pass + + +class TextChoices(str, Choices): + """Class for creating enumerated string choices.""" + + def _generate_next_value_(name, start, count, last_values): + return name diff --git a/docs/ref/models/fields.txt b/docs/ref/models/fields.txt index 01a56c1312..49494186ce 100644 --- a/docs/ref/models/fields.txt +++ b/docs/ref/models/fields.txt @@ -94,6 +94,7 @@ and the second element is the human-readable name. For example:: ('SO', 'Sophomore'), ('JR', 'Junior'), ('SR', 'Senior'), + ('GR', 'Graduate'), ] Generally, it's best to define choices inside a model class, and to @@ -106,11 +107,13 @@ define a suitably-named constant for each value:: SOPHOMORE = 'SO' JUNIOR = 'JR' SENIOR = 'SR' + GRADUATE = 'GR' YEAR_IN_SCHOOL_CHOICES = [ (FRESHMAN, 'Freshman'), (SOPHOMORE, 'Sophomore'), (JUNIOR, 'Junior'), (SENIOR, 'Senior'), + (GRADUATE, 'Graduate'), ] year_in_school = models.CharField( max_length=2, @@ -119,7 +122,7 @@ define a suitably-named constant for each value:: ) def is_upperclass(self): - return self.year_in_school in (self.JUNIOR, self.SENIOR) + return self.year_in_school in {self.JUNIOR, self.SENIOR} Though you can define a choices list outside of a model class and then refer to it, defining the choices and names for each choice inside the @@ -127,6 +130,95 @@ model class keeps all of that information with the class that uses it, and makes the choices easy to reference (e.g, ``Student.SOPHOMORE`` will work anywhere that the ``Student`` model has been imported). +In addition, Django provides enumeration types that you can subclass to define +choices in a concise way:: + + from django.utils.translation import gettext_lazy as _ + + class Student(models.Model): + + class YearInSchool(models.TextChoices): + FRESHMAN = 'FR', _('Freshman') + SOPHOMORE = 'SO', _('Sophomore') + JUNIOR = 'JR', _('Junior') + SENIOR = 'SR', _('Senior') + GRADUATE = 'GR', _('Graduate') + + year_in_school = models.CharField( + max_length=2, + choices=YearInSchool.choices, + default=YearInSchool.FRESHMAN, + ) + + def is_upperclass(self): + return self.year_in_school in {YearInSchool.JUNIOR, YearInSchool.SENIOR} + +These work similar to :mod:`enum` from Python's standard library, but with some +modifications: + +* Instead of values in the ``enum``, Django uses ``(value, label)`` tuples. The + ``label`` can be a lazy translatable string. If a tuple is not provided, the + label is automatically generated from the member name. +* ``.label`` property is added on values, to return the label specified. +* Number of custom properties are added to the enumeration classes -- + ``.choices``, ``.labels``, ``.values``, and ``.names`` -- to make it easier + to access lists of those separate parts of the enumeration. Use ``.choices`` + as a suitable value to pass to :attr:`~Field.choices` in a field definition. +* The use of :func:`enum.unique()` is enforced to ensure that values cannot be + defined multiple times. This is unlikely to be expected in choices for a + field. + +Note that ``YearInSchool.SENIOR``, ``YearInSchool['SENIOR']``, +``YearInSchool('SR')`` work as expected, while ``YearInSchool.SENIOR.label`` is +a translatable string. + +If you don't need to have the human-readable names translated, you can have +them inferred from the member name (replacing underscores to spaces and using +title-case):: + + class YearInSchool(models.TextChoices): + FRESHMAN = 'FR' + SOPHOMORE = 'SO' + JUNIOR = 'JR' + SENIOR = 'SR' + GRADUATE = 'GR' + +Since the case where the enum values need to be integers is extremely common, +Django provides a ``IntegerChoices`` class. For example:: + + class Card(models.Model): + + class Suit(models.IntegerChoices): + DIAMOND = 1 + SPADE = 2 + HEART = 3 + CLUB = 4 + + suit = models.IntegerField(choices=Suit.choices) + +It is also possible to make use of the `Enum Functional API +`_ with the caveat +that labels are automatically generated as highlighted above:: + + >>> MedalType = models.TextChoices('MedalType', 'GOLD SILVER BRONZE') + >>> MedalType.choices + [('GOLD', 'Gold'), ('SILVER', 'Silver'), ('BRONZE', 'Bronze')] + >>> Place = models.IntegerChoices('Place', 'FIRST SECOND THIRD') + >>> Place.choices + [(1, 'First'), (2, 'Second'), (3, 'Third')] + +If you require support for a concrete data type other than ``int`` or ``str``, +you can subclass ``Choices`` and the required concrete data type, e.g. +:class:``datetime.date`` for use with :class:`~django.db.models.DateField`:: + + class MoonLandings(datetime.date, models.Choices): + APOLLO_11 = 1969, 7, 20, 'Apollo 11 (Eagle)' + APOLLO_12 = 1969, 11, 19, 'Apollo 12 (Intrepid)' + APOLLO_14 = 1971, 2, 5, 'Apollo 14 (Antares)' + APOLLO_15 = 1971, 7, 30, 'Apollo 15 (Falcon)' + APOLLO_16 = 1972, 4, 21, 'Apollo 16 (Orion)' + APOLLO_17 = 1972, 12, 11, 'Apollo 17 (Challenger)' + You can also collect your available choices into named groups that can be used for organizational purposes:: @@ -148,7 +240,8 @@ The first element in each tuple is the name to apply to the group. The second element is an iterable of 2-tuples, with each 2-tuple containing a value and a human-readable name for an option. Grouped options may be combined with ungrouped options within a single list (such as the -`unknown` option in this example). +`unknown` option in this example). Grouping is not supported by the custom +enumeration types for managing choices. For each model field that has :attr:`~Field.choices` set, Django will add a method to retrieve the human-readable name for the field's current value. See @@ -169,7 +262,19 @@ Unless :attr:`blank=False` is set on the field along with a with the select box. To override this behavior, add a tuple to ``choices`` containing ``None``; e.g. ``(None, 'Your String For Display')``. Alternatively, you can use an empty string instead of ``None`` where this makes -sense - such as on a :class:`~django.db.models.CharField`. +sense - such as on a :class:`~django.db.models.CharField`. To change the label +when using one of the custom enumeration types, set the ``__empty__`` attribute +on the class:: + + class Answer(models.IntegerChoices): + NO = 0, _('No') + YES = 1, _('Yes') + + __empty__ = _('(Unknown)') + +.. versionadded:: 3.0 + + The ``TextChoices``, ``IntegerChoices``, and ``Choices`` classes were added. ``db_column`` ------------- diff --git a/docs/releases/3.0.txt b/docs/releases/3.0.txt index 9745a34660..2868e37e16 100644 --- a/docs/releases/3.0.txt +++ b/docs/releases/3.0.txt @@ -81,6 +81,18 @@ Expressions that outputs :class:`~django.db.models.BooleanField` may now be used directly in ``QuerySet`` filters, without having to first annotate and then filter against the annotation. +Enumerations for model field choices +------------------------------------ + +Custom enumeration types ``TextChoices``, ``IntegerChoices``, and ``Choices`` +are now available as a way to define :attr:`.Field.choices`. ``TextChoices`` +and ``IntegerChoices`` types are provided for text and integer fields. The +``Choices`` class allows defining a compatible enumeration for other concrete +data types. These custom enumeration types support human-readable labels that +can be translated and accessed via a property on the enumeration or its +members. See :ref:`Field.choices documentation ` for more +details and examples. + Minor features -------------- diff --git a/docs/spelling_wordlist b/docs/spelling_wordlist index 445a64adfc..43e8a68385 100644 --- a/docs/spelling_wordlist +++ b/docs/spelling_wordlist @@ -193,6 +193,7 @@ elidable encodings Endian Enero +enum environ esque Ess diff --git a/docs/topics/db/models.txt b/docs/topics/db/models.txt index f47490877f..0aab0bb35f 100644 --- a/docs/topics/db/models.txt +++ b/docs/topics/db/models.txt @@ -198,6 +198,19 @@ ones: >>> p.get_shirt_size_display() 'Large' + You can also use enumeration classes to define ``choices`` in a concise + way:: + + from django.db import models + + class Runner(models.Model): + MedalType = models.TextChoices('MedalType', 'GOLD SILVER BRONZE') + name = models.CharField(max_length=60) + medal = models.CharField(blank=True, choices=MedalType.choices, max_length=10) + + Further examples are available in the :ref:`model field reference + `. + :attr:`~Field.default` The default value for the field. This can be a value or a callable object. If callable it will be called every time a new object is diff --git a/tests/migrations/test_writer.py b/tests/migrations/test_writer.py index 790d6a3efa..f97d76e9ad 100644 --- a/tests/migrations/test_writer.py +++ b/tests/migrations/test_writer.py @@ -306,6 +306,48 @@ class WriterTests(SimpleTestCase): "default=migrations.test_writer.IntEnum(1))" ) + def test_serialize_choices(self): + class TextChoices(models.TextChoices): + A = 'A', 'A value' + B = 'B', 'B value' + + class IntegerChoices(models.IntegerChoices): + A = 1, 'One' + B = 2, 'Two' + + class DateChoices(datetime.date, models.Choices): + DATE_1 = 1969, 7, 20, 'First date' + DATE_2 = 1969, 11, 19, 'Second date' + + self.assertSerializedResultEqual(TextChoices.A, ("'A'", set())) + self.assertSerializedResultEqual(IntegerChoices.A, ('1', set())) + self.assertSerializedResultEqual( + DateChoices.DATE_1, + ('datetime.date(1969, 7, 20)', {'import datetime'}), + ) + field = models.CharField(default=TextChoices.B, choices=TextChoices.choices) + string = MigrationWriter.serialize(field)[0] + self.assertEqual( + string, + "models.CharField(choices=[('A', 'A value'), ('B', 'B value')], " + "default='B')", + ) + field = models.IntegerField(default=IntegerChoices.B, choices=IntegerChoices.choices) + string = MigrationWriter.serialize(field)[0] + self.assertEqual( + string, + "models.IntegerField(choices=[(1, 'One'), (2, 'Two')], default=2)", + ) + field = models.DateField(default=DateChoices.DATE_2, choices=DateChoices.choices) + string = MigrationWriter.serialize(field)[0] + self.assertEqual( + string, + "models.DateField(choices=[" + "(datetime.date(1969, 7, 20), 'First date'), " + "(datetime.date(1969, 11, 19), 'Second date')], " + "default=datetime.date(1969, 11, 19))" + ) + def test_serialize_uuid(self): self.assertSerializedEqual(uuid.uuid1()) self.assertSerializedEqual(uuid.uuid4()) diff --git a/tests/model_enums/tests.py b/tests/model_enums/tests.py new file mode 100644 index 0000000000..6b4bd6e7fd --- /dev/null +++ b/tests/model_enums/tests.py @@ -0,0 +1,253 @@ +import datetime +import decimal +import ipaddress +import uuid + +from django.db import models +from django.test import SimpleTestCase +from django.utils.functional import Promise +from django.utils.translation import gettext_lazy as _ + + +class Suit(models.IntegerChoices): + DIAMOND = 1, _('Diamond') + SPADE = 2, _('Spade') + HEART = 3, _('Heart') + CLUB = 4, _('Club') + + +class YearInSchool(models.TextChoices): + FRESHMAN = 'FR', _('Freshman') + SOPHOMORE = 'SO', _('Sophomore') + JUNIOR = 'JR', _('Junior') + SENIOR = 'SR', _('Senior') + GRADUATE = 'GR', _('Graduate') + + +class Vehicle(models.IntegerChoices): + CAR = 1, 'Carriage' + TRUCK = 2 + JET_SKI = 3 + + __empty__ = _('(Unknown)') + + +class Gender(models.TextChoices): + MALE = 'M' + FEMALE = 'F' + NOT_SPECIFIED = 'X' + + __empty__ = '(Undeclared)' + + +class ChoicesTests(SimpleTestCase): + def test_integerchoices(self): + self.assertEqual(Suit.choices, [(1, 'Diamond'), (2, 'Spade'), (3, 'Heart'), (4, 'Club')]) + self.assertEqual(Suit.labels, ['Diamond', 'Spade', 'Heart', 'Club']) + self.assertEqual(Suit.values, [1, 2, 3, 4]) + self.assertEqual(Suit.names, ['DIAMOND', 'SPADE', 'HEART', 'CLUB']) + + self.assertEqual(repr(Suit.DIAMOND), '') + self.assertEqual(Suit.DIAMOND.label, 'Diamond') + self.assertEqual(Suit.DIAMOND.value, 1) + self.assertEqual(Suit['DIAMOND'], Suit.DIAMOND) + self.assertEqual(Suit(1), Suit.DIAMOND) + + self.assertIsInstance(Suit, type(models.Choices)) + self.assertIsInstance(Suit.DIAMOND, Suit) + self.assertIsInstance(Suit.DIAMOND.label, Promise) + self.assertIsInstance(Suit.DIAMOND.value, int) + + def test_integerchoices_auto_label(self): + self.assertEqual(Vehicle.CAR.label, 'Carriage') + self.assertEqual(Vehicle.TRUCK.label, 'Truck') + self.assertEqual(Vehicle.JET_SKI.label, 'Jet Ski') + + def test_integerchoices_empty_label(self): + self.assertEqual(Vehicle.choices[0], (None, '(Unknown)')) + self.assertEqual(Vehicle.labels[0], '(Unknown)') + self.assertEqual(Vehicle.values[0], None) + self.assertEqual(Vehicle.names[0], '__empty__') + + def test_integerchoices_functional_api(self): + Place = models.IntegerChoices('Place', 'FIRST SECOND THIRD') + self.assertEqual(Place.labels, ['First', 'Second', 'Third']) + self.assertEqual(Place.values, [1, 2, 3]) + self.assertEqual(Place.names, ['FIRST', 'SECOND', 'THIRD']) + + def test_integerchoices_containment(self): + self.assertIn(Suit.DIAMOND, Suit) + self.assertIn(1, Suit) + self.assertNotIn(0, Suit) + + def test_textchoices(self): + self.assertEqual(YearInSchool.choices, [ + ('FR', 'Freshman'), ('SO', 'Sophomore'), ('JR', 'Junior'), ('SR', 'Senior'), ('GR', 'Graduate'), + ]) + self.assertEqual(YearInSchool.labels, ['Freshman', 'Sophomore', 'Junior', 'Senior', 'Graduate']) + self.assertEqual(YearInSchool.values, ['FR', 'SO', 'JR', 'SR', 'GR']) + self.assertEqual(YearInSchool.names, ['FRESHMAN', 'SOPHOMORE', 'JUNIOR', 'SENIOR', 'GRADUATE']) + + self.assertEqual(repr(YearInSchool.FRESHMAN), "") + self.assertEqual(YearInSchool.FRESHMAN.label, 'Freshman') + self.assertEqual(YearInSchool.FRESHMAN.value, 'FR') + self.assertEqual(YearInSchool['FRESHMAN'], YearInSchool.FRESHMAN) + self.assertEqual(YearInSchool('FR'), YearInSchool.FRESHMAN) + + self.assertIsInstance(YearInSchool, type(models.Choices)) + self.assertIsInstance(YearInSchool.FRESHMAN, YearInSchool) + self.assertIsInstance(YearInSchool.FRESHMAN.label, Promise) + self.assertIsInstance(YearInSchool.FRESHMAN.value, str) + + def test_textchoices_auto_label(self): + self.assertEqual(Gender.MALE.label, 'Male') + self.assertEqual(Gender.FEMALE.label, 'Female') + self.assertEqual(Gender.NOT_SPECIFIED.label, 'Not Specified') + + def test_textchoices_empty_label(self): + self.assertEqual(Gender.choices[0], (None, '(Undeclared)')) + self.assertEqual(Gender.labels[0], '(Undeclared)') + self.assertEqual(Gender.values[0], None) + self.assertEqual(Gender.names[0], '__empty__') + + def test_textchoices_functional_api(self): + Medal = models.TextChoices('Medal', 'GOLD SILVER BRONZE') + self.assertEqual(Medal.labels, ['Gold', 'Silver', 'Bronze']) + self.assertEqual(Medal.values, ['GOLD', 'SILVER', 'BRONZE']) + self.assertEqual(Medal.names, ['GOLD', 'SILVER', 'BRONZE']) + + def test_textchoices_containment(self): + self.assertIn(YearInSchool.FRESHMAN, YearInSchool) + self.assertIn('FR', YearInSchool) + self.assertNotIn('XX', YearInSchool) + + def test_textchoices_blank_value(self): + class BlankStr(models.TextChoices): + EMPTY = '', '(Empty)' + ONE = 'ONE', 'One' + + self.assertEqual(BlankStr.labels, ['(Empty)', 'One']) + self.assertEqual(BlankStr.values, ['', 'ONE']) + self.assertEqual(BlankStr.names, ['EMPTY', 'ONE']) + + def test_invalid_definition(self): + msg = "'str' object cannot be interpreted as an integer" + with self.assertRaisesMessage(TypeError, msg): + class InvalidArgumentEnum(models.IntegerChoices): + # A string is not permitted as the second argument to int(). + ONE = 1, 'X', 'Invalid' + + msg = "duplicate values found in : PINEAPPLE -> APPLE" + with self.assertRaisesMessage(ValueError, msg): + class Fruit(models.IntegerChoices): + APPLE = 1, 'Apple' + PINEAPPLE = 1, 'Pineapple' + + +class Separator(bytes, models.Choices): + FS = b'\x1c', 'File Separator' + GS = b'\x1d', 'Group Separator' + RS = b'\x1e', 'Record Separator' + US = b'\x1f', 'Unit Separator' + + +class Constants(float, models.Choices): + PI = 3.141592653589793, 'π' + TAU = 6.283185307179586, 'τ' + + +class Set(frozenset, models.Choices): + A = {1, 2} + B = {2, 3} + UNION = A | B + DIFFERENCE = A - B + INTERSECTION = A & B + + +class MoonLandings(datetime.date, models.Choices): + APOLLO_11 = 1969, 7, 20, 'Apollo 11 (Eagle)' + APOLLO_12 = 1969, 11, 19, 'Apollo 12 (Intrepid)' + APOLLO_14 = 1971, 2, 5, 'Apollo 14 (Antares)' + APOLLO_15 = 1971, 7, 30, 'Apollo 15 (Falcon)' + APOLLO_16 = 1972, 4, 21, 'Apollo 16 (Orion)' + APOLLO_17 = 1972, 12, 11, 'Apollo 17 (Challenger)' + + +class DateAndTime(datetime.datetime, models.Choices): + A = 2010, 10, 10, 10, 10, 10 + B = 2011, 11, 11, 11, 11, 11 + C = 2012, 12, 12, 12, 12, 12 + + +class MealTimes(datetime.time, models.Choices): + BREAKFAST = 7, 0 + LUNCH = 13, 0 + DINNER = 18, 30 + + +class Frequency(datetime.timedelta, models.Choices): + WEEK = 0, 0, 0, 0, 0, 0, 1, 'Week' + DAY = 1, 'Day' + HOUR = 0, 0, 0, 0, 0, 1, 'Hour' + MINUTE = 0, 0, 0, 0, 1, 'Hour' + SECOND = 0, 1, 'Second' + + +class Number(decimal.Decimal, models.Choices): + E = 2.718281828459045, 'e' + PI = '3.141592653589793', 'π' + TAU = decimal.Decimal('6.283185307179586'), 'τ' + + +class IPv4Address(ipaddress.IPv4Address, models.Choices): + LOCALHOST = '127.0.0.1', 'Localhost' + GATEWAY = '192.168.0.1', 'Gateway' + BROADCAST = '192.168.0.255', 'Broadcast' + + +class IPv6Address(ipaddress.IPv6Address, models.Choices): + LOCALHOST = '::1', 'Localhost' + UNSPECIFIED = '::', 'Unspecified' + + +class IPv4Network(ipaddress.IPv4Network, models.Choices): + LOOPBACK = '127.0.0.0/8', 'Loopback' + LINK_LOCAL = '169.254.0.0/16', 'Link-Local' + PRIVATE_USE_A = '10.0.0.0/8', 'Private-Use (Class A)' + + +class IPv6Network(ipaddress.IPv6Network, models.Choices): + LOOPBACK = '::1/128', 'Loopback' + UNSPECIFIED = '::/128', 'Unspecified' + UNIQUE_LOCAL = 'fc00::/7', 'Unique-Local' + LINK_LOCAL_UNICAST = 'fe80::/10', 'Link-Local Unicast' + + +class CustomChoicesTests(SimpleTestCase): + def test_labels_valid(self): + enums = ( + Separator, Constants, Set, MoonLandings, DateAndTime, MealTimes, + Frequency, Number, IPv4Address, IPv6Address, IPv4Network, + IPv6Network, + ) + for choice_enum in enums: + with self.subTest(choice_enum.__name__): + self.assertNotIn(None, choice_enum.labels) + + def test_bool_unsupported(self): + msg = "type 'bool' is not an acceptable base type" + with self.assertRaisesMessage(TypeError, msg): + class Boolean(bool, models.Choices): + pass + + def test_timezone_unsupported(self): + msg = "type 'datetime.timezone' is not an acceptable base type" + with self.assertRaisesMessage(TypeError, msg): + class Timezone(datetime.timezone, models.Choices): + pass + + def test_uuid_unsupported(self): + msg = 'UUID objects are immutable' + with self.assertRaisesMessage(TypeError, msg): + class Identifier(uuid.UUID, models.Choices): + A = '972ce4eb-a95f-4a56-9339-68c208a76f18' diff --git a/tests/model_fields/test_charfield.py b/tests/model_fields/test_charfield.py index 1b50f01f3a..051be2eeec 100644 --- a/tests/model_fields/test_charfield.py +++ b/tests/model_fields/test_charfield.py @@ -28,9 +28,27 @@ class TestCharField(TestCase): p.refresh_from_db() self.assertEqual(p.title, 'Smile 😀') + def test_assignment_from_choice_enum(self): + class Event(models.TextChoices): + C = 'Carnival!' + F = 'Festival!' + + p1 = Post.objects.create(title=Event.C, body=Event.F) + p1.refresh_from_db() + self.assertEqual(p1.title, 'Carnival!') + self.assertEqual(p1.body, 'Festival!') + self.assertEqual(p1.title, Event.C) + self.assertEqual(p1.body, Event.F) + p2 = Post.objects.get(title='Carnival!') + self.assertEquals(p1, p2) + self.assertEquals(p2.title, Event.C) + class ValidationTests(SimpleTestCase): + class Choices(models.TextChoices): + C = 'c', 'C' + def test_charfield_raises_error_on_empty_string(self): f = models.CharField() with self.assertRaises(ValidationError): @@ -49,6 +67,15 @@ class ValidationTests(SimpleTestCase): with self.assertRaises(ValidationError): f.clean('not a', None) + def test_enum_choices_cleans_valid_string(self): + f = models.CharField(choices=self.Choices.choices, max_length=1) + self.assertEqual(f.clean('c', None), 'c') + + def test_enum_choices_invalid_input(self): + f = models.CharField(choices=self.Choices.choices, max_length=1) + with self.assertRaises(ValidationError): + f.clean('a', None) + def test_charfield_raises_error_on_empty_input(self): f = models.CharField(null=False) with self.assertRaises(ValidationError): diff --git a/tests/model_fields/test_integerfield.py b/tests/model_fields/test_integerfield.py index c0bb7627cf..606d71057a 100644 --- a/tests/model_fields/test_integerfield.py +++ b/tests/model_fields/test_integerfield.py @@ -184,6 +184,9 @@ class PositiveIntegerFieldTests(IntegerFieldTests): class ValidationTests(SimpleTestCase): + class Choices(models.IntegerChoices): + A = 1 + def test_integerfield_cleans_valid_string(self): f = models.IntegerField() self.assertEqual(f.clean('2', None), 2) @@ -217,3 +220,14 @@ class ValidationTests(SimpleTestCase): f = models.IntegerField(choices=((1, 1),)) with self.assertRaises(ValidationError): f.clean('0', None) + + def test_enum_choices_cleans_valid_string(self): + f = models.IntegerField(choices=self.Choices.choices) + self.assertEqual(f.clean('1', None), 1) + + def test_enum_choices_invalid_input(self): + f = models.IntegerField(choices=self.Choices.choices) + with self.assertRaises(ValidationError): + f.clean('A', None) + with self.assertRaises(ValidationError): + f.clean('3', None)