Fixed #29391 -- Made PostgresSimpleLookup respect Field.get_db_prep_value().
This commit is contained in:
parent
c492fdfd24
commit
5a36c81f58
|
@ -1,10 +1,10 @@
|
|||
from django.db.models import Lookup, Transform
|
||||
from django.db.models.lookups import Exact
|
||||
from django.db.models.lookups import Exact, FieldGetDbPrepValueMixin
|
||||
|
||||
from .search import SearchVector, SearchVectorExact, SearchVectorField
|
||||
|
||||
|
||||
class PostgresSimpleLookup(Lookup):
|
||||
class PostgresSimpleLookup(FieldGetDbPrepValueMixin, Lookup):
|
||||
def as_sql(self, qn, connection):
|
||||
lhs, lhs_params = self.process_lhs(qn, connection)
|
||||
rhs, rhs_params = self.process_rhs(qn, connection)
|
||||
|
|
|
@ -2,6 +2,8 @@
|
|||
Indirection layer for PostgreSQL-specific fields, so the tests don't fail when
|
||||
run with a backend other than PostgreSQL.
|
||||
"""
|
||||
import enum
|
||||
|
||||
from django.db import models
|
||||
|
||||
try:
|
||||
|
@ -40,3 +42,8 @@ except ImportError:
|
|||
IntegerRangeField = models.Field
|
||||
JSONField = DummyJSONField
|
||||
SearchVectorField = models.Field
|
||||
|
||||
|
||||
class EnumField(models.CharField):
|
||||
def get_prep_value(self, value):
|
||||
return value.value if isinstance(value, enum.Enum) else value
|
||||
|
|
|
@ -3,8 +3,8 @@ from django.db import migrations, models
|
|||
|
||||
from ..fields import (
|
||||
ArrayField, BigIntegerRangeField, CICharField, CIEmailField, CITextField,
|
||||
DateRangeField, DateTimeRangeField, DecimalRangeField, HStoreField,
|
||||
IntegerRangeField, JSONField, SearchVectorField,
|
||||
DateRangeField, DateTimeRangeField, DecimalRangeField, EnumField,
|
||||
HStoreField, IntegerRangeField, JSONField, SearchVectorField,
|
||||
)
|
||||
from ..models import TagField
|
||||
|
||||
|
@ -249,4 +249,15 @@ class Migration(migrations.Migration):
|
|||
},
|
||||
bases=(models.Model,),
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name='ArrayEnumModel',
|
||||
fields=[
|
||||
('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
|
||||
('array_of_enums', ArrayField(EnumField(max_length=20), null=True, blank=True)),
|
||||
],
|
||||
options={
|
||||
'required_db_vendor': 'postgresql',
|
||||
},
|
||||
bases=(models.Model,),
|
||||
),
|
||||
]
|
||||
|
|
|
@ -3,8 +3,8 @@ from django.db import models
|
|||
|
||||
from .fields import (
|
||||
ArrayField, BigIntegerRangeField, CICharField, CIEmailField, CITextField,
|
||||
DateRangeField, DateTimeRangeField, DecimalRangeField, HStoreField,
|
||||
IntegerRangeField, JSONField, SearchVectorField,
|
||||
DateRangeField, DateTimeRangeField, DecimalRangeField, EnumField,
|
||||
HStoreField, IntegerRangeField, JSONField, SearchVectorField,
|
||||
)
|
||||
|
||||
|
||||
|
@ -77,6 +77,10 @@ class HStoreModel(PostgreSQLModel):
|
|||
array_field = ArrayField(HStoreField(), null=True)
|
||||
|
||||
|
||||
class ArrayEnumModel(PostgreSQLModel):
|
||||
array_of_enums = ArrayField(EnumField(max_length=20))
|
||||
|
||||
|
||||
class CharFieldModel(models.Model):
|
||||
field = models.CharField(max_length=16)
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import decimal
|
||||
import enum
|
||||
import json
|
||||
import unittest
|
||||
import uuid
|
||||
|
@ -16,9 +17,9 @@ from . import (
|
|||
PostgreSQLSimpleTestCase, PostgreSQLTestCase, PostgreSQLWidgetTestCase,
|
||||
)
|
||||
from .models import (
|
||||
ArrayFieldSubclass, CharArrayModel, DateTimeArrayModel, IntegerArrayModel,
|
||||
NestedIntegerArrayModel, NullableIntegerArrayModel, OtherTypesArrayModel,
|
||||
PostgreSQLModel, Tag,
|
||||
ArrayEnumModel, ArrayFieldSubclass, CharArrayModel, DateTimeArrayModel,
|
||||
IntegerArrayModel, NestedIntegerArrayModel, NullableIntegerArrayModel,
|
||||
OtherTypesArrayModel, PostgreSQLModel, Tag,
|
||||
)
|
||||
|
||||
try:
|
||||
|
@ -357,6 +358,16 @@ class TestQuerying(PostgreSQLTestCase):
|
|||
[self.objs[3]]
|
||||
)
|
||||
|
||||
def test_enum_lookup(self):
|
||||
class TestEnum(enum.Enum):
|
||||
VALUE_1 = 'value_1'
|
||||
|
||||
instance = ArrayEnumModel.objects.create(array_of_enums=[TestEnum.VALUE_1])
|
||||
self.assertSequenceEqual(
|
||||
ArrayEnumModel.objects.filter(array_of_enums__contains=[TestEnum.VALUE_1]),
|
||||
[instance]
|
||||
)
|
||||
|
||||
def test_unsupported_lookup(self):
|
||||
msg = "Unsupported lookup '0_bar' for ArrayField or join on the field not permitted."
|
||||
with self.assertRaisesMessage(FieldError, msg):
|
||||
|
|
Loading…
Reference in New Issue