Fixed #23112 -- Field.get_choices tries to index an iterable

This commit is contained in:
areski 2014-07-27 23:39:40 +02:00 committed by Florian Apolloner
parent c2ab501bab
commit 97a38de230
3 changed files with 54 additions and 4 deletions

View File

@ -730,9 +730,10 @@ class Field(RegisterLookupMixin):
"""Returns choices with a default blank choices included, for use """Returns choices with a default blank choices included, for use
as SelectField choices for this field.""" as SelectField choices for this field."""
blank_defined = False blank_defined = False
named_groups = self.choices and isinstance(self.choices[0][1], (list, tuple)) choices = list(self.choices) if self.choices else []
named_groups = choices and isinstance(choices[0][1], (list, tuple))
if not named_groups: if not named_groups:
for choice, __ in self.choices: for choice, __ in choices:
if choice in ('', None): if choice in ('', None):
blank_defined = True blank_defined = True
break break
@ -740,7 +741,7 @@ class Field(RegisterLookupMixin):
first_choice = (blank_choice if include_blank and first_choice = (blank_choice if include_blank and
not blank_defined else []) not blank_defined else [])
if self.choices: if self.choices:
return first_choice + list(self.choices) return first_choice + choices
rel_model = self.rel.to rel_model = self.rel.to
if hasattr(self.rel, 'get_related_field'): if hasattr(self.rel, 'get_related_field'):
lst = [(getattr(x, self.rel.get_related_field().attname), lst = [(getattr(x, self.rel.get_related_field().attname),

View File

@ -43,6 +43,29 @@ class Whiz(models.Model):
c = models.IntegerField(choices=CHOICES, null=True) c = models.IntegerField(choices=CHOICES, null=True)
class Counter:
def __init__(self):
self.n = 1
def __iter__(self):
return self
def next(self): # Python 3: def __next__(self)
if self.n > 5:
raise StopIteration
else:
self.n += 1
return (self.n, 'val-'+str(self.n))
class WhizIter(models.Model):
c = models.IntegerField(choices=Counter(), null=True)
class WhizIterEmpty(models.Model):
c = models.CharField(choices=(x for x in []), blank=True, max_length=1)
class BigD(models.Model): class BigD(models.Model):
d = models.DecimalField(max_digits=38, decimal_places=30) d = models.DecimalField(max_digits=38, decimal_places=30)

View File

@ -25,7 +25,8 @@ from .models import (
Foo, Bar, Whiz, BigD, BigS, BigIntegerModel, Post, NullBooleanModel, Foo, Bar, Whiz, BigD, BigS, BigIntegerModel, Post, NullBooleanModel,
BooleanModel, PrimaryKeyCharModel, DataModel, Document, RenamedField, BooleanModel, PrimaryKeyCharModel, DataModel, Document, RenamedField,
DateTimeModel, VerboseNameField, FksToBooleans, FkToChar, FloatModel, DateTimeModel, VerboseNameField, FksToBooleans, FkToChar, FloatModel,
SmallIntegerModel, IntegerModel, PositiveSmallIntegerModel, PositiveIntegerModel) SmallIntegerModel, IntegerModel, PositiveSmallIntegerModel, PositiveIntegerModel,
WhizIter, WhizIterEmpty)
class BasicFieldTests(test.TestCase): class BasicFieldTests(test.TestCase):
@ -375,6 +376,31 @@ class ChoicesTests(test.TestCase):
self.assertEqual(Whiz(c=None).get_c_display(), None) # Blank value self.assertEqual(Whiz(c=None).get_c_display(), None) # Blank value
self.assertEqual(Whiz(c='').get_c_display(), '') # Empty value self.assertEqual(Whiz(c='').get_c_display(), '') # Empty value
def test_iterator_choices(self):
"""
Check that get_choices works with Iterators (#23112).
"""
self.assertEqual(WhizIter(c=1).c, 1) # A nested value
self.assertEqual(WhizIter(c=9).c, 9) # Invalid value
self.assertEqual(WhizIter(c=None).c, None) # Blank value
self.assertEqual(WhizIter(c='').c, '') # Empty value
def test_empty_iterator_choices(self):
"""
Check that get_choices works with empty iterators (#23112).
"""
self.assertEqual(WhizIterEmpty(c="a").c, "a") # A nested value
self.assertEqual(WhizIterEmpty(c="b").c, "b") # Invalid value
self.assertEqual(WhizIterEmpty(c=None).c, None) # Blank value
self.assertEqual(WhizIterEmpty(c='').c, '') # Empty value
def test_charfield_get_choices_with_blank_iterator(self):
"""
Check that get_choices works with an empty Iterator
"""
f = models.CharField(choices=(x for x in []))
self.assertEqual(f.get_choices(include_blank=True), [('', '---------')])
class SlugFieldTests(test.TestCase): class SlugFieldTests(test.TestCase):
def test_slugfield_max_length(self): def test_slugfield_max_length(self):