From 31e6d58d46894ca35080b4eab7967e4c6aae82d4 Mon Sep 17 00:00:00 2001 From: Tai Lee Date: Sat, 4 May 2013 00:02:10 +1000 Subject: [PATCH] Fixed #20348 -- Consistently handle Promise objects in model fields. All Promise objects were passed to force_text() deep in ORM query code. Not only does this make it difficult or impossible for developers to prevent or alter this behaviour, but it is also wrong for non-text fields. This commit changes `Field.get_prep_value()` from a no-op to one that resolved Promise objects. All subclasses now call super() method first to ensure that they have a real value to work with. --- django/contrib/gis/db/models/fields.py | 1 + django/db/models/fields/__init__.py | 16 ++- django/db/models/fields/files.py | 1 + django/db/models/sql/subqueries.py | 12 -- tests/model_fields/tests.py | 191 +++++++++++++++++++++++-- 5 files changed, 195 insertions(+), 26 deletions(-) diff --git a/django/contrib/gis/db/models/fields.py b/django/contrib/gis/db/models/fields.py index 2e221b7477..d29705986f 100644 --- a/django/contrib/gis/db/models/fields.py +++ b/django/contrib/gis/db/models/fields.py @@ -148,6 +148,7 @@ class GeometryField(Field): value properly, and preserve any other lookup parameters before returning to the caller. """ + value = super(GeometryField, self).get_prep_value(value) if isinstance(value, SQLEvaluator): return value elif isinstance(value, (tuple, list)): diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py index 4f5707952e..d0a2defc48 100644 --- a/django/db/models/fields/__init__.py +++ b/django/db/models/fields/__init__.py @@ -17,7 +17,7 @@ from django import forms from django.core import exceptions, validators from django.utils.datastructures import DictWrapper from django.utils.dateparse import parse_date, parse_datetime, parse_time -from django.utils.functional import curry, total_ordering +from django.utils.functional import curry, total_ordering, Promise from django.utils.text import capfirst from django.utils import timezone from django.utils.translation import ugettext_lazy as _ @@ -421,6 +421,8 @@ class Field(object): """ Perform preliminary non-db specific value checks and conversions. """ + if isinstance(value, Promise): + value = value._proxy____cast() return value def get_db_prep_value(self, value, connection, prepared=False): @@ -704,6 +706,7 @@ class AutoField(Field): return value def get_prep_value(self, value): + value = super(AutoField, self).get_prep_value(value) if value is None: return None return int(value) @@ -763,6 +766,7 @@ class BooleanField(Field): return super(BooleanField, self).get_prep_lookup(lookup_type, value) def get_prep_value(self, value): + value = super(BooleanField, self).get_prep_value(value) if value is None: return None return bool(value) @@ -796,6 +800,7 @@ class CharField(Field): return smart_text(value) def get_prep_value(self, value): + value = super(CharField, self).get_prep_value(value) return self.to_python(value) def formfield(self, **kwargs): @@ -911,6 +916,7 @@ class DateField(Field): return super(DateField, self).get_prep_lookup(lookup_type, value) def get_prep_value(self, value): + value = super(DateField, self).get_prep_value(value) return self.to_python(value) def get_db_prep_value(self, value, connection, prepared=False): @@ -1008,6 +1014,7 @@ class DateTimeField(DateField): # get_prep_lookup is inherited from DateField def get_prep_value(self, value): + value = super(DateTimeField, self).get_prep_value(value) value = self.to_python(value) if value is not None and settings.USE_TZ and timezone.is_naive(value): # For backwards compatibility, interpret naive datetimes in local @@ -1096,6 +1103,7 @@ class DecimalField(Field): self.max_digits, self.decimal_places) def get_prep_value(self, value): + value = super(DecimalField, self).get_prep_value(value) return self.to_python(value) def formfield(self, **kwargs): @@ -1185,6 +1193,7 @@ class FloatField(Field): description = _("Floating point number") def get_prep_value(self, value): + value = super(FloatField, self).get_prep_value(value) if value is None: return None return float(value) @@ -1218,6 +1227,7 @@ class IntegerField(Field): description = _("Integer") def get_prep_value(self, value): + value = super(IntegerField, self).get_prep_value(value) if value is None: return None return int(value) @@ -1326,6 +1336,7 @@ class GenericIPAddressField(Field): return value or None def get_prep_value(self, value): + value = super(GenericIPAddressField, self).get_prep_value(value) if value and ':' in value: try: return clean_ipv6_address(value, self.unpack_ipv4) @@ -1391,6 +1402,7 @@ class NullBooleanField(Field): value) def get_prep_value(self, value): + value = super(NullBooleanField, self).get_prep_value(value) if value is None: return None return bool(value) @@ -1473,6 +1485,7 @@ class TextField(Field): return "TextField" def get_prep_value(self, value): + value = super(TextField, self).get_prep_value(value) if isinstance(value, six.string_types) or value is None: return value return smart_text(value) @@ -1549,6 +1562,7 @@ class TimeField(Field): return super(TimeField, self).pre_save(model_instance, add) def get_prep_value(self, value): + value = super(TimeField, self).get_prep_value(value) return self.to_python(value) def get_db_prep_value(self, value, connection, prepared=False): diff --git a/django/db/models/fields/files.py b/django/db/models/fields/files.py index 311f74a905..61e3eebf49 100644 --- a/django/db/models/fields/files.py +++ b/django/db/models/fields/files.py @@ -253,6 +253,7 @@ class FileField(Field): def get_prep_value(self, value): "Returns field's value prepared for saving into a database." + value = super(FileField, self).get_prep_value(value) # Need to convert File objects provided via a form to unicode for database insertion if value is None: return None diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py index 6aab02bd9a..8beb3fa74a 100644 --- a/django/db/models/sql/subqueries.py +++ b/django/db/models/sql/subqueries.py @@ -11,8 +11,6 @@ from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, SelectInfo from django.db.models.sql.datastructures import Date, DateTime from django.db.models.sql.query import Query from django.db.models.sql.where import AND, Constraint -from django.utils.functional import Promise -from django.utils.encoding import force_text from django.utils import six from django.utils import timezone @@ -147,10 +145,6 @@ class UpdateQuery(Query): Used by add_update_values() as well as the "fast" update path when saving models. """ - # Check that no Promise object passes to the query. Refs #10498. - values_seq = [(value[0], value[1], force_text(value[2])) - if isinstance(value[2], Promise) else value - for value in values_seq] self.values.extend(values_seq) def add_related_update(self, model, field, value): @@ -210,12 +204,6 @@ class InsertQuery(Query): into the query, for example. """ self.fields = fields - # Check that no Promise object reaches the DB. Refs #10498. - for field in fields: - for obj in objs: - value = getattr(obj, field.attname) - if isinstance(value, Promise): - setattr(obj, field.attname, force_text(value)) self.objs = objs self.raw = raw diff --git a/tests/model_fields/tests.py b/tests/model_fields/tests.py index a43f764407..04a17df947 100644 --- a/tests/model_fields/tests.py +++ b/tests/model_fields/tests.py @@ -8,11 +8,21 @@ from django import test from django import forms from django.core.exceptions import ValidationError from django.db import connection, models, IntegrityError -from django.db.models.fields.files import FieldFile +from django.db.models.fields import ( + AutoField, BigIntegerField, BinaryField, BooleanField, CharField, + CommaSeparatedIntegerField, DateField, DateTimeField, DecimalField, + EmailField, FilePathField, FloatField, IntegerField, IPAddressField, + GenericIPAddressField, NullBooleanField, PositiveIntegerField, + PositiveSmallIntegerField, SlugField, SmallIntegerField, TextField, + TimeField, URLField) +from django.db.models.fields.files import FileField, ImageField from django.utils import six +from django.utils.functional import lazy +from django.utils.unittest import skipIf -from .models import (Foo, Bar, Whiz, BigD, BigS, Image, BigInt, Post, - NullBooleanModel, BooleanModel, DataModel, Document, RenamedField, +from .models import ( + Foo, Bar, Whiz, BigD, BigS, BigInt, Post, NullBooleanModel, + BooleanModel, DataModel, Document, RenamedField, VerboseNameField, FksToBooleans) @@ -64,7 +74,7 @@ class BasicFieldTests(test.TestCase): m = VerboseNameField for i in range(1, 23): self.assertEqual(m._meta.get_field('field%d' % i).verbose_name, - 'verbose field%d' % i) + 'verbose field%d' % i) self.assertEqual(m._meta.get_field('id').verbose_name, 'verbose pk') @@ -290,9 +300,9 @@ class SlugFieldTests(test.TestCase): """ Make sure SlugField honors max_length (#9706) """ - bs = BigS.objects.create(s = 'slug'*50) + bs = BigS.objects.create(s='slug' * 50) bs = BigS.objects.get(pk=bs.pk) - self.assertEqual(bs.s, 'slug'*50) + self.assertEqual(bs.s, 'slug' * 50) class ValidationTest(test.TestCase): @@ -313,15 +323,17 @@ class ValidationTest(test.TestCase): self.assertRaises(ValidationError, f.clean, "a", None) def test_charfield_with_choices_cleans_valid_choice(self): - f = models.CharField(max_length=1, choices=[('a','A'), ('b','B')]) + f = models.CharField(max_length=1, + choices=[('a', 'A'), ('b', 'B')]) self.assertEqual('a', f.clean('a', None)) def test_charfield_with_choices_raises_error_on_invalid_choice(self): - f = models.CharField(choices=[('a','A'), ('b','B')]) + f = models.CharField(choices=[('a', 'A'), ('b', 'B')]) self.assertRaises(ValidationError, f.clean, "not a", None) def test_choices_validation_supports_named_groups(self): - f = models.IntegerField(choices=(('group',((10,'A'),(20,'B'))),(30,'C'))) + f = models.IntegerField( + choices=(('group', ((10, 'A'), (20, 'B'))), (30, 'C'))) self.assertEqual(10, f.clean(10, None)) def test_nullable_integerfield_raises_error_with_blank_false(self): @@ -370,7 +382,7 @@ class BigIntegerFieldTests(test.TestCase): self.assertEqual(qs[0].value, minval) def test_types(self): - b = BigInt(value = 0) + b = BigInt(value=0) self.assertIsInstance(b.value, six.integer_types) b.save() self.assertIsInstance(b.value, six.integer_types) @@ -378,8 +390,8 @@ class BigIntegerFieldTests(test.TestCase): self.assertIsInstance(b.value, six.integer_types) def test_coercing(self): - BigInt.objects.create(value ='10') - b = BigInt.objects.get(value = '10') + BigInt.objects.create(value='10') + b = BigInt.objects.get(value='10') self.assertEqual(b.value, 10) class TypeCoercionTests(test.TestCase): @@ -466,7 +478,7 @@ class BinaryFieldTests(test.TestCase): test_set_and_retrieve = unittest.expectedFailure(test_set_and_retrieve) def test_max_length(self): - dm = DataModel(short_data=self.binary_data*4) + dm = DataModel(short_data=self.binary_data * 4) self.assertRaises(ValidationError, dm.full_clean) class GenericIPAddressFieldTests(test.TestCase): @@ -481,3 +493,156 @@ class GenericIPAddressFieldTests(test.TestCase): model_field = models.GenericIPAddressField(protocol='IPv6') form_field = model_field.formfield() self.assertRaises(ValidationError, form_field.clean, '127.0.0.1') + + +class PromiseTest(test.TestCase): + def test_AutoField(self): + lazy_func = lazy(lambda: 1, int) + self.assertIsInstance( + AutoField(primary_key=True).get_prep_value(lazy_func()), + int) + + @skipIf(six.PY3, "Python 3 has no `long` type.") + def test_BigIntegerField(self): + lazy_func = lazy(lambda: long(9999999999999999999), long) + self.assertIsInstance( + BigIntegerField().get_prep_value(lazy_func()), + long) + + def test_BinaryField(self): + lazy_func = lazy(lambda: b'', bytes) + self.assertIsInstance( + BinaryField().get_prep_value(lazy_func()), + bytes) + + def test_BooleanField(self): + lazy_func = lazy(lambda: True, bool) + self.assertIsInstance( + BooleanField().get_prep_value(lazy_func()), + bool) + + def test_CharField(self): + lazy_func = lazy(lambda: '', six.text_type) + self.assertIsInstance( + CharField().get_prep_value(lazy_func()), + six.text_type) + + def test_CommaSeparatedIntegerField(self): + lazy_func = lazy(lambda: '1,2', six.text_type) + self.assertIsInstance( + CommaSeparatedIntegerField().get_prep_value(lazy_func()), + six.text_type) + + def test_DateField(self): + lazy_func = lazy(lambda: datetime.date.today(), datetime.date) + self.assertIsInstance( + DateField().get_prep_value(lazy_func()), + datetime.date) + + def test_DateTimeField(self): + lazy_func = lazy(lambda: datetime.datetime.now(), datetime.datetime) + self.assertIsInstance( + DateTimeField().get_prep_value(lazy_func()), + datetime.datetime) + + def test_DecimalField(self): + lazy_func = lazy(lambda: Decimal('1.2'), Decimal) + self.assertIsInstance( + DecimalField().get_prep_value(lazy_func()), + Decimal) + + def test_EmailField(self): + lazy_func = lazy(lambda: 'mailbox@domain.com', six.text_type) + self.assertIsInstance( + EmailField().get_prep_value(lazy_func()), + six.text_type) + + def test_FileField(self): + lazy_func = lazy(lambda: 'filename.ext', six.text_type) + self.assertIsInstance( + FileField().get_prep_value(lazy_func()), + six.text_type) + + def test_FilePathField(self): + lazy_func = lazy(lambda: 'tests.py', six.text_type) + self.assertIsInstance( + FilePathField().get_prep_value(lazy_func()), + six.text_type) + + def test_FloatField(self): + lazy_func = lazy(lambda: 1.2, float) + self.assertIsInstance( + FloatField().get_prep_value(lazy_func()), + float) + + def test_ImageField(self): + lazy_func = lazy(lambda: 'filename.ext', six.text_type) + self.assertIsInstance( + ImageField().get_prep_value(lazy_func()), + six.text_type) + + def test_IntegerField(self): + lazy_func = lazy(lambda: 1, int) + self.assertIsInstance( + IntegerField().get_prep_value(lazy_func()), + int) + + def test_IPAddressField(self): + lazy_func = lazy(lambda: '127.0.0.1', six.text_type) + self.assertIsInstance( + IPAddressField().get_prep_value(lazy_func()), + six.text_type) + + def test_GenericIPAddressField(self): + lazy_func = lazy(lambda: '127.0.0.1', six.text_type) + self.assertIsInstance( + GenericIPAddressField().get_prep_value(lazy_func()), + six.text_type) + + def test_NullBooleanField(self): + lazy_func = lazy(lambda: True, bool) + self.assertIsInstance( + NullBooleanField().get_prep_value(lazy_func()), + bool) + + def test_PositiveIntegerField(self): + lazy_func = lazy(lambda: 1, int) + self.assertIsInstance( + PositiveIntegerField().get_prep_value(lazy_func()), + int) + + def test_PositiveSmallIntegerField(self): + lazy_func = lazy(lambda: 1, int) + self.assertIsInstance( + PositiveSmallIntegerField().get_prep_value(lazy_func()), + int) + + def test_SlugField(self): + lazy_func = lazy(lambda: 'slug', six.text_type) + self.assertIsInstance( + SlugField().get_prep_value(lazy_func()), + six.text_type) + + def test_SmallIntegerField(self): + lazy_func = lazy(lambda: 1, int) + self.assertIsInstance( + SmallIntegerField().get_prep_value(lazy_func()), + int) + + def test_TextField(self): + lazy_func = lazy(lambda: 'Abc', six.text_type) + self.assertIsInstance( + TextField().get_prep_value(lazy_func()), + six.text_type) + + def test_TimeField(self): + lazy_func = lazy(lambda: datetime.datetime.now().time(), datetime.time) + self.assertIsInstance( + TimeField().get_prep_value(lazy_func()), + datetime.time) + + def test_URLField(self): + lazy_func = lazy(lambda: 'http://domain.com', six.text_type) + self.assertIsInstance( + URLField().get_prep_value(lazy_func()), + six.text_type)