Fixed #28853 -- Updated connection.cursor() uses to use a context manager.

This commit is contained in:
Jon Dufresne 2017-11-28 05:12:28 -08:00 committed by Tim Graham
parent 3308085838
commit 7a6fbf36b1
18 changed files with 234 additions and 257 deletions

View File

@ -11,8 +11,7 @@ class MySQLIntrospection(DatabaseIntrospection):
data_types_reverse[FIELD_TYPE.GEOMETRY] = 'GeometryField' data_types_reverse[FIELD_TYPE.GEOMETRY] = 'GeometryField'
def get_geometry_type(self, table_name, geo_col): def get_geometry_type(self, table_name, geo_col):
cursor = self.connection.cursor() with self.connection.cursor() as cursor:
try:
# In order to get the specific geometry type of the field, # In order to get the specific geometry type of the field,
# we introspect on the table definition using `DESCRIBE`. # we introspect on the table definition using `DESCRIBE`.
cursor.execute('DESCRIBE %s' % cursor.execute('DESCRIBE %s' %
@ -27,9 +26,6 @@ class MySQLIntrospection(DatabaseIntrospection):
field_type = OGRGeomType(typ).django field_type = OGRGeomType(typ).django
field_params = {} field_params = {}
break break
finally:
cursor.close()
return field_type, field_params return field_type, field_params
def supports_spatial_index(self, cursor, table_name): def supports_spatial_index(self, cursor, table_name):

View File

@ -11,8 +11,7 @@ class OracleIntrospection(DatabaseIntrospection):
data_types_reverse[cx_Oracle.OBJECT] = 'GeometryField' data_types_reverse[cx_Oracle.OBJECT] = 'GeometryField'
def get_geometry_type(self, table_name, geo_col): def get_geometry_type(self, table_name, geo_col):
cursor = self.connection.cursor() with self.connection.cursor() as cursor:
try:
# Querying USER_SDO_GEOM_METADATA to get the SRID and dimension information. # Querying USER_SDO_GEOM_METADATA to get the SRID and dimension information.
try: try:
cursor.execute( cursor.execute(
@ -40,7 +39,4 @@ class OracleIntrospection(DatabaseIntrospection):
dim = dim.size() dim = dim.size()
if dim != 2: if dim != 2:
field_params['dim'] = dim field_params['dim'] = dim
finally:
cursor.close()
return field_type, field_params return field_type, field_params

View File

@ -59,15 +59,11 @@ class PostGISIntrospection(DatabaseIntrospection):
# to query the PostgreSQL pg_type table corresponding to the # to query the PostgreSQL pg_type table corresponding to the
# PostGIS custom data types. # PostGIS custom data types.
oid_sql = 'SELECT "oid" FROM "pg_type" WHERE "typname" = %s' oid_sql = 'SELECT "oid" FROM "pg_type" WHERE "typname" = %s'
cursor = self.connection.cursor() with self.connection.cursor() as cursor:
try:
for field_type in field_types: for field_type in field_types:
cursor.execute(oid_sql, (field_type[0],)) cursor.execute(oid_sql, (field_type[0],))
for result in cursor.fetchall(): for result in cursor.fetchall():
postgis_types[result[0]] = field_type[1] postgis_types[result[0]] = field_type[1]
finally:
cursor.close()
return postgis_types return postgis_types
def get_field_type(self, data_type, description): def get_field_type(self, data_type, description):
@ -88,8 +84,7 @@ class PostGISIntrospection(DatabaseIntrospection):
PointField or a PolygonField). Thus, this routine queries the PostGIS PointField or a PolygonField). Thus, this routine queries the PostGIS
metadata tables to determine the geometry type. metadata tables to determine the geometry type.
""" """
cursor = self.connection.cursor() with self.connection.cursor() as cursor:
try:
try: try:
# First seeing if this geometry column is in the `geometry_columns` # First seeing if this geometry column is in the `geometry_columns`
cursor.execute('SELECT "coord_dimension", "srid", "type" ' cursor.execute('SELECT "coord_dimension", "srid", "type" '
@ -122,7 +117,4 @@ class PostGISIntrospection(DatabaseIntrospection):
field_params['srid'] = srid field_params['srid'] = srid
if dim != 2: if dim != 2:
field_params['dim'] = dim field_params['dim'] = dim
finally:
cursor.close()
return field_type, field_params return field_type, field_params

View File

@ -25,8 +25,7 @@ class SpatiaLiteIntrospection(DatabaseIntrospection):
data_types_reverse = GeoFlexibleFieldLookupDict() data_types_reverse = GeoFlexibleFieldLookupDict()
def get_geometry_type(self, table_name, geo_col): def get_geometry_type(self, table_name, geo_col):
cursor = self.connection.cursor() with self.connection.cursor() as cursor:
try:
# Querying the `geometry_columns` table to get additional metadata. # Querying the `geometry_columns` table to get additional metadata.
cursor.execute('SELECT coord_dimension, srid, geometry_type ' cursor.execute('SELECT coord_dimension, srid, geometry_type '
'FROM geometry_columns ' 'FROM geometry_columns '
@ -55,9 +54,6 @@ class SpatiaLiteIntrospection(DatabaseIntrospection):
field_params['srid'] = srid field_params['srid'] = srid
if (isinstance(dim, str) and 'Z' in dim) or dim == 3: if (isinstance(dim, str) and 'Z' in dim) or dim == 3:
field_params['dim'] = 3 field_params['dim'] = 3
finally:
cursor.close()
return field_type, field_params return field_type, field_params
def get_constraints(self, cursor, table_name): def get_constraints(self, cursor, table_name):

View File

@ -573,11 +573,10 @@ class BaseDatabaseWrapper:
Provide a cursor: with self.temporary_connection() as cursor: ... Provide a cursor: with self.temporary_connection() as cursor: ...
""" """
must_close = self.connection is None must_close = self.connection is None
cursor = self.cursor()
try: try:
yield cursor with self.cursor() as cursor:
yield cursor
finally: finally:
cursor.close()
if must_close: if must_close:
self.close() self.close()

View File

@ -116,21 +116,20 @@ class BaseDatabaseIntrospection:
from django.db import router from django.db import router
sequence_list = [] sequence_list = []
cursor = self.connection.cursor() with self.connection.cursor() as cursor:
for app_config in apps.get_app_configs():
for app_config in apps.get_app_configs(): for model in router.get_migratable_models(app_config, self.connection.alias):
for model in router.get_migratable_models(app_config, self.connection.alias): if not model._meta.managed:
if not model._meta.managed: continue
continue if model._meta.swapped:
if model._meta.swapped: continue
continue sequence_list.extend(self.get_sequences(cursor, model._meta.db_table, model._meta.local_fields))
sequence_list.extend(self.get_sequences(cursor, model._meta.db_table, model._meta.local_fields)) for f in model._meta.local_many_to_many:
for f in model._meta.local_many_to_many: # If this is an m2m using an intermediate table,
# If this is an m2m using an intermediate table, # we don't need to reset the sequence.
# we don't need to reset the sequence. if f.remote_field.through is None:
if f.remote_field.through is None: sequence = self.get_sequences(cursor, f.m2m_db_table())
sequence = self.get_sequences(cursor, f.m2m_db_table()) sequence_list.extend(sequence or [{'table': f.m2m_db_table(), 'column': None}])
sequence_list.extend(sequence or [{'table': f.m2m_db_table(), 'column': None}])
return sequence_list return sequence_list
def get_sequences(self, cursor, table_name, table_fields=()): def get_sequences(self, cursor, table_name, table_fields=()):

View File

@ -294,36 +294,37 @@ class DatabaseWrapper(BaseDatabaseWrapper):
Backends can override this method if they can more directly apply Backends can override this method if they can more directly apply
constraint checking (e.g. via "SET CONSTRAINTS ALL IMMEDIATE") constraint checking (e.g. via "SET CONSTRAINTS ALL IMMEDIATE")
""" """
cursor = self.cursor() with self.cursor() as cursor:
if table_names is None: if table_names is None:
table_names = self.introspection.table_names(cursor) table_names = self.introspection.table_names(cursor)
for table_name in table_names: for table_name in table_names:
primary_key_column_name = self.introspection.get_primary_key_column(cursor, table_name) primary_key_column_name = self.introspection.get_primary_key_column(cursor, table_name)
if not primary_key_column_name: if not primary_key_column_name:
continue continue
key_columns = self.introspection.get_key_columns(cursor, table_name) key_columns = self.introspection.get_key_columns(cursor, table_name)
for column_name, referenced_table_name, referenced_column_name in key_columns: for column_name, referenced_table_name, referenced_column_name in key_columns:
cursor.execute( cursor.execute(
""" """
SELECT REFERRING.`%s`, REFERRING.`%s` FROM `%s` as REFERRING SELECT REFERRING.`%s`, REFERRING.`%s` FROM `%s` as REFERRING
LEFT JOIN `%s` as REFERRED LEFT JOIN `%s` as REFERRED
ON (REFERRING.`%s` = REFERRED.`%s`) ON (REFERRING.`%s` = REFERRED.`%s`)
WHERE REFERRING.`%s` IS NOT NULL AND REFERRED.`%s` IS NULL WHERE REFERRING.`%s` IS NOT NULL AND REFERRED.`%s` IS NULL
""" % ( """ % (
primary_key_column_name, column_name, table_name, primary_key_column_name, column_name, table_name,
referenced_table_name, column_name, referenced_column_name, referenced_table_name, column_name, referenced_column_name,
column_name, referenced_column_name, column_name, referenced_column_name,
)
)
for bad_row in cursor.fetchall():
raise utils.IntegrityError(
"The row in table '%s' with primary key '%s' has an invalid "
"foreign key: %s.%s contains a value '%s' that does not have a corresponding value in %s.%s."
% (
table_name, bad_row[0], table_name, column_name,
bad_row[1], referenced_table_name, referenced_column_name,
) )
) )
for bad_row in cursor.fetchall():
raise utils.IntegrityError(
"The row in table '%s' with primary key '%s' has an invalid "
"foreign key: %s.%s contains a value '%s' that does not "
"have a corresponding value in %s.%s."
% (
table_name, bad_row[0], table_name, column_name,
bad_row[1], referenced_table_name, referenced_column_name,
)
)
def is_usable(self): def is_usable(self):
try: try:

View File

@ -30,75 +30,72 @@ class DatabaseCreation(BaseDatabaseCreation):
def _create_test_db(self, verbosity=1, autoclobber=False, keepdb=False): def _create_test_db(self, verbosity=1, autoclobber=False, keepdb=False):
parameters = self._get_test_db_params() parameters = self._get_test_db_params()
cursor = self._maindb_connection.cursor() with self._maindb_connection.cursor() as cursor:
if self._test_database_create(): if self._test_database_create():
try: try:
self._execute_test_db_creation(cursor, parameters, verbosity, keepdb) self._execute_test_db_creation(cursor, parameters, verbosity, keepdb)
except Exception as e: except Exception as e:
if 'ORA-01543' not in str(e): if 'ORA-01543' not in str(e):
# All errors except "tablespace already exists" cancel tests # All errors except "tablespace already exists" cancel tests
sys.stderr.write("Got an error creating the test database: %s\n" % e) sys.stderr.write("Got an error creating the test database: %s\n" % e)
sys.exit(2) sys.exit(2)
if not autoclobber: if not autoclobber:
confirm = input( confirm = input(
"It appears the test database, %s, already exists. " "It appears the test database, %s, already exists. "
"Type 'yes' to delete it, or 'no' to cancel: " % parameters['user']) "Type 'yes' to delete it, or 'no' to cancel: " % parameters['user'])
if autoclobber or confirm == 'yes': if autoclobber or confirm == 'yes':
if verbosity >= 1: if verbosity >= 1:
print("Destroying old test database for alias '%s'..." % self.connection.alias) print("Destroying old test database for alias '%s'..." % self.connection.alias)
try: try:
self._execute_test_db_destruction(cursor, parameters, verbosity) self._execute_test_db_destruction(cursor, parameters, verbosity)
except DatabaseError as e: except DatabaseError as e:
if 'ORA-29857' in str(e): if 'ORA-29857' in str(e):
self._handle_objects_preventing_db_destruction(cursor, parameters, self._handle_objects_preventing_db_destruction(cursor, parameters,
verbosity, autoclobber) verbosity, autoclobber)
else: else:
# Ran into a database error that isn't about leftover objects in the tablespace # Ran into a database error that isn't about leftover objects in the tablespace
sys.stderr.write("Got an error destroying the old test database: %s\n" % e)
sys.exit(2)
except Exception as e:
sys.stderr.write("Got an error destroying the old test database: %s\n" % e) sys.stderr.write("Got an error destroying the old test database: %s\n" % e)
sys.exit(2) sys.exit(2)
except Exception as e: try:
sys.stderr.write("Got an error destroying the old test database: %s\n" % e) self._execute_test_db_creation(cursor, parameters, verbosity, keepdb)
sys.exit(2) except Exception as e:
try: sys.stderr.write("Got an error recreating the test database: %s\n" % e)
self._execute_test_db_creation(cursor, parameters, verbosity, keepdb) sys.exit(2)
except Exception as e: else:
sys.stderr.write("Got an error recreating the test database: %s\n" % e) print("Tests cancelled.")
sys.exit(2) sys.exit(1)
else:
print("Tests cancelled.")
sys.exit(1)
if self._test_user_create(): if self._test_user_create():
if verbosity >= 1: if verbosity >= 1:
print("Creating test user...") print("Creating test user...")
try: try:
self._create_test_user(cursor, parameters, verbosity, keepdb) self._create_test_user(cursor, parameters, verbosity, keepdb)
except Exception as e: except Exception as e:
if 'ORA-01920' not in str(e): if 'ORA-01920' not in str(e):
# All errors except "user already exists" cancel tests # All errors except "user already exists" cancel tests
sys.stderr.write("Got an error creating the test user: %s\n" % e) sys.stderr.write("Got an error creating the test user: %s\n" % e)
sys.exit(2)
if not autoclobber:
confirm = input(
"It appears the test user, %s, already exists. Type "
"'yes' to delete it, or 'no' to cancel: " % parameters['user'])
if autoclobber or confirm == 'yes':
try:
if verbosity >= 1:
print("Destroying old test user...")
self._destroy_test_user(cursor, parameters, verbosity)
if verbosity >= 1:
print("Creating test user...")
self._create_test_user(cursor, parameters, verbosity, keepdb)
except Exception as e:
sys.stderr.write("Got an error recreating the test user: %s\n" % e)
sys.exit(2) sys.exit(2)
else: if not autoclobber:
print("Tests cancelled.") confirm = input(
sys.exit(1) "It appears the test user, %s, already exists. Type "
"'yes' to delete it, or 'no' to cancel: " % parameters['user'])
# Cursor must be closed before closing connection. if autoclobber or confirm == 'yes':
cursor.close() try:
if verbosity >= 1:
print("Destroying old test user...")
self._destroy_test_user(cursor, parameters, verbosity)
if verbosity >= 1:
print("Creating test user...")
self._create_test_user(cursor, parameters, verbosity, keepdb)
except Exception as e:
sys.stderr.write("Got an error recreating the test user: %s\n" % e)
sys.exit(2)
else:
print("Tests cancelled.")
sys.exit(1)
self._maindb_connection.close() # done with main user -- test user and tablespaces created self._maindb_connection.close() # done with main user -- test user and tablespaces created
self._switch_to_test_user(parameters) self._switch_to_test_user(parameters)
return self.connection.settings_dict['NAME'] return self.connection.settings_dict['NAME']
@ -175,17 +172,15 @@ class DatabaseCreation(BaseDatabaseCreation):
self.connection.settings_dict['PASSWORD'] = self.connection.settings_dict['SAVED_PASSWORD'] self.connection.settings_dict['PASSWORD'] = self.connection.settings_dict['SAVED_PASSWORD']
self.connection.close() self.connection.close()
parameters = self._get_test_db_params() parameters = self._get_test_db_params()
cursor = self._maindb_connection.cursor() with self._maindb_connection.cursor() as cursor:
if self._test_user_create(): if self._test_user_create():
if verbosity >= 1: if verbosity >= 1:
print('Destroying test user...') print('Destroying test user...')
self._destroy_test_user(cursor, parameters, verbosity) self._destroy_test_user(cursor, parameters, verbosity)
if self._test_database_create(): if self._test_database_create():
if verbosity >= 1: if verbosity >= 1:
print('Destroying test database tables...') print('Destroying test database tables...')
self._execute_test_db_destruction(cursor, parameters, verbosity) self._execute_test_db_destruction(cursor, parameters, verbosity)
# Cursor must be closed before closing connection.
cursor.close()
self._maindb_connection.close() self._maindb_connection.close()
def _execute_test_db_creation(self, cursor, parameters, verbosity, keepdb=False): def _execute_test_db_creation(self, cursor, parameters, verbosity, keepdb=False):

View File

@ -237,37 +237,37 @@ class DatabaseWrapper(BaseDatabaseWrapper):
Backends can override this method if they can more directly apply Backends can override this method if they can more directly apply
constraint checking (e.g. via "SET CONSTRAINTS ALL IMMEDIATE") constraint checking (e.g. via "SET CONSTRAINTS ALL IMMEDIATE")
""" """
cursor = self.cursor() with self.cursor() as cursor:
if table_names is None: if table_names is None:
table_names = self.introspection.table_names(cursor) table_names = self.introspection.table_names(cursor)
for table_name in table_names: for table_name in table_names:
primary_key_column_name = self.introspection.get_primary_key_column(cursor, table_name) primary_key_column_name = self.introspection.get_primary_key_column(cursor, table_name)
if not primary_key_column_name: if not primary_key_column_name:
continue continue
key_columns = self.introspection.get_key_columns(cursor, table_name) key_columns = self.introspection.get_key_columns(cursor, table_name)
for column_name, referenced_table_name, referenced_column_name in key_columns: for column_name, referenced_table_name, referenced_column_name in key_columns:
cursor.execute( cursor.execute(
""" """
SELECT REFERRING.`%s`, REFERRING.`%s` FROM `%s` as REFERRING SELECT REFERRING.`%s`, REFERRING.`%s` FROM `%s` as REFERRING
LEFT JOIN `%s` as REFERRED LEFT JOIN `%s` as REFERRED
ON (REFERRING.`%s` = REFERRED.`%s`) ON (REFERRING.`%s` = REFERRED.`%s`)
WHERE REFERRING.`%s` IS NOT NULL AND REFERRED.`%s` IS NULL WHERE REFERRING.`%s` IS NOT NULL AND REFERRED.`%s` IS NULL
""" """
% ( % (
primary_key_column_name, column_name, table_name, primary_key_column_name, column_name, table_name,
referenced_table_name, column_name, referenced_column_name, referenced_table_name, column_name, referenced_column_name,
column_name, referenced_column_name, column_name, referenced_column_name,
)
)
for bad_row in cursor.fetchall():
raise utils.IntegrityError(
"The row in table '%s' with primary key '%s' has an "
"invalid foreign key: %s.%s contains a value '%s' that "
"does not have a corresponding value in %s.%s." % (
table_name, bad_row[0], table_name, column_name,
bad_row[1], referenced_table_name, referenced_column_name,
) )
) )
for bad_row in cursor.fetchall():
raise utils.IntegrityError(
"The row in table '%s' with primary key '%s' has an "
"invalid foreign key: %s.%s contains a value '%s' that "
"does not have a corresponding value in %s.%s." % (
table_name, bad_row[0], table_name, column_name,
bad_row[1], referenced_table_name, referenced_column_name,
)
)
def is_usable(self): def is_usable(self):
return True return True

View File

@ -322,7 +322,8 @@ class MigrationExecutor:
apps = after_state.apps apps = after_state.apps
found_create_model_migration = False found_create_model_migration = False
found_add_field_migration = False found_add_field_migration = False
existing_table_names = self.connection.introspection.table_names(self.connection.cursor()) with self.connection.cursor() as cursor:
existing_table_names = self.connection.introspection.table_names(cursor)
# Make sure all create model and add field operations are done # Make sure all create model and add field operations are done
for operation in migration.operations: for operation in migration.operations:
if isinstance(operation, migrations.CreateModel): if isinstance(operation, migrations.CreateModel):

View File

@ -852,9 +852,9 @@ class TransactionTestCase(SimpleTestCase):
no_style(), conn.introspection.sequence_list()) no_style(), conn.introspection.sequence_list())
if sql_list: if sql_list:
with transaction.atomic(using=db_name): with transaction.atomic(using=db_name):
cursor = conn.cursor() with conn.cursor() as cursor:
for sql in sql_list: for sql in sql_list:
cursor.execute(sql) cursor.execute(sql)
def _fixture_setup(self): def _fixture_setup(self):
for db_name in self._databases_names(include_mirrors=False): for db_name in self._databases_names(include_mirrors=False):

View File

@ -664,7 +664,8 @@ object that allows you to retrieve a specific connection using its
alias:: alias::
from django.db import connections from django.db import connections
cursor = connections['my_db_alias'].cursor() with connections['my_db_alias'].cursor() as cursor:
...
Limitations of multiple databases Limitations of multiple databases
================================= =================================

View File

@ -279,8 +279,8 @@ object that allows you to retrieve a specific connection using its
alias:: alias::
from django.db import connections from django.db import connections
cursor = connections['my_db_alias'].cursor() with connections['my_db_alias'].cursor() as cursor:
# Your code here... # Your code here...
By default, the Python DB API will return results without their field names, By default, the Python DB API will return results without their field names,
which means you end up with a ``list`` of values, rather than a ``dict``. At a which means you end up with a ``list`` of values, rather than a ``dict``. At a

View File

@ -9,15 +9,15 @@ from ..models import Person
@unittest.skipUnless(connection.vendor == 'postgresql', "Test only for PostgreSQL") @unittest.skipUnless(connection.vendor == 'postgresql', "Test only for PostgreSQL")
class DatabaseSequenceTests(TestCase): class DatabaseSequenceTests(TestCase):
def test_get_sequences(self): def test_get_sequences(self):
cursor = connection.cursor() with connection.cursor() as cursor:
seqs = connection.introspection.get_sequences(cursor, Person._meta.db_table) seqs = connection.introspection.get_sequences(cursor, Person._meta.db_table)
self.assertEqual( self.assertEqual(
seqs, seqs,
[{'table': Person._meta.db_table, 'column': 'id', 'name': 'backends_person_id_seq'}] [{'table': Person._meta.db_table, 'column': 'id', 'name': 'backends_person_id_seq'}]
) )
cursor.execute('ALTER SEQUENCE backends_person_id_seq RENAME TO pers_seq') cursor.execute('ALTER SEQUENCE backends_person_id_seq RENAME TO pers_seq')
seqs = connection.introspection.get_sequences(cursor, Person._meta.db_table) seqs = connection.introspection.get_sequences(cursor, Person._meta.db_table)
self.assertEqual( self.assertEqual(
seqs, seqs,
[{'table': Person._meta.db_table, 'column': 'id', 'name': 'pers_seq'}] [{'table': Person._meta.db_table, 'column': 'id', 'name': 'pers_seq'}]
) )

View File

@ -44,10 +44,10 @@ class Tests(TestCase):
# Ensure the database default time zone is different than # Ensure the database default time zone is different than
# the time zone in new_connection.settings_dict. We can # the time zone in new_connection.settings_dict. We can
# get the default time zone by reset & show. # get the default time zone by reset & show.
cursor = new_connection.cursor() with new_connection.cursor() as cursor:
cursor.execute("RESET TIMEZONE") cursor.execute("RESET TIMEZONE")
cursor.execute("SHOW TIMEZONE") cursor.execute("SHOW TIMEZONE")
db_default_tz = cursor.fetchone()[0] db_default_tz = cursor.fetchone()[0]
new_tz = 'Europe/Paris' if db_default_tz == 'UTC' else 'UTC' new_tz = 'Europe/Paris' if db_default_tz == 'UTC' else 'UTC'
new_connection.close() new_connection.close()
@ -59,12 +59,12 @@ class Tests(TestCase):
# time zone, run a query and rollback. # time zone, run a query and rollback.
with self.settings(TIME_ZONE=new_tz): with self.settings(TIME_ZONE=new_tz):
new_connection.set_autocommit(False) new_connection.set_autocommit(False)
cursor = new_connection.cursor()
new_connection.rollback() new_connection.rollback()
# Now let's see if the rollback rolled back the SET TIME ZONE. # Now let's see if the rollback rolled back the SET TIME ZONE.
cursor.execute("SHOW TIMEZONE") with new_connection.cursor() as cursor:
tz = cursor.fetchone()[0] cursor.execute("SHOW TIMEZONE")
tz = cursor.fetchone()[0]
self.assertEqual(new_tz, tz) self.assertEqual(new_tz, tz)
finally: finally:

View File

@ -82,11 +82,11 @@ class LastExecutedQueryTest(TestCase):
# If SQLITE_MAX_VARIABLE_NUMBER (default = 999) has been changed to be # If SQLITE_MAX_VARIABLE_NUMBER (default = 999) has been changed to be
# greater than SQLITE_MAX_COLUMN (default = 2000), last_executed_query # greater than SQLITE_MAX_COLUMN (default = 2000), last_executed_query
# can hit the SQLITE_MAX_COLUMN limit (#26063). # can hit the SQLITE_MAX_COLUMN limit (#26063).
cursor = connection.cursor() with connection.cursor() as cursor:
sql = "SELECT MAX(%s)" % ", ".join(["%s"] * 2001) sql = "SELECT MAX(%s)" % ", ".join(["%s"] * 2001)
params = list(range(2001)) params = list(range(2001))
# This should not raise an exception. # This should not raise an exception.
cursor.db.ops.last_executed_query(cursor.cursor, sql, params) cursor.db.ops.last_executed_query(cursor.cursor, sql, params)
@unittest.skipUnless(connection.vendor == 'sqlite', 'SQLite tests') @unittest.skipUnless(connection.vendor == 'sqlite', 'SQLite tests')
@ -97,9 +97,9 @@ class EscapingChecks(TestCase):
""" """
def test_parameter_escaping(self): def test_parameter_escaping(self):
# '%s' escaping support for sqlite3 (#13648). # '%s' escaping support for sqlite3 (#13648).
cursor = connection.cursor() with connection.cursor() as cursor:
cursor.execute("select strftime('%s', date('now'))") cursor.execute("select strftime('%s', date('now'))")
response = cursor.fetchall()[0][0] response = cursor.fetchall()[0][0]
# response should be an non-zero integer # response should be an non-zero integer
self.assertTrue(int(response)) self.assertTrue(int(response))

View File

@ -56,8 +56,8 @@ class LastExecutedQueryTest(TestCase):
last_executed_query should not raise an exception even if no previous last_executed_query should not raise an exception even if no previous
query has been run. query has been run.
""" """
cursor = connection.cursor() with connection.cursor() as cursor:
connection.ops.last_executed_query(cursor, '', ()) connection.ops.last_executed_query(cursor, '', ())
def test_debug_sql(self): def test_debug_sql(self):
list(Reporter.objects.filter(first_name="test")) list(Reporter.objects.filter(first_name="test"))
@ -78,16 +78,16 @@ class ParameterHandlingTest(TestCase):
def test_bad_parameter_count(self): def test_bad_parameter_count(self):
"An executemany call with too many/not enough parameters will raise an exception (Refs #12612)" "An executemany call with too many/not enough parameters will raise an exception (Refs #12612)"
cursor = connection.cursor() with connection.cursor() as cursor:
query = ('INSERT INTO %s (%s, %s) VALUES (%%s, %%s)' % ( query = ('INSERT INTO %s (%s, %s) VALUES (%%s, %%s)' % (
connection.introspection.table_name_converter('backends_square'), connection.introspection.table_name_converter('backends_square'),
connection.ops.quote_name('root'), connection.ops.quote_name('root'),
connection.ops.quote_name('square') connection.ops.quote_name('square')
)) ))
with self.assertRaises(Exception): with self.assertRaises(Exception):
cursor.executemany(query, [(1, 2, 3)]) cursor.executemany(query, [(1, 2, 3)])
with self.assertRaises(Exception): with self.assertRaises(Exception):
cursor.executemany(query, [(1,)]) cursor.executemany(query, [(1,)])
class LongNameTest(TransactionTestCase): class LongNameTest(TransactionTestCase):
@ -133,9 +133,10 @@ class LongNameTest(TransactionTestCase):
'table': VLM._meta.db_table 'table': VLM._meta.db_table
}, },
] ]
cursor = connection.cursor() sql_list = connection.ops.sql_flush(no_style(), tables, sequences)
for statement in connection.ops.sql_flush(no_style(), tables, sequences): with connection.cursor() as cursor:
cursor.execute(statement) for statement in sql_list:
cursor.execute(statement)
class SequenceResetTest(TestCase): class SequenceResetTest(TestCase):
@ -146,10 +147,10 @@ class SequenceResetTest(TestCase):
Post.objects.create(id=10, name='1st post', text='hello world') Post.objects.create(id=10, name='1st post', text='hello world')
# Reset the sequences for the database # Reset the sequences for the database
cursor = connection.cursor()
commands = connections[DEFAULT_DB_ALIAS].ops.sequence_reset_sql(no_style(), [Post]) commands = connections[DEFAULT_DB_ALIAS].ops.sequence_reset_sql(no_style(), [Post])
for sql in commands: with connection.cursor() as cursor:
cursor.execute(sql) for sql in commands:
cursor.execute(sql)
# If we create a new object now, it should have a PK greater # If we create a new object now, it should have a PK greater
# than the PK we specified manually. # than the PK we specified manually.
@ -192,14 +193,14 @@ class EscapingChecks(TestCase):
bare_select_suffix = connection.features.bare_select_suffix bare_select_suffix = connection.features.bare_select_suffix
def test_paramless_no_escaping(self): def test_paramless_no_escaping(self):
cursor = connection.cursor() with connection.cursor() as cursor:
cursor.execute("SELECT '%s'" + self.bare_select_suffix) cursor.execute("SELECT '%s'" + self.bare_select_suffix)
self.assertEqual(cursor.fetchall()[0][0], '%s') self.assertEqual(cursor.fetchall()[0][0], '%s')
def test_parameter_escaping(self): def test_parameter_escaping(self):
cursor = connection.cursor() with connection.cursor() as cursor:
cursor.execute("SELECT '%%', %s" + self.bare_select_suffix, ('%d',)) cursor.execute("SELECT '%%', %s" + self.bare_select_suffix, ('%d',))
self.assertEqual(cursor.fetchall()[0], ('%', '%d')) self.assertEqual(cursor.fetchall()[0], ('%', '%d'))
@override_settings(DEBUG=True) @override_settings(DEBUG=True)
@ -215,7 +216,6 @@ class BackendTestCase(TransactionTestCase):
self.create_squares(args, 'format', True) self.create_squares(args, 'format', True)
def create_squares(self, args, paramstyle, multiple): def create_squares(self, args, paramstyle, multiple):
cursor = connection.cursor()
opts = Square._meta opts = Square._meta
tbl = connection.introspection.table_name_converter(opts.db_table) tbl = connection.introspection.table_name_converter(opts.db_table)
f1 = connection.ops.quote_name(opts.get_field('root').column) f1 = connection.ops.quote_name(opts.get_field('root').column)
@ -226,10 +226,11 @@ class BackendTestCase(TransactionTestCase):
query = 'INSERT INTO %s (%s, %s) VALUES (%%(root)s, %%(square)s)' % (tbl, f1, f2) query = 'INSERT INTO %s (%s, %s) VALUES (%%(root)s, %%(square)s)' % (tbl, f1, f2)
else: else:
raise ValueError("unsupported paramstyle in test") raise ValueError("unsupported paramstyle in test")
if multiple: with connection.cursor() as cursor:
cursor.executemany(query, args) if multiple:
else: cursor.executemany(query, args)
cursor.execute(query, args) else:
cursor.execute(query, args)
def test_cursor_executemany(self): def test_cursor_executemany(self):
# Test cursor.executemany #4896 # Test cursor.executemany #4896
@ -297,18 +298,18 @@ class BackendTestCase(TransactionTestCase):
Person(first_name="Clark", last_name="Kent").save() Person(first_name="Clark", last_name="Kent").save()
opts2 = Person._meta opts2 = Person._meta
f3, f4 = opts2.get_field('first_name'), opts2.get_field('last_name') f3, f4 = opts2.get_field('first_name'), opts2.get_field('last_name')
cursor = connection.cursor() with connection.cursor() as cursor:
cursor.execute( cursor.execute(
'SELECT %s, %s FROM %s ORDER BY %s' % ( 'SELECT %s, %s FROM %s ORDER BY %s' % (
qn(f3.column), qn(f3.column),
qn(f4.column), qn(f4.column),
connection.introspection.table_name_converter(opts2.db_table), connection.introspection.table_name_converter(opts2.db_table),
qn(f3.column), qn(f3.column),
)
) )
) self.assertEqual(cursor.fetchone(), ('Clark', 'Kent'))
self.assertEqual(cursor.fetchone(), ('Clark', 'Kent')) self.assertEqual(list(cursor.fetchmany(2)), [('Jane', 'Doe'), ('John', 'Doe')])
self.assertEqual(list(cursor.fetchmany(2)), [('Jane', 'Doe'), ('John', 'Doe')]) self.assertEqual(list(cursor.fetchall()), [('Mary', 'Agnelline'), ('Peter', 'Parker')])
self.assertEqual(list(cursor.fetchall()), [('Mary', 'Agnelline'), ('Peter', 'Parker')])
def test_unicode_password(self): def test_unicode_password(self):
old_password = connection.settings_dict['PASSWORD'] old_password = connection.settings_dict['PASSWORD']
@ -344,10 +345,10 @@ class BackendTestCase(TransactionTestCase):
def test_duplicate_table_error(self): def test_duplicate_table_error(self):
""" Creating an existing table returns a DatabaseError """ """ Creating an existing table returns a DatabaseError """
cursor = connection.cursor()
query = 'CREATE TABLE %s (id INTEGER);' % Article._meta.db_table query = 'CREATE TABLE %s (id INTEGER);' % Article._meta.db_table
with self.assertRaises(DatabaseError): with connection.cursor() as cursor:
cursor.execute(query) with self.assertRaises(DatabaseError):
cursor.execute(query)
def test_cursor_contextmanager(self): def test_cursor_contextmanager(self):
""" """

View File

@ -26,10 +26,10 @@ class DatabaseErrorWrapperTests(TestCase):
@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL test') @unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL test')
def test_reraising_backend_specific_database_exception(self): def test_reraising_backend_specific_database_exception(self):
cursor = connection.cursor() with connection.cursor() as cursor:
msg = 'table "X" does not exist' msg = 'table "X" does not exist'
with self.assertRaisesMessage(ProgrammingError, msg) as cm: with self.assertRaisesMessage(ProgrammingError, msg) as cm:
cursor.execute('DROP TABLE "X"') cursor.execute('DROP TABLE "X"')
self.assertNotEqual(type(cm.exception), type(cm.exception.__cause__)) self.assertNotEqual(type(cm.exception), type(cm.exception.__cause__))
self.assertIsNotNone(cm.exception.__cause__) self.assertIsNotNone(cm.exception.__cause__)
self.assertIsNotNone(cm.exception.__cause__.pgcode) self.assertIsNotNone(cm.exception.__cause__.pgcode)