Fixed #24483 -- Prevented keepdb from breaking with generator choices.

If Field.choices is provided as an iterator, consume it in __init__ instead
of using itertools.tee (which ends up holding everything in memory
anyway). Fixes a bug where deconstruct() was consuming the iterator but
bypassing the call to `tee`.
This commit is contained in:
David Szotten 2015-03-15 19:07:39 +00:00 committed by Tim Graham
parent 118cae2df8
commit 80e3444eca
3 changed files with 24 additions and 19 deletions

View File

@ -9,7 +9,6 @@ import math
import uuid import uuid
import warnings import warnings
from base64 import b64decode, b64encode from base64 import b64decode, b64encode
from itertools import tee
from django.apps import apps from django.apps import apps
from django.db import connection from django.db import connection
@ -155,7 +154,9 @@ class Field(RegisterLookupMixin):
self.unique_for_date = unique_for_date self.unique_for_date = unique_for_date
self.unique_for_month = unique_for_month self.unique_for_month = unique_for_month
self.unique_for_year = unique_for_year self.unique_for_year = unique_for_year
self._choices = choices or [] if isinstance(choices, collections.Iterator):
choices = list(choices)
self.choices = choices or []
self.help_text = help_text self.help_text = help_text
self.db_index = db_index self.db_index = db_index
self.db_column = db_column self.db_column = db_column
@ -405,7 +406,6 @@ class Field(RegisterLookupMixin):
} }
attr_overrides = { attr_overrides = {
"unique": "_unique", "unique": "_unique",
"choices": "_choices",
"error_messages": "_error_messages", "error_messages": "_error_messages",
"validators": "_validators", "validators": "_validators",
"verbose_name": "_verbose_name", "verbose_name": "_verbose_name",
@ -553,7 +553,7 @@ class Field(RegisterLookupMixin):
# Skip validation for non-editable fields. # Skip validation for non-editable fields.
return return
if self._choices and value not in self.empty_values: if self.choices and value not in self.empty_values:
for option_key, option_value in self.choices: for option_key, option_value in self.choices:
if isinstance(option_value, (list, tuple)): if isinstance(option_value, (list, tuple)):
# This is an optgroup, so look inside the group for # This is an optgroup, so look inside the group for
@ -848,14 +848,6 @@ class Field(RegisterLookupMixin):
""" """
return smart_text(self._get_val_from_obj(obj)) return smart_text(self._get_val_from_obj(obj))
def _get_choices(self):
if isinstance(self._choices, collections.Iterator):
choices, self._choices = tee(self._choices)
return choices
else:
return self._choices
choices = property(_get_choices)
def _get_flatchoices(self): def _get_flatchoices(self):
"""Flattened version of choices tuple.""" """Flattened version of choices tuple."""
flat = [] flat = []

View File

@ -595,6 +595,26 @@ class StateTests(TestCase):
self.assertIsNot(old_model.food_mgr.model, new_model.food_mgr.model) self.assertIsNot(old_model.food_mgr.model, new_model.food_mgr.model)
self.assertIsNot(old_model.food_qs.model, new_model.food_qs.model) self.assertIsNot(old_model.food_qs.model, new_model.food_qs.model)
def test_choices_iterator(self):
"""
#24483 - ProjectState.from_apps should not destructively consume
Field.choices iterators.
"""
new_apps = Apps(["migrations"])
choices = [('a', 'A'), ('b', 'B')]
class Author(models.Model):
name = models.CharField(max_length=255)
choice = models.CharField(max_length=255, choices=iter(choices))
class Meta:
app_label = "migrations"
apps = new_apps
ProjectState.from_apps(new_apps)
choices_field = Author._meta.get_field('choice')
self.assertEqual(list(choices_field.choices), choices)
class ModelStateTests(TestCase): class ModelStateTests(TestCase):
def test_custom_model_base(self): def test_custom_model_base(self):

View File

@ -445,13 +445,6 @@ class ChoicesTests(test.TestCase):
self.assertEqual(WhizIterEmpty(c=None).c, None) # Blank value self.assertEqual(WhizIterEmpty(c=None).c, None) # Blank value
self.assertEqual(WhizIterEmpty(c='').c, '') # Empty value self.assertEqual(WhizIterEmpty(c='').c, '') # Empty value
def test_charfield_get_choices_with_blank_iterator(self):
"""
Check that get_choices works with an empty Iterator
"""
f = models.CharField(choices=(x for x in []))
self.assertEqual(f.get_choices(include_blank=True), [('', '---------')])
class SlugFieldTests(test.TestCase): class SlugFieldTests(test.TestCase):
def test_slugfield_max_length(self): def test_slugfield_max_length(self):