Fixed #20348 -- Consistently handle Promise objects in model fields.

All Promise objects were passed to force_text() deep in ORM query code.
Not only does this make it difficult or impossible for developers to
prevent or alter this behaviour, but it is also wrong for non-text
fields.

This commit changes `Field.get_prep_value()` from a no-op to one that
resolved Promise objects. All subclasses now call super() method first
to ensure that they have a real value to work with.
This commit is contained in:
Tai Lee 2013-05-04 00:02:10 +10:00 committed by Anssi Kääriäinen
parent 8f5533ab25
commit 31e6d58d46
5 changed files with 195 additions and 26 deletions

View File

@ -148,6 +148,7 @@ class GeometryField(Field):
value properly, and preserve any other lookup parameters before
returning to the caller.
"""
value = super(GeometryField, self).get_prep_value(value)
if isinstance(value, SQLEvaluator):
return value
elif isinstance(value, (tuple, list)):

View File

@ -17,7 +17,7 @@ from django import forms
from django.core import exceptions, validators
from django.utils.datastructures import DictWrapper
from django.utils.dateparse import parse_date, parse_datetime, parse_time
from django.utils.functional import curry, total_ordering
from django.utils.functional import curry, total_ordering, Promise
from django.utils.text import capfirst
from django.utils import timezone
from django.utils.translation import ugettext_lazy as _
@ -421,6 +421,8 @@ class Field(object):
"""
Perform preliminary non-db specific value checks and conversions.
"""
if isinstance(value, Promise):
value = value._proxy____cast()
return value
def get_db_prep_value(self, value, connection, prepared=False):
@ -704,6 +706,7 @@ class AutoField(Field):
return value
def get_prep_value(self, value):
value = super(AutoField, self).get_prep_value(value)
if value is None:
return None
return int(value)
@ -763,6 +766,7 @@ class BooleanField(Field):
return super(BooleanField, self).get_prep_lookup(lookup_type, value)
def get_prep_value(self, value):
value = super(BooleanField, self).get_prep_value(value)
if value is None:
return None
return bool(value)
@ -796,6 +800,7 @@ class CharField(Field):
return smart_text(value)
def get_prep_value(self, value):
value = super(CharField, self).get_prep_value(value)
return self.to_python(value)
def formfield(self, **kwargs):
@ -911,6 +916,7 @@ class DateField(Field):
return super(DateField, self).get_prep_lookup(lookup_type, value)
def get_prep_value(self, value):
value = super(DateField, self).get_prep_value(value)
return self.to_python(value)
def get_db_prep_value(self, value, connection, prepared=False):
@ -1008,6 +1014,7 @@ class DateTimeField(DateField):
# get_prep_lookup is inherited from DateField
def get_prep_value(self, value):
value = super(DateTimeField, self).get_prep_value(value)
value = self.to_python(value)
if value is not None and settings.USE_TZ and timezone.is_naive(value):
# For backwards compatibility, interpret naive datetimes in local
@ -1096,6 +1103,7 @@ class DecimalField(Field):
self.max_digits, self.decimal_places)
def get_prep_value(self, value):
value = super(DecimalField, self).get_prep_value(value)
return self.to_python(value)
def formfield(self, **kwargs):
@ -1185,6 +1193,7 @@ class FloatField(Field):
description = _("Floating point number")
def get_prep_value(self, value):
value = super(FloatField, self).get_prep_value(value)
if value is None:
return None
return float(value)
@ -1218,6 +1227,7 @@ class IntegerField(Field):
description = _("Integer")
def get_prep_value(self, value):
value = super(IntegerField, self).get_prep_value(value)
if value is None:
return None
return int(value)
@ -1326,6 +1336,7 @@ class GenericIPAddressField(Field):
return value or None
def get_prep_value(self, value):
value = super(GenericIPAddressField, self).get_prep_value(value)
if value and ':' in value:
try:
return clean_ipv6_address(value, self.unpack_ipv4)
@ -1391,6 +1402,7 @@ class NullBooleanField(Field):
value)
def get_prep_value(self, value):
value = super(NullBooleanField, self).get_prep_value(value)
if value is None:
return None
return bool(value)
@ -1473,6 +1485,7 @@ class TextField(Field):
return "TextField"
def get_prep_value(self, value):
value = super(TextField, self).get_prep_value(value)
if isinstance(value, six.string_types) or value is None:
return value
return smart_text(value)
@ -1549,6 +1562,7 @@ class TimeField(Field):
return super(TimeField, self).pre_save(model_instance, add)
def get_prep_value(self, value):
value = super(TimeField, self).get_prep_value(value)
return self.to_python(value)
def get_db_prep_value(self, value, connection, prepared=False):

View File

@ -253,6 +253,7 @@ class FileField(Field):
def get_prep_value(self, value):
"Returns field's value prepared for saving into a database."
value = super(FileField, self).get_prep_value(value)
# Need to convert File objects provided via a form to unicode for database insertion
if value is None:
return None

View File

@ -11,8 +11,6 @@ from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, SelectInfo
from django.db.models.sql.datastructures import Date, DateTime
from django.db.models.sql.query import Query
from django.db.models.sql.where import AND, Constraint
from django.utils.functional import Promise
from django.utils.encoding import force_text
from django.utils import six
from django.utils import timezone
@ -147,10 +145,6 @@ class UpdateQuery(Query):
Used by add_update_values() as well as the "fast" update path when
saving models.
"""
# Check that no Promise object passes to the query. Refs #10498.
values_seq = [(value[0], value[1], force_text(value[2]))
if isinstance(value[2], Promise) else value
for value in values_seq]
self.values.extend(values_seq)
def add_related_update(self, model, field, value):
@ -210,12 +204,6 @@ class InsertQuery(Query):
into the query, for example.
"""
self.fields = fields
# Check that no Promise object reaches the DB. Refs #10498.
for field in fields:
for obj in objs:
value = getattr(obj, field.attname)
if isinstance(value, Promise):
setattr(obj, field.attname, force_text(value))
self.objs = objs
self.raw = raw

View File

@ -8,11 +8,21 @@ from django import test
from django import forms
from django.core.exceptions import ValidationError
from django.db import connection, models, IntegrityError
from django.db.models.fields.files import FieldFile
from django.db.models.fields import (
AutoField, BigIntegerField, BinaryField, BooleanField, CharField,
CommaSeparatedIntegerField, DateField, DateTimeField, DecimalField,
EmailField, FilePathField, FloatField, IntegerField, IPAddressField,
GenericIPAddressField, NullBooleanField, PositiveIntegerField,
PositiveSmallIntegerField, SlugField, SmallIntegerField, TextField,
TimeField, URLField)
from django.db.models.fields.files import FileField, ImageField
from django.utils import six
from django.utils.functional import lazy
from django.utils.unittest import skipIf
from .models import (Foo, Bar, Whiz, BigD, BigS, Image, BigInt, Post,
NullBooleanModel, BooleanModel, DataModel, Document, RenamedField,
from .models import (
Foo, Bar, Whiz, BigD, BigS, BigInt, Post, NullBooleanModel,
BooleanModel, DataModel, Document, RenamedField,
VerboseNameField, FksToBooleans)
@ -290,9 +300,9 @@ class SlugFieldTests(test.TestCase):
"""
Make sure SlugField honors max_length (#9706)
"""
bs = BigS.objects.create(s = 'slug'*50)
bs = BigS.objects.create(s='slug' * 50)
bs = BigS.objects.get(pk=bs.pk)
self.assertEqual(bs.s, 'slug'*50)
self.assertEqual(bs.s, 'slug' * 50)
class ValidationTest(test.TestCase):
@ -313,15 +323,17 @@ class ValidationTest(test.TestCase):
self.assertRaises(ValidationError, f.clean, "a", None)
def test_charfield_with_choices_cleans_valid_choice(self):
f = models.CharField(max_length=1, choices=[('a','A'), ('b','B')])
f = models.CharField(max_length=1,
choices=[('a', 'A'), ('b', 'B')])
self.assertEqual('a', f.clean('a', None))
def test_charfield_with_choices_raises_error_on_invalid_choice(self):
f = models.CharField(choices=[('a','A'), ('b','B')])
f = models.CharField(choices=[('a', 'A'), ('b', 'B')])
self.assertRaises(ValidationError, f.clean, "not a", None)
def test_choices_validation_supports_named_groups(self):
f = models.IntegerField(choices=(('group',((10,'A'),(20,'B'))),(30,'C')))
f = models.IntegerField(
choices=(('group', ((10, 'A'), (20, 'B'))), (30, 'C')))
self.assertEqual(10, f.clean(10, None))
def test_nullable_integerfield_raises_error_with_blank_false(self):
@ -370,7 +382,7 @@ class BigIntegerFieldTests(test.TestCase):
self.assertEqual(qs[0].value, minval)
def test_types(self):
b = BigInt(value = 0)
b = BigInt(value=0)
self.assertIsInstance(b.value, six.integer_types)
b.save()
self.assertIsInstance(b.value, six.integer_types)
@ -378,8 +390,8 @@ class BigIntegerFieldTests(test.TestCase):
self.assertIsInstance(b.value, six.integer_types)
def test_coercing(self):
BigInt.objects.create(value ='10')
b = BigInt.objects.get(value = '10')
BigInt.objects.create(value='10')
b = BigInt.objects.get(value='10')
self.assertEqual(b.value, 10)
class TypeCoercionTests(test.TestCase):
@ -466,7 +478,7 @@ class BinaryFieldTests(test.TestCase):
test_set_and_retrieve = unittest.expectedFailure(test_set_and_retrieve)
def test_max_length(self):
dm = DataModel(short_data=self.binary_data*4)
dm = DataModel(short_data=self.binary_data * 4)
self.assertRaises(ValidationError, dm.full_clean)
class GenericIPAddressFieldTests(test.TestCase):
@ -481,3 +493,156 @@ class GenericIPAddressFieldTests(test.TestCase):
model_field = models.GenericIPAddressField(protocol='IPv6')
form_field = model_field.formfield()
self.assertRaises(ValidationError, form_field.clean, '127.0.0.1')
class PromiseTest(test.TestCase):
def test_AutoField(self):
lazy_func = lazy(lambda: 1, int)
self.assertIsInstance(
AutoField(primary_key=True).get_prep_value(lazy_func()),
int)
@skipIf(six.PY3, "Python 3 has no `long` type.")
def test_BigIntegerField(self):
lazy_func = lazy(lambda: long(9999999999999999999), long)
self.assertIsInstance(
BigIntegerField().get_prep_value(lazy_func()),
long)
def test_BinaryField(self):
lazy_func = lazy(lambda: b'', bytes)
self.assertIsInstance(
BinaryField().get_prep_value(lazy_func()),
bytes)
def test_BooleanField(self):
lazy_func = lazy(lambda: True, bool)
self.assertIsInstance(
BooleanField().get_prep_value(lazy_func()),
bool)
def test_CharField(self):
lazy_func = lazy(lambda: '', six.text_type)
self.assertIsInstance(
CharField().get_prep_value(lazy_func()),
six.text_type)
def test_CommaSeparatedIntegerField(self):
lazy_func = lazy(lambda: '1,2', six.text_type)
self.assertIsInstance(
CommaSeparatedIntegerField().get_prep_value(lazy_func()),
six.text_type)
def test_DateField(self):
lazy_func = lazy(lambda: datetime.date.today(), datetime.date)
self.assertIsInstance(
DateField().get_prep_value(lazy_func()),
datetime.date)
def test_DateTimeField(self):
lazy_func = lazy(lambda: datetime.datetime.now(), datetime.datetime)
self.assertIsInstance(
DateTimeField().get_prep_value(lazy_func()),
datetime.datetime)
def test_DecimalField(self):
lazy_func = lazy(lambda: Decimal('1.2'), Decimal)
self.assertIsInstance(
DecimalField().get_prep_value(lazy_func()),
Decimal)
def test_EmailField(self):
lazy_func = lazy(lambda: 'mailbox@domain.com', six.text_type)
self.assertIsInstance(
EmailField().get_prep_value(lazy_func()),
six.text_type)
def test_FileField(self):
lazy_func = lazy(lambda: 'filename.ext', six.text_type)
self.assertIsInstance(
FileField().get_prep_value(lazy_func()),
six.text_type)
def test_FilePathField(self):
lazy_func = lazy(lambda: 'tests.py', six.text_type)
self.assertIsInstance(
FilePathField().get_prep_value(lazy_func()),
six.text_type)
def test_FloatField(self):
lazy_func = lazy(lambda: 1.2, float)
self.assertIsInstance(
FloatField().get_prep_value(lazy_func()),
float)
def test_ImageField(self):
lazy_func = lazy(lambda: 'filename.ext', six.text_type)
self.assertIsInstance(
ImageField().get_prep_value(lazy_func()),
six.text_type)
def test_IntegerField(self):
lazy_func = lazy(lambda: 1, int)
self.assertIsInstance(
IntegerField().get_prep_value(lazy_func()),
int)
def test_IPAddressField(self):
lazy_func = lazy(lambda: '127.0.0.1', six.text_type)
self.assertIsInstance(
IPAddressField().get_prep_value(lazy_func()),
six.text_type)
def test_GenericIPAddressField(self):
lazy_func = lazy(lambda: '127.0.0.1', six.text_type)
self.assertIsInstance(
GenericIPAddressField().get_prep_value(lazy_func()),
six.text_type)
def test_NullBooleanField(self):
lazy_func = lazy(lambda: True, bool)
self.assertIsInstance(
NullBooleanField().get_prep_value(lazy_func()),
bool)
def test_PositiveIntegerField(self):
lazy_func = lazy(lambda: 1, int)
self.assertIsInstance(
PositiveIntegerField().get_prep_value(lazy_func()),
int)
def test_PositiveSmallIntegerField(self):
lazy_func = lazy(lambda: 1, int)
self.assertIsInstance(
PositiveSmallIntegerField().get_prep_value(lazy_func()),
int)
def test_SlugField(self):
lazy_func = lazy(lambda: 'slug', six.text_type)
self.assertIsInstance(
SlugField().get_prep_value(lazy_func()),
six.text_type)
def test_SmallIntegerField(self):
lazy_func = lazy(lambda: 1, int)
self.assertIsInstance(
SmallIntegerField().get_prep_value(lazy_func()),
int)
def test_TextField(self):
lazy_func = lazy(lambda: 'Abc', six.text_type)
self.assertIsInstance(
TextField().get_prep_value(lazy_func()),
six.text_type)
def test_TimeField(self):
lazy_func = lazy(lambda: datetime.datetime.now().time(), datetime.time)
self.assertIsInstance(
TimeField().get_prep_value(lazy_func()),
datetime.time)
def test_URLField(self):
lazy_func = lazy(lambda: 'http://domain.com', six.text_type)
self.assertIsInstance(
URLField().get_prep_value(lazy_func()),
six.text_type)