Fixed #29391 -- Made PostgresSimpleLookup respect Field.get_db_prep_value().

This commit is contained in:
Vinay Karanam 2019-02-04 04:57:19 +05:30 committed by Tim Graham
parent c492fdfd24
commit 5a36c81f58
5 changed files with 42 additions and 9 deletions

View File

@ -1,10 +1,10 @@
from django.db.models import Lookup, Transform from django.db.models import Lookup, Transform
from django.db.models.lookups import Exact from django.db.models.lookups import Exact, FieldGetDbPrepValueMixin
from .search import SearchVector, SearchVectorExact, SearchVectorField from .search import SearchVector, SearchVectorExact, SearchVectorField
class PostgresSimpleLookup(Lookup): class PostgresSimpleLookup(FieldGetDbPrepValueMixin, Lookup):
def as_sql(self, qn, connection): def as_sql(self, qn, connection):
lhs, lhs_params = self.process_lhs(qn, connection) lhs, lhs_params = self.process_lhs(qn, connection)
rhs, rhs_params = self.process_rhs(qn, connection) rhs, rhs_params = self.process_rhs(qn, connection)

View File

@ -2,6 +2,8 @@
Indirection layer for PostgreSQL-specific fields, so the tests don't fail when Indirection layer for PostgreSQL-specific fields, so the tests don't fail when
run with a backend other than PostgreSQL. run with a backend other than PostgreSQL.
""" """
import enum
from django.db import models from django.db import models
try: try:
@ -40,3 +42,8 @@ except ImportError:
IntegerRangeField = models.Field IntegerRangeField = models.Field
JSONField = DummyJSONField JSONField = DummyJSONField
SearchVectorField = models.Field SearchVectorField = models.Field
class EnumField(models.CharField):
def get_prep_value(self, value):
return value.value if isinstance(value, enum.Enum) else value

View File

@ -3,8 +3,8 @@ from django.db import migrations, models
from ..fields import ( from ..fields import (
ArrayField, BigIntegerRangeField, CICharField, CIEmailField, CITextField, ArrayField, BigIntegerRangeField, CICharField, CIEmailField, CITextField,
DateRangeField, DateTimeRangeField, DecimalRangeField, HStoreField, DateRangeField, DateTimeRangeField, DecimalRangeField, EnumField,
IntegerRangeField, JSONField, SearchVectorField, HStoreField, IntegerRangeField, JSONField, SearchVectorField,
) )
from ..models import TagField from ..models import TagField
@ -249,4 +249,15 @@ class Migration(migrations.Migration):
}, },
bases=(models.Model,), bases=(models.Model,),
), ),
migrations.CreateModel(
name='ArrayEnumModel',
fields=[
('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
('array_of_enums', ArrayField(EnumField(max_length=20), null=True, blank=True)),
],
options={
'required_db_vendor': 'postgresql',
},
bases=(models.Model,),
),
] ]

View File

@ -3,8 +3,8 @@ from django.db import models
from .fields import ( from .fields import (
ArrayField, BigIntegerRangeField, CICharField, CIEmailField, CITextField, ArrayField, BigIntegerRangeField, CICharField, CIEmailField, CITextField,
DateRangeField, DateTimeRangeField, DecimalRangeField, HStoreField, DateRangeField, DateTimeRangeField, DecimalRangeField, EnumField,
IntegerRangeField, JSONField, SearchVectorField, HStoreField, IntegerRangeField, JSONField, SearchVectorField,
) )
@ -77,6 +77,10 @@ class HStoreModel(PostgreSQLModel):
array_field = ArrayField(HStoreField(), null=True) array_field = ArrayField(HStoreField(), null=True)
class ArrayEnumModel(PostgreSQLModel):
array_of_enums = ArrayField(EnumField(max_length=20))
class CharFieldModel(models.Model): class CharFieldModel(models.Model):
field = models.CharField(max_length=16) field = models.CharField(max_length=16)

View File

@ -1,4 +1,5 @@
import decimal import decimal
import enum
import json import json
import unittest import unittest
import uuid import uuid
@ -16,9 +17,9 @@ from . import (
PostgreSQLSimpleTestCase, PostgreSQLTestCase, PostgreSQLWidgetTestCase, PostgreSQLSimpleTestCase, PostgreSQLTestCase, PostgreSQLWidgetTestCase,
) )
from .models import ( from .models import (
ArrayFieldSubclass, CharArrayModel, DateTimeArrayModel, IntegerArrayModel, ArrayEnumModel, ArrayFieldSubclass, CharArrayModel, DateTimeArrayModel,
NestedIntegerArrayModel, NullableIntegerArrayModel, OtherTypesArrayModel, IntegerArrayModel, NestedIntegerArrayModel, NullableIntegerArrayModel,
PostgreSQLModel, Tag, OtherTypesArrayModel, PostgreSQLModel, Tag,
) )
try: try:
@ -357,6 +358,16 @@ class TestQuerying(PostgreSQLTestCase):
[self.objs[3]] [self.objs[3]]
) )
def test_enum_lookup(self):
class TestEnum(enum.Enum):
VALUE_1 = 'value_1'
instance = ArrayEnumModel.objects.create(array_of_enums=[TestEnum.VALUE_1])
self.assertSequenceEqual(
ArrayEnumModel.objects.filter(array_of_enums__contains=[TestEnum.VALUE_1]),
[instance]
)
def test_unsupported_lookup(self): def test_unsupported_lookup(self):
msg = "Unsupported lookup '0_bar' for ArrayField or join on the field not permitted." msg = "Unsupported lookup '0_bar' for ArrayField or join on the field not permitted."
with self.assertRaisesMessage(FieldError, msg): with self.assertRaisesMessage(FieldError, msg):