From 97a38de230371c0b6ad8a86abba8425186c147c7 Mon Sep 17 00:00:00 2001 From: areski Date: Sun, 27 Jul 2014 23:39:40 +0200 Subject: [PATCH] Fixed #23112 -- Field.get_choices tries to index an iterable --- django/db/models/fields/__init__.py | 7 ++++--- tests/model_fields/models.py | 23 +++++++++++++++++++++++ tests/model_fields/tests.py | 28 +++++++++++++++++++++++++++- 3 files changed, 54 insertions(+), 4 deletions(-) diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py index 99bb948184f..3e2c894a4d5 100644 --- a/django/db/models/fields/__init__.py +++ b/django/db/models/fields/__init__.py @@ -730,9 +730,10 @@ class Field(RegisterLookupMixin): """Returns choices with a default blank choices included, for use as SelectField choices for this field.""" blank_defined = False - named_groups = self.choices and isinstance(self.choices[0][1], (list, tuple)) + choices = list(self.choices) if self.choices else [] + named_groups = choices and isinstance(choices[0][1], (list, tuple)) if not named_groups: - for choice, __ in self.choices: + for choice, __ in choices: if choice in ('', None): blank_defined = True break @@ -740,7 +741,7 @@ class Field(RegisterLookupMixin): first_choice = (blank_choice if include_blank and not blank_defined else []) if self.choices: - return first_choice + list(self.choices) + return first_choice + choices rel_model = self.rel.to if hasattr(self.rel, 'get_related_field'): lst = [(getattr(x, self.rel.get_related_field().attname), diff --git a/tests/model_fields/models.py b/tests/model_fields/models.py index 30140a72311..8dfdcb91d58 100644 --- a/tests/model_fields/models.py +++ b/tests/model_fields/models.py @@ -43,6 +43,29 @@ class Whiz(models.Model): c = models.IntegerField(choices=CHOICES, null=True) +class Counter: + def __init__(self): + self.n = 1 + + def __iter__(self): + return self + + def next(self): # Python 3: def __next__(self) + if self.n > 5: + raise StopIteration + else: + self.n += 1 + return (self.n, 'val-'+str(self.n)) + + +class WhizIter(models.Model): + c = models.IntegerField(choices=Counter(), null=True) + + +class WhizIterEmpty(models.Model): + c = models.CharField(choices=(x for x in []), blank=True, max_length=1) + + class BigD(models.Model): d = models.DecimalField(max_digits=38, decimal_places=30) diff --git a/tests/model_fields/tests.py b/tests/model_fields/tests.py index 2bf3b90a743..64b848af7a6 100644 --- a/tests/model_fields/tests.py +++ b/tests/model_fields/tests.py @@ -25,7 +25,8 @@ from .models import ( Foo, Bar, Whiz, BigD, BigS, BigIntegerModel, Post, NullBooleanModel, BooleanModel, PrimaryKeyCharModel, DataModel, Document, RenamedField, DateTimeModel, VerboseNameField, FksToBooleans, FkToChar, FloatModel, - SmallIntegerModel, IntegerModel, PositiveSmallIntegerModel, PositiveIntegerModel) + SmallIntegerModel, IntegerModel, PositiveSmallIntegerModel, PositiveIntegerModel, + WhizIter, WhizIterEmpty) class BasicFieldTests(test.TestCase): @@ -375,6 +376,31 @@ class ChoicesTests(test.TestCase): self.assertEqual(Whiz(c=None).get_c_display(), None) # Blank value self.assertEqual(Whiz(c='').get_c_display(), '') # Empty value + def test_iterator_choices(self): + """ + Check that get_choices works with Iterators (#23112). + """ + self.assertEqual(WhizIter(c=1).c, 1) # A nested value + self.assertEqual(WhizIter(c=9).c, 9) # Invalid value + self.assertEqual(WhizIter(c=None).c, None) # Blank value + self.assertEqual(WhizIter(c='').c, '') # Empty value + + def test_empty_iterator_choices(self): + """ + Check that get_choices works with empty iterators (#23112). + """ + self.assertEqual(WhizIterEmpty(c="a").c, "a") # A nested value + self.assertEqual(WhizIterEmpty(c="b").c, "b") # Invalid value + self.assertEqual(WhizIterEmpty(c=None).c, None) # Blank value + self.assertEqual(WhizIterEmpty(c='').c, '') # Empty value + + def test_charfield_get_choices_with_blank_iterator(self): + """ + Check that get_choices works with an empty Iterator + """ + f = models.CharField(choices=(x for x in [])) + self.assertEqual(f.get_choices(include_blank=True), [('', '---------')]) + class SlugFieldTests(test.TestCase): def test_slugfield_max_length(self):