diff --git a/django/db/backends/mysql/base.py b/django/db/backends/mysql/base.py index 527197971f7..be000248498 100644 --- a/django/db/backends/mysql/base.py +++ b/django/db/backends/mysql/base.py @@ -269,7 +269,8 @@ class DatabaseWrapper(BaseDatabaseWrapper): forward references. Always return True to indicate constraint checks need to be re-enabled. """ - self.cursor().execute('SET foreign_key_checks=0') + with self.cursor() as cursor: + cursor.execute('SET foreign_key_checks=0') return True def enable_constraint_checking(self): @@ -280,7 +281,8 @@ class DatabaseWrapper(BaseDatabaseWrapper): # nested inside transaction.atomic. self.needs_rollback, needs_rollback = False, self.needs_rollback try: - self.cursor().execute('SET foreign_key_checks=1') + with self.cursor() as cursor: + cursor.execute('SET foreign_key_checks=1') finally: self.needs_rollback = needs_rollback diff --git a/django/db/backends/oracle/base.py b/django/db/backends/oracle/base.py index e303ec1f3dd..652e3235217 100644 --- a/django/db/backends/oracle/base.py +++ b/django/db/backends/oracle/base.py @@ -301,8 +301,9 @@ class DatabaseWrapper(BaseDatabaseWrapper): Check constraints by setting them to immediate. Return them to deferred afterward. """ - self.cursor().execute('SET CONSTRAINTS ALL IMMEDIATE') - self.cursor().execute('SET CONSTRAINTS ALL DEFERRED') + with self.cursor() as cursor: + cursor.execute('SET CONSTRAINTS ALL IMMEDIATE') + cursor.execute('SET CONSTRAINTS ALL DEFERRED') def is_usable(self): try: diff --git a/django/db/backends/postgresql/base.py b/django/db/backends/postgresql/base.py index 0f45ca93e18..192316d7fbc 100644 --- a/django/db/backends/postgresql/base.py +++ b/django/db/backends/postgresql/base.py @@ -277,13 +277,15 @@ class DatabaseWrapper(BaseDatabaseWrapper): Check constraints by setting them to immediate. Return them to deferred afterward. """ - self.cursor().execute('SET CONSTRAINTS ALL IMMEDIATE') - self.cursor().execute('SET CONSTRAINTS ALL DEFERRED') + with self.cursor() as cursor: + cursor.execute('SET CONSTRAINTS ALL IMMEDIATE') + cursor.execute('SET CONSTRAINTS ALL DEFERRED') def is_usable(self): try: # Use a psycopg cursor directly, bypassing Django's utilities. - self.connection.cursor().execute("SELECT 1") + with self.connection.cursor() as cursor: + cursor.execute('SELECT 1') except Database.Error: return False else: diff --git a/django/db/backends/sqlite3/base.py b/django/db/backends/sqlite3/base.py index 642d3d6c5ae..7b3f90a2fd7 100644 --- a/django/db/backends/sqlite3/base.py +++ b/django/db/backends/sqlite3/base.py @@ -296,7 +296,8 @@ class DatabaseWrapper(BaseDatabaseWrapper): return not bool(enabled) def enable_constraint_checking(self): - self.cursor().execute('PRAGMA foreign_keys = ON') + with self.cursor() as cursor: + cursor.execute('PRAGMA foreign_keys = ON') def check_constraints(self, table_names=None): """ @@ -309,7 +310,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): if self.features.supports_pragma_foreign_key_check: with self.cursor() as cursor: if table_names is None: - violations = self.cursor().execute('PRAGMA foreign_key_check').fetchall() + violations = cursor.execute('PRAGMA foreign_key_check').fetchall() else: violations = chain.from_iterable( cursor.execute('PRAGMA foreign_key_check(%s)' % table_name).fetchall() diff --git a/django/db/migrations/executor.py b/django/db/migrations/executor.py index 1f65a7fe998..2765ac28bd8 100644 --- a/django/db/migrations/executor.py +++ b/django/db/migrations/executor.py @@ -372,10 +372,8 @@ class MigrationExecutor: else: found_add_field_migration = True continue - columns = self.connection.introspection.get_table_description( - self.connection.cursor(), - table, - ) + with self.connection.cursor() as cursor: + columns = self.connection.introspection.get_table_description(cursor, table) for column in columns: field_column = field.column column_name = column.name diff --git a/django/db/migrations/recorder.py b/django/db/migrations/recorder.py index b8e32a468cc..1a37c6b7d06 100644 --- a/django/db/migrations/recorder.py +++ b/django/db/migrations/recorder.py @@ -52,7 +52,9 @@ class MigrationRecorder: def has_table(self): """Return True if the django_migrations table exists.""" - return self.Migration._meta.db_table in self.connection.introspection.table_names(self.connection.cursor()) + with self.connection.cursor() as cursor: + tables = self.connection.introspection.table_names(cursor) + return self.Migration._meta.db_table in tables def ensure_schema(self): """Ensure the table exists and has the correct schema.""" diff --git a/django/db/models/query.py b/django/db/models/query.py index 74be5df0ac9..a7c16c4bd80 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -747,7 +747,10 @@ class QuerySet: query = self.query.clone() query.__class__ = sql.DeleteQuery cursor = query.get_compiler(using).execute_sql(CURSOR) - return cursor.rowcount if cursor else 0 + if cursor: + with cursor: + return cursor.rowcount + return 0 _raw_delete.alters_data = True def update(self, **kwargs): diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py index af48239cdc7..72b6712864b 100644 --- a/django/db/models/sql/subqueries.py +++ b/django/db/models/sql/subqueries.py @@ -21,7 +21,10 @@ class DeleteQuery(Query): self.alias_map = {table: self.alias_map[table]} self.where = where cursor = self.get_compiler(using).execute_sql(CURSOR) - return cursor.rowcount if cursor else 0 + if cursor: + with cursor: + return cursor.rowcount + return 0 def delete_batch(self, pk_list, using): """ diff --git a/tests/backends/postgresql/tests.py b/tests/backends/postgresql/tests.py index ab8ad4bb937..1dcb14b964c 100644 --- a/tests/backends/postgresql/tests.py +++ b/tests/backends/postgresql/tests.py @@ -114,8 +114,8 @@ class Tests(TestCase): try: # Open a database connection. - new_connection.cursor() - self.assertFalse(new_connection.get_autocommit()) + with new_connection.cursor(): + self.assertFalse(new_connection.get_autocommit()) finally: new_connection.close() @@ -149,9 +149,12 @@ class Tests(TestCase): def test_connect_no_is_usable_checks(self): new_connection = connection.copy() - with mock.patch.object(new_connection, 'is_usable') as is_usable: - new_connection.connect() - is_usable.assert_not_called() + try: + with mock.patch.object(new_connection, 'is_usable') as is_usable: + new_connection.connect() + is_usable.assert_not_called() + finally: + new_connection.close() def _select(self, val): with connection.cursor() as cursor: diff --git a/tests/backends/sqlite/tests.py b/tests/backends/sqlite/tests.py index fc02671ca50..c26377c8af8 100644 --- a/tests/backends/sqlite/tests.py +++ b/tests/backends/sqlite/tests.py @@ -197,7 +197,8 @@ class LastExecutedQueryTest(TestCase): def test_no_interpolation(self): # This shouldn't raise an exception (#17158) query = "SELECT strftime('%Y', 'now');" - connection.cursor().execute(query) + with connection.cursor() as cursor: + cursor.execute(query) self.assertEqual(connection.queries[-1]['sql'], query) def test_parameter_quoting(self): @@ -205,7 +206,8 @@ class LastExecutedQueryTest(TestCase): # worth testing that parameters are quoted (#14091). query = "SELECT %s" params = ["\"'\\"] - connection.cursor().execute(query, params) + with connection.cursor() as cursor: + cursor.execute(query, params) # Note that the single quote is repeated substituted = "SELECT '\"''\\'" self.assertEqual(connection.queries[-1]['sql'], substituted) diff --git a/tests/backends/tests.py b/tests/backends/tests.py index 918fc32166d..ce3b863016d 100644 --- a/tests/backends/tests.py +++ b/tests/backends/tests.py @@ -69,8 +69,8 @@ class LastExecutedQueryTest(TestCase): """last_executed_query() returns a string.""" data = RawData.objects.filter(raw_data=b'\x00\x46 \xFE').extra(select={'föö': 1}) sql, params = data.query.sql_with_params() - cursor = data.query.get_compiler('default').execute_sql(CURSOR) - last_sql = cursor.db.ops.last_executed_query(cursor, sql, params) + with data.query.get_compiler('default').execute_sql(CURSOR) as cursor: + last_sql = cursor.db.ops.last_executed_query(cursor, sql, params) self.assertIsInstance(last_sql, str) def test_last_executed_query(self): @@ -81,11 +81,11 @@ class LastExecutedQueryTest(TestCase): Article.objects.filter(pk__in=(1, 2), reporter__pk=3), ): sql, params = qs.query.sql_with_params() - cursor = qs.query.get_compiler(DEFAULT_DB_ALIAS).execute_sql(CURSOR) - self.assertEqual( - cursor.db.ops.last_executed_query(cursor, sql, params), - str(qs.query), - ) + with qs.query.get_compiler(DEFAULT_DB_ALIAS).execute_sql(CURSOR) as cursor: + self.assertEqual( + cursor.db.ops.last_executed_query(cursor, sql, params), + str(qs.query), + ) @skipUnlessDBFeature('supports_paramstyle_pyformat') def test_last_executed_query_dict(self): @@ -205,12 +205,14 @@ class ConnectionCreatedSignalTest(TransactionTestCase): connection_created.connect(receiver) connection.close() - connection.cursor() + with connection.cursor(): + pass self.assertIs(data["connection"].connection, connection.connection) connection_created.disconnect(receiver) data.clear() - connection.cursor() + with connection.cursor(): + pass self.assertEqual(data, {}) @@ -345,7 +347,8 @@ class BackendTestCase(TransactionTestCase): old_password = connection.settings_dict['PASSWORD'] connection.settings_dict['PASSWORD'] = "françois" try: - connection.cursor() + with connection.cursor(): + pass except DatabaseError: # As password is probably wrong, a database exception is expected pass @@ -639,7 +642,8 @@ class ThreadTests(TransactionTestCase): # Map connections by id because connections with identical aliases # have the same hash. connections_dict = {} - connection.cursor() + with connection.cursor(): + pass connections_dict[id(connection)] = connection def runner(): @@ -650,7 +654,8 @@ class ThreadTests(TransactionTestCase): # Allow thread sharing so the connection can be closed by the # main thread. connection.inc_thread_sharing() - connection.cursor() + with connection.cursor(): + pass connections_dict[id(connection)] = connection try: for x in range(2): @@ -729,6 +734,7 @@ class ThreadTests(TransactionTestCase): do_thread() # Forbidden! self.assertIsInstance(exceptions[0], DatabaseError) + connections['default'].close() # After calling inc_thread_sharing() on the connection. connections['default'].inc_thread_sharing() diff --git a/tests/fixtures/tests.py b/tests/fixtures/tests.py index 02dd38e6350..16e90dea382 100644 --- a/tests/fixtures/tests.py +++ b/tests/fixtures/tests.py @@ -564,7 +564,8 @@ class FixtureLoadingTests(DumpDataAssertMixin, TestCase): # This won't affect other tests because the database connection # is closed at the end of each test. if connection.vendor == 'mysql': - connection.cursor().execute("SET sql_mode = 'TRADITIONAL'") + with connection.cursor() as cursor: + cursor.execute("SET sql_mode = 'TRADITIONAL'") with self.assertRaises(IntegrityError) as cm: management.call_command('loaddata', 'invalid.json', verbosity=0) self.assertIn("Could not load fixtures.Article(pk=1):", cm.exception.args[0]) diff --git a/tests/indexes/tests.py b/tests/indexes/tests.py index 5ef2835f4eb..274ee54a373 100644 --- a/tests/indexes/tests.py +++ b/tests/indexes/tests.py @@ -270,9 +270,10 @@ class SchemaIndexesMySQLTests(TransactionTestCase): MySQL on InnoDB already creates indexes automatically for foreign keys. (#14180). An index should be created if db_constraint=False (#26171). """ - storage = connection.introspection.get_storage_engine( - connection.cursor(), ArticleTranslation._meta.db_table - ) + with connection.cursor() as cursor: + storage = connection.introspection.get_storage_engine( + cursor, ArticleTranslation._meta.db_table, + ) if storage != "InnoDB": self.skip("This test only applies to the InnoDB storage engine") index_sql = [str(statement) for statement in connection.schema_editor()._model_indexes_sql(ArticleTranslation)] @@ -326,9 +327,10 @@ class PartialIndexTests(TransactionTestCase): str(index.create_sql(Article, schema_editor=editor)) ) editor.add_index(index=index, model=Article) - self.assertIn(index.name, connection.introspection.get_constraints( - cursor=connection.cursor(), table_name=Article._meta.db_table, - )) + with connection.cursor() as cursor: + self.assertIn(index.name, connection.introspection.get_constraints( + cursor=cursor, table_name=Article._meta.db_table, + )) editor.remove_index(index=index, model=Article) def test_integer_restriction_partial(self): @@ -343,9 +345,10 @@ class PartialIndexTests(TransactionTestCase): str(index.create_sql(Article, schema_editor=editor)) ) editor.add_index(index=index, model=Article) - self.assertIn(index.name, connection.introspection.get_constraints( - cursor=connection.cursor(), table_name=Article._meta.db_table, - )) + with connection.cursor() as cursor: + self.assertIn(index.name, connection.introspection.get_constraints( + cursor=cursor, table_name=Article._meta.db_table, + )) editor.remove_index(index=index, model=Article) def test_boolean_restriction_partial(self): @@ -360,9 +363,10 @@ class PartialIndexTests(TransactionTestCase): str(index.create_sql(Article, schema_editor=editor)) ) editor.add_index(index=index, model=Article) - self.assertIn(index.name, connection.introspection.get_constraints( - cursor=connection.cursor(), table_name=Article._meta.db_table, - )) + with connection.cursor() as cursor: + self.assertIn(index.name, connection.introspection.get_constraints( + cursor=cursor, table_name=Article._meta.db_table, + )) editor.remove_index(index=index, model=Article) @skipUnlessDBFeature('supports_functions_in_partial_indexes') @@ -390,9 +394,10 @@ class PartialIndexTests(TransactionTestCase): # check ONLY the occurrence of headline in the SQL. self.assertGreater(sql.rfind('headline'), where) editor.add_index(index=index, model=Article) - self.assertIn(index.name, connection.introspection.get_constraints( - cursor=connection.cursor(), table_name=Article._meta.db_table, - )) + with connection.cursor() as cursor: + self.assertIn(index.name, connection.introspection.get_constraints( + cursor=cursor, table_name=Article._meta.db_table, + )) editor.remove_index(index=index, model=Article) def test_is_null_condition(self): @@ -407,7 +412,8 @@ class PartialIndexTests(TransactionTestCase): str(index.create_sql(Article, schema_editor=editor)) ) editor.add_index(index=index, model=Article) - self.assertIn(index.name, connection.introspection.get_constraints( - cursor=connection.cursor(), table_name=Article._meta.db_table, - )) + with connection.cursor() as cursor: + self.assertIn(index.name, connection.introspection.get_constraints( + cursor=cursor, table_name=Article._meta.db_table, + )) editor.remove_index(index=index, model=Article) diff --git a/tests/test_utils/tests.py b/tests/test_utils/tests.py index 32007ebeb28..c61358f4c80 100644 --- a/tests/test_utils/tests.py +++ b/tests/test_utils/tests.py @@ -226,7 +226,8 @@ class AssertNumQueriesUponConnectionTests(TransactionTestCase): if is_opening_connection: # Avoid infinite recursion. Creating a cursor calls # ensure_connection() which is currently mocked by this method. - connection.cursor().execute('SELECT 1' + connection.features.bare_select_suffix) + with connection.cursor() as cursor: + cursor.execute('SELECT 1' + connection.features.bare_select_suffix) ensure_connection = 'django.db.backends.base.base.BaseDatabaseWrapper.ensure_connection' with mock.patch(ensure_connection, side_effect=make_configuration_query):