diff --git a/django/db/models/query.py b/django/db/models/query.py index 48d295ccca..353dd95794 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -14,6 +14,7 @@ from django.db.models.fields import AutoField, Empty from django.db.models.query_utils import (Q, select_related_descend, deferred_class_factory, InvalidQuery) from django.db.models.deletion import Collector +from django.db.models.sql.constants import CURSOR from django.db.models import sql from django.utils.functional import partition from django.utils import six @@ -574,7 +575,7 @@ class QuerySet(object): query = self.query.clone(sql.UpdateQuery) query.add_update_values(kwargs) with transaction.commit_on_success_unless_managed(using=self.db): - rows = query.get_compiler(self.db).execute_sql(None) + rows = query.get_compiler(self.db).execute_sql(CURSOR) self._result_cache = None return rows update.alters_data = True @@ -591,7 +592,7 @@ class QuerySet(object): query = self.query.clone(sql.UpdateQuery) query.add_update_fields(values) self._result_cache = None - return query.get_compiler(self.db).execute_sql(None) + return query.get_compiler(self.db).execute_sql(CURSOR) _update.alters_data = True _update.queryset_only = False diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 123427cf8b..536a66d139 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -5,8 +5,8 @@ from django.core.exceptions import FieldError from django.db.backends.utils import truncate_name from django.db.models.constants import LOOKUP_SEP from django.db.models.query_utils import select_related_descend, QueryWrapper -from django.db.models.sql.constants import (SINGLE, MULTI, ORDER_DIR, - GET_ITERATOR_CHUNK_SIZE, SelectInfo) +from django.db.models.sql.constants import (CURSOR, SINGLE, MULTI, NO_RESULTS, + ORDER_DIR, GET_ITERATOR_CHUNK_SIZE, SelectInfo) from django.db.models.sql.datastructures import EmptyResultSet from django.db.models.sql.expressions import SQLEvaluator from django.db.models.sql.query import get_order_dir, Query @@ -762,6 +762,8 @@ class SQLCompiler(object): is needed, as the filters describe an empty set. In that case, None is returned, to avoid any unnecessary database interaction. """ + if not result_type: + result_type = NO_RESULTS try: sql, params = self.as_sql() if not sql: @@ -773,27 +775,44 @@ class SQLCompiler(object): return cursor = self.connection.cursor() - cursor.execute(sql, params) + try: + cursor.execute(sql, params) + except: + cursor.close() + raise - if not result_type: + if result_type == CURSOR: + # Caller didn't specify a result_type, so just give them back the + # cursor to process (and close). return cursor if result_type == SINGLE: - if self.ordering_aliases: - return cursor.fetchone()[:-len(self.ordering_aliases)] - return cursor.fetchone() + try: + if self.ordering_aliases: + return cursor.fetchone()[:-len(self.ordering_aliases)] + return cursor.fetchone() + finally: + # done with the cursor + cursor.close() + if result_type == NO_RESULTS: + cursor.close() + return # The MULTI case. if self.ordering_aliases: result = order_modified_iter(cursor, len(self.ordering_aliases), self.connection.features.empty_fetchmany_value) else: - result = iter((lambda: cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)), - self.connection.features.empty_fetchmany_value) + result = cursor_iter(cursor, + self.connection.features.empty_fetchmany_value) if not self.connection.features.can_use_chunked_reads: - # If we are using non-chunked reads, we return the same data - # structure as normally, but ensure it is all read into memory - # before going any further. - return list(result) + try: + # If we are using non-chunked reads, we return the same data + # structure as normally, but ensure it is all read into memory + # before going any further. + return list(result) + finally: + # done with the cursor + cursor.close() return result def as_subquery_condition(self, alias, columns, qn): @@ -970,12 +989,15 @@ class SQLUpdateCompiler(SQLCompiler): related queries are not available. """ cursor = super(SQLUpdateCompiler, self).execute_sql(result_type) - rows = cursor.rowcount if cursor else 0 - is_empty = cursor is None - del cursor + try: + rows = cursor.rowcount if cursor else 0 + is_empty = cursor is None + finally: + if cursor: + cursor.close() for query in self.query.get_related_updates(): aux_rows = query.get_compiler(self.using).execute_sql(result_type) - if is_empty: + if is_empty and aux_rows: rows = aux_rows is_empty = False return rows @@ -1111,6 +1133,19 @@ class SQLDateTimeCompiler(SQLCompiler): yield datetime +def cursor_iter(cursor, sentinel): + """ + Yields blocks of rows from a cursor and ensures the cursor is closed when + done. + """ + try: + for rows in iter((lambda: cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)), + sentinel): + yield rows + finally: + cursor.close() + + def order_modified_iter(cursor, trim, sentinel): """ Yields blocks of rows from a cursor. We use this iterator in the special @@ -1118,6 +1153,9 @@ def order_modified_iter(cursor, trim, sentinel): requirements. We must trim those extra columns before anything else can use the results, since they're only needed to make the SQL valid. """ - for rows in iter((lambda: cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)), - sentinel): - yield [r[:-trim] for r in rows] + try: + for rows in iter((lambda: cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)), + sentinel): + yield [r[:-trim] for r in rows] + finally: + cursor.close() diff --git a/django/db/models/sql/constants.py b/django/db/models/sql/constants.py index 904f7b2c8b..36aab23bae 100644 --- a/django/db/models/sql/constants.py +++ b/django/db/models/sql/constants.py @@ -33,6 +33,8 @@ SelectInfo = namedtuple('SelectInfo', 'col field') # How many results to expect from a cursor.execute call MULTI = 'multi' SINGLE = 'single' +CURSOR = 'cursor' +NO_RESULTS = 'no results' ORDER_PATTERN = re.compile(r'\?|[-+]?[.\w]+$') ORDER_DIR = { diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py index 86b1efd3f8..cfda1f552c 100644 --- a/django/db/models/sql/subqueries.py +++ b/django/db/models/sql/subqueries.py @@ -8,7 +8,7 @@ from django.db import connections from django.db.models.query_utils import Q from django.db.models.constants import LOOKUP_SEP from django.db.models.fields import DateField, DateTimeField, FieldDoesNotExist -from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, SelectInfo +from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, NO_RESULTS, SelectInfo from django.db.models.sql.datastructures import Date, DateTime from django.db.models.sql.query import Query from django.utils import six @@ -30,7 +30,7 @@ class DeleteQuery(Query): def do_query(self, table, where, using): self.tables = [table] self.where = where - self.get_compiler(using).execute_sql(None) + self.get_compiler(using).execute_sql(NO_RESULTS) def delete_batch(self, pk_list, using, field=None): """ @@ -82,7 +82,7 @@ class DeleteQuery(Query): values = innerq self.where = self.where_class() self.add_q(Q(pk__in=values)) - self.get_compiler(using).execute_sql(None) + self.get_compiler(using).execute_sql(NO_RESULTS) class UpdateQuery(Query): @@ -116,7 +116,7 @@ class UpdateQuery(Query): for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): self.where = self.where_class() self.add_q(Q(pk__in=pk_list[offset: offset + GET_ITERATOR_CHUNK_SIZE])) - self.get_compiler(using).execute_sql(None) + self.get_compiler(using).execute_sql(NO_RESULTS) def add_update_values(self, values): """ diff --git a/tests/backends/tests.py b/tests/backends/tests.py index 0ff3ad0bba..4a3fc31b7a 100644 --- a/tests/backends/tests.py +++ b/tests/backends/tests.py @@ -20,6 +20,7 @@ from django.db.backends.utils import format_number, CursorWrapper from django.db.models import Sum, Avg, Variance, StdDev from django.db.models.fields import (AutoField, DateField, DateTimeField, DecimalField, IntegerField, TimeField) +from django.db.models.sql.constants import CURSOR from django.db.utils import ConnectionHandler from django.test import (TestCase, TransactionTestCase, override_settings, skipUnlessDBFeature, skipIfDBFeature) @@ -209,7 +210,7 @@ class LastExecutedQueryTest(TestCase): """ persons = models.Reporter.objects.filter(raw_data=b'\x00\x46 \xFE').extra(select={'föö': 1}) sql, params = persons.query.sql_with_params() - cursor = persons.query.get_compiler('default').execute_sql(None) + cursor = persons.query.get_compiler('default').execute_sql(CURSOR) last_sql = cursor.db.ops.last_executed_query(cursor, sql, params) self.assertIsInstance(last_sql, six.text_type)