diff --git a/django/contrib/postgres/lookups.py b/django/contrib/postgres/lookups.py index c2b3d2b569..f0a523d849 100644 --- a/django/contrib/postgres/lookups.py +++ b/django/contrib/postgres/lookups.py @@ -1,10 +1,10 @@ from django.db.models import Lookup, Transform -from django.db.models.lookups import Exact +from django.db.models.lookups import Exact, FieldGetDbPrepValueMixin from .search import SearchVector, SearchVectorExact, SearchVectorField -class PostgresSimpleLookup(Lookup): +class PostgresSimpleLookup(FieldGetDbPrepValueMixin, Lookup): def as_sql(self, qn, connection): lhs, lhs_params = self.process_lhs(qn, connection) rhs, rhs_params = self.process_rhs(qn, connection) diff --git a/tests/postgres_tests/fields.py b/tests/postgres_tests/fields.py index 2275eb2ab2..4ebc0ce7dc 100644 --- a/tests/postgres_tests/fields.py +++ b/tests/postgres_tests/fields.py @@ -2,6 +2,8 @@ Indirection layer for PostgreSQL-specific fields, so the tests don't fail when run with a backend other than PostgreSQL. """ +import enum + from django.db import models try: @@ -40,3 +42,8 @@ except ImportError: IntegerRangeField = models.Field JSONField = DummyJSONField SearchVectorField = models.Field + + +class EnumField(models.CharField): + def get_prep_value(self, value): + return value.value if isinstance(value, enum.Enum) else value diff --git a/tests/postgres_tests/migrations/0002_create_test_models.py b/tests/postgres_tests/migrations/0002_create_test_models.py index 5db8a71385..dc941de139 100644 --- a/tests/postgres_tests/migrations/0002_create_test_models.py +++ b/tests/postgres_tests/migrations/0002_create_test_models.py @@ -3,8 +3,8 @@ from django.db import migrations, models from ..fields import ( ArrayField, BigIntegerRangeField, CICharField, CIEmailField, CITextField, - DateRangeField, DateTimeRangeField, DecimalRangeField, HStoreField, - IntegerRangeField, JSONField, SearchVectorField, + DateRangeField, DateTimeRangeField, DecimalRangeField, EnumField, + HStoreField, IntegerRangeField, JSONField, SearchVectorField, ) from ..models import TagField @@ -249,4 +249,15 @@ class Migration(migrations.Migration): }, bases=(models.Model,), ), + migrations.CreateModel( + name='ArrayEnumModel', + fields=[ + ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), + ('array_of_enums', ArrayField(EnumField(max_length=20), null=True, blank=True)), + ], + options={ + 'required_db_vendor': 'postgresql', + }, + bases=(models.Model,), + ), ] diff --git a/tests/postgres_tests/models.py b/tests/postgres_tests/models.py index cbe477e402..2bb6e6fcdf 100644 --- a/tests/postgres_tests/models.py +++ b/tests/postgres_tests/models.py @@ -3,8 +3,8 @@ from django.db import models from .fields import ( ArrayField, BigIntegerRangeField, CICharField, CIEmailField, CITextField, - DateRangeField, DateTimeRangeField, DecimalRangeField, HStoreField, - IntegerRangeField, JSONField, SearchVectorField, + DateRangeField, DateTimeRangeField, DecimalRangeField, EnumField, + HStoreField, IntegerRangeField, JSONField, SearchVectorField, ) @@ -77,6 +77,10 @@ class HStoreModel(PostgreSQLModel): array_field = ArrayField(HStoreField(), null=True) +class ArrayEnumModel(PostgreSQLModel): + array_of_enums = ArrayField(EnumField(max_length=20)) + + class CharFieldModel(models.Model): field = models.CharField(max_length=16) diff --git a/tests/postgres_tests/test_array.py b/tests/postgres_tests/test_array.py index 447d511c9f..465eac1785 100644 --- a/tests/postgres_tests/test_array.py +++ b/tests/postgres_tests/test_array.py @@ -1,4 +1,5 @@ import decimal +import enum import json import unittest import uuid @@ -16,9 +17,9 @@ from . import ( PostgreSQLSimpleTestCase, PostgreSQLTestCase, PostgreSQLWidgetTestCase, ) from .models import ( - ArrayFieldSubclass, CharArrayModel, DateTimeArrayModel, IntegerArrayModel, - NestedIntegerArrayModel, NullableIntegerArrayModel, OtherTypesArrayModel, - PostgreSQLModel, Tag, + ArrayEnumModel, ArrayFieldSubclass, CharArrayModel, DateTimeArrayModel, + IntegerArrayModel, NestedIntegerArrayModel, NullableIntegerArrayModel, + OtherTypesArrayModel, PostgreSQLModel, Tag, ) try: @@ -357,6 +358,16 @@ class TestQuerying(PostgreSQLTestCase): [self.objs[3]] ) + def test_enum_lookup(self): + class TestEnum(enum.Enum): + VALUE_1 = 'value_1' + + instance = ArrayEnumModel.objects.create(array_of_enums=[TestEnum.VALUE_1]) + self.assertSequenceEqual( + ArrayEnumModel.objects.filter(array_of_enums__contains=[TestEnum.VALUE_1]), + [instance] + ) + def test_unsupported_lookup(self): msg = "Unsupported lookup '0_bar' for ArrayField or join on the field not permitted." with self.assertRaisesMessage(FieldError, msg):