Fixed #28817 -- Made QuerySet.iterator() use server-side cursors after values() and values_list().

This commit is contained in:
Dražen Odobašić 2017-11-19 10:13:10 -05:00 committed by Tim Graham
parent 6cb6382639
commit d97f026a7a
4 changed files with 27 additions and 7 deletions

View File

@ -105,7 +105,7 @@ class ValuesIterable(BaseIterable):
names = extra_names + field_names + annotation_names names = extra_names + field_names + annotation_names
indexes = range(len(names)) indexes = range(len(names))
for row in compiler.results_iter(): for row in compiler.results_iter(chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size):
yield {names[i]: row[i] for i in indexes} yield {names[i]: row[i] for i in indexes}
@ -133,8 +133,11 @@ class ValuesListIterable(BaseIterable):
# Reorder according to fields. # Reorder according to fields.
index_map = {name: idx for idx, name in enumerate(names)} index_map = {name: idx for idx, name in enumerate(names)}
rowfactory = operator.itemgetter(*[index_map[f] for f in fields]) rowfactory = operator.itemgetter(*[index_map[f] for f in fields])
return map(rowfactory, compiler.results_iter()) return map(
return compiler.results_iter(tuple_expected=True) rowfactory,
compiler.results_iter(chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size)
)
return compiler.results_iter(tuple_expected=True, chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size)
class NamedValuesListIterable(ValuesListIterable): class NamedValuesListIterable(ValuesListIterable):
@ -174,7 +177,7 @@ class FlatValuesListIterable(BaseIterable):
def __iter__(self): def __iter__(self):
queryset = self.queryset queryset = self.queryset
compiler = queryset.query.get_compiler(queryset.db) compiler = queryset.query.get_compiler(queryset.db)
return chain.from_iterable(compiler.results_iter()) return chain.from_iterable(compiler.results_iter(chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size))
class QuerySet: class QuerySet:

View File

@ -1002,10 +1002,11 @@ class SQLCompiler:
row[pos] = value row[pos] = value
yield row yield row
def results_iter(self, results=None, tuple_expected=False): def results_iter(self, results=None, tuple_expected=False, chunked_fetch=False,
chunk_size=GET_ITERATOR_CHUNK_SIZE):
"""Return an iterator over the results from executing this query.""" """Return an iterator over the results from executing this query."""
if results is None: if results is None:
results = self.execute_sql(MULTI) results = self.execute_sql(MULTI, chunked_fetch=chunked_fetch, chunk_size=chunk_size)
fields = [s[0] for s in self.select[0:self.col_count]] fields = [s[0] for s in self.select[0:self.col_count]]
converters = self.get_converters(fields) converters = self.get_converters(fields)
rows = chain.from_iterable(results) rows = chain.from_iterable(results)

View File

@ -18,3 +18,6 @@ Bugfixes
* Fixed incorrect index name truncation when using a namespaced ``db_table`` * Fixed incorrect index name truncation when using a namespaced ``db_table``
(:ticket:`28792`). (:ticket:`28792`).
* Made ``QuerySet.iterator()`` use server-side cursors on PostgreSQL after
``values()`` and ``values_list()`` (:ticket:`28817`).

View File

@ -3,7 +3,7 @@ import unittest
from collections import namedtuple from collections import namedtuple
from contextlib import contextmanager from contextlib import contextmanager
from django.db import connection from django.db import connection, models
from django.test import TestCase from django.test import TestCase
from ..models import Person from ..models import Person
@ -53,6 +53,19 @@ class ServerSideCursorsPostgres(TestCase):
def test_server_side_cursor(self): def test_server_side_cursor(self):
self.assertUsesCursor(Person.objects.iterator()) self.assertUsesCursor(Person.objects.iterator())
def test_values(self):
self.assertUsesCursor(Person.objects.values('first_name').iterator())
def test_values_list(self):
self.assertUsesCursor(Person.objects.values_list('first_name').iterator())
def test_values_list_flat(self):
self.assertUsesCursor(Person.objects.values_list('first_name', flat=True).iterator())
def test_values_list_fields_not_equal_to_names(self):
expr = models.Count('id')
self.assertUsesCursor(Person.objects.annotate(id__count=expr).values_list(expr, 'id__count').iterator())
def test_server_side_cursor_many_cursors(self): def test_server_side_cursor_many_cursors(self):
persons = Person.objects.iterator() persons = Person.objects.iterator()
persons2 = Person.objects.iterator() persons2 = Person.objects.iterator()