Fixed #31233 -- Closed database connections and cursors after use.

This commit is contained in:
Jon Dufresne 2020-02-03 19:07:00 -08:00 committed by Mariusz Felisiak
parent f48f671223
commit 3259983f56
14 changed files with 86 additions and 55 deletions

View File

@ -269,7 +269,8 @@ class DatabaseWrapper(BaseDatabaseWrapper):
forward references. Always return True to indicate constraint checks forward references. Always return True to indicate constraint checks
need to be re-enabled. need to be re-enabled.
""" """
self.cursor().execute('SET foreign_key_checks=0') with self.cursor() as cursor:
cursor.execute('SET foreign_key_checks=0')
return True return True
def enable_constraint_checking(self): def enable_constraint_checking(self):
@ -280,7 +281,8 @@ class DatabaseWrapper(BaseDatabaseWrapper):
# nested inside transaction.atomic. # nested inside transaction.atomic.
self.needs_rollback, needs_rollback = False, self.needs_rollback self.needs_rollback, needs_rollback = False, self.needs_rollback
try: try:
self.cursor().execute('SET foreign_key_checks=1') with self.cursor() as cursor:
cursor.execute('SET foreign_key_checks=1')
finally: finally:
self.needs_rollback = needs_rollback self.needs_rollback = needs_rollback

View File

@ -301,8 +301,9 @@ class DatabaseWrapper(BaseDatabaseWrapper):
Check constraints by setting them to immediate. Return them to deferred Check constraints by setting them to immediate. Return them to deferred
afterward. afterward.
""" """
self.cursor().execute('SET CONSTRAINTS ALL IMMEDIATE') with self.cursor() as cursor:
self.cursor().execute('SET CONSTRAINTS ALL DEFERRED') cursor.execute('SET CONSTRAINTS ALL IMMEDIATE')
cursor.execute('SET CONSTRAINTS ALL DEFERRED')
def is_usable(self): def is_usable(self):
try: try:

View File

@ -277,13 +277,15 @@ class DatabaseWrapper(BaseDatabaseWrapper):
Check constraints by setting them to immediate. Return them to deferred Check constraints by setting them to immediate. Return them to deferred
afterward. afterward.
""" """
self.cursor().execute('SET CONSTRAINTS ALL IMMEDIATE') with self.cursor() as cursor:
self.cursor().execute('SET CONSTRAINTS ALL DEFERRED') cursor.execute('SET CONSTRAINTS ALL IMMEDIATE')
cursor.execute('SET CONSTRAINTS ALL DEFERRED')
def is_usable(self): def is_usable(self):
try: try:
# Use a psycopg cursor directly, bypassing Django's utilities. # Use a psycopg cursor directly, bypassing Django's utilities.
self.connection.cursor().execute("SELECT 1") with self.connection.cursor() as cursor:
cursor.execute('SELECT 1')
except Database.Error: except Database.Error:
return False return False
else: else:

View File

@ -296,7 +296,8 @@ class DatabaseWrapper(BaseDatabaseWrapper):
return not bool(enabled) return not bool(enabled)
def enable_constraint_checking(self): def enable_constraint_checking(self):
self.cursor().execute('PRAGMA foreign_keys = ON') with self.cursor() as cursor:
cursor.execute('PRAGMA foreign_keys = ON')
def check_constraints(self, table_names=None): def check_constraints(self, table_names=None):
""" """
@ -309,7 +310,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
if self.features.supports_pragma_foreign_key_check: if self.features.supports_pragma_foreign_key_check:
with self.cursor() as cursor: with self.cursor() as cursor:
if table_names is None: if table_names is None:
violations = self.cursor().execute('PRAGMA foreign_key_check').fetchall() violations = cursor.execute('PRAGMA foreign_key_check').fetchall()
else: else:
violations = chain.from_iterable( violations = chain.from_iterable(
cursor.execute('PRAGMA foreign_key_check(%s)' % table_name).fetchall() cursor.execute('PRAGMA foreign_key_check(%s)' % table_name).fetchall()

View File

@ -372,10 +372,8 @@ class MigrationExecutor:
else: else:
found_add_field_migration = True found_add_field_migration = True
continue continue
columns = self.connection.introspection.get_table_description( with self.connection.cursor() as cursor:
self.connection.cursor(), columns = self.connection.introspection.get_table_description(cursor, table)
table,
)
for column in columns: for column in columns:
field_column = field.column field_column = field.column
column_name = column.name column_name = column.name

View File

@ -52,7 +52,9 @@ class MigrationRecorder:
def has_table(self): def has_table(self):
"""Return True if the django_migrations table exists.""" """Return True if the django_migrations table exists."""
return self.Migration._meta.db_table in self.connection.introspection.table_names(self.connection.cursor()) with self.connection.cursor() as cursor:
tables = self.connection.introspection.table_names(cursor)
return self.Migration._meta.db_table in tables
def ensure_schema(self): def ensure_schema(self):
"""Ensure the table exists and has the correct schema.""" """Ensure the table exists and has the correct schema."""

View File

@ -747,7 +747,10 @@ class QuerySet:
query = self.query.clone() query = self.query.clone()
query.__class__ = sql.DeleteQuery query.__class__ = sql.DeleteQuery
cursor = query.get_compiler(using).execute_sql(CURSOR) cursor = query.get_compiler(using).execute_sql(CURSOR)
return cursor.rowcount if cursor else 0 if cursor:
with cursor:
return cursor.rowcount
return 0
_raw_delete.alters_data = True _raw_delete.alters_data = True
def update(self, **kwargs): def update(self, **kwargs):

View File

@ -21,7 +21,10 @@ class DeleteQuery(Query):
self.alias_map = {table: self.alias_map[table]} self.alias_map = {table: self.alias_map[table]}
self.where = where self.where = where
cursor = self.get_compiler(using).execute_sql(CURSOR) cursor = self.get_compiler(using).execute_sql(CURSOR)
return cursor.rowcount if cursor else 0 if cursor:
with cursor:
return cursor.rowcount
return 0
def delete_batch(self, pk_list, using): def delete_batch(self, pk_list, using):
""" """

View File

@ -114,8 +114,8 @@ class Tests(TestCase):
try: try:
# Open a database connection. # Open a database connection.
new_connection.cursor() with new_connection.cursor():
self.assertFalse(new_connection.get_autocommit()) self.assertFalse(new_connection.get_autocommit())
finally: finally:
new_connection.close() new_connection.close()
@ -149,9 +149,12 @@ class Tests(TestCase):
def test_connect_no_is_usable_checks(self): def test_connect_no_is_usable_checks(self):
new_connection = connection.copy() new_connection = connection.copy()
with mock.patch.object(new_connection, 'is_usable') as is_usable: try:
new_connection.connect() with mock.patch.object(new_connection, 'is_usable') as is_usable:
is_usable.assert_not_called() new_connection.connect()
is_usable.assert_not_called()
finally:
new_connection.close()
def _select(self, val): def _select(self, val):
with connection.cursor() as cursor: with connection.cursor() as cursor:

View File

@ -197,7 +197,8 @@ class LastExecutedQueryTest(TestCase):
def test_no_interpolation(self): def test_no_interpolation(self):
# This shouldn't raise an exception (#17158) # This shouldn't raise an exception (#17158)
query = "SELECT strftime('%Y', 'now');" query = "SELECT strftime('%Y', 'now');"
connection.cursor().execute(query) with connection.cursor() as cursor:
cursor.execute(query)
self.assertEqual(connection.queries[-1]['sql'], query) self.assertEqual(connection.queries[-1]['sql'], query)
def test_parameter_quoting(self): def test_parameter_quoting(self):
@ -205,7 +206,8 @@ class LastExecutedQueryTest(TestCase):
# worth testing that parameters are quoted (#14091). # worth testing that parameters are quoted (#14091).
query = "SELECT %s" query = "SELECT %s"
params = ["\"'\\"] params = ["\"'\\"]
connection.cursor().execute(query, params) with connection.cursor() as cursor:
cursor.execute(query, params)
# Note that the single quote is repeated # Note that the single quote is repeated
substituted = "SELECT '\"''\\'" substituted = "SELECT '\"''\\'"
self.assertEqual(connection.queries[-1]['sql'], substituted) self.assertEqual(connection.queries[-1]['sql'], substituted)

View File

@ -69,8 +69,8 @@ class LastExecutedQueryTest(TestCase):
"""last_executed_query() returns a string.""" """last_executed_query() returns a string."""
data = RawData.objects.filter(raw_data=b'\x00\x46 \xFE').extra(select={'föö': 1}) data = RawData.objects.filter(raw_data=b'\x00\x46 \xFE').extra(select={'föö': 1})
sql, params = data.query.sql_with_params() sql, params = data.query.sql_with_params()
cursor = data.query.get_compiler('default').execute_sql(CURSOR) with data.query.get_compiler('default').execute_sql(CURSOR) as cursor:
last_sql = cursor.db.ops.last_executed_query(cursor, sql, params) last_sql = cursor.db.ops.last_executed_query(cursor, sql, params)
self.assertIsInstance(last_sql, str) self.assertIsInstance(last_sql, str)
def test_last_executed_query(self): def test_last_executed_query(self):
@ -81,11 +81,11 @@ class LastExecutedQueryTest(TestCase):
Article.objects.filter(pk__in=(1, 2), reporter__pk=3), Article.objects.filter(pk__in=(1, 2), reporter__pk=3),
): ):
sql, params = qs.query.sql_with_params() sql, params = qs.query.sql_with_params()
cursor = qs.query.get_compiler(DEFAULT_DB_ALIAS).execute_sql(CURSOR) with qs.query.get_compiler(DEFAULT_DB_ALIAS).execute_sql(CURSOR) as cursor:
self.assertEqual( self.assertEqual(
cursor.db.ops.last_executed_query(cursor, sql, params), cursor.db.ops.last_executed_query(cursor, sql, params),
str(qs.query), str(qs.query),
) )
@skipUnlessDBFeature('supports_paramstyle_pyformat') @skipUnlessDBFeature('supports_paramstyle_pyformat')
def test_last_executed_query_dict(self): def test_last_executed_query_dict(self):
@ -205,12 +205,14 @@ class ConnectionCreatedSignalTest(TransactionTestCase):
connection_created.connect(receiver) connection_created.connect(receiver)
connection.close() connection.close()
connection.cursor() with connection.cursor():
pass
self.assertIs(data["connection"].connection, connection.connection) self.assertIs(data["connection"].connection, connection.connection)
connection_created.disconnect(receiver) connection_created.disconnect(receiver)
data.clear() data.clear()
connection.cursor() with connection.cursor():
pass
self.assertEqual(data, {}) self.assertEqual(data, {})
@ -345,7 +347,8 @@ class BackendTestCase(TransactionTestCase):
old_password = connection.settings_dict['PASSWORD'] old_password = connection.settings_dict['PASSWORD']
connection.settings_dict['PASSWORD'] = "françois" connection.settings_dict['PASSWORD'] = "françois"
try: try:
connection.cursor() with connection.cursor():
pass
except DatabaseError: except DatabaseError:
# As password is probably wrong, a database exception is expected # As password is probably wrong, a database exception is expected
pass pass
@ -639,7 +642,8 @@ class ThreadTests(TransactionTestCase):
# Map connections by id because connections with identical aliases # Map connections by id because connections with identical aliases
# have the same hash. # have the same hash.
connections_dict = {} connections_dict = {}
connection.cursor() with connection.cursor():
pass
connections_dict[id(connection)] = connection connections_dict[id(connection)] = connection
def runner(): def runner():
@ -650,7 +654,8 @@ class ThreadTests(TransactionTestCase):
# Allow thread sharing so the connection can be closed by the # Allow thread sharing so the connection can be closed by the
# main thread. # main thread.
connection.inc_thread_sharing() connection.inc_thread_sharing()
connection.cursor() with connection.cursor():
pass
connections_dict[id(connection)] = connection connections_dict[id(connection)] = connection
try: try:
for x in range(2): for x in range(2):
@ -729,6 +734,7 @@ class ThreadTests(TransactionTestCase):
do_thread() do_thread()
# Forbidden! # Forbidden!
self.assertIsInstance(exceptions[0], DatabaseError) self.assertIsInstance(exceptions[0], DatabaseError)
connections['default'].close()
# After calling inc_thread_sharing() on the connection. # After calling inc_thread_sharing() on the connection.
connections['default'].inc_thread_sharing() connections['default'].inc_thread_sharing()

View File

@ -564,7 +564,8 @@ class FixtureLoadingTests(DumpDataAssertMixin, TestCase):
# This won't affect other tests because the database connection # This won't affect other tests because the database connection
# is closed at the end of each test. # is closed at the end of each test.
if connection.vendor == 'mysql': if connection.vendor == 'mysql':
connection.cursor().execute("SET sql_mode = 'TRADITIONAL'") with connection.cursor() as cursor:
cursor.execute("SET sql_mode = 'TRADITIONAL'")
with self.assertRaises(IntegrityError) as cm: with self.assertRaises(IntegrityError) as cm:
management.call_command('loaddata', 'invalid.json', verbosity=0) management.call_command('loaddata', 'invalid.json', verbosity=0)
self.assertIn("Could not load fixtures.Article(pk=1):", cm.exception.args[0]) self.assertIn("Could not load fixtures.Article(pk=1):", cm.exception.args[0])

View File

@ -270,9 +270,10 @@ class SchemaIndexesMySQLTests(TransactionTestCase):
MySQL on InnoDB already creates indexes automatically for foreign keys. MySQL on InnoDB already creates indexes automatically for foreign keys.
(#14180). An index should be created if db_constraint=False (#26171). (#14180). An index should be created if db_constraint=False (#26171).
""" """
storage = connection.introspection.get_storage_engine( with connection.cursor() as cursor:
connection.cursor(), ArticleTranslation._meta.db_table storage = connection.introspection.get_storage_engine(
) cursor, ArticleTranslation._meta.db_table,
)
if storage != "InnoDB": if storage != "InnoDB":
self.skip("This test only applies to the InnoDB storage engine") self.skip("This test only applies to the InnoDB storage engine")
index_sql = [str(statement) for statement in connection.schema_editor()._model_indexes_sql(ArticleTranslation)] index_sql = [str(statement) for statement in connection.schema_editor()._model_indexes_sql(ArticleTranslation)]
@ -326,9 +327,10 @@ class PartialIndexTests(TransactionTestCase):
str(index.create_sql(Article, schema_editor=editor)) str(index.create_sql(Article, schema_editor=editor))
) )
editor.add_index(index=index, model=Article) editor.add_index(index=index, model=Article)
self.assertIn(index.name, connection.introspection.get_constraints( with connection.cursor() as cursor:
cursor=connection.cursor(), table_name=Article._meta.db_table, self.assertIn(index.name, connection.introspection.get_constraints(
)) cursor=cursor, table_name=Article._meta.db_table,
))
editor.remove_index(index=index, model=Article) editor.remove_index(index=index, model=Article)
def test_integer_restriction_partial(self): def test_integer_restriction_partial(self):
@ -343,9 +345,10 @@ class PartialIndexTests(TransactionTestCase):
str(index.create_sql(Article, schema_editor=editor)) str(index.create_sql(Article, schema_editor=editor))
) )
editor.add_index(index=index, model=Article) editor.add_index(index=index, model=Article)
self.assertIn(index.name, connection.introspection.get_constraints( with connection.cursor() as cursor:
cursor=connection.cursor(), table_name=Article._meta.db_table, self.assertIn(index.name, connection.introspection.get_constraints(
)) cursor=cursor, table_name=Article._meta.db_table,
))
editor.remove_index(index=index, model=Article) editor.remove_index(index=index, model=Article)
def test_boolean_restriction_partial(self): def test_boolean_restriction_partial(self):
@ -360,9 +363,10 @@ class PartialIndexTests(TransactionTestCase):
str(index.create_sql(Article, schema_editor=editor)) str(index.create_sql(Article, schema_editor=editor))
) )
editor.add_index(index=index, model=Article) editor.add_index(index=index, model=Article)
self.assertIn(index.name, connection.introspection.get_constraints( with connection.cursor() as cursor:
cursor=connection.cursor(), table_name=Article._meta.db_table, self.assertIn(index.name, connection.introspection.get_constraints(
)) cursor=cursor, table_name=Article._meta.db_table,
))
editor.remove_index(index=index, model=Article) editor.remove_index(index=index, model=Article)
@skipUnlessDBFeature('supports_functions_in_partial_indexes') @skipUnlessDBFeature('supports_functions_in_partial_indexes')
@ -390,9 +394,10 @@ class PartialIndexTests(TransactionTestCase):
# check ONLY the occurrence of headline in the SQL. # check ONLY the occurrence of headline in the SQL.
self.assertGreater(sql.rfind('headline'), where) self.assertGreater(sql.rfind('headline'), where)
editor.add_index(index=index, model=Article) editor.add_index(index=index, model=Article)
self.assertIn(index.name, connection.introspection.get_constraints( with connection.cursor() as cursor:
cursor=connection.cursor(), table_name=Article._meta.db_table, self.assertIn(index.name, connection.introspection.get_constraints(
)) cursor=cursor, table_name=Article._meta.db_table,
))
editor.remove_index(index=index, model=Article) editor.remove_index(index=index, model=Article)
def test_is_null_condition(self): def test_is_null_condition(self):
@ -407,7 +412,8 @@ class PartialIndexTests(TransactionTestCase):
str(index.create_sql(Article, schema_editor=editor)) str(index.create_sql(Article, schema_editor=editor))
) )
editor.add_index(index=index, model=Article) editor.add_index(index=index, model=Article)
self.assertIn(index.name, connection.introspection.get_constraints( with connection.cursor() as cursor:
cursor=connection.cursor(), table_name=Article._meta.db_table, self.assertIn(index.name, connection.introspection.get_constraints(
)) cursor=cursor, table_name=Article._meta.db_table,
))
editor.remove_index(index=index, model=Article) editor.remove_index(index=index, model=Article)

View File

@ -226,7 +226,8 @@ class AssertNumQueriesUponConnectionTests(TransactionTestCase):
if is_opening_connection: if is_opening_connection:
# Avoid infinite recursion. Creating a cursor calls # Avoid infinite recursion. Creating a cursor calls
# ensure_connection() which is currently mocked by this method. # ensure_connection() which is currently mocked by this method.
connection.cursor().execute('SELECT 1' + connection.features.bare_select_suffix) with connection.cursor() as cursor:
cursor.execute('SELECT 1' + connection.features.bare_select_suffix)
ensure_connection = 'django.db.backends.base.base.BaseDatabaseWrapper.ensure_connection' ensure_connection = 'django.db.backends.base.base.BaseDatabaseWrapper.ensure_connection'
with mock.patch(ensure_connection, side_effect=make_configuration_query): with mock.patch(ensure_connection, side_effect=make_configuration_query):