Fixed #28853 -- Updated connection.cursor() uses to use a context manager.

This commit is contained in:
Jon Dufresne 2017-11-28 05:12:28 -08:00 committed by Tim Graham
parent 3308085838
commit 7a6fbf36b1
18 changed files with 234 additions and 257 deletions

View File

@ -11,8 +11,7 @@ class MySQLIntrospection(DatabaseIntrospection):
data_types_reverse[FIELD_TYPE.GEOMETRY] = 'GeometryField'
def get_geometry_type(self, table_name, geo_col):
cursor = self.connection.cursor()
try:
with self.connection.cursor() as cursor:
# In order to get the specific geometry type of the field,
# we introspect on the table definition using `DESCRIBE`.
cursor.execute('DESCRIBE %s' %
@ -27,9 +26,6 @@ class MySQLIntrospection(DatabaseIntrospection):
field_type = OGRGeomType(typ).django
field_params = {}
break
finally:
cursor.close()
return field_type, field_params
def supports_spatial_index(self, cursor, table_name):

View File

@ -11,8 +11,7 @@ class OracleIntrospection(DatabaseIntrospection):
data_types_reverse[cx_Oracle.OBJECT] = 'GeometryField'
def get_geometry_type(self, table_name, geo_col):
cursor = self.connection.cursor()
try:
with self.connection.cursor() as cursor:
# Querying USER_SDO_GEOM_METADATA to get the SRID and dimension information.
try:
cursor.execute(
@ -40,7 +39,4 @@ class OracleIntrospection(DatabaseIntrospection):
dim = dim.size()
if dim != 2:
field_params['dim'] = dim
finally:
cursor.close()
return field_type, field_params

View File

@ -59,15 +59,11 @@ class PostGISIntrospection(DatabaseIntrospection):
# to query the PostgreSQL pg_type table corresponding to the
# PostGIS custom data types.
oid_sql = 'SELECT "oid" FROM "pg_type" WHERE "typname" = %s'
cursor = self.connection.cursor()
try:
with self.connection.cursor() as cursor:
for field_type in field_types:
cursor.execute(oid_sql, (field_type[0],))
for result in cursor.fetchall():
postgis_types[result[0]] = field_type[1]
finally:
cursor.close()
return postgis_types
def get_field_type(self, data_type, description):
@ -88,8 +84,7 @@ class PostGISIntrospection(DatabaseIntrospection):
PointField or a PolygonField). Thus, this routine queries the PostGIS
metadata tables to determine the geometry type.
"""
cursor = self.connection.cursor()
try:
with self.connection.cursor() as cursor:
try:
# First seeing if this geometry column is in the `geometry_columns`
cursor.execute('SELECT "coord_dimension", "srid", "type" '
@ -122,7 +117,4 @@ class PostGISIntrospection(DatabaseIntrospection):
field_params['srid'] = srid
if dim != 2:
field_params['dim'] = dim
finally:
cursor.close()
return field_type, field_params

View File

@ -25,8 +25,7 @@ class SpatiaLiteIntrospection(DatabaseIntrospection):
data_types_reverse = GeoFlexibleFieldLookupDict()
def get_geometry_type(self, table_name, geo_col):
cursor = self.connection.cursor()
try:
with self.connection.cursor() as cursor:
# Querying the `geometry_columns` table to get additional metadata.
cursor.execute('SELECT coord_dimension, srid, geometry_type '
'FROM geometry_columns '
@ -55,9 +54,6 @@ class SpatiaLiteIntrospection(DatabaseIntrospection):
field_params['srid'] = srid
if (isinstance(dim, str) and 'Z' in dim) or dim == 3:
field_params['dim'] = 3
finally:
cursor.close()
return field_type, field_params
def get_constraints(self, cursor, table_name):

View File

@ -573,11 +573,10 @@ class BaseDatabaseWrapper:
Provide a cursor: with self.temporary_connection() as cursor: ...
"""
must_close = self.connection is None
cursor = self.cursor()
try:
with self.cursor() as cursor:
yield cursor
finally:
cursor.close()
if must_close:
self.close()

View File

@ -116,8 +116,7 @@ class BaseDatabaseIntrospection:
from django.db import router
sequence_list = []
cursor = self.connection.cursor()
with self.connection.cursor() as cursor:
for app_config in apps.get_app_configs():
for model in router.get_migratable_models(app_config, self.connection.alias):
if not model._meta.managed:

View File

@ -294,7 +294,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
Backends can override this method if they can more directly apply
constraint checking (e.g. via "SET CONSTRAINTS ALL IMMEDIATE")
"""
cursor = self.cursor()
with self.cursor() as cursor:
if table_names is None:
table_names = self.introspection.table_names(cursor)
for table_name in table_names:
@ -318,7 +318,8 @@ class DatabaseWrapper(BaseDatabaseWrapper):
for bad_row in cursor.fetchall():
raise utils.IntegrityError(
"The row in table '%s' with primary key '%s' has an invalid "
"foreign key: %s.%s contains a value '%s' that does not have a corresponding value in %s.%s."
"foreign key: %s.%s contains a value '%s' that does not "
"have a corresponding value in %s.%s."
% (
table_name, bad_row[0], table_name, column_name,
bad_row[1], referenced_table_name, referenced_column_name,

View File

@ -30,7 +30,7 @@ class DatabaseCreation(BaseDatabaseCreation):
def _create_test_db(self, verbosity=1, autoclobber=False, keepdb=False):
parameters = self._get_test_db_params()
cursor = self._maindb_connection.cursor()
with self._maindb_connection.cursor() as cursor:
if self._test_database_create():
try:
self._execute_test_db_creation(cursor, parameters, verbosity, keepdb)
@ -96,9 +96,6 @@ class DatabaseCreation(BaseDatabaseCreation):
else:
print("Tests cancelled.")
sys.exit(1)
# Cursor must be closed before closing connection.
cursor.close()
self._maindb_connection.close() # done with main user -- test user and tablespaces created
self._switch_to_test_user(parameters)
return self.connection.settings_dict['NAME']
@ -175,7 +172,7 @@ class DatabaseCreation(BaseDatabaseCreation):
self.connection.settings_dict['PASSWORD'] = self.connection.settings_dict['SAVED_PASSWORD']
self.connection.close()
parameters = self._get_test_db_params()
cursor = self._maindb_connection.cursor()
with self._maindb_connection.cursor() as cursor:
if self._test_user_create():
if verbosity >= 1:
print('Destroying test user...')
@ -184,8 +181,6 @@ class DatabaseCreation(BaseDatabaseCreation):
if verbosity >= 1:
print('Destroying test database tables...')
self._execute_test_db_destruction(cursor, parameters, verbosity)
# Cursor must be closed before closing connection.
cursor.close()
self._maindb_connection.close()
def _execute_test_db_creation(self, cursor, parameters, verbosity, keepdb=False):

View File

@ -237,7 +237,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
Backends can override this method if they can more directly apply
constraint checking (e.g. via "SET CONSTRAINTS ALL IMMEDIATE")
"""
cursor = self.cursor()
with self.cursor() as cursor:
if table_names is None:
table_names = self.introspection.table_names(cursor)
for table_name in table_names:

View File

@ -322,7 +322,8 @@ class MigrationExecutor:
apps = after_state.apps
found_create_model_migration = False
found_add_field_migration = False
existing_table_names = self.connection.introspection.table_names(self.connection.cursor())
with self.connection.cursor() as cursor:
existing_table_names = self.connection.introspection.table_names(cursor)
# Make sure all create model and add field operations are done
for operation in migration.operations:
if isinstance(operation, migrations.CreateModel):

View File

@ -852,7 +852,7 @@ class TransactionTestCase(SimpleTestCase):
no_style(), conn.introspection.sequence_list())
if sql_list:
with transaction.atomic(using=db_name):
cursor = conn.cursor()
with conn.cursor() as cursor:
for sql in sql_list:
cursor.execute(sql)

View File

@ -664,7 +664,8 @@ object that allows you to retrieve a specific connection using its
alias::
from django.db import connections
cursor = connections['my_db_alias'].cursor()
with connections['my_db_alias'].cursor() as cursor:
...
Limitations of multiple databases
=================================

View File

@ -279,7 +279,7 @@ object that allows you to retrieve a specific connection using its
alias::
from django.db import connections
cursor = connections['my_db_alias'].cursor()
with connections['my_db_alias'].cursor() as cursor:
# Your code here...
By default, the Python DB API will return results without their field names,

View File

@ -9,7 +9,7 @@ from ..models import Person
@unittest.skipUnless(connection.vendor == 'postgresql', "Test only for PostgreSQL")
class DatabaseSequenceTests(TestCase):
def test_get_sequences(self):
cursor = connection.cursor()
with connection.cursor() as cursor:
seqs = connection.introspection.get_sequences(cursor, Person._meta.db_table)
self.assertEqual(
seqs,

View File

@ -44,7 +44,7 @@ class Tests(TestCase):
# Ensure the database default time zone is different than
# the time zone in new_connection.settings_dict. We can
# get the default time zone by reset & show.
cursor = new_connection.cursor()
with new_connection.cursor() as cursor:
cursor.execute("RESET TIMEZONE")
cursor.execute("SHOW TIMEZONE")
db_default_tz = cursor.fetchone()[0]
@ -59,10 +59,10 @@ class Tests(TestCase):
# time zone, run a query and rollback.
with self.settings(TIME_ZONE=new_tz):
new_connection.set_autocommit(False)
cursor = new_connection.cursor()
new_connection.rollback()
# Now let's see if the rollback rolled back the SET TIME ZONE.
with new_connection.cursor() as cursor:
cursor.execute("SHOW TIMEZONE")
tz = cursor.fetchone()[0]
self.assertEqual(new_tz, tz)

View File

@ -82,7 +82,7 @@ class LastExecutedQueryTest(TestCase):
# If SQLITE_MAX_VARIABLE_NUMBER (default = 999) has been changed to be
# greater than SQLITE_MAX_COLUMN (default = 2000), last_executed_query
# can hit the SQLITE_MAX_COLUMN limit (#26063).
cursor = connection.cursor()
with connection.cursor() as cursor:
sql = "SELECT MAX(%s)" % ", ".join(["%s"] * 2001)
params = list(range(2001))
# This should not raise an exception.
@ -97,7 +97,7 @@ class EscapingChecks(TestCase):
"""
def test_parameter_escaping(self):
# '%s' escaping support for sqlite3 (#13648).
cursor = connection.cursor()
with connection.cursor() as cursor:
cursor.execute("select strftime('%s', date('now'))")
response = cursor.fetchall()[0][0]
# response should be an non-zero integer

View File

@ -56,7 +56,7 @@ class LastExecutedQueryTest(TestCase):
last_executed_query should not raise an exception even if no previous
query has been run.
"""
cursor = connection.cursor()
with connection.cursor() as cursor:
connection.ops.last_executed_query(cursor, '', ())
def test_debug_sql(self):
@ -78,7 +78,7 @@ class ParameterHandlingTest(TestCase):
def test_bad_parameter_count(self):
"An executemany call with too many/not enough parameters will raise an exception (Refs #12612)"
cursor = connection.cursor()
with connection.cursor() as cursor:
query = ('INSERT INTO %s (%s, %s) VALUES (%%s, %%s)' % (
connection.introspection.table_name_converter('backends_square'),
connection.ops.quote_name('root'),
@ -133,8 +133,9 @@ class LongNameTest(TransactionTestCase):
'table': VLM._meta.db_table
},
]
cursor = connection.cursor()
for statement in connection.ops.sql_flush(no_style(), tables, sequences):
sql_list = connection.ops.sql_flush(no_style(), tables, sequences)
with connection.cursor() as cursor:
for statement in sql_list:
cursor.execute(statement)
@ -146,8 +147,8 @@ class SequenceResetTest(TestCase):
Post.objects.create(id=10, name='1st post', text='hello world')
# Reset the sequences for the database
cursor = connection.cursor()
commands = connections[DEFAULT_DB_ALIAS].ops.sequence_reset_sql(no_style(), [Post])
with connection.cursor() as cursor:
for sql in commands:
cursor.execute(sql)
@ -192,12 +193,12 @@ class EscapingChecks(TestCase):
bare_select_suffix = connection.features.bare_select_suffix
def test_paramless_no_escaping(self):
cursor = connection.cursor()
with connection.cursor() as cursor:
cursor.execute("SELECT '%s'" + self.bare_select_suffix)
self.assertEqual(cursor.fetchall()[0][0], '%s')
def test_parameter_escaping(self):
cursor = connection.cursor()
with connection.cursor() as cursor:
cursor.execute("SELECT '%%', %s" + self.bare_select_suffix, ('%d',))
self.assertEqual(cursor.fetchall()[0], ('%', '%d'))
@ -215,7 +216,6 @@ class BackendTestCase(TransactionTestCase):
self.create_squares(args, 'format', True)
def create_squares(self, args, paramstyle, multiple):
cursor = connection.cursor()
opts = Square._meta
tbl = connection.introspection.table_name_converter(opts.db_table)
f1 = connection.ops.quote_name(opts.get_field('root').column)
@ -226,6 +226,7 @@ class BackendTestCase(TransactionTestCase):
query = 'INSERT INTO %s (%s, %s) VALUES (%%(root)s, %%(square)s)' % (tbl, f1, f2)
else:
raise ValueError("unsupported paramstyle in test")
with connection.cursor() as cursor:
if multiple:
cursor.executemany(query, args)
else:
@ -297,7 +298,7 @@ class BackendTestCase(TransactionTestCase):
Person(first_name="Clark", last_name="Kent").save()
opts2 = Person._meta
f3, f4 = opts2.get_field('first_name'), opts2.get_field('last_name')
cursor = connection.cursor()
with connection.cursor() as cursor:
cursor.execute(
'SELECT %s, %s FROM %s ORDER BY %s' % (
qn(f3.column),
@ -344,8 +345,8 @@ class BackendTestCase(TransactionTestCase):
def test_duplicate_table_error(self):
""" Creating an existing table returns a DatabaseError """
cursor = connection.cursor()
query = 'CREATE TABLE %s (id INTEGER);' % Article._meta.db_table
with connection.cursor() as cursor:
with self.assertRaises(DatabaseError):
cursor.execute(query)

View File

@ -26,7 +26,7 @@ class DatabaseErrorWrapperTests(TestCase):
@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL test')
def test_reraising_backend_specific_database_exception(self):
cursor = connection.cursor()
with connection.cursor() as cursor:
msg = 'table "X" does not exist'
with self.assertRaisesMessage(ProgrammingError, msg) as cm:
cursor.execute('DROP TABLE "X"')