Fixed #31233 -- Closed database connections and cursors after use.

This commit is contained in:
Jon Dufresne 2020-02-03 19:07:00 -08:00 committed by Mariusz Felisiak
parent f48f671223
commit 3259983f56
14 changed files with 86 additions and 55 deletions

View File

@ -269,7 +269,8 @@ class DatabaseWrapper(BaseDatabaseWrapper):
forward references. Always return True to indicate constraint checks
need to be re-enabled.
"""
self.cursor().execute('SET foreign_key_checks=0')
with self.cursor() as cursor:
cursor.execute('SET foreign_key_checks=0')
return True
def enable_constraint_checking(self):
@ -280,7 +281,8 @@ class DatabaseWrapper(BaseDatabaseWrapper):
# nested inside transaction.atomic.
self.needs_rollback, needs_rollback = False, self.needs_rollback
try:
self.cursor().execute('SET foreign_key_checks=1')
with self.cursor() as cursor:
cursor.execute('SET foreign_key_checks=1')
finally:
self.needs_rollback = needs_rollback

View File

@ -301,8 +301,9 @@ class DatabaseWrapper(BaseDatabaseWrapper):
Check constraints by setting them to immediate. Return them to deferred
afterward.
"""
self.cursor().execute('SET CONSTRAINTS ALL IMMEDIATE')
self.cursor().execute('SET CONSTRAINTS ALL DEFERRED')
with self.cursor() as cursor:
cursor.execute('SET CONSTRAINTS ALL IMMEDIATE')
cursor.execute('SET CONSTRAINTS ALL DEFERRED')
def is_usable(self):
try:

View File

@ -277,13 +277,15 @@ class DatabaseWrapper(BaseDatabaseWrapper):
Check constraints by setting them to immediate. Return them to deferred
afterward.
"""
self.cursor().execute('SET CONSTRAINTS ALL IMMEDIATE')
self.cursor().execute('SET CONSTRAINTS ALL DEFERRED')
with self.cursor() as cursor:
cursor.execute('SET CONSTRAINTS ALL IMMEDIATE')
cursor.execute('SET CONSTRAINTS ALL DEFERRED')
def is_usable(self):
try:
# Use a psycopg cursor directly, bypassing Django's utilities.
self.connection.cursor().execute("SELECT 1")
with self.connection.cursor() as cursor:
cursor.execute('SELECT 1')
except Database.Error:
return False
else:

View File

@ -296,7 +296,8 @@ class DatabaseWrapper(BaseDatabaseWrapper):
return not bool(enabled)
def enable_constraint_checking(self):
self.cursor().execute('PRAGMA foreign_keys = ON')
with self.cursor() as cursor:
cursor.execute('PRAGMA foreign_keys = ON')
def check_constraints(self, table_names=None):
"""
@ -309,7 +310,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
if self.features.supports_pragma_foreign_key_check:
with self.cursor() as cursor:
if table_names is None:
violations = self.cursor().execute('PRAGMA foreign_key_check').fetchall()
violations = cursor.execute('PRAGMA foreign_key_check').fetchall()
else:
violations = chain.from_iterable(
cursor.execute('PRAGMA foreign_key_check(%s)' % table_name).fetchall()

View File

@ -372,10 +372,8 @@ class MigrationExecutor:
else:
found_add_field_migration = True
continue
columns = self.connection.introspection.get_table_description(
self.connection.cursor(),
table,
)
with self.connection.cursor() as cursor:
columns = self.connection.introspection.get_table_description(cursor, table)
for column in columns:
field_column = field.column
column_name = column.name

View File

@ -52,7 +52,9 @@ class MigrationRecorder:
def has_table(self):
"""Return True if the django_migrations table exists."""
return self.Migration._meta.db_table in self.connection.introspection.table_names(self.connection.cursor())
with self.connection.cursor() as cursor:
tables = self.connection.introspection.table_names(cursor)
return self.Migration._meta.db_table in tables
def ensure_schema(self):
"""Ensure the table exists and has the correct schema."""

View File

@ -747,7 +747,10 @@ class QuerySet:
query = self.query.clone()
query.__class__ = sql.DeleteQuery
cursor = query.get_compiler(using).execute_sql(CURSOR)
return cursor.rowcount if cursor else 0
if cursor:
with cursor:
return cursor.rowcount
return 0
_raw_delete.alters_data = True
def update(self, **kwargs):

View File

@ -21,7 +21,10 @@ class DeleteQuery(Query):
self.alias_map = {table: self.alias_map[table]}
self.where = where
cursor = self.get_compiler(using).execute_sql(CURSOR)
return cursor.rowcount if cursor else 0
if cursor:
with cursor:
return cursor.rowcount
return 0
def delete_batch(self, pk_list, using):
"""

View File

@ -114,8 +114,8 @@ class Tests(TestCase):
try:
# Open a database connection.
new_connection.cursor()
self.assertFalse(new_connection.get_autocommit())
with new_connection.cursor():
self.assertFalse(new_connection.get_autocommit())
finally:
new_connection.close()
@ -149,9 +149,12 @@ class Tests(TestCase):
def test_connect_no_is_usable_checks(self):
new_connection = connection.copy()
with mock.patch.object(new_connection, 'is_usable') as is_usable:
new_connection.connect()
is_usable.assert_not_called()
try:
with mock.patch.object(new_connection, 'is_usable') as is_usable:
new_connection.connect()
is_usable.assert_not_called()
finally:
new_connection.close()
def _select(self, val):
with connection.cursor() as cursor:

View File

@ -197,7 +197,8 @@ class LastExecutedQueryTest(TestCase):
def test_no_interpolation(self):
# This shouldn't raise an exception (#17158)
query = "SELECT strftime('%Y', 'now');"
connection.cursor().execute(query)
with connection.cursor() as cursor:
cursor.execute(query)
self.assertEqual(connection.queries[-1]['sql'], query)
def test_parameter_quoting(self):
@ -205,7 +206,8 @@ class LastExecutedQueryTest(TestCase):
# worth testing that parameters are quoted (#14091).
query = "SELECT %s"
params = ["\"'\\"]
connection.cursor().execute(query, params)
with connection.cursor() as cursor:
cursor.execute(query, params)
# Note that the single quote is repeated
substituted = "SELECT '\"''\\'"
self.assertEqual(connection.queries[-1]['sql'], substituted)

View File

@ -69,8 +69,8 @@ class LastExecutedQueryTest(TestCase):
"""last_executed_query() returns a string."""
data = RawData.objects.filter(raw_data=b'\x00\x46 \xFE').extra(select={'föö': 1})
sql, params = data.query.sql_with_params()
cursor = data.query.get_compiler('default').execute_sql(CURSOR)
last_sql = cursor.db.ops.last_executed_query(cursor, sql, params)
with data.query.get_compiler('default').execute_sql(CURSOR) as cursor:
last_sql = cursor.db.ops.last_executed_query(cursor, sql, params)
self.assertIsInstance(last_sql, str)
def test_last_executed_query(self):
@ -81,11 +81,11 @@ class LastExecutedQueryTest(TestCase):
Article.objects.filter(pk__in=(1, 2), reporter__pk=3),
):
sql, params = qs.query.sql_with_params()
cursor = qs.query.get_compiler(DEFAULT_DB_ALIAS).execute_sql(CURSOR)
self.assertEqual(
cursor.db.ops.last_executed_query(cursor, sql, params),
str(qs.query),
)
with qs.query.get_compiler(DEFAULT_DB_ALIAS).execute_sql(CURSOR) as cursor:
self.assertEqual(
cursor.db.ops.last_executed_query(cursor, sql, params),
str(qs.query),
)
@skipUnlessDBFeature('supports_paramstyle_pyformat')
def test_last_executed_query_dict(self):
@ -205,12 +205,14 @@ class ConnectionCreatedSignalTest(TransactionTestCase):
connection_created.connect(receiver)
connection.close()
connection.cursor()
with connection.cursor():
pass
self.assertIs(data["connection"].connection, connection.connection)
connection_created.disconnect(receiver)
data.clear()
connection.cursor()
with connection.cursor():
pass
self.assertEqual(data, {})
@ -345,7 +347,8 @@ class BackendTestCase(TransactionTestCase):
old_password = connection.settings_dict['PASSWORD']
connection.settings_dict['PASSWORD'] = "françois"
try:
connection.cursor()
with connection.cursor():
pass
except DatabaseError:
# As password is probably wrong, a database exception is expected
pass
@ -639,7 +642,8 @@ class ThreadTests(TransactionTestCase):
# Map connections by id because connections with identical aliases
# have the same hash.
connections_dict = {}
connection.cursor()
with connection.cursor():
pass
connections_dict[id(connection)] = connection
def runner():
@ -650,7 +654,8 @@ class ThreadTests(TransactionTestCase):
# Allow thread sharing so the connection can be closed by the
# main thread.
connection.inc_thread_sharing()
connection.cursor()
with connection.cursor():
pass
connections_dict[id(connection)] = connection
try:
for x in range(2):
@ -729,6 +734,7 @@ class ThreadTests(TransactionTestCase):
do_thread()
# Forbidden!
self.assertIsInstance(exceptions[0], DatabaseError)
connections['default'].close()
# After calling inc_thread_sharing() on the connection.
connections['default'].inc_thread_sharing()

View File

@ -564,7 +564,8 @@ class FixtureLoadingTests(DumpDataAssertMixin, TestCase):
# This won't affect other tests because the database connection
# is closed at the end of each test.
if connection.vendor == 'mysql':
connection.cursor().execute("SET sql_mode = 'TRADITIONAL'")
with connection.cursor() as cursor:
cursor.execute("SET sql_mode = 'TRADITIONAL'")
with self.assertRaises(IntegrityError) as cm:
management.call_command('loaddata', 'invalid.json', verbosity=0)
self.assertIn("Could not load fixtures.Article(pk=1):", cm.exception.args[0])

View File

@ -270,9 +270,10 @@ class SchemaIndexesMySQLTests(TransactionTestCase):
MySQL on InnoDB already creates indexes automatically for foreign keys.
(#14180). An index should be created if db_constraint=False (#26171).
"""
storage = connection.introspection.get_storage_engine(
connection.cursor(), ArticleTranslation._meta.db_table
)
with connection.cursor() as cursor:
storage = connection.introspection.get_storage_engine(
cursor, ArticleTranslation._meta.db_table,
)
if storage != "InnoDB":
self.skip("This test only applies to the InnoDB storage engine")
index_sql = [str(statement) for statement in connection.schema_editor()._model_indexes_sql(ArticleTranslation)]
@ -326,9 +327,10 @@ class PartialIndexTests(TransactionTestCase):
str(index.create_sql(Article, schema_editor=editor))
)
editor.add_index(index=index, model=Article)
self.assertIn(index.name, connection.introspection.get_constraints(
cursor=connection.cursor(), table_name=Article._meta.db_table,
))
with connection.cursor() as cursor:
self.assertIn(index.name, connection.introspection.get_constraints(
cursor=cursor, table_name=Article._meta.db_table,
))
editor.remove_index(index=index, model=Article)
def test_integer_restriction_partial(self):
@ -343,9 +345,10 @@ class PartialIndexTests(TransactionTestCase):
str(index.create_sql(Article, schema_editor=editor))
)
editor.add_index(index=index, model=Article)
self.assertIn(index.name, connection.introspection.get_constraints(
cursor=connection.cursor(), table_name=Article._meta.db_table,
))
with connection.cursor() as cursor:
self.assertIn(index.name, connection.introspection.get_constraints(
cursor=cursor, table_name=Article._meta.db_table,
))
editor.remove_index(index=index, model=Article)
def test_boolean_restriction_partial(self):
@ -360,9 +363,10 @@ class PartialIndexTests(TransactionTestCase):
str(index.create_sql(Article, schema_editor=editor))
)
editor.add_index(index=index, model=Article)
self.assertIn(index.name, connection.introspection.get_constraints(
cursor=connection.cursor(), table_name=Article._meta.db_table,
))
with connection.cursor() as cursor:
self.assertIn(index.name, connection.introspection.get_constraints(
cursor=cursor, table_name=Article._meta.db_table,
))
editor.remove_index(index=index, model=Article)
@skipUnlessDBFeature('supports_functions_in_partial_indexes')
@ -390,9 +394,10 @@ class PartialIndexTests(TransactionTestCase):
# check ONLY the occurrence of headline in the SQL.
self.assertGreater(sql.rfind('headline'), where)
editor.add_index(index=index, model=Article)
self.assertIn(index.name, connection.introspection.get_constraints(
cursor=connection.cursor(), table_name=Article._meta.db_table,
))
with connection.cursor() as cursor:
self.assertIn(index.name, connection.introspection.get_constraints(
cursor=cursor, table_name=Article._meta.db_table,
))
editor.remove_index(index=index, model=Article)
def test_is_null_condition(self):
@ -407,7 +412,8 @@ class PartialIndexTests(TransactionTestCase):
str(index.create_sql(Article, schema_editor=editor))
)
editor.add_index(index=index, model=Article)
self.assertIn(index.name, connection.introspection.get_constraints(
cursor=connection.cursor(), table_name=Article._meta.db_table,
))
with connection.cursor() as cursor:
self.assertIn(index.name, connection.introspection.get_constraints(
cursor=cursor, table_name=Article._meta.db_table,
))
editor.remove_index(index=index, model=Article)

View File

@ -226,7 +226,8 @@ class AssertNumQueriesUponConnectionTests(TransactionTestCase):
if is_opening_connection:
# Avoid infinite recursion. Creating a cursor calls
# ensure_connection() which is currently mocked by this method.
connection.cursor().execute('SELECT 1' + connection.features.bare_select_suffix)
with connection.cursor() as cursor:
cursor.execute('SELECT 1' + connection.features.bare_select_suffix)
ensure_connection = 'django.db.backends.base.base.BaseDatabaseWrapper.ensure_connection'
with mock.patch(ensure_connection, side_effect=make_configuration_query):