Fixed #25335 -- Fixed regression where QuerySet.iterator() didn't return an iterator.

This commit is contained in:
Gavin Wahl 2015-09-02 13:17:53 -06:00 committed by Tim Graham
parent a8eb715b66
commit 627c7eb7bf
2 changed files with 17 additions and 14 deletions

View File

@ -33,14 +33,14 @@ REPR_OUTPUT_SIZE = 20
EmptyResultSet = sql.EmptyResultSet EmptyResultSet = sql.EmptyResultSet
class BaseIterator(object): class BaseIterable(object):
def __init__(self, queryset): def __init__(self, queryset):
self.queryset = queryset self.queryset = queryset
class ModelIterator(BaseIterator): class ModelIterable(BaseIterable):
""" """
Iterator that yields a model instance for each row. Iterable that yields a model instance for each row.
""" """
def __iter__(self): def __iter__(self):
@ -91,9 +91,9 @@ class ModelIterator(BaseIterator):
yield obj yield obj
class ValuesIterator(BaseIterator): class ValuesIterable(BaseIterable):
""" """
Iterator returned by QuerySet.values() that yields a dict Iterable returned by QuerySet.values() that yields a dict
for each row. for each row.
""" """
@ -113,9 +113,9 @@ class ValuesIterator(BaseIterator):
yield dict(zip(names, row)) yield dict(zip(names, row))
class ValuesListIterator(BaseIterator): class ValuesListIterable(BaseIterable):
""" """
Iterator returned by QuerySet.values_lists(flat=False) Iterable returned by QuerySet.values_lists(flat=False)
that yields a tuple for each row. that yields a tuple for each row.
""" """
@ -146,9 +146,9 @@ class ValuesListIterator(BaseIterator):
yield tuple(data[f] for f in fields) yield tuple(data[f] for f in fields)
class FlatValuesListIterator(BaseIterator): class FlatValuesListIterable(BaseIterable):
""" """
Iterator returned by QuerySet.values_lists(flat=True) that Iterable returned by QuerySet.values_lists(flat=True) that
yields single values. yields single values.
""" """
@ -175,7 +175,7 @@ class QuerySet(object):
self._prefetch_related_lookups = [] self._prefetch_related_lookups = []
self._prefetch_done = False self._prefetch_done = False
self._known_related_objects = {} # {rel_field, {pk: rel_obj}} self._known_related_objects = {} # {rel_field, {pk: rel_obj}}
self._iterator_class = ModelIterator self._iterable_class = ModelIterable
self._fields = None self._fields = None
def as_manager(cls): def as_manager(cls):
@ -327,7 +327,7 @@ class QuerySet(object):
An iterator over the results from applying this QuerySet to the An iterator over the results from applying this QuerySet to the
database. database.
""" """
return self._iterator_class(self) return iter(self._iterable_class(self))
def aggregate(self, *args, **kwargs): def aggregate(self, *args, **kwargs):
""" """
@ -708,7 +708,7 @@ class QuerySet(object):
def values(self, *fields): def values(self, *fields):
clone = self._values(*fields) clone = self._values(*fields)
clone._iterator_class = ValuesIterator clone._iterable_class = ValuesIterable
return clone return clone
def values_list(self, *fields, **kwargs): def values_list(self, *fields, **kwargs):
@ -721,7 +721,7 @@ class QuerySet(object):
raise TypeError("'flat' is not valid when values_list is called with more than one field.") raise TypeError("'flat' is not valid when values_list is called with more than one field.")
clone = self._values(*fields) clone = self._values(*fields)
clone._iterator_class = FlatValuesListIterator if flat else ValuesListIterator clone._iterable_class = FlatValuesListIterable if flat else ValuesListIterable
return clone return clone
def dates(self, field_name, kind, order='ASC'): def dates(self, field_name, kind, order='ASC'):
@ -1061,7 +1061,7 @@ class QuerySet(object):
clone._for_write = self._for_write clone._for_write = self._for_write
clone._prefetch_related_lookups = self._prefetch_related_lookups[:] clone._prefetch_related_lookups = self._prefetch_related_lookups[:]
clone._known_related_objects = self._known_related_objects clone._known_related_objects = self._known_related_objects
clone._iterator_class = self._iterator_class clone._iterable_class = self._iterable_class
clone._fields = self._fields clone._fields = self._fields
clone.__dict__.update(kwargs) clone.__dict__.update(kwargs)

View File

@ -1,5 +1,6 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import collections
from datetime import datetime from datetime import datetime
from operator import attrgetter from operator import attrgetter
from unittest import skipUnless from unittest import skipUnless
@ -75,6 +76,8 @@ class LookupTests(TestCase):
def test_iterator(self): def test_iterator(self):
# Each QuerySet gets iterator(), which is a generator that "lazily" # Each QuerySet gets iterator(), which is a generator that "lazily"
# returns results using database-level iteration. # returns results using database-level iteration.
self.assertIsInstance(Article.objects.iterator(), collections.Iterator)
self.assertQuerysetEqual(Article.objects.iterator(), self.assertQuerysetEqual(Article.objects.iterator(),
[ [
'Article 5', 'Article 5',