From 7a6fbf36b1fdb8978ea0842075ccce83bcd63789 Mon Sep 17 00:00:00 2001 From: Jon Dufresne Date: Tue, 28 Nov 2017 05:12:28 -0800 Subject: [PATCH] Fixed #28853 -- Updated connection.cursor() uses to use a context manager. --- .../gis/db/backends/mysql/introspection.py | 6 +- .../gis/db/backends/oracle/introspection.py | 6 +- .../gis/db/backends/postgis/introspection.py | 12 +- .../db/backends/spatialite/introspection.py | 6 +- django/db/backends/base/base.py | 5 +- django/db/backends/base/introspection.py | 29 ++-- django/db/backends/mysql/base.py | 57 +++---- django/db/backends/oracle/creation.py | 147 +++++++++--------- django/db/backends/sqlite3/base.py | 58 +++---- django/db/migrations/executor.py | 3 +- django/test/testcases.py | 6 +- docs/topics/db/multi-db.txt | 3 +- docs/topics/db/sql.txt | 4 +- .../backends/postgresql/test_introspection.py | 24 +-- tests/backends/postgresql/tests.py | 14 +- tests/backends/sqlite/tests.py | 16 +- tests/backends/tests.py | 87 ++++++----- tests/db_utils/tests.py | 8 +- 18 files changed, 234 insertions(+), 257 deletions(-) diff --git a/django/contrib/gis/db/backends/mysql/introspection.py b/django/contrib/gis/db/backends/mysql/introspection.py index 364cdeecfeb..ba125ec4af7 100644 --- a/django/contrib/gis/db/backends/mysql/introspection.py +++ b/django/contrib/gis/db/backends/mysql/introspection.py @@ -11,8 +11,7 @@ class MySQLIntrospection(DatabaseIntrospection): data_types_reverse[FIELD_TYPE.GEOMETRY] = 'GeometryField' def get_geometry_type(self, table_name, geo_col): - cursor = self.connection.cursor() - try: + with self.connection.cursor() as cursor: # In order to get the specific geometry type of the field, # we introspect on the table definition using `DESCRIBE`. cursor.execute('DESCRIBE %s' % @@ -27,9 +26,6 @@ class MySQLIntrospection(DatabaseIntrospection): field_type = OGRGeomType(typ).django field_params = {} break - finally: - cursor.close() - return field_type, field_params def supports_spatial_index(self, cursor, table_name): diff --git a/django/contrib/gis/db/backends/oracle/introspection.py b/django/contrib/gis/db/backends/oracle/introspection.py index 446dc782161..7f4f886a341 100644 --- a/django/contrib/gis/db/backends/oracle/introspection.py +++ b/django/contrib/gis/db/backends/oracle/introspection.py @@ -11,8 +11,7 @@ class OracleIntrospection(DatabaseIntrospection): data_types_reverse[cx_Oracle.OBJECT] = 'GeometryField' def get_geometry_type(self, table_name, geo_col): - cursor = self.connection.cursor() - try: + with self.connection.cursor() as cursor: # Querying USER_SDO_GEOM_METADATA to get the SRID and dimension information. try: cursor.execute( @@ -40,7 +39,4 @@ class OracleIntrospection(DatabaseIntrospection): dim = dim.size() if dim != 2: field_params['dim'] = dim - finally: - cursor.close() - return field_type, field_params diff --git a/django/contrib/gis/db/backends/postgis/introspection.py b/django/contrib/gis/db/backends/postgis/introspection.py index 3a90ebf5c5a..97fa7480e61 100644 --- a/django/contrib/gis/db/backends/postgis/introspection.py +++ b/django/contrib/gis/db/backends/postgis/introspection.py @@ -59,15 +59,11 @@ class PostGISIntrospection(DatabaseIntrospection): # to query the PostgreSQL pg_type table corresponding to the # PostGIS custom data types. oid_sql = 'SELECT "oid" FROM "pg_type" WHERE "typname" = %s' - cursor = self.connection.cursor() - try: + with self.connection.cursor() as cursor: for field_type in field_types: cursor.execute(oid_sql, (field_type[0],)) for result in cursor.fetchall(): postgis_types[result[0]] = field_type[1] - finally: - cursor.close() - return postgis_types def get_field_type(self, data_type, description): @@ -88,8 +84,7 @@ class PostGISIntrospection(DatabaseIntrospection): PointField or a PolygonField). Thus, this routine queries the PostGIS metadata tables to determine the geometry type. """ - cursor = self.connection.cursor() - try: + with self.connection.cursor() as cursor: try: # First seeing if this geometry column is in the `geometry_columns` cursor.execute('SELECT "coord_dimension", "srid", "type" ' @@ -122,7 +117,4 @@ class PostGISIntrospection(DatabaseIntrospection): field_params['srid'] = srid if dim != 2: field_params['dim'] = dim - finally: - cursor.close() - return field_type, field_params diff --git a/django/contrib/gis/db/backends/spatialite/introspection.py b/django/contrib/gis/db/backends/spatialite/introspection.py index 98fda427d04..5cd5613b69a 100644 --- a/django/contrib/gis/db/backends/spatialite/introspection.py +++ b/django/contrib/gis/db/backends/spatialite/introspection.py @@ -25,8 +25,7 @@ class SpatiaLiteIntrospection(DatabaseIntrospection): data_types_reverse = GeoFlexibleFieldLookupDict() def get_geometry_type(self, table_name, geo_col): - cursor = self.connection.cursor() - try: + with self.connection.cursor() as cursor: # Querying the `geometry_columns` table to get additional metadata. cursor.execute('SELECT coord_dimension, srid, geometry_type ' 'FROM geometry_columns ' @@ -55,9 +54,6 @@ class SpatiaLiteIntrospection(DatabaseIntrospection): field_params['srid'] = srid if (isinstance(dim, str) and 'Z' in dim) or dim == 3: field_params['dim'] = 3 - finally: - cursor.close() - return field_type, field_params def get_constraints(self, cursor, table_name): diff --git a/django/db/backends/base/base.py b/django/db/backends/base/base.py index 468eb16e140..53c3c30063d 100644 --- a/django/db/backends/base/base.py +++ b/django/db/backends/base/base.py @@ -573,11 +573,10 @@ class BaseDatabaseWrapper: Provide a cursor: with self.temporary_connection() as cursor: ... """ must_close = self.connection is None - cursor = self.cursor() try: - yield cursor + with self.cursor() as cursor: + yield cursor finally: - cursor.close() if must_close: self.close() diff --git a/django/db/backends/base/introspection.py b/django/db/backends/base/introspection.py index 154ae22bf15..8fe8966a215 100644 --- a/django/db/backends/base/introspection.py +++ b/django/db/backends/base/introspection.py @@ -116,21 +116,20 @@ class BaseDatabaseIntrospection: from django.db import router sequence_list = [] - cursor = self.connection.cursor() - - for app_config in apps.get_app_configs(): - for model in router.get_migratable_models(app_config, self.connection.alias): - if not model._meta.managed: - continue - if model._meta.swapped: - continue - sequence_list.extend(self.get_sequences(cursor, model._meta.db_table, model._meta.local_fields)) - for f in model._meta.local_many_to_many: - # If this is an m2m using an intermediate table, - # we don't need to reset the sequence. - if f.remote_field.through is None: - sequence = self.get_sequences(cursor, f.m2m_db_table()) - sequence_list.extend(sequence or [{'table': f.m2m_db_table(), 'column': None}]) + with self.connection.cursor() as cursor: + for app_config in apps.get_app_configs(): + for model in router.get_migratable_models(app_config, self.connection.alias): + if not model._meta.managed: + continue + if model._meta.swapped: + continue + sequence_list.extend(self.get_sequences(cursor, model._meta.db_table, model._meta.local_fields)) + for f in model._meta.local_many_to_many: + # If this is an m2m using an intermediate table, + # we don't need to reset the sequence. + if f.remote_field.through is None: + sequence = self.get_sequences(cursor, f.m2m_db_table()) + sequence_list.extend(sequence or [{'table': f.m2m_db_table(), 'column': None}]) return sequence_list def get_sequences(self, cursor, table_name, table_fields=()): diff --git a/django/db/backends/mysql/base.py b/django/db/backends/mysql/base.py index abf8f557368..6b82afc45d6 100644 --- a/django/db/backends/mysql/base.py +++ b/django/db/backends/mysql/base.py @@ -294,36 +294,37 @@ class DatabaseWrapper(BaseDatabaseWrapper): Backends can override this method if they can more directly apply constraint checking (e.g. via "SET CONSTRAINTS ALL IMMEDIATE") """ - cursor = self.cursor() - if table_names is None: - table_names = self.introspection.table_names(cursor) - for table_name in table_names: - primary_key_column_name = self.introspection.get_primary_key_column(cursor, table_name) - if not primary_key_column_name: - continue - key_columns = self.introspection.get_key_columns(cursor, table_name) - for column_name, referenced_table_name, referenced_column_name in key_columns: - cursor.execute( - """ - SELECT REFERRING.`%s`, REFERRING.`%s` FROM `%s` as REFERRING - LEFT JOIN `%s` as REFERRED - ON (REFERRING.`%s` = REFERRED.`%s`) - WHERE REFERRING.`%s` IS NOT NULL AND REFERRED.`%s` IS NULL - """ % ( - primary_key_column_name, column_name, table_name, - referenced_table_name, column_name, referenced_column_name, - column_name, referenced_column_name, - ) - ) - for bad_row in cursor.fetchall(): - raise utils.IntegrityError( - "The row in table '%s' with primary key '%s' has an invalid " - "foreign key: %s.%s contains a value '%s' that does not have a corresponding value in %s.%s." - % ( - table_name, bad_row[0], table_name, column_name, - bad_row[1], referenced_table_name, referenced_column_name, + with self.cursor() as cursor: + if table_names is None: + table_names = self.introspection.table_names(cursor) + for table_name in table_names: + primary_key_column_name = self.introspection.get_primary_key_column(cursor, table_name) + if not primary_key_column_name: + continue + key_columns = self.introspection.get_key_columns(cursor, table_name) + for column_name, referenced_table_name, referenced_column_name in key_columns: + cursor.execute( + """ + SELECT REFERRING.`%s`, REFERRING.`%s` FROM `%s` as REFERRING + LEFT JOIN `%s` as REFERRED + ON (REFERRING.`%s` = REFERRED.`%s`) + WHERE REFERRING.`%s` IS NOT NULL AND REFERRED.`%s` IS NULL + """ % ( + primary_key_column_name, column_name, table_name, + referenced_table_name, column_name, referenced_column_name, + column_name, referenced_column_name, ) ) + for bad_row in cursor.fetchall(): + raise utils.IntegrityError( + "The row in table '%s' with primary key '%s' has an invalid " + "foreign key: %s.%s contains a value '%s' that does not " + "have a corresponding value in %s.%s." + % ( + table_name, bad_row[0], table_name, column_name, + bad_row[1], referenced_table_name, referenced_column_name, + ) + ) def is_usable(self): try: diff --git a/django/db/backends/oracle/creation.py b/django/db/backends/oracle/creation.py index fe053c54f72..aa57e6a6ab4 100644 --- a/django/db/backends/oracle/creation.py +++ b/django/db/backends/oracle/creation.py @@ -30,75 +30,72 @@ class DatabaseCreation(BaseDatabaseCreation): def _create_test_db(self, verbosity=1, autoclobber=False, keepdb=False): parameters = self._get_test_db_params() - cursor = self._maindb_connection.cursor() - if self._test_database_create(): - try: - self._execute_test_db_creation(cursor, parameters, verbosity, keepdb) - except Exception as e: - if 'ORA-01543' not in str(e): - # All errors except "tablespace already exists" cancel tests - sys.stderr.write("Got an error creating the test database: %s\n" % e) - sys.exit(2) - if not autoclobber: - confirm = input( - "It appears the test database, %s, already exists. " - "Type 'yes' to delete it, or 'no' to cancel: " % parameters['user']) - if autoclobber or confirm == 'yes': - if verbosity >= 1: - print("Destroying old test database for alias '%s'..." % self.connection.alias) - try: - self._execute_test_db_destruction(cursor, parameters, verbosity) - except DatabaseError as e: - if 'ORA-29857' in str(e): - self._handle_objects_preventing_db_destruction(cursor, parameters, - verbosity, autoclobber) - else: - # Ran into a database error that isn't about leftover objects in the tablespace + with self._maindb_connection.cursor() as cursor: + if self._test_database_create(): + try: + self._execute_test_db_creation(cursor, parameters, verbosity, keepdb) + except Exception as e: + if 'ORA-01543' not in str(e): + # All errors except "tablespace already exists" cancel tests + sys.stderr.write("Got an error creating the test database: %s\n" % e) + sys.exit(2) + if not autoclobber: + confirm = input( + "It appears the test database, %s, already exists. " + "Type 'yes' to delete it, or 'no' to cancel: " % parameters['user']) + if autoclobber or confirm == 'yes': + if verbosity >= 1: + print("Destroying old test database for alias '%s'..." % self.connection.alias) + try: + self._execute_test_db_destruction(cursor, parameters, verbosity) + except DatabaseError as e: + if 'ORA-29857' in str(e): + self._handle_objects_preventing_db_destruction(cursor, parameters, + verbosity, autoclobber) + else: + # Ran into a database error that isn't about leftover objects in the tablespace + sys.stderr.write("Got an error destroying the old test database: %s\n" % e) + sys.exit(2) + except Exception as e: sys.stderr.write("Got an error destroying the old test database: %s\n" % e) sys.exit(2) - except Exception as e: - sys.stderr.write("Got an error destroying the old test database: %s\n" % e) - sys.exit(2) - try: - self._execute_test_db_creation(cursor, parameters, verbosity, keepdb) - except Exception as e: - sys.stderr.write("Got an error recreating the test database: %s\n" % e) - sys.exit(2) - else: - print("Tests cancelled.") - sys.exit(1) + try: + self._execute_test_db_creation(cursor, parameters, verbosity, keepdb) + except Exception as e: + sys.stderr.write("Got an error recreating the test database: %s\n" % e) + sys.exit(2) + else: + print("Tests cancelled.") + sys.exit(1) - if self._test_user_create(): - if verbosity >= 1: - print("Creating test user...") - try: - self._create_test_user(cursor, parameters, verbosity, keepdb) - except Exception as e: - if 'ORA-01920' not in str(e): - # All errors except "user already exists" cancel tests - sys.stderr.write("Got an error creating the test user: %s\n" % e) - sys.exit(2) - if not autoclobber: - confirm = input( - "It appears the test user, %s, already exists. Type " - "'yes' to delete it, or 'no' to cancel: " % parameters['user']) - if autoclobber or confirm == 'yes': - try: - if verbosity >= 1: - print("Destroying old test user...") - self._destroy_test_user(cursor, parameters, verbosity) - if verbosity >= 1: - print("Creating test user...") - self._create_test_user(cursor, parameters, verbosity, keepdb) - except Exception as e: - sys.stderr.write("Got an error recreating the test user: %s\n" % e) + if self._test_user_create(): + if verbosity >= 1: + print("Creating test user...") + try: + self._create_test_user(cursor, parameters, verbosity, keepdb) + except Exception as e: + if 'ORA-01920' not in str(e): + # All errors except "user already exists" cancel tests + sys.stderr.write("Got an error creating the test user: %s\n" % e) sys.exit(2) - else: - print("Tests cancelled.") - sys.exit(1) - - # Cursor must be closed before closing connection. - cursor.close() + if not autoclobber: + confirm = input( + "It appears the test user, %s, already exists. Type " + "'yes' to delete it, or 'no' to cancel: " % parameters['user']) + if autoclobber or confirm == 'yes': + try: + if verbosity >= 1: + print("Destroying old test user...") + self._destroy_test_user(cursor, parameters, verbosity) + if verbosity >= 1: + print("Creating test user...") + self._create_test_user(cursor, parameters, verbosity, keepdb) + except Exception as e: + sys.stderr.write("Got an error recreating the test user: %s\n" % e) + sys.exit(2) + else: + print("Tests cancelled.") + sys.exit(1) self._maindb_connection.close() # done with main user -- test user and tablespaces created self._switch_to_test_user(parameters) return self.connection.settings_dict['NAME'] @@ -175,17 +172,15 @@ class DatabaseCreation(BaseDatabaseCreation): self.connection.settings_dict['PASSWORD'] = self.connection.settings_dict['SAVED_PASSWORD'] self.connection.close() parameters = self._get_test_db_params() - cursor = self._maindb_connection.cursor() - if self._test_user_create(): - if verbosity >= 1: - print('Destroying test user...') - self._destroy_test_user(cursor, parameters, verbosity) - if self._test_database_create(): - if verbosity >= 1: - print('Destroying test database tables...') - self._execute_test_db_destruction(cursor, parameters, verbosity) - # Cursor must be closed before closing connection. - cursor.close() + with self._maindb_connection.cursor() as cursor: + if self._test_user_create(): + if verbosity >= 1: + print('Destroying test user...') + self._destroy_test_user(cursor, parameters, verbosity) + if self._test_database_create(): + if verbosity >= 1: + print('Destroying test database tables...') + self._execute_test_db_destruction(cursor, parameters, verbosity) self._maindb_connection.close() def _execute_test_db_creation(self, cursor, parameters, verbosity, keepdb=False): diff --git a/django/db/backends/sqlite3/base.py b/django/db/backends/sqlite3/base.py index 40b205cf30d..13f9a575e2a 100644 --- a/django/db/backends/sqlite3/base.py +++ b/django/db/backends/sqlite3/base.py @@ -237,37 +237,37 @@ class DatabaseWrapper(BaseDatabaseWrapper): Backends can override this method if they can more directly apply constraint checking (e.g. via "SET CONSTRAINTS ALL IMMEDIATE") """ - cursor = self.cursor() - if table_names is None: - table_names = self.introspection.table_names(cursor) - for table_name in table_names: - primary_key_column_name = self.introspection.get_primary_key_column(cursor, table_name) - if not primary_key_column_name: - continue - key_columns = self.introspection.get_key_columns(cursor, table_name) - for column_name, referenced_table_name, referenced_column_name in key_columns: - cursor.execute( - """ - SELECT REFERRING.`%s`, REFERRING.`%s` FROM `%s` as REFERRING - LEFT JOIN `%s` as REFERRED - ON (REFERRING.`%s` = REFERRED.`%s`) - WHERE REFERRING.`%s` IS NOT NULL AND REFERRED.`%s` IS NULL - """ - % ( - primary_key_column_name, column_name, table_name, - referenced_table_name, column_name, referenced_column_name, - column_name, referenced_column_name, - ) - ) - for bad_row in cursor.fetchall(): - raise utils.IntegrityError( - "The row in table '%s' with primary key '%s' has an " - "invalid foreign key: %s.%s contains a value '%s' that " - "does not have a corresponding value in %s.%s." % ( - table_name, bad_row[0], table_name, column_name, - bad_row[1], referenced_table_name, referenced_column_name, + with self.cursor() as cursor: + if table_names is None: + table_names = self.introspection.table_names(cursor) + for table_name in table_names: + primary_key_column_name = self.introspection.get_primary_key_column(cursor, table_name) + if not primary_key_column_name: + continue + key_columns = self.introspection.get_key_columns(cursor, table_name) + for column_name, referenced_table_name, referenced_column_name in key_columns: + cursor.execute( + """ + SELECT REFERRING.`%s`, REFERRING.`%s` FROM `%s` as REFERRING + LEFT JOIN `%s` as REFERRED + ON (REFERRING.`%s` = REFERRED.`%s`) + WHERE REFERRING.`%s` IS NOT NULL AND REFERRED.`%s` IS NULL + """ + % ( + primary_key_column_name, column_name, table_name, + referenced_table_name, column_name, referenced_column_name, + column_name, referenced_column_name, ) ) + for bad_row in cursor.fetchall(): + raise utils.IntegrityError( + "The row in table '%s' with primary key '%s' has an " + "invalid foreign key: %s.%s contains a value '%s' that " + "does not have a corresponding value in %s.%s." % ( + table_name, bad_row[0], table_name, column_name, + bad_row[1], referenced_table_name, referenced_column_name, + ) + ) def is_usable(self): return True diff --git a/django/db/migrations/executor.py b/django/db/migrations/executor.py index ea7bc70db35..ebaf75634b0 100644 --- a/django/db/migrations/executor.py +++ b/django/db/migrations/executor.py @@ -322,7 +322,8 @@ class MigrationExecutor: apps = after_state.apps found_create_model_migration = False found_add_field_migration = False - existing_table_names = self.connection.introspection.table_names(self.connection.cursor()) + with self.connection.cursor() as cursor: + existing_table_names = self.connection.introspection.table_names(cursor) # Make sure all create model and add field operations are done for operation in migration.operations: if isinstance(operation, migrations.CreateModel): diff --git a/django/test/testcases.py b/django/test/testcases.py index afab58f2dc4..4c32c87da55 100644 --- a/django/test/testcases.py +++ b/django/test/testcases.py @@ -852,9 +852,9 @@ class TransactionTestCase(SimpleTestCase): no_style(), conn.introspection.sequence_list()) if sql_list: with transaction.atomic(using=db_name): - cursor = conn.cursor() - for sql in sql_list: - cursor.execute(sql) + with conn.cursor() as cursor: + for sql in sql_list: + cursor.execute(sql) def _fixture_setup(self): for db_name in self._databases_names(include_mirrors=False): diff --git a/docs/topics/db/multi-db.txt b/docs/topics/db/multi-db.txt index 78f3fe23d92..394d25cfd9a 100644 --- a/docs/topics/db/multi-db.txt +++ b/docs/topics/db/multi-db.txt @@ -664,7 +664,8 @@ object that allows you to retrieve a specific connection using its alias:: from django.db import connections - cursor = connections['my_db_alias'].cursor() + with connections['my_db_alias'].cursor() as cursor: + ... Limitations of multiple databases ================================= diff --git a/docs/topics/db/sql.txt b/docs/topics/db/sql.txt index dc2bcc50b6b..96d79a999ce 100644 --- a/docs/topics/db/sql.txt +++ b/docs/topics/db/sql.txt @@ -279,8 +279,8 @@ object that allows you to retrieve a specific connection using its alias:: from django.db import connections - cursor = connections['my_db_alias'].cursor() - # Your code here... + with connections['my_db_alias'].cursor() as cursor: + # Your code here... By default, the Python DB API will return results without their field names, which means you end up with a ``list`` of values, rather than a ``dict``. At a diff --git a/tests/backends/postgresql/test_introspection.py b/tests/backends/postgresql/test_introspection.py index cfa801a77f7..4dcadbd7333 100644 --- a/tests/backends/postgresql/test_introspection.py +++ b/tests/backends/postgresql/test_introspection.py @@ -9,15 +9,15 @@ from ..models import Person @unittest.skipUnless(connection.vendor == 'postgresql', "Test only for PostgreSQL") class DatabaseSequenceTests(TestCase): def test_get_sequences(self): - cursor = connection.cursor() - seqs = connection.introspection.get_sequences(cursor, Person._meta.db_table) - self.assertEqual( - seqs, - [{'table': Person._meta.db_table, 'column': 'id', 'name': 'backends_person_id_seq'}] - ) - cursor.execute('ALTER SEQUENCE backends_person_id_seq RENAME TO pers_seq') - seqs = connection.introspection.get_sequences(cursor, Person._meta.db_table) - self.assertEqual( - seqs, - [{'table': Person._meta.db_table, 'column': 'id', 'name': 'pers_seq'}] - ) + with connection.cursor() as cursor: + seqs = connection.introspection.get_sequences(cursor, Person._meta.db_table) + self.assertEqual( + seqs, + [{'table': Person._meta.db_table, 'column': 'id', 'name': 'backends_person_id_seq'}] + ) + cursor.execute('ALTER SEQUENCE backends_person_id_seq RENAME TO pers_seq') + seqs = connection.introspection.get_sequences(cursor, Person._meta.db_table) + self.assertEqual( + seqs, + [{'table': Person._meta.db_table, 'column': 'id', 'name': 'pers_seq'}] + ) diff --git a/tests/backends/postgresql/tests.py b/tests/backends/postgresql/tests.py index 140fbbc444c..caa8b87fb94 100644 --- a/tests/backends/postgresql/tests.py +++ b/tests/backends/postgresql/tests.py @@ -44,10 +44,10 @@ class Tests(TestCase): # Ensure the database default time zone is different than # the time zone in new_connection.settings_dict. We can # get the default time zone by reset & show. - cursor = new_connection.cursor() - cursor.execute("RESET TIMEZONE") - cursor.execute("SHOW TIMEZONE") - db_default_tz = cursor.fetchone()[0] + with new_connection.cursor() as cursor: + cursor.execute("RESET TIMEZONE") + cursor.execute("SHOW TIMEZONE") + db_default_tz = cursor.fetchone()[0] new_tz = 'Europe/Paris' if db_default_tz == 'UTC' else 'UTC' new_connection.close() @@ -59,12 +59,12 @@ class Tests(TestCase): # time zone, run a query and rollback. with self.settings(TIME_ZONE=new_tz): new_connection.set_autocommit(False) - cursor = new_connection.cursor() new_connection.rollback() # Now let's see if the rollback rolled back the SET TIME ZONE. - cursor.execute("SHOW TIMEZONE") - tz = cursor.fetchone()[0] + with new_connection.cursor() as cursor: + cursor.execute("SHOW TIMEZONE") + tz = cursor.fetchone()[0] self.assertEqual(new_tz, tz) finally: diff --git a/tests/backends/sqlite/tests.py b/tests/backends/sqlite/tests.py index 0c07f95e6f9..0fbb1392785 100644 --- a/tests/backends/sqlite/tests.py +++ b/tests/backends/sqlite/tests.py @@ -82,11 +82,11 @@ class LastExecutedQueryTest(TestCase): # If SQLITE_MAX_VARIABLE_NUMBER (default = 999) has been changed to be # greater than SQLITE_MAX_COLUMN (default = 2000), last_executed_query # can hit the SQLITE_MAX_COLUMN limit (#26063). - cursor = connection.cursor() - sql = "SELECT MAX(%s)" % ", ".join(["%s"] * 2001) - params = list(range(2001)) - # This should not raise an exception. - cursor.db.ops.last_executed_query(cursor.cursor, sql, params) + with connection.cursor() as cursor: + sql = "SELECT MAX(%s)" % ", ".join(["%s"] * 2001) + params = list(range(2001)) + # This should not raise an exception. + cursor.db.ops.last_executed_query(cursor.cursor, sql, params) @unittest.skipUnless(connection.vendor == 'sqlite', 'SQLite tests') @@ -97,9 +97,9 @@ class EscapingChecks(TestCase): """ def test_parameter_escaping(self): # '%s' escaping support for sqlite3 (#13648). - cursor = connection.cursor() - cursor.execute("select strftime('%s', date('now'))") - response = cursor.fetchall()[0][0] + with connection.cursor() as cursor: + cursor.execute("select strftime('%s', date('now'))") + response = cursor.fetchall()[0][0] # response should be an non-zero integer self.assertTrue(int(response)) diff --git a/tests/backends/tests.py b/tests/backends/tests.py index 6d38625a986..32d76577a1d 100644 --- a/tests/backends/tests.py +++ b/tests/backends/tests.py @@ -56,8 +56,8 @@ class LastExecutedQueryTest(TestCase): last_executed_query should not raise an exception even if no previous query has been run. """ - cursor = connection.cursor() - connection.ops.last_executed_query(cursor, '', ()) + with connection.cursor() as cursor: + connection.ops.last_executed_query(cursor, '', ()) def test_debug_sql(self): list(Reporter.objects.filter(first_name="test")) @@ -78,16 +78,16 @@ class ParameterHandlingTest(TestCase): def test_bad_parameter_count(self): "An executemany call with too many/not enough parameters will raise an exception (Refs #12612)" - cursor = connection.cursor() - query = ('INSERT INTO %s (%s, %s) VALUES (%%s, %%s)' % ( - connection.introspection.table_name_converter('backends_square'), - connection.ops.quote_name('root'), - connection.ops.quote_name('square') - )) - with self.assertRaises(Exception): - cursor.executemany(query, [(1, 2, 3)]) - with self.assertRaises(Exception): - cursor.executemany(query, [(1,)]) + with connection.cursor() as cursor: + query = ('INSERT INTO %s (%s, %s) VALUES (%%s, %%s)' % ( + connection.introspection.table_name_converter('backends_square'), + connection.ops.quote_name('root'), + connection.ops.quote_name('square') + )) + with self.assertRaises(Exception): + cursor.executemany(query, [(1, 2, 3)]) + with self.assertRaises(Exception): + cursor.executemany(query, [(1,)]) class LongNameTest(TransactionTestCase): @@ -133,9 +133,10 @@ class LongNameTest(TransactionTestCase): 'table': VLM._meta.db_table }, ] - cursor = connection.cursor() - for statement in connection.ops.sql_flush(no_style(), tables, sequences): - cursor.execute(statement) + sql_list = connection.ops.sql_flush(no_style(), tables, sequences) + with connection.cursor() as cursor: + for statement in sql_list: + cursor.execute(statement) class SequenceResetTest(TestCase): @@ -146,10 +147,10 @@ class SequenceResetTest(TestCase): Post.objects.create(id=10, name='1st post', text='hello world') # Reset the sequences for the database - cursor = connection.cursor() commands = connections[DEFAULT_DB_ALIAS].ops.sequence_reset_sql(no_style(), [Post]) - for sql in commands: - cursor.execute(sql) + with connection.cursor() as cursor: + for sql in commands: + cursor.execute(sql) # If we create a new object now, it should have a PK greater # than the PK we specified manually. @@ -192,14 +193,14 @@ class EscapingChecks(TestCase): bare_select_suffix = connection.features.bare_select_suffix def test_paramless_no_escaping(self): - cursor = connection.cursor() - cursor.execute("SELECT '%s'" + self.bare_select_suffix) - self.assertEqual(cursor.fetchall()[0][0], '%s') + with connection.cursor() as cursor: + cursor.execute("SELECT '%s'" + self.bare_select_suffix) + self.assertEqual(cursor.fetchall()[0][0], '%s') def test_parameter_escaping(self): - cursor = connection.cursor() - cursor.execute("SELECT '%%', %s" + self.bare_select_suffix, ('%d',)) - self.assertEqual(cursor.fetchall()[0], ('%', '%d')) + with connection.cursor() as cursor: + cursor.execute("SELECT '%%', %s" + self.bare_select_suffix, ('%d',)) + self.assertEqual(cursor.fetchall()[0], ('%', '%d')) @override_settings(DEBUG=True) @@ -215,7 +216,6 @@ class BackendTestCase(TransactionTestCase): self.create_squares(args, 'format', True) def create_squares(self, args, paramstyle, multiple): - cursor = connection.cursor() opts = Square._meta tbl = connection.introspection.table_name_converter(opts.db_table) f1 = connection.ops.quote_name(opts.get_field('root').column) @@ -226,10 +226,11 @@ class BackendTestCase(TransactionTestCase): query = 'INSERT INTO %s (%s, %s) VALUES (%%(root)s, %%(square)s)' % (tbl, f1, f2) else: raise ValueError("unsupported paramstyle in test") - if multiple: - cursor.executemany(query, args) - else: - cursor.execute(query, args) + with connection.cursor() as cursor: + if multiple: + cursor.executemany(query, args) + else: + cursor.execute(query, args) def test_cursor_executemany(self): # Test cursor.executemany #4896 @@ -297,18 +298,18 @@ class BackendTestCase(TransactionTestCase): Person(first_name="Clark", last_name="Kent").save() opts2 = Person._meta f3, f4 = opts2.get_field('first_name'), opts2.get_field('last_name') - cursor = connection.cursor() - cursor.execute( - 'SELECT %s, %s FROM %s ORDER BY %s' % ( - qn(f3.column), - qn(f4.column), - connection.introspection.table_name_converter(opts2.db_table), - qn(f3.column), + with connection.cursor() as cursor: + cursor.execute( + 'SELECT %s, %s FROM %s ORDER BY %s' % ( + qn(f3.column), + qn(f4.column), + connection.introspection.table_name_converter(opts2.db_table), + qn(f3.column), + ) ) - ) - self.assertEqual(cursor.fetchone(), ('Clark', 'Kent')) - self.assertEqual(list(cursor.fetchmany(2)), [('Jane', 'Doe'), ('John', 'Doe')]) - self.assertEqual(list(cursor.fetchall()), [('Mary', 'Agnelline'), ('Peter', 'Parker')]) + self.assertEqual(cursor.fetchone(), ('Clark', 'Kent')) + self.assertEqual(list(cursor.fetchmany(2)), [('Jane', 'Doe'), ('John', 'Doe')]) + self.assertEqual(list(cursor.fetchall()), [('Mary', 'Agnelline'), ('Peter', 'Parker')]) def test_unicode_password(self): old_password = connection.settings_dict['PASSWORD'] @@ -344,10 +345,10 @@ class BackendTestCase(TransactionTestCase): def test_duplicate_table_error(self): """ Creating an existing table returns a DatabaseError """ - cursor = connection.cursor() query = 'CREATE TABLE %s (id INTEGER);' % Article._meta.db_table - with self.assertRaises(DatabaseError): - cursor.execute(query) + with connection.cursor() as cursor: + with self.assertRaises(DatabaseError): + cursor.execute(query) def test_cursor_contextmanager(self): """ diff --git a/tests/db_utils/tests.py b/tests/db_utils/tests.py index 2a45342df53..4e35e6bb8bb 100644 --- a/tests/db_utils/tests.py +++ b/tests/db_utils/tests.py @@ -26,10 +26,10 @@ class DatabaseErrorWrapperTests(TestCase): @unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL test') def test_reraising_backend_specific_database_exception(self): - cursor = connection.cursor() - msg = 'table "X" does not exist' - with self.assertRaisesMessage(ProgrammingError, msg) as cm: - cursor.execute('DROP TABLE "X"') + with connection.cursor() as cursor: + msg = 'table "X" does not exist' + with self.assertRaisesMessage(ProgrammingError, msg) as cm: + cursor.execute('DROP TABLE "X"') self.assertNotEqual(type(cm.exception), type(cm.exception.__cause__)) self.assertIsNotNone(cm.exception.__cause__) self.assertIsNotNone(cm.exception.__cause__.pgcode)