Fixed #29391 -- Made PostgresSimpleLookup respect Field.get_db_prep_value().
This commit is contained in:
parent
c492fdfd24
commit
5a36c81f58
|
@ -1,10 +1,10 @@
|
||||||
from django.db.models import Lookup, Transform
|
from django.db.models import Lookup, Transform
|
||||||
from django.db.models.lookups import Exact
|
from django.db.models.lookups import Exact, FieldGetDbPrepValueMixin
|
||||||
|
|
||||||
from .search import SearchVector, SearchVectorExact, SearchVectorField
|
from .search import SearchVector, SearchVectorExact, SearchVectorField
|
||||||
|
|
||||||
|
|
||||||
class PostgresSimpleLookup(Lookup):
|
class PostgresSimpleLookup(FieldGetDbPrepValueMixin, Lookup):
|
||||||
def as_sql(self, qn, connection):
|
def as_sql(self, qn, connection):
|
||||||
lhs, lhs_params = self.process_lhs(qn, connection)
|
lhs, lhs_params = self.process_lhs(qn, connection)
|
||||||
rhs, rhs_params = self.process_rhs(qn, connection)
|
rhs, rhs_params = self.process_rhs(qn, connection)
|
||||||
|
|
|
@ -2,6 +2,8 @@
|
||||||
Indirection layer for PostgreSQL-specific fields, so the tests don't fail when
|
Indirection layer for PostgreSQL-specific fields, so the tests don't fail when
|
||||||
run with a backend other than PostgreSQL.
|
run with a backend other than PostgreSQL.
|
||||||
"""
|
"""
|
||||||
|
import enum
|
||||||
|
|
||||||
from django.db import models
|
from django.db import models
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -40,3 +42,8 @@ except ImportError:
|
||||||
IntegerRangeField = models.Field
|
IntegerRangeField = models.Field
|
||||||
JSONField = DummyJSONField
|
JSONField = DummyJSONField
|
||||||
SearchVectorField = models.Field
|
SearchVectorField = models.Field
|
||||||
|
|
||||||
|
|
||||||
|
class EnumField(models.CharField):
|
||||||
|
def get_prep_value(self, value):
|
||||||
|
return value.value if isinstance(value, enum.Enum) else value
|
||||||
|
|
|
@ -3,8 +3,8 @@ from django.db import migrations, models
|
||||||
|
|
||||||
from ..fields import (
|
from ..fields import (
|
||||||
ArrayField, BigIntegerRangeField, CICharField, CIEmailField, CITextField,
|
ArrayField, BigIntegerRangeField, CICharField, CIEmailField, CITextField,
|
||||||
DateRangeField, DateTimeRangeField, DecimalRangeField, HStoreField,
|
DateRangeField, DateTimeRangeField, DecimalRangeField, EnumField,
|
||||||
IntegerRangeField, JSONField, SearchVectorField,
|
HStoreField, IntegerRangeField, JSONField, SearchVectorField,
|
||||||
)
|
)
|
||||||
from ..models import TagField
|
from ..models import TagField
|
||||||
|
|
||||||
|
@ -249,4 +249,15 @@ class Migration(migrations.Migration):
|
||||||
},
|
},
|
||||||
bases=(models.Model,),
|
bases=(models.Model,),
|
||||||
),
|
),
|
||||||
|
migrations.CreateModel(
|
||||||
|
name='ArrayEnumModel',
|
||||||
|
fields=[
|
||||||
|
('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
|
||||||
|
('array_of_enums', ArrayField(EnumField(max_length=20), null=True, blank=True)),
|
||||||
|
],
|
||||||
|
options={
|
||||||
|
'required_db_vendor': 'postgresql',
|
||||||
|
},
|
||||||
|
bases=(models.Model,),
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
|
@ -3,8 +3,8 @@ from django.db import models
|
||||||
|
|
||||||
from .fields import (
|
from .fields import (
|
||||||
ArrayField, BigIntegerRangeField, CICharField, CIEmailField, CITextField,
|
ArrayField, BigIntegerRangeField, CICharField, CIEmailField, CITextField,
|
||||||
DateRangeField, DateTimeRangeField, DecimalRangeField, HStoreField,
|
DateRangeField, DateTimeRangeField, DecimalRangeField, EnumField,
|
||||||
IntegerRangeField, JSONField, SearchVectorField,
|
HStoreField, IntegerRangeField, JSONField, SearchVectorField,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -77,6 +77,10 @@ class HStoreModel(PostgreSQLModel):
|
||||||
array_field = ArrayField(HStoreField(), null=True)
|
array_field = ArrayField(HStoreField(), null=True)
|
||||||
|
|
||||||
|
|
||||||
|
class ArrayEnumModel(PostgreSQLModel):
|
||||||
|
array_of_enums = ArrayField(EnumField(max_length=20))
|
||||||
|
|
||||||
|
|
||||||
class CharFieldModel(models.Model):
|
class CharFieldModel(models.Model):
|
||||||
field = models.CharField(max_length=16)
|
field = models.CharField(max_length=16)
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import decimal
|
import decimal
|
||||||
|
import enum
|
||||||
import json
|
import json
|
||||||
import unittest
|
import unittest
|
||||||
import uuid
|
import uuid
|
||||||
|
@ -16,9 +17,9 @@ from . import (
|
||||||
PostgreSQLSimpleTestCase, PostgreSQLTestCase, PostgreSQLWidgetTestCase,
|
PostgreSQLSimpleTestCase, PostgreSQLTestCase, PostgreSQLWidgetTestCase,
|
||||||
)
|
)
|
||||||
from .models import (
|
from .models import (
|
||||||
ArrayFieldSubclass, CharArrayModel, DateTimeArrayModel, IntegerArrayModel,
|
ArrayEnumModel, ArrayFieldSubclass, CharArrayModel, DateTimeArrayModel,
|
||||||
NestedIntegerArrayModel, NullableIntegerArrayModel, OtherTypesArrayModel,
|
IntegerArrayModel, NestedIntegerArrayModel, NullableIntegerArrayModel,
|
||||||
PostgreSQLModel, Tag,
|
OtherTypesArrayModel, PostgreSQLModel, Tag,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -357,6 +358,16 @@ class TestQuerying(PostgreSQLTestCase):
|
||||||
[self.objs[3]]
|
[self.objs[3]]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_enum_lookup(self):
|
||||||
|
class TestEnum(enum.Enum):
|
||||||
|
VALUE_1 = 'value_1'
|
||||||
|
|
||||||
|
instance = ArrayEnumModel.objects.create(array_of_enums=[TestEnum.VALUE_1])
|
||||||
|
self.assertSequenceEqual(
|
||||||
|
ArrayEnumModel.objects.filter(array_of_enums__contains=[TestEnum.VALUE_1]),
|
||||||
|
[instance]
|
||||||
|
)
|
||||||
|
|
||||||
def test_unsupported_lookup(self):
|
def test_unsupported_lookup(self):
|
||||||
msg = "Unsupported lookup '0_bar' for ArrayField or join on the field not permitted."
|
msg = "Unsupported lookup '0_bar' for ArrayField or join on the field not permitted."
|
||||||
with self.assertRaisesMessage(FieldError, msg):
|
with self.assertRaisesMessage(FieldError, msg):
|
||||||
|
|
Loading…
Reference in New Issue