Fixed #28853 -- Updated connection.cursor() uses to use a context manager.

This commit is contained in:
Jon Dufresne 2017-11-28 05:12:28 -08:00 committed by Tim Graham
parent 3308085838
commit 7a6fbf36b1
18 changed files with 234 additions and 257 deletions

View File

@ -11,8 +11,7 @@ class MySQLIntrospection(DatabaseIntrospection):
data_types_reverse[FIELD_TYPE.GEOMETRY] = 'GeometryField' data_types_reverse[FIELD_TYPE.GEOMETRY] = 'GeometryField'
def get_geometry_type(self, table_name, geo_col): def get_geometry_type(self, table_name, geo_col):
cursor = self.connection.cursor() with self.connection.cursor() as cursor:
try:
# In order to get the specific geometry type of the field, # In order to get the specific geometry type of the field,
# we introspect on the table definition using `DESCRIBE`. # we introspect on the table definition using `DESCRIBE`.
cursor.execute('DESCRIBE %s' % cursor.execute('DESCRIBE %s' %
@ -27,9 +26,6 @@ class MySQLIntrospection(DatabaseIntrospection):
field_type = OGRGeomType(typ).django field_type = OGRGeomType(typ).django
field_params = {} field_params = {}
break break
finally:
cursor.close()
return field_type, field_params return field_type, field_params
def supports_spatial_index(self, cursor, table_name): def supports_spatial_index(self, cursor, table_name):

View File

@ -11,8 +11,7 @@ class OracleIntrospection(DatabaseIntrospection):
data_types_reverse[cx_Oracle.OBJECT] = 'GeometryField' data_types_reverse[cx_Oracle.OBJECT] = 'GeometryField'
def get_geometry_type(self, table_name, geo_col): def get_geometry_type(self, table_name, geo_col):
cursor = self.connection.cursor() with self.connection.cursor() as cursor:
try:
# Querying USER_SDO_GEOM_METADATA to get the SRID and dimension information. # Querying USER_SDO_GEOM_METADATA to get the SRID and dimension information.
try: try:
cursor.execute( cursor.execute(
@ -40,7 +39,4 @@ class OracleIntrospection(DatabaseIntrospection):
dim = dim.size() dim = dim.size()
if dim != 2: if dim != 2:
field_params['dim'] = dim field_params['dim'] = dim
finally:
cursor.close()
return field_type, field_params return field_type, field_params

View File

@ -59,15 +59,11 @@ class PostGISIntrospection(DatabaseIntrospection):
# to query the PostgreSQL pg_type table corresponding to the # to query the PostgreSQL pg_type table corresponding to the
# PostGIS custom data types. # PostGIS custom data types.
oid_sql = 'SELECT "oid" FROM "pg_type" WHERE "typname" = %s' oid_sql = 'SELECT "oid" FROM "pg_type" WHERE "typname" = %s'
cursor = self.connection.cursor() with self.connection.cursor() as cursor:
try:
for field_type in field_types: for field_type in field_types:
cursor.execute(oid_sql, (field_type[0],)) cursor.execute(oid_sql, (field_type[0],))
for result in cursor.fetchall(): for result in cursor.fetchall():
postgis_types[result[0]] = field_type[1] postgis_types[result[0]] = field_type[1]
finally:
cursor.close()
return postgis_types return postgis_types
def get_field_type(self, data_type, description): def get_field_type(self, data_type, description):
@ -88,8 +84,7 @@ class PostGISIntrospection(DatabaseIntrospection):
PointField or a PolygonField). Thus, this routine queries the PostGIS PointField or a PolygonField). Thus, this routine queries the PostGIS
metadata tables to determine the geometry type. metadata tables to determine the geometry type.
""" """
cursor = self.connection.cursor() with self.connection.cursor() as cursor:
try:
try: try:
# First seeing if this geometry column is in the `geometry_columns` # First seeing if this geometry column is in the `geometry_columns`
cursor.execute('SELECT "coord_dimension", "srid", "type" ' cursor.execute('SELECT "coord_dimension", "srid", "type" '
@ -122,7 +117,4 @@ class PostGISIntrospection(DatabaseIntrospection):
field_params['srid'] = srid field_params['srid'] = srid
if dim != 2: if dim != 2:
field_params['dim'] = dim field_params['dim'] = dim
finally:
cursor.close()
return field_type, field_params return field_type, field_params

View File

@ -25,8 +25,7 @@ class SpatiaLiteIntrospection(DatabaseIntrospection):
data_types_reverse = GeoFlexibleFieldLookupDict() data_types_reverse = GeoFlexibleFieldLookupDict()
def get_geometry_type(self, table_name, geo_col): def get_geometry_type(self, table_name, geo_col):
cursor = self.connection.cursor() with self.connection.cursor() as cursor:
try:
# Querying the `geometry_columns` table to get additional metadata. # Querying the `geometry_columns` table to get additional metadata.
cursor.execute('SELECT coord_dimension, srid, geometry_type ' cursor.execute('SELECT coord_dimension, srid, geometry_type '
'FROM geometry_columns ' 'FROM geometry_columns '
@ -55,9 +54,6 @@ class SpatiaLiteIntrospection(DatabaseIntrospection):
field_params['srid'] = srid field_params['srid'] = srid
if (isinstance(dim, str) and 'Z' in dim) or dim == 3: if (isinstance(dim, str) and 'Z' in dim) or dim == 3:
field_params['dim'] = 3 field_params['dim'] = 3
finally:
cursor.close()
return field_type, field_params return field_type, field_params
def get_constraints(self, cursor, table_name): def get_constraints(self, cursor, table_name):

View File

@ -573,11 +573,10 @@ class BaseDatabaseWrapper:
Provide a cursor: with self.temporary_connection() as cursor: ... Provide a cursor: with self.temporary_connection() as cursor: ...
""" """
must_close = self.connection is None must_close = self.connection is None
cursor = self.cursor()
try: try:
with self.cursor() as cursor:
yield cursor yield cursor
finally: finally:
cursor.close()
if must_close: if must_close:
self.close() self.close()

View File

@ -116,8 +116,7 @@ class BaseDatabaseIntrospection:
from django.db import router from django.db import router
sequence_list = [] sequence_list = []
cursor = self.connection.cursor() with self.connection.cursor() as cursor:
for app_config in apps.get_app_configs(): for app_config in apps.get_app_configs():
for model in router.get_migratable_models(app_config, self.connection.alias): for model in router.get_migratable_models(app_config, self.connection.alias):
if not model._meta.managed: if not model._meta.managed:

View File

@ -294,7 +294,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
Backends can override this method if they can more directly apply Backends can override this method if they can more directly apply
constraint checking (e.g. via "SET CONSTRAINTS ALL IMMEDIATE") constraint checking (e.g. via "SET CONSTRAINTS ALL IMMEDIATE")
""" """
cursor = self.cursor() with self.cursor() as cursor:
if table_names is None: if table_names is None:
table_names = self.introspection.table_names(cursor) table_names = self.introspection.table_names(cursor)
for table_name in table_names: for table_name in table_names:
@ -318,7 +318,8 @@ class DatabaseWrapper(BaseDatabaseWrapper):
for bad_row in cursor.fetchall(): for bad_row in cursor.fetchall():
raise utils.IntegrityError( raise utils.IntegrityError(
"The row in table '%s' with primary key '%s' has an invalid " "The row in table '%s' with primary key '%s' has an invalid "
"foreign key: %s.%s contains a value '%s' that does not have a corresponding value in %s.%s." "foreign key: %s.%s contains a value '%s' that does not "
"have a corresponding value in %s.%s."
% ( % (
table_name, bad_row[0], table_name, column_name, table_name, bad_row[0], table_name, column_name,
bad_row[1], referenced_table_name, referenced_column_name, bad_row[1], referenced_table_name, referenced_column_name,

View File

@ -30,7 +30,7 @@ class DatabaseCreation(BaseDatabaseCreation):
def _create_test_db(self, verbosity=1, autoclobber=False, keepdb=False): def _create_test_db(self, verbosity=1, autoclobber=False, keepdb=False):
parameters = self._get_test_db_params() parameters = self._get_test_db_params()
cursor = self._maindb_connection.cursor() with self._maindb_connection.cursor() as cursor:
if self._test_database_create(): if self._test_database_create():
try: try:
self._execute_test_db_creation(cursor, parameters, verbosity, keepdb) self._execute_test_db_creation(cursor, parameters, verbosity, keepdb)
@ -96,9 +96,6 @@ class DatabaseCreation(BaseDatabaseCreation):
else: else:
print("Tests cancelled.") print("Tests cancelled.")
sys.exit(1) sys.exit(1)
# Cursor must be closed before closing connection.
cursor.close()
self._maindb_connection.close() # done with main user -- test user and tablespaces created self._maindb_connection.close() # done with main user -- test user and tablespaces created
self._switch_to_test_user(parameters) self._switch_to_test_user(parameters)
return self.connection.settings_dict['NAME'] return self.connection.settings_dict['NAME']
@ -175,7 +172,7 @@ class DatabaseCreation(BaseDatabaseCreation):
self.connection.settings_dict['PASSWORD'] = self.connection.settings_dict['SAVED_PASSWORD'] self.connection.settings_dict['PASSWORD'] = self.connection.settings_dict['SAVED_PASSWORD']
self.connection.close() self.connection.close()
parameters = self._get_test_db_params() parameters = self._get_test_db_params()
cursor = self._maindb_connection.cursor() with self._maindb_connection.cursor() as cursor:
if self._test_user_create(): if self._test_user_create():
if verbosity >= 1: if verbosity >= 1:
print('Destroying test user...') print('Destroying test user...')
@ -184,8 +181,6 @@ class DatabaseCreation(BaseDatabaseCreation):
if verbosity >= 1: if verbosity >= 1:
print('Destroying test database tables...') print('Destroying test database tables...')
self._execute_test_db_destruction(cursor, parameters, verbosity) self._execute_test_db_destruction(cursor, parameters, verbosity)
# Cursor must be closed before closing connection.
cursor.close()
self._maindb_connection.close() self._maindb_connection.close()
def _execute_test_db_creation(self, cursor, parameters, verbosity, keepdb=False): def _execute_test_db_creation(self, cursor, parameters, verbosity, keepdb=False):

View File

@ -237,7 +237,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
Backends can override this method if they can more directly apply Backends can override this method if they can more directly apply
constraint checking (e.g. via "SET CONSTRAINTS ALL IMMEDIATE") constraint checking (e.g. via "SET CONSTRAINTS ALL IMMEDIATE")
""" """
cursor = self.cursor() with self.cursor() as cursor:
if table_names is None: if table_names is None:
table_names = self.introspection.table_names(cursor) table_names = self.introspection.table_names(cursor)
for table_name in table_names: for table_name in table_names:

View File

@ -322,7 +322,8 @@ class MigrationExecutor:
apps = after_state.apps apps = after_state.apps
found_create_model_migration = False found_create_model_migration = False
found_add_field_migration = False found_add_field_migration = False
existing_table_names = self.connection.introspection.table_names(self.connection.cursor()) with self.connection.cursor() as cursor:
existing_table_names = self.connection.introspection.table_names(cursor)
# Make sure all create model and add field operations are done # Make sure all create model and add field operations are done
for operation in migration.operations: for operation in migration.operations:
if isinstance(operation, migrations.CreateModel): if isinstance(operation, migrations.CreateModel):

View File

@ -852,7 +852,7 @@ class TransactionTestCase(SimpleTestCase):
no_style(), conn.introspection.sequence_list()) no_style(), conn.introspection.sequence_list())
if sql_list: if sql_list:
with transaction.atomic(using=db_name): with transaction.atomic(using=db_name):
cursor = conn.cursor() with conn.cursor() as cursor:
for sql in sql_list: for sql in sql_list:
cursor.execute(sql) cursor.execute(sql)

View File

@ -664,7 +664,8 @@ object that allows you to retrieve a specific connection using its
alias:: alias::
from django.db import connections from django.db import connections
cursor = connections['my_db_alias'].cursor() with connections['my_db_alias'].cursor() as cursor:
...
Limitations of multiple databases Limitations of multiple databases
================================= =================================

View File

@ -279,7 +279,7 @@ object that allows you to retrieve a specific connection using its
alias:: alias::
from django.db import connections from django.db import connections
cursor = connections['my_db_alias'].cursor() with connections['my_db_alias'].cursor() as cursor:
# Your code here... # Your code here...
By default, the Python DB API will return results without their field names, By default, the Python DB API will return results without their field names,

View File

@ -9,7 +9,7 @@ from ..models import Person
@unittest.skipUnless(connection.vendor == 'postgresql', "Test only for PostgreSQL") @unittest.skipUnless(connection.vendor == 'postgresql', "Test only for PostgreSQL")
class DatabaseSequenceTests(TestCase): class DatabaseSequenceTests(TestCase):
def test_get_sequences(self): def test_get_sequences(self):
cursor = connection.cursor() with connection.cursor() as cursor:
seqs = connection.introspection.get_sequences(cursor, Person._meta.db_table) seqs = connection.introspection.get_sequences(cursor, Person._meta.db_table)
self.assertEqual( self.assertEqual(
seqs, seqs,

View File

@ -44,7 +44,7 @@ class Tests(TestCase):
# Ensure the database default time zone is different than # Ensure the database default time zone is different than
# the time zone in new_connection.settings_dict. We can # the time zone in new_connection.settings_dict. We can
# get the default time zone by reset & show. # get the default time zone by reset & show.
cursor = new_connection.cursor() with new_connection.cursor() as cursor:
cursor.execute("RESET TIMEZONE") cursor.execute("RESET TIMEZONE")
cursor.execute("SHOW TIMEZONE") cursor.execute("SHOW TIMEZONE")
db_default_tz = cursor.fetchone()[0] db_default_tz = cursor.fetchone()[0]
@ -59,10 +59,10 @@ class Tests(TestCase):
# time zone, run a query and rollback. # time zone, run a query and rollback.
with self.settings(TIME_ZONE=new_tz): with self.settings(TIME_ZONE=new_tz):
new_connection.set_autocommit(False) new_connection.set_autocommit(False)
cursor = new_connection.cursor()
new_connection.rollback() new_connection.rollback()
# Now let's see if the rollback rolled back the SET TIME ZONE. # Now let's see if the rollback rolled back the SET TIME ZONE.
with new_connection.cursor() as cursor:
cursor.execute("SHOW TIMEZONE") cursor.execute("SHOW TIMEZONE")
tz = cursor.fetchone()[0] tz = cursor.fetchone()[0]
self.assertEqual(new_tz, tz) self.assertEqual(new_tz, tz)

View File

@ -82,7 +82,7 @@ class LastExecutedQueryTest(TestCase):
# If SQLITE_MAX_VARIABLE_NUMBER (default = 999) has been changed to be # If SQLITE_MAX_VARIABLE_NUMBER (default = 999) has been changed to be
# greater than SQLITE_MAX_COLUMN (default = 2000), last_executed_query # greater than SQLITE_MAX_COLUMN (default = 2000), last_executed_query
# can hit the SQLITE_MAX_COLUMN limit (#26063). # can hit the SQLITE_MAX_COLUMN limit (#26063).
cursor = connection.cursor() with connection.cursor() as cursor:
sql = "SELECT MAX(%s)" % ", ".join(["%s"] * 2001) sql = "SELECT MAX(%s)" % ", ".join(["%s"] * 2001)
params = list(range(2001)) params = list(range(2001))
# This should not raise an exception. # This should not raise an exception.
@ -97,7 +97,7 @@ class EscapingChecks(TestCase):
""" """
def test_parameter_escaping(self): def test_parameter_escaping(self):
# '%s' escaping support for sqlite3 (#13648). # '%s' escaping support for sqlite3 (#13648).
cursor = connection.cursor() with connection.cursor() as cursor:
cursor.execute("select strftime('%s', date('now'))") cursor.execute("select strftime('%s', date('now'))")
response = cursor.fetchall()[0][0] response = cursor.fetchall()[0][0]
# response should be an non-zero integer # response should be an non-zero integer

View File

@ -56,7 +56,7 @@ class LastExecutedQueryTest(TestCase):
last_executed_query should not raise an exception even if no previous last_executed_query should not raise an exception even if no previous
query has been run. query has been run.
""" """
cursor = connection.cursor() with connection.cursor() as cursor:
connection.ops.last_executed_query(cursor, '', ()) connection.ops.last_executed_query(cursor, '', ())
def test_debug_sql(self): def test_debug_sql(self):
@ -78,7 +78,7 @@ class ParameterHandlingTest(TestCase):
def test_bad_parameter_count(self): def test_bad_parameter_count(self):
"An executemany call with too many/not enough parameters will raise an exception (Refs #12612)" "An executemany call with too many/not enough parameters will raise an exception (Refs #12612)"
cursor = connection.cursor() with connection.cursor() as cursor:
query = ('INSERT INTO %s (%s, %s) VALUES (%%s, %%s)' % ( query = ('INSERT INTO %s (%s, %s) VALUES (%%s, %%s)' % (
connection.introspection.table_name_converter('backends_square'), connection.introspection.table_name_converter('backends_square'),
connection.ops.quote_name('root'), connection.ops.quote_name('root'),
@ -133,8 +133,9 @@ class LongNameTest(TransactionTestCase):
'table': VLM._meta.db_table 'table': VLM._meta.db_table
}, },
] ]
cursor = connection.cursor() sql_list = connection.ops.sql_flush(no_style(), tables, sequences)
for statement in connection.ops.sql_flush(no_style(), tables, sequences): with connection.cursor() as cursor:
for statement in sql_list:
cursor.execute(statement) cursor.execute(statement)
@ -146,8 +147,8 @@ class SequenceResetTest(TestCase):
Post.objects.create(id=10, name='1st post', text='hello world') Post.objects.create(id=10, name='1st post', text='hello world')
# Reset the sequences for the database # Reset the sequences for the database
cursor = connection.cursor()
commands = connections[DEFAULT_DB_ALIAS].ops.sequence_reset_sql(no_style(), [Post]) commands = connections[DEFAULT_DB_ALIAS].ops.sequence_reset_sql(no_style(), [Post])
with connection.cursor() as cursor:
for sql in commands: for sql in commands:
cursor.execute(sql) cursor.execute(sql)
@ -192,12 +193,12 @@ class EscapingChecks(TestCase):
bare_select_suffix = connection.features.bare_select_suffix bare_select_suffix = connection.features.bare_select_suffix
def test_paramless_no_escaping(self): def test_paramless_no_escaping(self):
cursor = connection.cursor() with connection.cursor() as cursor:
cursor.execute("SELECT '%s'" + self.bare_select_suffix) cursor.execute("SELECT '%s'" + self.bare_select_suffix)
self.assertEqual(cursor.fetchall()[0][0], '%s') self.assertEqual(cursor.fetchall()[0][0], '%s')
def test_parameter_escaping(self): def test_parameter_escaping(self):
cursor = connection.cursor() with connection.cursor() as cursor:
cursor.execute("SELECT '%%', %s" + self.bare_select_suffix, ('%d',)) cursor.execute("SELECT '%%', %s" + self.bare_select_suffix, ('%d',))
self.assertEqual(cursor.fetchall()[0], ('%', '%d')) self.assertEqual(cursor.fetchall()[0], ('%', '%d'))
@ -215,7 +216,6 @@ class BackendTestCase(TransactionTestCase):
self.create_squares(args, 'format', True) self.create_squares(args, 'format', True)
def create_squares(self, args, paramstyle, multiple): def create_squares(self, args, paramstyle, multiple):
cursor = connection.cursor()
opts = Square._meta opts = Square._meta
tbl = connection.introspection.table_name_converter(opts.db_table) tbl = connection.introspection.table_name_converter(opts.db_table)
f1 = connection.ops.quote_name(opts.get_field('root').column) f1 = connection.ops.quote_name(opts.get_field('root').column)
@ -226,6 +226,7 @@ class BackendTestCase(TransactionTestCase):
query = 'INSERT INTO %s (%s, %s) VALUES (%%(root)s, %%(square)s)' % (tbl, f1, f2) query = 'INSERT INTO %s (%s, %s) VALUES (%%(root)s, %%(square)s)' % (tbl, f1, f2)
else: else:
raise ValueError("unsupported paramstyle in test") raise ValueError("unsupported paramstyle in test")
with connection.cursor() as cursor:
if multiple: if multiple:
cursor.executemany(query, args) cursor.executemany(query, args)
else: else:
@ -297,7 +298,7 @@ class BackendTestCase(TransactionTestCase):
Person(first_name="Clark", last_name="Kent").save() Person(first_name="Clark", last_name="Kent").save()
opts2 = Person._meta opts2 = Person._meta
f3, f4 = opts2.get_field('first_name'), opts2.get_field('last_name') f3, f4 = opts2.get_field('first_name'), opts2.get_field('last_name')
cursor = connection.cursor() with connection.cursor() as cursor:
cursor.execute( cursor.execute(
'SELECT %s, %s FROM %s ORDER BY %s' % ( 'SELECT %s, %s FROM %s ORDER BY %s' % (
qn(f3.column), qn(f3.column),
@ -344,8 +345,8 @@ class BackendTestCase(TransactionTestCase):
def test_duplicate_table_error(self): def test_duplicate_table_error(self):
""" Creating an existing table returns a DatabaseError """ """ Creating an existing table returns a DatabaseError """
cursor = connection.cursor()
query = 'CREATE TABLE %s (id INTEGER);' % Article._meta.db_table query = 'CREATE TABLE %s (id INTEGER);' % Article._meta.db_table
with connection.cursor() as cursor:
with self.assertRaises(DatabaseError): with self.assertRaises(DatabaseError):
cursor.execute(query) cursor.execute(query)

View File

@ -26,7 +26,7 @@ class DatabaseErrorWrapperTests(TestCase):
@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL test') @unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL test')
def test_reraising_backend_specific_database_exception(self): def test_reraising_backend_specific_database_exception(self):
cursor = connection.cursor() with connection.cursor() as cursor:
msg = 'table "X" does not exist' msg = 'table "X" does not exist'
with self.assertRaisesMessage(ProgrammingError, msg) as cm: with self.assertRaisesMessage(ProgrammingError, msg) as cm:
cursor.execute('DROP TABLE "X"') cursor.execute('DROP TABLE "X"')