Made SQLCompiler.execute_sql(result_type) more explicit.

Updated SQLUpdateCompiler.execute_sql to match the behavior described in
the docstring; the 'first non-empty query' will now include all queries,
not just the main and first related update.

Added CURSOR and NO_RESULTS result_type constants to make the usages more
self documenting and allow execute_sql to explicitly close the cursor when
it is no longer needed.
This commit is contained in:
Michael Manfre 2014-01-08 23:31:34 -05:00
parent ab2f21080d
commit 0837eacc4e
5 changed files with 69 additions and 27 deletions

View File

@ -14,6 +14,7 @@ from django.db.models.fields import AutoField, Empty
from django.db.models.query_utils import (Q, select_related_descend, from django.db.models.query_utils import (Q, select_related_descend,
deferred_class_factory, InvalidQuery) deferred_class_factory, InvalidQuery)
from django.db.models.deletion import Collector from django.db.models.deletion import Collector
from django.db.models.sql.constants import CURSOR
from django.db.models import sql from django.db.models import sql
from django.utils.functional import partition from django.utils.functional import partition
from django.utils import six from django.utils import six
@ -574,7 +575,7 @@ class QuerySet(object):
query = self.query.clone(sql.UpdateQuery) query = self.query.clone(sql.UpdateQuery)
query.add_update_values(kwargs) query.add_update_values(kwargs)
with transaction.commit_on_success_unless_managed(using=self.db): with transaction.commit_on_success_unless_managed(using=self.db):
rows = query.get_compiler(self.db).execute_sql(None) rows = query.get_compiler(self.db).execute_sql(CURSOR)
self._result_cache = None self._result_cache = None
return rows return rows
update.alters_data = True update.alters_data = True
@ -591,7 +592,7 @@ class QuerySet(object):
query = self.query.clone(sql.UpdateQuery) query = self.query.clone(sql.UpdateQuery)
query.add_update_fields(values) query.add_update_fields(values)
self._result_cache = None self._result_cache = None
return query.get_compiler(self.db).execute_sql(None) return query.get_compiler(self.db).execute_sql(CURSOR)
_update.alters_data = True _update.alters_data = True
_update.queryset_only = False _update.queryset_only = False

View File

@ -5,8 +5,8 @@ from django.core.exceptions import FieldError
from django.db.backends.utils import truncate_name from django.db.backends.utils import truncate_name
from django.db.models.constants import LOOKUP_SEP from django.db.models.constants import LOOKUP_SEP
from django.db.models.query_utils import select_related_descend, QueryWrapper from django.db.models.query_utils import select_related_descend, QueryWrapper
from django.db.models.sql.constants import (SINGLE, MULTI, ORDER_DIR, from django.db.models.sql.constants import (CURSOR, SINGLE, MULTI, NO_RESULTS,
GET_ITERATOR_CHUNK_SIZE, SelectInfo) ORDER_DIR, GET_ITERATOR_CHUNK_SIZE, SelectInfo)
from django.db.models.sql.datastructures import EmptyResultSet from django.db.models.sql.datastructures import EmptyResultSet
from django.db.models.sql.expressions import SQLEvaluator from django.db.models.sql.expressions import SQLEvaluator
from django.db.models.sql.query import get_order_dir, Query from django.db.models.sql.query import get_order_dir, Query
@ -762,6 +762,8 @@ class SQLCompiler(object):
is needed, as the filters describe an empty set. In that case, None is is needed, as the filters describe an empty set. In that case, None is
returned, to avoid any unnecessary database interaction. returned, to avoid any unnecessary database interaction.
""" """
if not result_type:
result_type = NO_RESULTS
try: try:
sql, params = self.as_sql() sql, params = self.as_sql()
if not sql: if not sql:
@ -773,27 +775,44 @@ class SQLCompiler(object):
return return
cursor = self.connection.cursor() cursor = self.connection.cursor()
cursor.execute(sql, params) try:
cursor.execute(sql, params)
except:
cursor.close()
raise
if not result_type: if result_type == CURSOR:
# Caller didn't specify a result_type, so just give them back the
# cursor to process (and close).
return cursor return cursor
if result_type == SINGLE: if result_type == SINGLE:
if self.ordering_aliases: try:
return cursor.fetchone()[:-len(self.ordering_aliases)] if self.ordering_aliases:
return cursor.fetchone() return cursor.fetchone()[:-len(self.ordering_aliases)]
return cursor.fetchone()
finally:
# done with the cursor
cursor.close()
if result_type == NO_RESULTS:
cursor.close()
return
# The MULTI case. # The MULTI case.
if self.ordering_aliases: if self.ordering_aliases:
result = order_modified_iter(cursor, len(self.ordering_aliases), result = order_modified_iter(cursor, len(self.ordering_aliases),
self.connection.features.empty_fetchmany_value) self.connection.features.empty_fetchmany_value)
else: else:
result = iter((lambda: cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)), result = cursor_iter(cursor,
self.connection.features.empty_fetchmany_value) self.connection.features.empty_fetchmany_value)
if not self.connection.features.can_use_chunked_reads: if not self.connection.features.can_use_chunked_reads:
# If we are using non-chunked reads, we return the same data try:
# structure as normally, but ensure it is all read into memory # If we are using non-chunked reads, we return the same data
# before going any further. # structure as normally, but ensure it is all read into memory
return list(result) # before going any further.
return list(result)
finally:
# done with the cursor
cursor.close()
return result return result
def as_subquery_condition(self, alias, columns, qn): def as_subquery_condition(self, alias, columns, qn):
@ -970,12 +989,15 @@ class SQLUpdateCompiler(SQLCompiler):
related queries are not available. related queries are not available.
""" """
cursor = super(SQLUpdateCompiler, self).execute_sql(result_type) cursor = super(SQLUpdateCompiler, self).execute_sql(result_type)
rows = cursor.rowcount if cursor else 0 try:
is_empty = cursor is None rows = cursor.rowcount if cursor else 0
del cursor is_empty = cursor is None
finally:
if cursor:
cursor.close()
for query in self.query.get_related_updates(): for query in self.query.get_related_updates():
aux_rows = query.get_compiler(self.using).execute_sql(result_type) aux_rows = query.get_compiler(self.using).execute_sql(result_type)
if is_empty: if is_empty and aux_rows:
rows = aux_rows rows = aux_rows
is_empty = False is_empty = False
return rows return rows
@ -1111,6 +1133,19 @@ class SQLDateTimeCompiler(SQLCompiler):
yield datetime yield datetime
def cursor_iter(cursor, sentinel):
"""
Yields blocks of rows from a cursor and ensures the cursor is closed when
done.
"""
try:
for rows in iter((lambda: cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)),
sentinel):
yield rows
finally:
cursor.close()
def order_modified_iter(cursor, trim, sentinel): def order_modified_iter(cursor, trim, sentinel):
""" """
Yields blocks of rows from a cursor. We use this iterator in the special Yields blocks of rows from a cursor. We use this iterator in the special
@ -1118,6 +1153,9 @@ def order_modified_iter(cursor, trim, sentinel):
requirements. We must trim those extra columns before anything else can use requirements. We must trim those extra columns before anything else can use
the results, since they're only needed to make the SQL valid. the results, since they're only needed to make the SQL valid.
""" """
for rows in iter((lambda: cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)), try:
sentinel): for rows in iter((lambda: cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)),
yield [r[:-trim] for r in rows] sentinel):
yield [r[:-trim] for r in rows]
finally:
cursor.close()

View File

@ -33,6 +33,8 @@ SelectInfo = namedtuple('SelectInfo', 'col field')
# How many results to expect from a cursor.execute call # How many results to expect from a cursor.execute call
MULTI = 'multi' MULTI = 'multi'
SINGLE = 'single' SINGLE = 'single'
CURSOR = 'cursor'
NO_RESULTS = 'no results'
ORDER_PATTERN = re.compile(r'\?|[-+]?[.\w]+$') ORDER_PATTERN = re.compile(r'\?|[-+]?[.\w]+$')
ORDER_DIR = { ORDER_DIR = {

View File

@ -8,7 +8,7 @@ from django.db import connections
from django.db.models.query_utils import Q from django.db.models.query_utils import Q
from django.db.models.constants import LOOKUP_SEP from django.db.models.constants import LOOKUP_SEP
from django.db.models.fields import DateField, DateTimeField, FieldDoesNotExist from django.db.models.fields import DateField, DateTimeField, FieldDoesNotExist
from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, SelectInfo from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, NO_RESULTS, SelectInfo
from django.db.models.sql.datastructures import Date, DateTime from django.db.models.sql.datastructures import Date, DateTime
from django.db.models.sql.query import Query from django.db.models.sql.query import Query
from django.utils import six from django.utils import six
@ -30,7 +30,7 @@ class DeleteQuery(Query):
def do_query(self, table, where, using): def do_query(self, table, where, using):
self.tables = [table] self.tables = [table]
self.where = where self.where = where
self.get_compiler(using).execute_sql(None) self.get_compiler(using).execute_sql(NO_RESULTS)
def delete_batch(self, pk_list, using, field=None): def delete_batch(self, pk_list, using, field=None):
""" """
@ -82,7 +82,7 @@ class DeleteQuery(Query):
values = innerq values = innerq
self.where = self.where_class() self.where = self.where_class()
self.add_q(Q(pk__in=values)) self.add_q(Q(pk__in=values))
self.get_compiler(using).execute_sql(None) self.get_compiler(using).execute_sql(NO_RESULTS)
class UpdateQuery(Query): class UpdateQuery(Query):
@ -116,7 +116,7 @@ class UpdateQuery(Query):
for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
self.where = self.where_class() self.where = self.where_class()
self.add_q(Q(pk__in=pk_list[offset: offset + GET_ITERATOR_CHUNK_SIZE])) self.add_q(Q(pk__in=pk_list[offset: offset + GET_ITERATOR_CHUNK_SIZE]))
self.get_compiler(using).execute_sql(None) self.get_compiler(using).execute_sql(NO_RESULTS)
def add_update_values(self, values): def add_update_values(self, values):
""" """

View File

@ -20,6 +20,7 @@ from django.db.backends.utils import format_number, CursorWrapper
from django.db.models import Sum, Avg, Variance, StdDev from django.db.models import Sum, Avg, Variance, StdDev
from django.db.models.fields import (AutoField, DateField, DateTimeField, from django.db.models.fields import (AutoField, DateField, DateTimeField,
DecimalField, IntegerField, TimeField) DecimalField, IntegerField, TimeField)
from django.db.models.sql.constants import CURSOR
from django.db.utils import ConnectionHandler from django.db.utils import ConnectionHandler
from django.test import (TestCase, TransactionTestCase, override_settings, from django.test import (TestCase, TransactionTestCase, override_settings,
skipUnlessDBFeature, skipIfDBFeature) skipUnlessDBFeature, skipIfDBFeature)
@ -209,7 +210,7 @@ class LastExecutedQueryTest(TestCase):
""" """
persons = models.Reporter.objects.filter(raw_data=b'\x00\x46 \xFE').extra(select={'föö': 1}) persons = models.Reporter.objects.filter(raw_data=b'\x00\x46 \xFE').extra(select={'föö': 1})
sql, params = persons.query.sql_with_params() sql, params = persons.query.sql_with_params()
cursor = persons.query.get_compiler('default').execute_sql(None) cursor = persons.query.get_compiler('default').execute_sql(CURSOR)
last_sql = cursor.db.ops.last_executed_query(cursor, sql, params) last_sql = cursor.db.ops.last_executed_query(cursor, sql, params)
self.assertIsInstance(last_sql, six.text_type) self.assertIsInstance(last_sql, six.text_type)