Fixed #29391 -- Made PostgresSimpleLookup respect Field.get_db_prep_value().

This commit is contained in:
Vinay Karanam 2019-02-04 04:57:19 +05:30 committed by Tim Graham
parent c492fdfd24
commit 5a36c81f58
5 changed files with 42 additions and 9 deletions

View File

@ -1,10 +1,10 @@
from django.db.models import Lookup, Transform
from django.db.models.lookups import Exact
from django.db.models.lookups import Exact, FieldGetDbPrepValueMixin
from .search import SearchVector, SearchVectorExact, SearchVectorField
class PostgresSimpleLookup(Lookup):
class PostgresSimpleLookup(FieldGetDbPrepValueMixin, Lookup):
def as_sql(self, qn, connection):
lhs, lhs_params = self.process_lhs(qn, connection)
rhs, rhs_params = self.process_rhs(qn, connection)

View File

@ -2,6 +2,8 @@
Indirection layer for PostgreSQL-specific fields, so the tests don't fail when
run with a backend other than PostgreSQL.
"""
import enum
from django.db import models
try:
@ -40,3 +42,8 @@ except ImportError:
IntegerRangeField = models.Field
JSONField = DummyJSONField
SearchVectorField = models.Field
class EnumField(models.CharField):
def get_prep_value(self, value):
return value.value if isinstance(value, enum.Enum) else value

View File

@ -3,8 +3,8 @@ from django.db import migrations, models
from ..fields import (
ArrayField, BigIntegerRangeField, CICharField, CIEmailField, CITextField,
DateRangeField, DateTimeRangeField, DecimalRangeField, HStoreField,
IntegerRangeField, JSONField, SearchVectorField,
DateRangeField, DateTimeRangeField, DecimalRangeField, EnumField,
HStoreField, IntegerRangeField, JSONField, SearchVectorField,
)
from ..models import TagField
@ -249,4 +249,15 @@ class Migration(migrations.Migration):
},
bases=(models.Model,),
),
migrations.CreateModel(
name='ArrayEnumModel',
fields=[
('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
('array_of_enums', ArrayField(EnumField(max_length=20), null=True, blank=True)),
],
options={
'required_db_vendor': 'postgresql',
},
bases=(models.Model,),
),
]

View File

@ -3,8 +3,8 @@ from django.db import models
from .fields import (
ArrayField, BigIntegerRangeField, CICharField, CIEmailField, CITextField,
DateRangeField, DateTimeRangeField, DecimalRangeField, HStoreField,
IntegerRangeField, JSONField, SearchVectorField,
DateRangeField, DateTimeRangeField, DecimalRangeField, EnumField,
HStoreField, IntegerRangeField, JSONField, SearchVectorField,
)
@ -77,6 +77,10 @@ class HStoreModel(PostgreSQLModel):
array_field = ArrayField(HStoreField(), null=True)
class ArrayEnumModel(PostgreSQLModel):
array_of_enums = ArrayField(EnumField(max_length=20))
class CharFieldModel(models.Model):
field = models.CharField(max_length=16)

View File

@ -1,4 +1,5 @@
import decimal
import enum
import json
import unittest
import uuid
@ -16,9 +17,9 @@ from . import (
PostgreSQLSimpleTestCase, PostgreSQLTestCase, PostgreSQLWidgetTestCase,
)
from .models import (
ArrayFieldSubclass, CharArrayModel, DateTimeArrayModel, IntegerArrayModel,
NestedIntegerArrayModel, NullableIntegerArrayModel, OtherTypesArrayModel,
PostgreSQLModel, Tag,
ArrayEnumModel, ArrayFieldSubclass, CharArrayModel, DateTimeArrayModel,
IntegerArrayModel, NestedIntegerArrayModel, NullableIntegerArrayModel,
OtherTypesArrayModel, PostgreSQLModel, Tag,
)
try:
@ -357,6 +358,16 @@ class TestQuerying(PostgreSQLTestCase):
[self.objs[3]]
)
def test_enum_lookup(self):
class TestEnum(enum.Enum):
VALUE_1 = 'value_1'
instance = ArrayEnumModel.objects.create(array_of_enums=[TestEnum.VALUE_1])
self.assertSequenceEqual(
ArrayEnumModel.objects.filter(array_of_enums__contains=[TestEnum.VALUE_1]),
[instance]
)
def test_unsupported_lookup(self):
msg = "Unsupported lookup '0_bar' for ArrayField or join on the field not permitted."
with self.assertRaisesMessage(FieldError, msg):