Merge pull request #2154 from manfre/close-cursors

Fixed #21751 -- Explicitly closed cursors.
Commit 54bfa4caab by Aymeric Augustin, 2014-02-02 10:37:27 -08:00
33 changed files with 725 additions and 641 deletions
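
Every hunk below applies the same transformation: cursors that were previously created and left to be closed by garbage collection (or closed by hand) are now acquired with a `with` statement, so they are closed deterministically even when an exception escapes. A minimal sketch of the before/after pattern (the query is illustrative, not taken from the diff):

    # Before: the cursor stays open until it is garbage-collected.
    cursor = connection.cursor()
    cursor.execute("SELECT 1")

    # After: the cursor wrapper's __exit__ closes it, error or not.
    with connection.cursor() as cursor:
        cursor.execute("SELECT 1")

This relies on Django's cursor wrapper implementing the context-manager protocol (`__enter__` returning the cursor, `__exit__` closing it); the test-mock hunk near the end of the diff adds exactly those two methods to a stand-in cursor.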


@ -11,7 +11,7 @@ class PostGISCreation(DatabaseCreation):
@cached_property
def template_postgis(self):
template_postgis = getattr(settings, 'POSTGIS_TEMPLATE', 'template_postgis')
cursor = self.connection.cursor()
with self.connection.cursor() as cursor:
cursor.execute('SELECT 1 FROM pg_database WHERE datname = %s LIMIT 1;', (template_postgis,))
if cursor.fetchone():
return template_postgis
@ -88,7 +88,7 @@ class PostGISCreation(DatabaseCreation):
# Connect to the test database in order to create the postgis extension
self.connection.close()
self.connection.settings_dict["NAME"] = test_database_name
cursor = self.connection.cursor()
with self.connection.cursor() as cursor:
cursor.execute("CREATE EXTENSION IF NOT EXISTS postgis")
cursor.connection.commit()
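
Note that the `with` block only closes the cursor; it does not commit. That is why the hunk keeps an explicit `cursor.connection.commit()` inside the block: the CREATE EXTENSION statement still has to be committed through the underlying connection.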


@ -55,9 +55,8 @@ class SpatiaLiteCreation(DatabaseCreation):
call_command('createcachetable', database=self.connection.alias)
# Get a cursor (even though we don't need one yet). This has
# the side effect of initializing the test database.
self.connection.cursor()
# Ensure a connection for the side effect of initializing the test database.
self.connection.ensure_connection()
return test_database_name
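
`ensure_connection()` states the intent directly instead of creating a throwaway cursor for its side effect. A rough sketch of the assumed semantics (not the actual implementation):

    def ensure_connection(self):
        # Open a database connection if one is not already open.
        if self.connection is None:
            self.connect()

The same substitution recurs in the base `DatabaseCreation` class and in the test suite further down.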


@ -33,7 +33,7 @@ def create_default_site(app_config, verbosity=2, interactive=True, db=DEFAULT_DB_ALIAS
if sequence_sql:
if verbosity >= 2:
print("Resetting sequence")
cursor = connections[db].cursor()
with connections[db].cursor() as cursor:
for command in sequence_sql:
cursor.execute(command)


@ -59,8 +59,8 @@ class DatabaseCache(BaseDatabaseCache):
self.validate_key(key)
db = router.db_for_read(self.cache_model_class)
table = connections[db].ops.quote_name(self._table)
cursor = connections[db].cursor()
with connections[db].cursor() as cursor:
cursor.execute("SELECT cache_key, value, expires FROM %s "
"WHERE cache_key = %%s" % table, [key])
row = cursor.fetchone()
@ -75,7 +75,7 @@ class DatabaseCache(BaseDatabaseCache):
expires = typecast_timestamp(str(expires))
if expires < now:
db = router.db_for_write(self.cache_model_class)
cursor = connections[db].cursor()
with connections[db].cursor() as cursor:
cursor.execute("DELETE FROM %s "
"WHERE cache_key = %%s" % table, [key])
return default
@ -96,8 +96,8 @@ class DatabaseCache(BaseDatabaseCache):
timeout = self.get_backend_timeout(timeout)
db = router.db_for_write(self.cache_model_class)
table = connections[db].ops.quote_name(self._table)
cursor = connections[db].cursor()
with connections[db].cursor() as cursor:
cursor.execute("SELECT COUNT(*) FROM %s" % table)
num = cursor.fetchone()[0]
now = timezone.now()
@ -152,8 +152,8 @@ class DatabaseCache(BaseDatabaseCache):
db = router.db_for_write(self.cache_model_class)
table = connections[db].ops.quote_name(self._table)
cursor = connections[db].cursor()
with connections[db].cursor() as cursor:
cursor.execute("DELETE FROM %s WHERE cache_key = %%s" % table, [key])
def has_key(self, key, version=None):
@ -162,13 +162,14 @@ class DatabaseCache(BaseDatabaseCache):
db = router.db_for_read(self.cache_model_class)
table = connections[db].ops.quote_name(self._table)
cursor = connections[db].cursor()
if settings.USE_TZ:
now = datetime.utcnow()
else:
now = datetime.now()
now = now.replace(microsecond=0)
with connections[db].cursor() as cursor:
cursor.execute("SELECT cache_key FROM %s "
"WHERE cache_key = %%s and expires > %%s" % table,
[key, connections[db].ops.value_to_db_datetime(now)])
@ -197,7 +198,7 @@ class DatabaseCache(BaseDatabaseCache):
def clear(self):
db = router.db_for_write(self.cache_model_class)
table = connections[db].ops.quote_name(self._table)
cursor = connections[db].cursor()
with connections[db].cursor() as cursor:
cursor.execute('DELETE FROM %s' % table)


@ -72,7 +72,7 @@ class Command(BaseCommand):
full_statement.append(' %s%s' % (line, ',' if i < len(table_output) - 1 else ''))
full_statement.append(');')
with transaction.commit_on_success_unless_managed():
curs = connection.cursor()
with connection.cursor() as curs:
try:
curs.execute("\n".join(full_statement))
except DatabaseError as e:


@ -64,7 +64,7 @@ Are you sure you want to do this?
if confirm == 'yes':
try:
with transaction.commit_on_success_unless_managed():
cursor = connection.cursor()
with connection.cursor() as cursor:
for sql in sql_list:
cursor.execute(sql)
except Exception as e:


@ -37,7 +37,7 @@ class Command(NoArgsCommand):
table2model = lambda table_name: table_name.title().replace('_', '').replace(' ', '').replace('-', '')
strip_prefix = lambda s: s[1:] if s.startswith("u'") else s
cursor = connection.cursor()
with connection.cursor() as cursor:
yield "# This is an auto-generated Django model module."
yield "# You'll have to do the following manually to clean this up:"
yield "# * Rearrange models' order"


@ -100,10 +100,9 @@ class Command(BaseCommand):
if sequence_sql:
if self.verbosity >= 2:
self.stdout.write("Resetting sequences\n")
cursor = connection.cursor()
with connection.cursor() as cursor:
for line in sequence_sql:
cursor.execute(line)
cursor.close()
if self.verbosity >= 1:
if self.fixture_object_count == self.loaded_object_count:


@ -171,8 +171,9 @@ class Command(BaseCommand):
"Runs the old syncdb-style operation on a list of app_labels."
cursor = connection.cursor()
try:
# Get a list of already installed *models* so that references work right.
tables = connection.introspection.table_names()
tables = connection.introspection.table_names(cursor)
seen_models = connection.introspection.installed_models(tables)
created_models = set()
pending_references = {}
@ -226,10 +227,12 @@ class Command(BaseCommand):
# We force a commit here, as that was the previous behaviour.
# If you can prove we don't need this, remove it.
transaction.set_dirty(using=connection.alias)
finally:
cursor.close()
# The connection may have been closed by a syncdb handler.
cursor = connection.cursor()
try:
# Install custom SQL for the app (but only if this
# is a model we've just created)
if self.verbosity >= 1:
@ -270,6 +273,8 @@ class Command(BaseCommand):
cursor.execute(sql)
except Exception as e:
self.stderr.write(" Failed to install index for %s.%s model: %s\n" % (app_name, model._meta.object_name, e))
finally:
cursor.close()
# Load initial_data fixtures (unless that has been disabled)
if self.load_initial_data:
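
The shape of this change: each cursor now lives in its own try/finally, and a fresh cursor is opened between the two blocks because, as the comment notes, a syncdb signal handler may have closed the connection in the meantime, leaving the first cursor unusable. Schematically:

    cursor = connection.cursor()
    try:
        ...  # create tables; signal handlers may run and close the connection
    finally:
        cursor.close()
    cursor = connection.cursor()  # re-acquire on the (possibly new) connection
    try:
        ...  # install custom SQL and indexes
    finally:
        cursor.close()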


@ -67,6 +67,7 @@ def sql_delete(app_config, style, connection):
except Exception:
cursor = None
try:
# Figure out which tables already exist
if cursor:
table_names = connection.introspection.table_names(cursor)
@ -93,7 +94,7 @@ def sql_delete(app_config, style, connection):
for model in app_models:
if connection.introspection.table_name_converter(model._meta.db_table) in table_names:
output.extend(connection.creation.sql_destroy_model(model, references_to_delete, style))
finally:
# Close database connection explicitly, in case this output is being piped
# directly into a database client, to avoid locking issues.
if cursor:


@ -194,13 +194,16 @@ class BaseDatabaseWrapper(object):
##### Backend-specific savepoint management methods #####
def _savepoint(self, sid):
self.cursor().execute(self.ops.savepoint_create_sql(sid))
with self.cursor() as cursor:
cursor.execute(self.ops.savepoint_create_sql(sid))
def _savepoint_rollback(self, sid):
self.cursor().execute(self.ops.savepoint_rollback_sql(sid))
with self.cursor() as cursor:
cursor.execute(self.ops.savepoint_rollback_sql(sid))
def _savepoint_commit(self, sid):
self.cursor().execute(self.ops.savepoint_commit_sql(sid))
with self.cursor() as cursor:
cursor.execute(self.ops.savepoint_commit_sql(sid))
def _savepoint_allowed(self):
# Savepoints cannot be created outside a transaction
@ -688,7 +691,7 @@ class BaseDatabaseFeatures(object):
# otherwise autocommit will cause the confirmation to
# fail.
self.connection.enter_transaction_management()
cursor = self.connection.cursor()
with self.connection.cursor() as cursor:
cursor.execute('CREATE TABLE ROLLBACK_TEST (X INT)')
self.connection.commit()
cursor.execute('INSERT INTO ROLLBACK_TEST (X) VALUES (8)')
@ -1253,7 +1256,8 @@ class BaseDatabaseIntrospection(object):
in sorting order between databases.
"""
if cursor is None:
cursor = self.connection.cursor()
with self.connection.cursor() as cursor:
return sorted(self.get_table_list(cursor))
return sorted(self.get_table_list(cursor))
def get_table_list(self, cursor):
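
The convention here is cursor ownership: when the caller passes in a cursor, the caller remains responsible for closing it, and only when `table_names()` opens its own cursor does it manage the lifetime with `with`. The duplicated `return` is deliberate:

    def table_names(self, cursor=None):
        if cursor is None:
            with self.connection.cursor() as cursor:  # we opened it, we close it
                return sorted(self.get_table_list(cursor))
        return sorted(self.get_table_list(cursor))    # borrowed cursor stays open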


@ -378,9 +378,8 @@ class BaseDatabaseCreation(object):
call_command('createcachetable', database=self.connection.alias)
# Get a cursor (even though we don't need one yet). This has
# the side effect of initializing the test database.
self.connection.cursor()
# Ensure a connection for the side effect of initializing the test database.
self.connection.ensure_connection()
return test_database_name
@ -406,7 +405,7 @@ class BaseDatabaseCreation(object):
qn = self.connection.ops.quote_name
# Create the test database and connect to it.
cursor = self._nodb_connection.cursor()
with self._nodb_connection.cursor() as cursor:
try:
cursor.execute(
"CREATE DATABASE %s %s" % (qn(test_database_name), suffix))
@ -461,7 +460,7 @@ class BaseDatabaseCreation(object):
# ourselves. Connect to the previous database (not the test database)
# to do so, because it's not allowed to delete a database while being
# connected to it.
cursor = self._nodb_connection.cursor()
with self._nodb_connection.cursor() as cursor:
# Wait to avoid "database is being accessed by other users" errors.
time.sleep(1)
cursor.execute("DROP DATABASE %s"


@ -180,7 +180,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
@cached_property
def _mysql_storage_engine(self):
"Internal method used in Django tests. Don't rely on this from your code"
cursor = self.connection.cursor()
with self.connection.cursor() as cursor:
cursor.execute('CREATE TABLE INTROSPECT_TEST (X INT)')
# This command is MySQL specific; the second column
# will tell you the default table type of the created
@ -207,7 +207,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
return False
# Test if the time zone definitions are installed.
cursor = self.connection.cursor()
with self.connection.cursor() as cursor:
cursor.execute("SELECT 1 FROM mysql.time_zone LIMIT 1")
return cursor.fetchone() is not None
@ -461,13 +461,12 @@ class DatabaseWrapper(BaseDatabaseWrapper):
return conn
def init_connection_state(self):
cursor = self.connection.cursor()
with self.connection.cursor() as cursor:
# SQL_AUTO_IS_NULL in MySQL controls whether an AUTO_INCREMENT column
# on a recently-inserted row will return when the field is tested for
# NULL. Disabling this value brings this aspect of MySQL in line with
# SQL standards.
cursor.execute('SET SQL_AUTO_IS_NULL = 0')
cursor.close()
def create_cursor(self):
cursor = self.connection.cursor()


@ -353,7 +353,7 @@ WHEN (new.%(col_name)s IS NULL)
def regex_lookup(self, lookup_type):
# If regex_lookup is called before it's been initialized, then create
# a cursor to initialize it and recur.
self.connection.cursor()
with self.connection.cursor():
return self.connection.ops.regex_lookup(lookup_type)
def return_insert_id(self):
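
The recursion terminates because opening the first cursor initializes the connection state, and that initialization (by assumption here) rebinds `self.connection.ops.regex_lookup` to the real, version-specific implementation, so the recursive call takes the initialized path. The `with` block exists purely for that side effect, which is why no cursor variable is bound.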


@ -149,7 +149,9 @@ class DatabaseWrapper(BaseDatabaseWrapper):
if conn_tz != tz:
cursor = self.connection.cursor()
try:
cursor.execute(self.ops.set_time_zone_sql(), [tz])
finally:
cursor.close()
# Commit after setting the time zone (see #17062)
if not self.get_autocommit():


@ -39,6 +39,6 @@ def get_version(connection):
if hasattr(connection, 'server_version'):
return connection.server_version
else:
cursor = connection.cursor()
with connection.cursor() as cursor:
cursor.execute("SELECT version()")
return _parse_version(cursor.fetchone()[0])


@ -86,13 +86,12 @@ class BaseDatabaseSchemaEditor(object):
"""
Executes the given SQL statement, with optional parameters.
"""
# Get the cursor
cursor = self.connection.cursor()
# Log the command we're running, then run it
logger.debug("%s; (params %r)" % (sql, params))
if self.collect_sql:
self.collected_sql.append((sql % tuple(map(self.connection.ops.quote_parameter, params))) + ";")
else:
with self.connection.cursor() as cursor:
cursor.execute(sql, params)
def quote_name(self, name):
@ -791,7 +790,8 @@ class BaseDatabaseSchemaEditor(object):
Returns all constraint names matching the columns and conditions
"""
column_names = list(column_names) if column_names else None
constraints = self.connection.introspection.get_constraints(self.connection.cursor(), model._meta.db_table)
with self.connection.cursor() as cursor:
constraints = self.connection.introspection.get_constraints(cursor, model._meta.db_table)
result = []
for name, infodict in constraints.items():
if column_names is None or column_names == infodict['columns']:
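
Moving the cursor acquisition into the `else` branch also means that when `collect_sql` is set, i.e. when the statements are being collected for display rather than executed, no cursor (and hence no database round-trip) is needed at all; previously a cursor was opened even on the collect-only path.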


@ -122,7 +122,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
rule out support for STDDEV. We need to manually check
whether the call works.
"""
cursor = self.connection.cursor()
with self.connection.cursor() as cursor:
cursor.execute('CREATE TABLE STDDEV_TEST (X INT)')
try:
cursor.execute('SELECT STDDEV(*) FROM STDDEV_TEST')


@ -14,6 +14,7 @@ from django.db.models.fields import AutoField, Empty
from django.db.models.query_utils import (Q, select_related_descend,
deferred_class_factory, InvalidQuery)
from django.db.models.deletion import Collector
from django.db.models.sql.constants import CURSOR
from django.db.models import sql
from django.utils.functional import partition
from django.utils import six
@ -574,7 +575,7 @@ class QuerySet(object):
query = self.query.clone(sql.UpdateQuery)
query.add_update_values(kwargs)
with transaction.commit_on_success_unless_managed(using=self.db):
rows = query.get_compiler(self.db).execute_sql(None)
rows = query.get_compiler(self.db).execute_sql(CURSOR)
self._result_cache = None
return rows
update.alters_data = True
@ -591,7 +592,7 @@ class QuerySet(object):
query = self.query.clone(sql.UpdateQuery)
query.add_update_fields(values)
self._result_cache = None
return query.get_compiler(self.db).execute_sql(None)
return query.get_compiler(self.db).execute_sql(CURSOR)
_update.alters_data = True
_update.queryset_only = False
@ -1521,6 +1522,7 @@ class RawQuerySet(object):
query = iter(self.query)
try:
# Find out which columns are model's fields, and which ones should be
# annotated to the model.
for pos, column in enumerate(self.columns):
@ -1569,6 +1571,10 @@ class RawQuerySet(object):
instance._state.adding = False
yield instance
finally:
# Done iterating the Query. If it has its own cursor, close it.
if hasattr(self.query, 'cursor') and self.query.cursor:
self.query.cursor.close()
def __repr__(self):
text = self.raw_query
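
A `RawQuerySet` keeps its cursor on the query object itself, so the iterator closes it in a `finally` once iteration finishes or aborts; the `hasattr` guard covers query implementations that never opened a cursor. Hypothetical usage showing when the close happens (model and SQL are illustrative):

    for book in Book.objects.raw("SELECT * FROM myapp_book"):
        ...  # the cursor is open while rows are being yielded
    # iteration is done (or an exception escaped): the cursor is now closed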


@ -1,12 +1,13 @@
import datetime
import sys
from django.conf import settings
from django.core.exceptions import FieldError
from django.db.backends.utils import truncate_name
from django.db.models.constants import LOOKUP_SEP
from django.db.models.query_utils import select_related_descend, QueryWrapper
from django.db.models.sql.constants import (SINGLE, MULTI, ORDER_DIR,
GET_ITERATOR_CHUNK_SIZE, SelectInfo)
from django.db.models.sql.constants import (CURSOR, SINGLE, MULTI, NO_RESULTS,
ORDER_DIR, GET_ITERATOR_CHUNK_SIZE, SelectInfo)
from django.db.models.sql.datastructures import EmptyResultSet
from django.db.models.sql.expressions import SQLEvaluator
from django.db.models.sql.query import get_order_dir, Query
@ -762,6 +763,8 @@ class SQLCompiler(object):
is needed, as the filters describe an empty set. In that case, None is
returned, to avoid any unnecessary database interaction.
"""
if not result_type:
result_type = NO_RESULTS
try:
sql, params = self.as_sql()
if not sql:
@ -773,27 +776,44 @@ class SQLCompiler(object):
return
cursor = self.connection.cursor()
try:
cursor.execute(sql, params)
except Exception:
cursor.close()
raise
if not result_type:
if result_type == CURSOR:
# Caller asked for the raw cursor, so just give it back to them
# to process (and close).
return cursor
if result_type == SINGLE:
try:
if self.ordering_aliases:
return cursor.fetchone()[:-len(self.ordering_aliases)]
return cursor.fetchone()
finally:
# done with the cursor
cursor.close()
if result_type == NO_RESULTS:
cursor.close()
return
# The MULTI case.
if self.ordering_aliases:
result = order_modified_iter(cursor, len(self.ordering_aliases),
self.connection.features.empty_fetchmany_value)
else:
result = iter((lambda: cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)),
result = cursor_iter(cursor,
self.connection.features.empty_fetchmany_value)
if not self.connection.features.can_use_chunked_reads:
try:
# If we are using non-chunked reads, we return the same data
# structure as normally, but ensure it is all read into memory
# before going any further.
return list(result)
finally:
# done with the cursor
cursor.close()
return result
def as_subquery_condition(self, alias, columns, qn):
@ -889,7 +909,7 @@ class SQLInsertCompiler(SQLCompiler):
def execute_sql(self, return_id=False):
assert not (return_id and len(self.query.objs) != 1)
self.return_id = return_id
cursor = self.connection.cursor()
with self.connection.cursor() as cursor:
for sql, params in self.as_sql():
cursor.execute(sql, params)
if not (return_id and cursor):
@ -970,12 +990,15 @@ class SQLUpdateCompiler(SQLCompiler):
related queries are not available.
"""
cursor = super(SQLUpdateCompiler, self).execute_sql(result_type)
try:
rows = cursor.rowcount if cursor else 0
is_empty = cursor is None
del cursor
finally:
if cursor:
cursor.close()
for query in self.query.get_related_updates():
aux_rows = query.get_compiler(self.using).execute_sql(result_type)
if is_empty:
if is_empty and aux_rows:
rows = aux_rows
is_empty = False
return rows
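
Two fixes share this hunk: `rowcount` is now read inside a try/finally so the cursor is closed no matter what, and the `and aux_rows` guard stops a related update that produced no row count from clearing `is_empty` prematurely, so a later related update can still supply the count.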
@ -1111,6 +1134,19 @@ class SQLDateTimeCompiler(SQLCompiler):
yield datetime
def cursor_iter(cursor, sentinel):
"""
Yields blocks of rows from a cursor and ensures the cursor is closed when
done.
"""
try:
for rows in iter((lambda: cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)),
sentinel):
yield rows
finally:
cursor.close()
def order_modified_iter(cursor, trim, sentinel):
"""
Yields blocks of rows from a cursor. We use this iterator in the special
@ -1118,6 +1154,9 @@ def order_modified_iter(cursor, trim, sentinel):
requirements. We must trim those extra columns before anything else can use
the results, since they're only needed to make the SQL valid.
"""
try:
for rows in iter((lambda: cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)),
sentinel):
yield [r[:-trim] for r in rows]
finally:
cursor.close()
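
Both iterators close their cursor in a `finally`, which runs when the rows are exhausted, when the consumer raises, or when the generator itself is closed or garbage-collected. A hypothetical consumer (`handle()` is illustrative):

    result = cursor_iter(cursor, connection.features.empty_fetchmany_value)
    for chunk in result:      # each chunk holds up to GET_ITERATOR_CHUNK_SIZE rows
        for row in chunk:
            handle(row)
    # the iterator is exhausted here, so the cursor has been closed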


@ -33,6 +33,8 @@ SelectInfo = namedtuple('SelectInfo', 'col field')
# How many results to expect from a cursor.execute call
MULTI = 'multi'
SINGLE = 'single'
CURSOR = 'cursor'
NO_RESULTS = 'no results'
ORDER_PATTERN = re.compile(r'\?|[-+]?[.\w]+$')
ORDER_DIR = {
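
The two new constants make the caller's intent explicit where `None` previously did double duty. What `execute_sql()` returns for each value, per the compiler hunk above:

    # MULTI      -> an iterator over chunks of rows; the iterator closes the cursor
    # SINGLE     -> one row; the cursor is closed before returning
    # CURSOR     -> the raw cursor; the caller must close it (see SQLUpdateCompiler)
    # NO_RESULTS -> None; the cursor is closed immediately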


@ -8,7 +8,7 @@ from django.db import connections
from django.db.models.query_utils import Q
from django.db.models.constants import LOOKUP_SEP
from django.db.models.fields import DateField, DateTimeField, FieldDoesNotExist
from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, SelectInfo
from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, NO_RESULTS, SelectInfo
from django.db.models.sql.datastructures import Date, DateTime
from django.db.models.sql.query import Query
from django.utils import six
@ -30,7 +30,7 @@ class DeleteQuery(Query):
def do_query(self, table, where, using):
self.tables = [table]
self.where = where
self.get_compiler(using).execute_sql(None)
self.get_compiler(using).execute_sql(NO_RESULTS)
def delete_batch(self, pk_list, using, field=None):
"""
@ -82,7 +82,7 @@ class DeleteQuery(Query):
values = innerq
self.where = self.where_class()
self.add_q(Q(pk__in=values))
self.get_compiler(using).execute_sql(None)
self.get_compiler(using).execute_sql(NO_RESULTS)
class UpdateQuery(Query):
@ -116,7 +116,7 @@ class UpdateQuery(Query):
for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
self.where = self.where_class()
self.add_q(Q(pk__in=pk_list[offset: offset + GET_ITERATOR_CHUNK_SIZE]))
self.get_compiler(using).execute_sql(None)
self.get_compiler(using).execute_sql(NO_RESULTS)
def add_update_values(self, values):
"""


@ -20,6 +20,7 @@ from django.db.backends.utils import format_number, CursorWrapper
from django.db.models import Sum, Avg, Variance, StdDev
from django.db.models.fields import (AutoField, DateField, DateTimeField,
DecimalField, IntegerField, TimeField)
from django.db.models.sql.constants import CURSOR
from django.db.utils import ConnectionHandler
from django.test import (TestCase, TransactionTestCase, override_settings,
skipUnlessDBFeature, skipIfDBFeature)
@ -58,7 +59,7 @@ class OracleChecks(unittest.TestCase):
# stored procedure through our cursor wrapper.
from django.db.backends.oracle.base import convert_unicode
cursor = connection.cursor()
with connection.cursor() as cursor:
cursor.callproc(convert_unicode('DBMS_SESSION.SET_IDENTIFIER'),
[convert_unicode('_django_testing!')])
@ -69,7 +70,7 @@ class OracleChecks(unittest.TestCase):
# as query parameters.
from django.db.backends.oracle.base import Database
cursor = connection.cursor()
with connection.cursor() as cursor:
var = cursor.var(Database.STRING)
cursor.execute("BEGIN %s := 'X'; END; ", [var])
self.assertEqual(var.getvalue(), 'X')
@ -79,21 +80,21 @@ class OracleChecks(unittest.TestCase):
def test_long_string(self):
# If the backend is Oracle, test that we can save a text longer
# than 4000 chars and read it properly
c = connection.cursor()
c.execute('CREATE TABLE ltext ("TEXT" NCLOB)')
with connection.cursor() as cursor:
cursor.execute('CREATE TABLE ltext ("TEXT" NCLOB)')
long_str = ''.join(six.text_type(x) for x in xrange(4000))
c.execute('INSERT INTO ltext VALUES (%s)', [long_str])
c.execute('SELECT text FROM ltext')
row = c.fetchone()
cursor.execute('INSERT INTO ltext VALUES (%s)', [long_str])
cursor.execute('SELECT text FROM ltext')
row = cursor.fetchone()
self.assertEqual(long_str, row[0].read())
c.execute('DROP TABLE ltext')
cursor.execute('DROP TABLE ltext')
@unittest.skipUnless(connection.vendor == 'oracle',
"No need to check Oracle connection semantics")
def test_client_encoding(self):
# If the backend is Oracle, test that the client encoding is set
# correctly. This was broken under Cygwin prior to r14781.
connection.cursor() # Ensure the connection is initialized.
connection.ensure_connection()
self.assertEqual(connection.connection.encoding, "UTF-8")
self.assertEqual(connection.connection.nencoding, "UTF-8")
@ -102,12 +103,12 @@ class OracleChecks(unittest.TestCase):
def test_order_of_nls_parameters(self):
# an 'almost right' datetime should work with configured
# NLS parameters as per #18465.
c = connection.cursor()
with connection.cursor() as cursor:
query = "select 1 from dual where '1936-12-29 00:00' < sysdate"
# Test that the query succeeds without errors - pre #18465 this
# wasn't the case.
c.execute(query)
self.assertEqual(c.fetchone()[0], 1)
cursor.execute(query)
self.assertEqual(cursor.fetchone()[0], 1)
class SQLiteTests(TestCase):
@ -209,7 +210,7 @@ class LastExecutedQueryTest(TestCase):
"""
persons = models.Reporter.objects.filter(raw_data=b'\x00\x46 \xFE').extra(select={'föö': 1})
sql, params = persons.query.sql_with_params()
cursor = persons.query.get_compiler('default').execute_sql(None)
cursor = persons.query.get_compiler('default').execute_sql(CURSOR)
last_sql = cursor.db.ops.last_executed_query(cursor, sql, params)
self.assertIsInstance(last_sql, six.text_type)
@ -327,6 +328,12 @@ class PostgresVersionTest(TestCase):
def fetchone(self):
return ["PostgreSQL 8.3"]
def __enter__(self):
return self
def __exit__(self, type, value, traceback):
pass
class OlderConnectionMock(object):
"Mock of psycopg2 (< 2.0.12) connection"
def cursor(self):
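
The `__enter__`/`__exit__` additions to the cursor mock above follow directly from the `get_version()` hunk earlier: version probing now uses the cursor as a context manager, so any test double standing in for a cursor has to implement the protocol as well.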


@ -896,10 +896,9 @@ class DBCacheTests(BaseCacheTests, TransactionTestCase):
management.call_command('createcachetable', verbosity=0, interactive=False)
def drop_table(self):
cursor = connection.cursor()
with connection.cursor() as cursor:
table_name = connection.ops.quote_name('test cache table')
cursor.execute('DROP TABLE %s' % table_name)
cursor.close()
def test_zero_cull(self):
self._perform_cull_test(caches['zero_cull'], 50, 18)


@ -30,7 +30,7 @@ class Article(models.Model):
database query for the sake of demonstration.
"""
from django.db import connection
cursor = connection.cursor()
with connection.cursor() as cursor:
cursor.execute("""
SELECT id, headline, pub_date
FROM custom_methods_article


@ -28,7 +28,7 @@ class InitialSQLTests(TestCase):
connection = connections[DEFAULT_DB_ALIAS]
custom_sql = custom_sql_for_model(Simple, no_style(), connection)
self.assertEqual(len(custom_sql), 9)
cursor = connection.cursor()
with connection.cursor() as cursor:
for sql in custom_sql:
cursor.execute(sql)
self.assertEqual(Simple.objects.count(), 9)


@ -23,7 +23,7 @@ class IntrospectionTests(TestCase):
"'%s' isn't in table_list()." % Article._meta.db_table)
def test_django_table_names(self):
cursor = connection.cursor()
with connection.cursor() as cursor:
cursor.execute('CREATE TABLE django_ixn_test_table (id INTEGER);')
tl = connection.introspection.django_table_names()
cursor.execute("DROP TABLE django_ixn_test_table;")
@ -32,7 +32,7 @@ class IntrospectionTests(TestCase):
def test_django_table_names_retval_type(self):
# Ticket #15216
cursor = connection.cursor()
with connection.cursor() as cursor:
cursor.execute('CREATE TABLE django_ixn_test_table (id INTEGER);')
tl = connection.introspection.django_table_names(only_existing=True)
@ -53,13 +53,13 @@ class IntrospectionTests(TestCase):
'Reporter sequence not found in sequence_list()')
def test_get_table_description_names(self):
cursor = connection.cursor()
with connection.cursor() as cursor:
desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table)
self.assertEqual([r[0] for r in desc],
[f.column for f in Reporter._meta.fields])
def test_get_table_description_types(self):
cursor = connection.cursor()
with connection.cursor() as cursor:
desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table)
# The MySQL exception is due to the cursor.description returning the same constant for
# text and blob columns. TODO: use information_schema database to retrieve the proper
@ -75,7 +75,7 @@ class IntrospectionTests(TestCase):
# inspect the length of character columns).
@expectedFailureOnOracle
def test_get_table_description_col_lengths(self):
cursor = connection.cursor()
with connection.cursor() as cursor:
desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table)
self.assertEqual(
[r[3] for r in desc if datatype(r[1], r) == 'CharField'],
@ -87,7 +87,7 @@ class IntrospectionTests(TestCase):
# so its idea about null_ok in cursor.description is different from ours.
@skipIfDBFeature('interprets_empty_strings_as_nulls')
def test_get_table_description_nullable(self):
cursor = connection.cursor()
with connection.cursor() as cursor:
desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table)
self.assertEqual(
[r[6] for r in desc],
@ -97,14 +97,14 @@ class IntrospectionTests(TestCase):
# Regression test for #9991 - 'real' types in postgres
@skipUnlessDBFeature('has_real_datatype')
def test_postgresql_real_type(self):
cursor = connection.cursor()
with connection.cursor() as cursor:
cursor.execute("CREATE TABLE django_ixn_real_test_table (number REAL);")
desc = connection.introspection.get_table_description(cursor, 'django_ixn_real_test_table')
cursor.execute('DROP TABLE django_ixn_real_test_table;')
self.assertEqual(datatype(desc[0][1], desc[0]), 'FloatField')
def test_get_relations(self):
cursor = connection.cursor()
with connection.cursor() as cursor:
relations = connection.introspection.get_relations(cursor, Article._meta.db_table)
# Older versions of MySQL don't have the chops to report on this stuff,
@ -117,7 +117,7 @@ class IntrospectionTests(TestCase):
@skipUnlessDBFeature('can_introspect_foreign_keys')
def test_get_key_columns(self):
cursor = connection.cursor()
with connection.cursor() as cursor:
key_columns = connection.introspection.get_key_columns(cursor, Article._meta.db_table)
self.assertEqual(
set(key_columns),
@ -125,12 +125,12 @@ class IntrospectionTests(TestCase):
('response_to_id', Article._meta.db_table, 'id')]))
def test_get_primary_key_column(self):
cursor = connection.cursor()
with connection.cursor() as cursor:
primary_key_column = connection.introspection.get_primary_key_column(cursor, Article._meta.db_table)
self.assertEqual(primary_key_column, 'id')
def test_get_indexes(self):
cursor = connection.cursor()
with connection.cursor() as cursor:
indexes = connection.introspection.get_indexes(cursor, Article._meta.db_table)
self.assertEqual(indexes['reporter_id'], {'unique': False, 'primary_key': False})
@ -139,7 +139,7 @@ class IntrospectionTests(TestCase):
Test that multicolumn indexes are not included in the introspection
results.
"""
cursor = connection.cursor()
with connection.cursor() as cursor:
indexes = connection.introspection.get_indexes(cursor, Reporter._meta.db_table)
self.assertNotIn('first_name', indexes)
self.assertIn('id', indexes)


@ -9,30 +9,37 @@ class MigrationTestBase(TransactionTestCase):
available_apps = ["migrations"]
def get_table_description(self, table):
with connection.cursor() as cursor:
return connection.introspection.get_table_description(cursor, table)
def assertTableExists(self, table):
self.assertIn(table, connection.introspection.get_table_list(connection.cursor()))
with connection.cursor() as cursor:
self.assertIn(table, connection.introspection.get_table_list(cursor))
def assertTableNotExists(self, table):
self.assertNotIn(table, connection.introspection.get_table_list(connection.cursor()))
with connection.cursor() as cursor:
self.assertNotIn(table, connection.introspection.get_table_list(cursor))
def assertColumnExists(self, table, column):
self.assertIn(column, [c.name for c in connection.introspection.get_table_description(connection.cursor(), table)])
self.assertIn(column, [c.name for c in self.get_table_description(table)])
def assertColumnNotExists(self, table, column):
self.assertNotIn(column, [c.name for c in connection.introspection.get_table_description(connection.cursor(), table)])
self.assertNotIn(column, [c.name for c in self.get_table_description(table)])
def assertColumnNull(self, table, column):
self.assertEqual([c.null_ok for c in connection.introspection.get_table_description(connection.cursor(), table) if c.name == column][0], True)
self.assertEqual([c.null_ok for c in self.get_table_description(table) if c.name == column][0], True)
def assertColumnNotNull(self, table, column):
self.assertEqual([c.null_ok for c in connection.introspection.get_table_description(connection.cursor(), table) if c.name == column][0], False)
self.assertEqual([c.null_ok for c in self.get_table_description(table) if c.name == column][0], False)
def assertIndexExists(self, table, columns, value=True):
with connection.cursor() as cursor:
self.assertEqual(
value,
any(
c["index"]
for c in connection.introspection.get_constraints(connection.cursor(), table).values()
for c in connection.introspection.get_constraints(cursor, table).values()
if c['columns'] == list(columns)
),
)
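
Extracting helpers such as `get_table_description()` gives each assertion its own short-lived cursor in place of the previous one-cursor-per-expression style, and keeps the `with` plumbing out of the assertions themselves. The same helper extraction (`get_indexes()`, `get_constraints()`) appears in the schema tests below.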


@ -19,7 +19,7 @@ class OperationTests(MigrationTestBase):
Creates a test model state and database table.
"""
# Delete the tables if they already exist
cursor = connection.cursor()
with connection.cursor() as cursor:
try:
cursor.execute("DROP TABLE %s_pony" % app_label)
except:
@ -348,21 +348,21 @@ class OperationTests(MigrationTestBase):
operation.state_forwards("test_alflpkfk", new_state)
self.assertIsInstance(project_state.models["test_alflpkfk", "pony"].get_field_by_name("id"), models.AutoField)
self.assertIsInstance(new_state.models["test_alflpkfk", "pony"].get_field_by_name("id"), models.FloatField)
# Test the database alteration
id_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_pony") if c.name == "id"][0]
fk_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_rider") if c.name == "pony_id"][0]
def assertIdTypeEqualsFkType():
with connection.cursor() as cursor:
id_type = [c.type_code for c in connection.introspection.get_table_description(cursor, "test_alflpkfk_pony") if c.name == "id"][0]
fk_type = [c.type_code for c in connection.introspection.get_table_description(cursor, "test_alflpkfk_rider") if c.name == "pony_id"][0]
self.assertEqual(id_type, fk_type)
assertIdTypeEqualsFkType()
# Test the database alteration
with connection.schema_editor() as editor:
operation.database_forwards("test_alflpkfk", editor, project_state, new_state)
id_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_pony") if c.name == "id"][0]
fk_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_rider") if c.name == "pony_id"][0]
self.assertEqual(id_type, fk_type)
assertIdTypeEqualsFkType()
# And test reversal
with connection.schema_editor() as editor:
operation.database_backwards("test_alflpkfk", editor, new_state, project_state)
id_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_pony") if c.name == "id"][0]
fk_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_rider") if c.name == "pony_id"][0]
self.assertEqual(id_type, fk_type)
assertIdTypeEqualsFkType()
def test_rename_field(self):
"""
@ -400,7 +400,7 @@ class OperationTests(MigrationTestBase):
self.assertEqual(len(project_state.models["test_alunto", "pony"].options.get("unique_together", set())), 0)
self.assertEqual(len(new_state.models["test_alunto", "pony"].options.get("unique_together", set())), 1)
# Make sure we can insert duplicate rows
cursor = connection.cursor()
with connection.cursor() as cursor:
cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (1, 1, 1)")
cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (2, 1, 1)")
cursor.execute("DELETE FROM test_alunto_pony")


@ -725,7 +725,7 @@ class DatabaseConnectionHandlingTests(TransactionTestCase):
# request_finished signal.
response = self.client.get('/')
# Make sure there is an open connection
connection.cursor()
connection.ensure_connection()
connection.enter_transaction_management()
signals.request_finished.send(sender=response._handler_class)
self.assertEqual(len(connection.transaction_state), 0)


@ -37,7 +37,7 @@ class SchemaTests(TransactionTestCase):
def delete_tables(self):
"Deletes all model tables for our models for a clean test environment"
cursor = connection.cursor()
with connection.cursor() as cursor:
connection.disable_constraint_checking()
table_names = connection.introspection.table_names(cursor)
for model in self.models:
@ -61,7 +61,7 @@ class SchemaTests(TransactionTestCase):
connection.enable_constraint_checking()
def column_classes(self, model):
cursor = connection.cursor()
with connection.cursor() as cursor:
columns = dict(
(d[0], (connection.introspection.get_field_type(d[1], d), d))
for d in connection.introspection.get_table_description(
@ -78,6 +78,20 @@ class SchemaTests(TransactionTestCase):
raise DatabaseError("Table does not exist (empty pragma)")
return columns
def get_indexes(self, table):
"""
Get the indexes on the table using a new cursor.
"""
with connection.cursor() as cursor:
return connection.introspection.get_indexes(cursor, table)
def get_constraints(self, table):
"""
Get the constraints on a table using a new cursor.
"""
with connection.cursor() as cursor:
return connection.introspection.get_constraints(cursor, table)
# Tests
def test_creation_deletion(self):
@ -127,7 +141,7 @@ class SchemaTests(TransactionTestCase):
strict=True,
)
# Make sure the new FK constraint is present
constraints = connection.introspection.get_constraints(connection.cursor(), Book._meta.db_table)
constraints = self.get_constraints(Book._meta.db_table)
for name, details in constraints.items():
if details['columns'] == ["author_id"] and details['foreign_key']:
self.assertEqual(details['foreign_key'], ('schema_tag', 'id'))
@ -342,7 +356,7 @@ class SchemaTests(TransactionTestCase):
editor.create_model(TagM2MTest)
editor.create_model(UniqueTest)
# Ensure the M2M exists and points to TagM2MTest
constraints = connection.introspection.get_constraints(connection.cursor(), BookWithM2M._meta.get_field_by_name("tags")[0].rel.through._meta.db_table)
constraints = self.get_constraints(BookWithM2M._meta.get_field_by_name("tags")[0].rel.through._meta.db_table)
if connection.features.supports_foreign_keys:
for name, details in constraints.items():
if details['columns'] == ["tagm2mtest_id"] and details['foreign_key']:
@ -363,7 +377,7 @@ class SchemaTests(TransactionTestCase):
# Ensure old M2M is gone
self.assertRaises(DatabaseError, self.column_classes, BookWithM2M._meta.get_field_by_name("tags")[0].rel.through)
# Ensure the new M2M exists and points to UniqueTest
constraints = connection.introspection.get_constraints(connection.cursor(), new_field.rel.through._meta.db_table)
constraints = self.get_constraints(new_field.rel.through._meta.db_table)
if connection.features.supports_foreign_keys:
for name, details in constraints.items():
if details['columns'] == ["uniquetest_id"] and details['foreign_key']:
@ -388,7 +402,7 @@ class SchemaTests(TransactionTestCase):
with connection.schema_editor() as editor:
editor.create_model(Author)
# Ensure the constraint exists
constraints = connection.introspection.get_constraints(connection.cursor(), Author._meta.db_table)
constraints = self.get_constraints(Author._meta.db_table)
for name, details in constraints.items():
if details['columns'] == ["height"] and details['check']:
break
@ -404,7 +418,7 @@ class SchemaTests(TransactionTestCase):
new_field,
strict=True,
)
constraints = connection.introspection.get_constraints(connection.cursor(), Author._meta.db_table)
constraints = self.get_constraints(Author._meta.db_table)
for name, details in constraints.items():
if details['columns'] == ["height"] and details['check']:
self.fail("Check constraint for height found")
@ -416,7 +430,7 @@ class SchemaTests(TransactionTestCase):
Author._meta.get_field_by_name("height")[0],
strict=True,
)
constraints = connection.introspection.get_constraints(connection.cursor(), Author._meta.db_table)
constraints = self.get_constraints(Author._meta.db_table)
for name, details in constraints.items():
if details['columns'] == ["height"] and details['check']:
break
@ -527,7 +541,7 @@ class SchemaTests(TransactionTestCase):
False,
any(
c["index"]
for c in connection.introspection.get_constraints(connection.cursor(), "schema_tag").values()
for c in self.get_constraints("schema_tag").values()
if c['columns'] == ["slug", "title"]
),
)
@ -543,7 +557,7 @@ class SchemaTests(TransactionTestCase):
True,
any(
c["index"]
for c in connection.introspection.get_constraints(connection.cursor(), "schema_tag").values()
for c in self.get_constraints("schema_tag").values()
if c['columns'] == ["slug", "title"]
),
)
@ -561,7 +575,7 @@ class SchemaTests(TransactionTestCase):
False,
any(
c["index"]
for c in connection.introspection.get_constraints(connection.cursor(), "schema_tag").values()
for c in self.get_constraints("schema_tag").values()
if c['columns'] == ["slug", "title"]
),
)
@ -578,7 +592,7 @@ class SchemaTests(TransactionTestCase):
True,
any(
c["index"]
for c in connection.introspection.get_constraints(connection.cursor(), "schema_tagindexed").values()
for c in self.get_constraints("schema_tagindexed").values()
if c['columns'] == ["slug", "title"]
),
)
@ -627,7 +641,7 @@ class SchemaTests(TransactionTestCase):
# Ensure the table is there and has the right index
self.assertIn(
"title",
connection.introspection.get_indexes(connection.cursor(), Book._meta.db_table),
self.get_indexes(Book._meta.db_table),
)
# Alter to remove the index
new_field = CharField(max_length=100, db_index=False)
@ -642,7 +656,7 @@ class SchemaTests(TransactionTestCase):
# Ensure the table is there and has no index
self.assertNotIn(
"title",
connection.introspection.get_indexes(connection.cursor(), Book._meta.db_table),
self.get_indexes(Book._meta.db_table),
)
# Alter to re-add the index
with connection.schema_editor() as editor:
@ -655,7 +669,7 @@ class SchemaTests(TransactionTestCase):
# Ensure the table is there and has the index again
self.assertIn(
"title",
connection.introspection.get_indexes(connection.cursor(), Book._meta.db_table),
self.get_indexes(Book._meta.db_table),
)
# Add a unique column, verify that creates an implicit index
with connection.schema_editor() as editor:
@ -665,7 +679,7 @@ class SchemaTests(TransactionTestCase):
)
self.assertIn(
"slug",
connection.introspection.get_indexes(connection.cursor(), Book._meta.db_table),
self.get_indexes(Book._meta.db_table),
)
# Remove the unique, check the index goes with it
new_field2 = CharField(max_length=20, unique=False)
@ -679,7 +693,7 @@ class SchemaTests(TransactionTestCase):
)
self.assertNotIn(
"slug",
connection.introspection.get_indexes(connection.cursor(), Book._meta.db_table),
self.get_indexes(Book._meta.db_table),
)
def test_primary_key(self):
@ -691,7 +705,7 @@ class SchemaTests(TransactionTestCase):
editor.create_model(Tag)
# Ensure the table is there and has the right PK
self.assertTrue(
connection.introspection.get_indexes(connection.cursor(), Tag._meta.db_table)['id']['primary_key'],
self.get_indexes(Tag._meta.db_table)['id']['primary_key'],
)
# Alter to change the PK
new_field = SlugField(primary_key=True)
@ -707,10 +721,10 @@ class SchemaTests(TransactionTestCase):
# Ensure the PK changed
self.assertNotIn(
'id',
connection.introspection.get_indexes(connection.cursor(), Tag._meta.db_table),
self.get_indexes(Tag._meta.db_table),
)
self.assertTrue(
connection.introspection.get_indexes(connection.cursor(), Tag._meta.db_table)['slug']['primary_key'],
self.get_indexes(Tag._meta.db_table)['slug']['primary_key'],
)
def test_context_manager_exit(self):
@ -741,7 +755,7 @@ class SchemaTests(TransactionTestCase):
# Ensure the table is there and has an index on the column
self.assertIn(
column_name,
connection.introspection.get_indexes(connection.cursor(), BookWithLongName._meta.db_table),
self.get_indexes(BookWithLongName._meta.db_table),
)
def test_creation_deletion_reserved_names(self):


@ -202,7 +202,8 @@ class AtomicTests(TransactionTestCase):
# trigger a database error inside an inner atomic without savepoint
with self.assertRaises(DatabaseError):
with transaction.atomic(savepoint=False):
connection.cursor().execute(
with connection.cursor() as cursor:
cursor.execute(
"SELECT no_such_col FROM transactions_reporter")
# prevent atomic from rolling back since we're recovering manually
self.assertTrue(transaction.get_rollback())
@ -534,7 +535,7 @@ class TransactionRollbackTests(IgnoreDeprecationWarningsMixin, TransactionTestCase):
available_apps = ['transactions']
def execute_bad_sql(self):
cursor = connection.cursor()
with connection.cursor() as cursor:
cursor.execute("INSERT INTO transactions_reporter (first_name, last_name) VALUES ('Douglas', 'Adams');")
@skipUnlessDBFeature('requires_rollback_on_dirty_transaction')
@ -678,6 +679,6 @@ class TransactionContextManagerTests(IgnoreDeprecationWarningsMixin, TransactionTestCase):
"""
with self.assertRaises(IntegrityError):
with transaction.commit_on_success():
cursor = connection.cursor()
with connection.cursor() as cursor:
cursor.execute("INSERT INTO transactions_reporter (first_name, last_name) VALUES ('Douglas', 'Adams');")
transaction.rollback()
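
One caveat worth keeping in mind (an observation about the pattern, not part of the diff): unlike a DB-API connection used as a context manager, which commits on success and rolls back on error, this cursor context manager only closes the cursor; transaction control remains with `transaction.atomic`, `commit_on_success`, and friends, as these tests exercise.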


@ -54,7 +54,7 @@ class TestTransactionClosing(IgnoreDeprecationWarningsMixin, TransactionTestCase
@commit_on_success
def raw_sql():
"Write a record using raw sql under a commit_on_success decorator"
cursor = connection.cursor()
with connection.cursor() as cursor:
cursor.execute("INSERT into transactions_regress_mod (fld) values (18)")
raw_sql()
@ -143,7 +143,7 @@ class TestTransactionClosing(IgnoreDeprecationWarningsMixin, TransactionTestCase
(reference). All this under commit_on_success, so the second insert should
be committed.
"""
cursor = connection.cursor()
with connection.cursor() as cursor:
cursor.execute("INSERT into transactions_regress_mod (fld) values (2)")
transaction.rollback()
cursor.execute("INSERT into transactions_regress_mod (fld) values (2)")