Fixed #20348 -- Consistently handle Promise objects in model fields.

All Promise objects were passed to force_text() deep in ORM query code.
Not only does this make it difficult or impossible for developers to
prevent or alter this behaviour, but it is also wrong for non-text
fields.

This commit changes `Field.get_prep_value()` from a no-op to one that
resolved Promise objects. All subclasses now call super() method first
to ensure that they have a real value to work with.
This commit is contained in:
Tai Lee 2013-05-04 00:02:10 +10:00 committed by Anssi Kääriäinen
parent 8f5533ab25
commit 31e6d58d46
5 changed files with 195 additions and 26 deletions

View File

@ -148,6 +148,7 @@ class GeometryField(Field):
value properly, and preserve any other lookup parameters before value properly, and preserve any other lookup parameters before
returning to the caller. returning to the caller.
""" """
value = super(GeometryField, self).get_prep_value(value)
if isinstance(value, SQLEvaluator): if isinstance(value, SQLEvaluator):
return value return value
elif isinstance(value, (tuple, list)): elif isinstance(value, (tuple, list)):

View File

@ -17,7 +17,7 @@ from django import forms
from django.core import exceptions, validators from django.core import exceptions, validators
from django.utils.datastructures import DictWrapper from django.utils.datastructures import DictWrapper
from django.utils.dateparse import parse_date, parse_datetime, parse_time from django.utils.dateparse import parse_date, parse_datetime, parse_time
from django.utils.functional import curry, total_ordering from django.utils.functional import curry, total_ordering, Promise
from django.utils.text import capfirst from django.utils.text import capfirst
from django.utils import timezone from django.utils import timezone
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
@ -421,6 +421,8 @@ class Field(object):
""" """
Perform preliminary non-db specific value checks and conversions. Perform preliminary non-db specific value checks and conversions.
""" """
if isinstance(value, Promise):
value = value._proxy____cast()
return value return value
def get_db_prep_value(self, value, connection, prepared=False): def get_db_prep_value(self, value, connection, prepared=False):
@ -704,6 +706,7 @@ class AutoField(Field):
return value return value
def get_prep_value(self, value): def get_prep_value(self, value):
value = super(AutoField, self).get_prep_value(value)
if value is None: if value is None:
return None return None
return int(value) return int(value)
@ -763,6 +766,7 @@ class BooleanField(Field):
return super(BooleanField, self).get_prep_lookup(lookup_type, value) return super(BooleanField, self).get_prep_lookup(lookup_type, value)
def get_prep_value(self, value): def get_prep_value(self, value):
value = super(BooleanField, self).get_prep_value(value)
if value is None: if value is None:
return None return None
return bool(value) return bool(value)
@ -796,6 +800,7 @@ class CharField(Field):
return smart_text(value) return smart_text(value)
def get_prep_value(self, value): def get_prep_value(self, value):
value = super(CharField, self).get_prep_value(value)
return self.to_python(value) return self.to_python(value)
def formfield(self, **kwargs): def formfield(self, **kwargs):
@ -911,6 +916,7 @@ class DateField(Field):
return super(DateField, self).get_prep_lookup(lookup_type, value) return super(DateField, self).get_prep_lookup(lookup_type, value)
def get_prep_value(self, value): def get_prep_value(self, value):
value = super(DateField, self).get_prep_value(value)
return self.to_python(value) return self.to_python(value)
def get_db_prep_value(self, value, connection, prepared=False): def get_db_prep_value(self, value, connection, prepared=False):
@ -1008,6 +1014,7 @@ class DateTimeField(DateField):
# get_prep_lookup is inherited from DateField # get_prep_lookup is inherited from DateField
def get_prep_value(self, value): def get_prep_value(self, value):
value = super(DateTimeField, self).get_prep_value(value)
value = self.to_python(value) value = self.to_python(value)
if value is not None and settings.USE_TZ and timezone.is_naive(value): if value is not None and settings.USE_TZ and timezone.is_naive(value):
# For backwards compatibility, interpret naive datetimes in local # For backwards compatibility, interpret naive datetimes in local
@ -1096,6 +1103,7 @@ class DecimalField(Field):
self.max_digits, self.decimal_places) self.max_digits, self.decimal_places)
def get_prep_value(self, value): def get_prep_value(self, value):
value = super(DecimalField, self).get_prep_value(value)
return self.to_python(value) return self.to_python(value)
def formfield(self, **kwargs): def formfield(self, **kwargs):
@ -1185,6 +1193,7 @@ class FloatField(Field):
description = _("Floating point number") description = _("Floating point number")
def get_prep_value(self, value): def get_prep_value(self, value):
value = super(FloatField, self).get_prep_value(value)
if value is None: if value is None:
return None return None
return float(value) return float(value)
@ -1218,6 +1227,7 @@ class IntegerField(Field):
description = _("Integer") description = _("Integer")
def get_prep_value(self, value): def get_prep_value(self, value):
value = super(IntegerField, self).get_prep_value(value)
if value is None: if value is None:
return None return None
return int(value) return int(value)
@ -1326,6 +1336,7 @@ class GenericIPAddressField(Field):
return value or None return value or None
def get_prep_value(self, value): def get_prep_value(self, value):
value = super(GenericIPAddressField, self).get_prep_value(value)
if value and ':' in value: if value and ':' in value:
try: try:
return clean_ipv6_address(value, self.unpack_ipv4) return clean_ipv6_address(value, self.unpack_ipv4)
@ -1391,6 +1402,7 @@ class NullBooleanField(Field):
value) value)
def get_prep_value(self, value): def get_prep_value(self, value):
value = super(NullBooleanField, self).get_prep_value(value)
if value is None: if value is None:
return None return None
return bool(value) return bool(value)
@ -1473,6 +1485,7 @@ class TextField(Field):
return "TextField" return "TextField"
def get_prep_value(self, value): def get_prep_value(self, value):
value = super(TextField, self).get_prep_value(value)
if isinstance(value, six.string_types) or value is None: if isinstance(value, six.string_types) or value is None:
return value return value
return smart_text(value) return smart_text(value)
@ -1549,6 +1562,7 @@ class TimeField(Field):
return super(TimeField, self).pre_save(model_instance, add) return super(TimeField, self).pre_save(model_instance, add)
def get_prep_value(self, value): def get_prep_value(self, value):
value = super(TimeField, self).get_prep_value(value)
return self.to_python(value) return self.to_python(value)
def get_db_prep_value(self, value, connection, prepared=False): def get_db_prep_value(self, value, connection, prepared=False):

View File

@ -253,6 +253,7 @@ class FileField(Field):
def get_prep_value(self, value): def get_prep_value(self, value):
"Returns field's value prepared for saving into a database." "Returns field's value prepared for saving into a database."
value = super(FileField, self).get_prep_value(value)
# Need to convert File objects provided via a form to unicode for database insertion # Need to convert File objects provided via a form to unicode for database insertion
if value is None: if value is None:
return None return None

View File

@ -11,8 +11,6 @@ from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, SelectInfo
from django.db.models.sql.datastructures import Date, DateTime from django.db.models.sql.datastructures import Date, DateTime
from django.db.models.sql.query import Query from django.db.models.sql.query import Query
from django.db.models.sql.where import AND, Constraint from django.db.models.sql.where import AND, Constraint
from django.utils.functional import Promise
from django.utils.encoding import force_text
from django.utils import six from django.utils import six
from django.utils import timezone from django.utils import timezone
@ -147,10 +145,6 @@ class UpdateQuery(Query):
Used by add_update_values() as well as the "fast" update path when Used by add_update_values() as well as the "fast" update path when
saving models. saving models.
""" """
# Check that no Promise object passes to the query. Refs #10498.
values_seq = [(value[0], value[1], force_text(value[2]))
if isinstance(value[2], Promise) else value
for value in values_seq]
self.values.extend(values_seq) self.values.extend(values_seq)
def add_related_update(self, model, field, value): def add_related_update(self, model, field, value):
@ -210,12 +204,6 @@ class InsertQuery(Query):
into the query, for example. into the query, for example.
""" """
self.fields = fields self.fields = fields
# Check that no Promise object reaches the DB. Refs #10498.
for field in fields:
for obj in objs:
value = getattr(obj, field.attname)
if isinstance(value, Promise):
setattr(obj, field.attname, force_text(value))
self.objs = objs self.objs = objs
self.raw = raw self.raw = raw

View File

@ -8,11 +8,21 @@ from django import test
from django import forms from django import forms
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.db import connection, models, IntegrityError from django.db import connection, models, IntegrityError
from django.db.models.fields.files import FieldFile from django.db.models.fields import (
AutoField, BigIntegerField, BinaryField, BooleanField, CharField,
CommaSeparatedIntegerField, DateField, DateTimeField, DecimalField,
EmailField, FilePathField, FloatField, IntegerField, IPAddressField,
GenericIPAddressField, NullBooleanField, PositiveIntegerField,
PositiveSmallIntegerField, SlugField, SmallIntegerField, TextField,
TimeField, URLField)
from django.db.models.fields.files import FileField, ImageField
from django.utils import six from django.utils import six
from django.utils.functional import lazy
from django.utils.unittest import skipIf
from .models import (Foo, Bar, Whiz, BigD, BigS, Image, BigInt, Post, from .models import (
NullBooleanModel, BooleanModel, DataModel, Document, RenamedField, Foo, Bar, Whiz, BigD, BigS, BigInt, Post, NullBooleanModel,
BooleanModel, DataModel, Document, RenamedField,
VerboseNameField, FksToBooleans) VerboseNameField, FksToBooleans)
@ -64,7 +74,7 @@ class BasicFieldTests(test.TestCase):
m = VerboseNameField m = VerboseNameField
for i in range(1, 23): for i in range(1, 23):
self.assertEqual(m._meta.get_field('field%d' % i).verbose_name, self.assertEqual(m._meta.get_field('field%d' % i).verbose_name,
'verbose field%d' % i) 'verbose field%d' % i)
self.assertEqual(m._meta.get_field('id').verbose_name, 'verbose pk') self.assertEqual(m._meta.get_field('id').verbose_name, 'verbose pk')
@ -290,9 +300,9 @@ class SlugFieldTests(test.TestCase):
""" """
Make sure SlugField honors max_length (#9706) Make sure SlugField honors max_length (#9706)
""" """
bs = BigS.objects.create(s = 'slug'*50) bs = BigS.objects.create(s='slug' * 50)
bs = BigS.objects.get(pk=bs.pk) bs = BigS.objects.get(pk=bs.pk)
self.assertEqual(bs.s, 'slug'*50) self.assertEqual(bs.s, 'slug' * 50)
class ValidationTest(test.TestCase): class ValidationTest(test.TestCase):
@ -313,15 +323,17 @@ class ValidationTest(test.TestCase):
self.assertRaises(ValidationError, f.clean, "a", None) self.assertRaises(ValidationError, f.clean, "a", None)
def test_charfield_with_choices_cleans_valid_choice(self): def test_charfield_with_choices_cleans_valid_choice(self):
f = models.CharField(max_length=1, choices=[('a','A'), ('b','B')]) f = models.CharField(max_length=1,
choices=[('a', 'A'), ('b', 'B')])
self.assertEqual('a', f.clean('a', None)) self.assertEqual('a', f.clean('a', None))
def test_charfield_with_choices_raises_error_on_invalid_choice(self): def test_charfield_with_choices_raises_error_on_invalid_choice(self):
f = models.CharField(choices=[('a','A'), ('b','B')]) f = models.CharField(choices=[('a', 'A'), ('b', 'B')])
self.assertRaises(ValidationError, f.clean, "not a", None) self.assertRaises(ValidationError, f.clean, "not a", None)
def test_choices_validation_supports_named_groups(self): def test_choices_validation_supports_named_groups(self):
f = models.IntegerField(choices=(('group',((10,'A'),(20,'B'))),(30,'C'))) f = models.IntegerField(
choices=(('group', ((10, 'A'), (20, 'B'))), (30, 'C')))
self.assertEqual(10, f.clean(10, None)) self.assertEqual(10, f.clean(10, None))
def test_nullable_integerfield_raises_error_with_blank_false(self): def test_nullable_integerfield_raises_error_with_blank_false(self):
@ -370,7 +382,7 @@ class BigIntegerFieldTests(test.TestCase):
self.assertEqual(qs[0].value, minval) self.assertEqual(qs[0].value, minval)
def test_types(self): def test_types(self):
b = BigInt(value = 0) b = BigInt(value=0)
self.assertIsInstance(b.value, six.integer_types) self.assertIsInstance(b.value, six.integer_types)
b.save() b.save()
self.assertIsInstance(b.value, six.integer_types) self.assertIsInstance(b.value, six.integer_types)
@ -378,8 +390,8 @@ class BigIntegerFieldTests(test.TestCase):
self.assertIsInstance(b.value, six.integer_types) self.assertIsInstance(b.value, six.integer_types)
def test_coercing(self): def test_coercing(self):
BigInt.objects.create(value ='10') BigInt.objects.create(value='10')
b = BigInt.objects.get(value = '10') b = BigInt.objects.get(value='10')
self.assertEqual(b.value, 10) self.assertEqual(b.value, 10)
class TypeCoercionTests(test.TestCase): class TypeCoercionTests(test.TestCase):
@ -466,7 +478,7 @@ class BinaryFieldTests(test.TestCase):
test_set_and_retrieve = unittest.expectedFailure(test_set_and_retrieve) test_set_and_retrieve = unittest.expectedFailure(test_set_and_retrieve)
def test_max_length(self): def test_max_length(self):
dm = DataModel(short_data=self.binary_data*4) dm = DataModel(short_data=self.binary_data * 4)
self.assertRaises(ValidationError, dm.full_clean) self.assertRaises(ValidationError, dm.full_clean)
class GenericIPAddressFieldTests(test.TestCase): class GenericIPAddressFieldTests(test.TestCase):
@ -481,3 +493,156 @@ class GenericIPAddressFieldTests(test.TestCase):
model_field = models.GenericIPAddressField(protocol='IPv6') model_field = models.GenericIPAddressField(protocol='IPv6')
form_field = model_field.formfield() form_field = model_field.formfield()
self.assertRaises(ValidationError, form_field.clean, '127.0.0.1') self.assertRaises(ValidationError, form_field.clean, '127.0.0.1')
class PromiseTest(test.TestCase):
def test_AutoField(self):
lazy_func = lazy(lambda: 1, int)
self.assertIsInstance(
AutoField(primary_key=True).get_prep_value(lazy_func()),
int)
@skipIf(six.PY3, "Python 3 has no `long` type.")
def test_BigIntegerField(self):
lazy_func = lazy(lambda: long(9999999999999999999), long)
self.assertIsInstance(
BigIntegerField().get_prep_value(lazy_func()),
long)
def test_BinaryField(self):
lazy_func = lazy(lambda: b'', bytes)
self.assertIsInstance(
BinaryField().get_prep_value(lazy_func()),
bytes)
def test_BooleanField(self):
lazy_func = lazy(lambda: True, bool)
self.assertIsInstance(
BooleanField().get_prep_value(lazy_func()),
bool)
def test_CharField(self):
lazy_func = lazy(lambda: '', six.text_type)
self.assertIsInstance(
CharField().get_prep_value(lazy_func()),
six.text_type)
def test_CommaSeparatedIntegerField(self):
lazy_func = lazy(lambda: '1,2', six.text_type)
self.assertIsInstance(
CommaSeparatedIntegerField().get_prep_value(lazy_func()),
six.text_type)
def test_DateField(self):
lazy_func = lazy(lambda: datetime.date.today(), datetime.date)
self.assertIsInstance(
DateField().get_prep_value(lazy_func()),
datetime.date)
def test_DateTimeField(self):
lazy_func = lazy(lambda: datetime.datetime.now(), datetime.datetime)
self.assertIsInstance(
DateTimeField().get_prep_value(lazy_func()),
datetime.datetime)
def test_DecimalField(self):
lazy_func = lazy(lambda: Decimal('1.2'), Decimal)
self.assertIsInstance(
DecimalField().get_prep_value(lazy_func()),
Decimal)
def test_EmailField(self):
lazy_func = lazy(lambda: 'mailbox@domain.com', six.text_type)
self.assertIsInstance(
EmailField().get_prep_value(lazy_func()),
six.text_type)
def test_FileField(self):
lazy_func = lazy(lambda: 'filename.ext', six.text_type)
self.assertIsInstance(
FileField().get_prep_value(lazy_func()),
six.text_type)
def test_FilePathField(self):
lazy_func = lazy(lambda: 'tests.py', six.text_type)
self.assertIsInstance(
FilePathField().get_prep_value(lazy_func()),
six.text_type)
def test_FloatField(self):
lazy_func = lazy(lambda: 1.2, float)
self.assertIsInstance(
FloatField().get_prep_value(lazy_func()),
float)
def test_ImageField(self):
lazy_func = lazy(lambda: 'filename.ext', six.text_type)
self.assertIsInstance(
ImageField().get_prep_value(lazy_func()),
six.text_type)
def test_IntegerField(self):
lazy_func = lazy(lambda: 1, int)
self.assertIsInstance(
IntegerField().get_prep_value(lazy_func()),
int)
def test_IPAddressField(self):
lazy_func = lazy(lambda: '127.0.0.1', six.text_type)
self.assertIsInstance(
IPAddressField().get_prep_value(lazy_func()),
six.text_type)
def test_GenericIPAddressField(self):
lazy_func = lazy(lambda: '127.0.0.1', six.text_type)
self.assertIsInstance(
GenericIPAddressField().get_prep_value(lazy_func()),
six.text_type)
def test_NullBooleanField(self):
lazy_func = lazy(lambda: True, bool)
self.assertIsInstance(
NullBooleanField().get_prep_value(lazy_func()),
bool)
def test_PositiveIntegerField(self):
lazy_func = lazy(lambda: 1, int)
self.assertIsInstance(
PositiveIntegerField().get_prep_value(lazy_func()),
int)
def test_PositiveSmallIntegerField(self):
lazy_func = lazy(lambda: 1, int)
self.assertIsInstance(
PositiveSmallIntegerField().get_prep_value(lazy_func()),
int)
def test_SlugField(self):
lazy_func = lazy(lambda: 'slug', six.text_type)
self.assertIsInstance(
SlugField().get_prep_value(lazy_func()),
six.text_type)
def test_SmallIntegerField(self):
lazy_func = lazy(lambda: 1, int)
self.assertIsInstance(
SmallIntegerField().get_prep_value(lazy_func()),
int)
def test_TextField(self):
lazy_func = lazy(lambda: 'Abc', six.text_type)
self.assertIsInstance(
TextField().get_prep_value(lazy_func()),
six.text_type)
def test_TimeField(self):
lazy_func = lazy(lambda: datetime.datetime.now().time(), datetime.time)
self.assertIsInstance(
TimeField().get_prep_value(lazy_func()),
datetime.time)
def test_URLField(self):
lazy_func = lazy(lambda: 'http://domain.com', six.text_type)
self.assertIsInstance(
URLField().get_prep_value(lazy_func()),
six.text_type)