Ensure cursors are closed when no longer needed.

This commit touchs various parts of the code base and test framework. Any
found usage of opening a cursor for the sake of initializing a connection
has been replaced with 'ensure_connection()'.
This commit is contained in:
Michael Manfre 2014-01-09 10:05:15 -05:00
parent 0837eacc4e
commit 3ffeb93186
31 changed files with 657 additions and 615 deletions

View File

@ -11,7 +11,7 @@ class PostGISCreation(DatabaseCreation):
@cached_property
def template_postgis(self):
template_postgis = getattr(settings, 'POSTGIS_TEMPLATE', 'template_postgis')
cursor = self.connection.cursor()
with self.connection.cursor() as cursor:
cursor.execute('SELECT 1 FROM pg_database WHERE datname = %s LIMIT 1;', (template_postgis,))
if cursor.fetchone():
return template_postgis
@ -88,7 +88,7 @@ class PostGISCreation(DatabaseCreation):
# Connect to the test database in order to create the postgis extension
self.connection.close()
self.connection.settings_dict["NAME"] = test_database_name
cursor = self.connection.cursor()
with self.connection.cursor() as cursor:
cursor.execute("CREATE EXTENSION IF NOT EXISTS postgis")
cursor.connection.commit()

View File

@ -55,9 +55,8 @@ class SpatiaLiteCreation(DatabaseCreation):
call_command('createcachetable', database=self.connection.alias)
# Get a cursor (even though we don't need one yet). This has
# the side effect of initializing the test database.
self.connection.cursor()
# Ensure a connection for the side effect of initializing the test database.
self.connection.ensure_connection()
return test_database_name

View File

@ -33,7 +33,7 @@ def create_default_site(app_config, verbosity=2, interactive=True, db=DEFAULT_DB
if sequence_sql:
if verbosity >= 2:
print("Resetting sequence")
cursor = connections[db].cursor()
with connections[db].cursor() as cursor:
for command in sequence_sql:
cursor.execute(command)

View File

@ -59,8 +59,8 @@ class DatabaseCache(BaseDatabaseCache):
self.validate_key(key)
db = router.db_for_read(self.cache_model_class)
table = connections[db].ops.quote_name(self._table)
cursor = connections[db].cursor()
with connections[db].cursor() as cursor:
cursor.execute("SELECT cache_key, value, expires FROM %s "
"WHERE cache_key = %%s" % table, [key])
row = cursor.fetchone()
@ -75,7 +75,7 @@ class DatabaseCache(BaseDatabaseCache):
expires = typecast_timestamp(str(expires))
if expires < now:
db = router.db_for_write(self.cache_model_class)
cursor = connections[db].cursor()
with connections[db].cursor() as cursor:
cursor.execute("DELETE FROM %s "
"WHERE cache_key = %%s" % table, [key])
return default
@ -96,8 +96,8 @@ class DatabaseCache(BaseDatabaseCache):
timeout = self.get_backend_timeout(timeout)
db = router.db_for_write(self.cache_model_class)
table = connections[db].ops.quote_name(self._table)
cursor = connections[db].cursor()
with connections[db].cursor() as cursor:
cursor.execute("SELECT COUNT(*) FROM %s" % table)
num = cursor.fetchone()[0]
now = timezone.now()
@ -152,8 +152,8 @@ class DatabaseCache(BaseDatabaseCache):
db = router.db_for_write(self.cache_model_class)
table = connections[db].ops.quote_name(self._table)
cursor = connections[db].cursor()
with connections[db].cursor() as cursor:
cursor.execute("DELETE FROM %s WHERE cache_key = %%s" % table, [key])
def has_key(self, key, version=None):
@ -162,13 +162,14 @@ class DatabaseCache(BaseDatabaseCache):
db = router.db_for_read(self.cache_model_class)
table = connections[db].ops.quote_name(self._table)
cursor = connections[db].cursor()
if settings.USE_TZ:
now = datetime.utcnow()
else:
now = datetime.now()
now = now.replace(microsecond=0)
with connections[db].cursor() as cursor:
cursor.execute("SELECT cache_key FROM %s "
"WHERE cache_key = %%s and expires > %%s" % table,
[key, connections[db].ops.value_to_db_datetime(now)])
@ -197,7 +198,7 @@ class DatabaseCache(BaseDatabaseCache):
def clear(self):
db = router.db_for_write(self.cache_model_class)
table = connections[db].ops.quote_name(self._table)
cursor = connections[db].cursor()
with connections[db].cursor() as cursor:
cursor.execute('DELETE FROM %s' % table)

View File

@ -72,7 +72,7 @@ class Command(BaseCommand):
full_statement.append(' %s%s' % (line, ',' if i < len(table_output) - 1 else ''))
full_statement.append(');')
with transaction.commit_on_success_unless_managed():
curs = connection.cursor()
with connection.cursor() as curs:
try:
curs.execute("\n".join(full_statement))
except DatabaseError as e:

View File

@ -64,7 +64,7 @@ Are you sure you want to do this?
if confirm == 'yes':
try:
with transaction.commit_on_success_unless_managed():
cursor = connection.cursor()
with connection.cursor() as cursor:
for sql in sql_list:
cursor.execute(sql)
except Exception as e:

View File

@ -37,7 +37,7 @@ class Command(NoArgsCommand):
table2model = lambda table_name: table_name.title().replace('_', '').replace(' ', '').replace('-', '')
strip_prefix = lambda s: s[1:] if s.startswith("u'") else s
cursor = connection.cursor()
with connection.cursor() as cursor:
yield "# This is an auto-generated Django model module."
yield "# You'll have to do the following manually to clean this up:"
yield "# * Rearrange models' order"

View File

@ -100,10 +100,9 @@ class Command(BaseCommand):
if sequence_sql:
if self.verbosity >= 2:
self.stdout.write("Resetting sequences\n")
cursor = connection.cursor()
with connection.cursor() as cursor:
for line in sequence_sql:
cursor.execute(line)
cursor.close()
if self.verbosity >= 1:
if self.fixture_object_count == self.loaded_object_count:

View File

@ -171,8 +171,9 @@ class Command(BaseCommand):
"Runs the old syncdb-style operation on a list of app_labels."
cursor = connection.cursor()
try:
# Get a list of already installed *models* so that references work right.
tables = connection.introspection.table_names()
tables = connection.introspection.table_names(cursor)
seen_models = connection.introspection.installed_models(tables)
created_models = set()
pending_references = {}
@ -226,10 +227,12 @@ class Command(BaseCommand):
# We force a commit here, as that was the previous behaviour.
# If you can prove we don't need this, remove it.
transaction.set_dirty(using=connection.alias)
finally:
cursor.close()
# The connection may have been closed by a syncdb handler.
cursor = connection.cursor()
try:
# Install custom SQL for the app (but only if this
# is a model we've just created)
if self.verbosity >= 1:
@ -270,6 +273,8 @@ class Command(BaseCommand):
cursor.execute(sql)
except Exception as e:
self.stderr.write(" Failed to install index for %s.%s model: %s\n" % (app_name, model._meta.object_name, e))
finally:
cursor.close()
# Load initial_data fixtures (unless that has been disabled)
if self.load_initial_data:

View File

@ -67,6 +67,7 @@ def sql_delete(app_config, style, connection):
except Exception:
cursor = None
try:
# Figure out which tables already exist
if cursor:
table_names = connection.introspection.table_names(cursor)
@ -93,7 +94,7 @@ def sql_delete(app_config, style, connection):
for model in app_models:
if connection.introspection.table_name_converter(model._meta.db_table) in table_names:
output.extend(connection.creation.sql_destroy_model(model, references_to_delete, style))
finally:
# Close database connection explicitly, in case this output is being piped
# directly into a database client, to avoid locking issues.
if cursor:

View File

@ -194,13 +194,16 @@ class BaseDatabaseWrapper(object):
##### Backend-specific savepoint management methods #####
def _savepoint(self, sid):
self.cursor().execute(self.ops.savepoint_create_sql(sid))
with self.cursor() as cursor:
cursor.execute(self.ops.savepoint_create_sql(sid))
def _savepoint_rollback(self, sid):
self.cursor().execute(self.ops.savepoint_rollback_sql(sid))
with self.cursor() as cursor:
cursor.execute(self.ops.savepoint_rollback_sql(sid))
def _savepoint_commit(self, sid):
self.cursor().execute(self.ops.savepoint_commit_sql(sid))
with self.cursor() as cursor:
cursor.execute(self.ops.savepoint_commit_sql(sid))
def _savepoint_allowed(self):
# Savepoints cannot be created outside a transaction
@ -688,7 +691,7 @@ class BaseDatabaseFeatures(object):
# otherwise autocommit will cause the confimation to
# fail.
self.connection.enter_transaction_management()
cursor = self.connection.cursor()
with self.connection.cursor() as cursor:
cursor.execute('CREATE TABLE ROLLBACK_TEST (X INT)')
self.connection.commit()
cursor.execute('INSERT INTO ROLLBACK_TEST (X) VALUES (8)')
@ -1253,7 +1256,8 @@ class BaseDatabaseIntrospection(object):
in sorting order between databases.
"""
if cursor is None:
cursor = self.connection.cursor()
with self.connection.cursor() as cursor:
return sorted(self.get_table_list(cursor))
return sorted(self.get_table_list(cursor))
def get_table_list(self, cursor):

View File

@ -378,9 +378,8 @@ class BaseDatabaseCreation(object):
call_command('createcachetable', database=self.connection.alias)
# Get a cursor (even though we don't need one yet). This has
# the side effect of initializing the test database.
self.connection.cursor()
# Ensure a connection for the side effect of initializing the test database.
self.connection.ensure_connection()
return test_database_name
@ -406,7 +405,7 @@ class BaseDatabaseCreation(object):
qn = self.connection.ops.quote_name
# Create the test database and connect to it.
cursor = self._nodb_connection.cursor()
with self._nodb_connection.cursor() as cursor:
try:
cursor.execute(
"CREATE DATABASE %s %s" % (qn(test_database_name), suffix))
@ -461,7 +460,7 @@ class BaseDatabaseCreation(object):
# ourselves. Connect to the previous database (not the test database)
# to do so, because it's not allowed to delete a database while being
# connected to it.
cursor = self._nodb_connection.cursor()
with self._nodb_connection.cursor() as cursor:
# Wait to avoid "database is being accessed by other users" errors.
time.sleep(1)
cursor.execute("DROP DATABASE %s"

View File

@ -180,7 +180,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
@cached_property
def _mysql_storage_engine(self):
"Internal method used in Django tests. Don't rely on this from your code"
cursor = self.connection.cursor()
with self.connection.cursor() as cursor:
cursor.execute('CREATE TABLE INTROSPECT_TEST (X INT)')
# This command is MySQL specific; the second column
# will tell you the default table type of the created
@ -207,7 +207,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
return False
# Test if the time zone definitions are installed.
cursor = self.connection.cursor()
with self.connection.cursor() as cursor:
cursor.execute("SELECT 1 FROM mysql.time_zone LIMIT 1")
return cursor.fetchone() is not None
@ -461,13 +461,12 @@ class DatabaseWrapper(BaseDatabaseWrapper):
return conn
def init_connection_state(self):
cursor = self.connection.cursor()
with self.connection.cursor() as cursor:
# SQL_AUTO_IS_NULL in MySQL controls whether an AUTO_INCREMENT column
# on a recently-inserted row will return when the field is tested for
# NULL. Disabling this value brings this aspect of MySQL in line with
# SQL standards.
cursor.execute('SET SQL_AUTO_IS_NULL = 0')
cursor.close()
def create_cursor(self):
cursor = self.connection.cursor()

View File

@ -353,7 +353,7 @@ WHEN (new.%(col_name)s IS NULL)
def regex_lookup(self, lookup_type):
# If regex_lookup is called before it's been initialized, then create
# a cursor to initialize it and recur.
self.connection.cursor()
with self.connection.cursor():
return self.connection.ops.regex_lookup(lookup_type)
def return_insert_id(self):

View File

@ -149,7 +149,9 @@ class DatabaseWrapper(BaseDatabaseWrapper):
if conn_tz != tz:
cursor = self.connection.cursor()
try:
cursor.execute(self.ops.set_time_zone_sql(), [tz])
finally:
cursor.close()
# Commit after setting the time zone (see #17062)
if not self.get_autocommit():

View File

@ -39,6 +39,6 @@ def get_version(connection):
if hasattr(connection, 'server_version'):
return connection.server_version
else:
cursor = connection.cursor()
with connection.cursor() as cursor:
cursor.execute("SELECT version()")
return _parse_version(cursor.fetchone()[0])

View File

@ -86,13 +86,12 @@ class BaseDatabaseSchemaEditor(object):
"""
Executes the given SQL statement, with optional parameters.
"""
# Get the cursor
cursor = self.connection.cursor()
# Log the command we're running, then run it
logger.debug("%s; (params %r)" % (sql, params))
if self.collect_sql:
self.collected_sql.append((sql % tuple(map(self.connection.ops.quote_parameter, params))) + ";")
else:
with self.connection.cursor() as cursor:
cursor.execute(sql, params)
def quote_name(self, name):
@ -791,7 +790,8 @@ class BaseDatabaseSchemaEditor(object):
Returns all constraint names matching the columns and conditions
"""
column_names = list(column_names) if column_names else None
constraints = self.connection.introspection.get_constraints(self.connection.cursor(), model._meta.db_table)
with self.connection.cursor() as cursor:
constraints = self.connection.introspection.get_constraints(cursor, model._meta.db_table)
result = []
for name, infodict in constraints.items():
if column_names is None or column_names == infodict['columns']:

View File

@ -122,7 +122,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
rule out support for STDDEV. We need to manually check
whether the call works.
"""
cursor = self.connection.cursor()
with self.connection.cursor() as cursor:
cursor.execute('CREATE TABLE STDDEV_TEST (X INT)')
try:
cursor.execute('SELECT STDDEV(*) FROM STDDEV_TEST')

View File

@ -1522,6 +1522,7 @@ class RawQuerySet(object):
query = iter(self.query)
try:
# Find out which columns are model's fields, and which ones should be
# annotated to the model.
for pos, column in enumerate(self.columns):
@ -1570,6 +1571,10 @@ class RawQuerySet(object):
instance._state.adding = False
yield instance
finally:
# Done iterating the Query. If it has its own cursor, close it.
if hasattr(self.query, 'cursor') and self.query.cursor:
self.query.cursor.close()
def __repr__(self):
text = self.raw_query

View File

@ -1,4 +1,5 @@
import datetime
import sys
from django.conf import settings
from django.core.exceptions import FieldError
@ -777,7 +778,7 @@ class SQLCompiler(object):
cursor = self.connection.cursor()
try:
cursor.execute(sql, params)
except:
except Exception:
cursor.close()
raise
@ -908,7 +909,7 @@ class SQLInsertCompiler(SQLCompiler):
def execute_sql(self, return_id=False):
assert not (return_id and len(self.query.objs) != 1)
self.return_id = return_id
cursor = self.connection.cursor()
with self.connection.cursor() as cursor:
for sql, params in self.as_sql():
cursor.execute(sql, params)
if not (return_id and cursor):

View File

@ -59,7 +59,7 @@ class OracleChecks(unittest.TestCase):
# stored procedure through our cursor wrapper.
from django.db.backends.oracle.base import convert_unicode
cursor = connection.cursor()
with connection.cursor() as cursor:
cursor.callproc(convert_unicode('DBMS_SESSION.SET_IDENTIFIER'),
[convert_unicode('_django_testing!')])
@ -70,7 +70,7 @@ class OracleChecks(unittest.TestCase):
# as query parameters.
from django.db.backends.oracle.base import Database
cursor = connection.cursor()
with connection.cursor() as cursor:
var = cursor.var(Database.STRING)
cursor.execute("BEGIN %s := 'X'; END; ", [var])
self.assertEqual(var.getvalue(), 'X')
@ -80,21 +80,21 @@ class OracleChecks(unittest.TestCase):
def test_long_string(self):
# If the backend is Oracle, test that we can save a text longer
# than 4000 chars and read it properly
c = connection.cursor()
c.execute('CREATE TABLE ltext ("TEXT" NCLOB)')
with connection.cursor() as cursor:
cursor.execute('CREATE TABLE ltext ("TEXT" NCLOB)')
long_str = ''.join(six.text_type(x) for x in xrange(4000))
c.execute('INSERT INTO ltext VALUES (%s)', [long_str])
c.execute('SELECT text FROM ltext')
row = c.fetchone()
cursor.execute('INSERT INTO ltext VALUES (%s)', [long_str])
cursor.execute('SELECT text FROM ltext')
row = cursor.fetchone()
self.assertEqual(long_str, row[0].read())
c.execute('DROP TABLE ltext')
cursor.execute('DROP TABLE ltext')
@unittest.skipUnless(connection.vendor == 'oracle',
"No need to check Oracle connection semantics")
def test_client_encoding(self):
# If the backend is Oracle, test that the client encoding is set
# correctly. This was broken under Cygwin prior to r14781.
connection.cursor() # Ensure the connection is initialized.
self.connection.ensure_connection()
self.assertEqual(connection.connection.encoding, "UTF-8")
self.assertEqual(connection.connection.nencoding, "UTF-8")
@ -103,12 +103,12 @@ class OracleChecks(unittest.TestCase):
def test_order_of_nls_parameters(self):
# an 'almost right' datetime should work with configured
# NLS parameters as per #18465.
c = connection.cursor()
with connection.cursor() as cursor:
query = "select 1 from dual where '1936-12-29 00:00' < sysdate"
# Test that the query succeeds without errors - pre #18465 this
# wasn't the case.
c.execute(query)
self.assertEqual(c.fetchone()[0], 1)
cursor.execute(query)
self.assertEqual(cursor.fetchone()[0], 1)
class SQLiteTests(TestCase):
@ -328,6 +328,12 @@ class PostgresVersionTest(TestCase):
def fetchone(self):
return ["PostgreSQL 8.3"]
def __enter__(self):
return self
def __exit__(self, type, value, traceback):
pass
class OlderConnectionMock(object):
"Mock of psycopg2 (< 2.0.12) connection"
def cursor(self):

View File

@ -896,10 +896,9 @@ class DBCacheTests(BaseCacheTests, TransactionTestCase):
management.call_command('createcachetable', verbosity=0, interactive=False)
def drop_table(self):
cursor = connection.cursor()
with connection.cursor() as cursor:
table_name = connection.ops.quote_name('test cache table')
cursor.execute('DROP TABLE %s' % table_name)
cursor.close()
def test_zero_cull(self):
self._perform_cull_test(caches['zero_cull'], 50, 18)

View File

@ -30,7 +30,7 @@ class Article(models.Model):
database query for the sake of demonstration.
"""
from django.db import connection
cursor = connection.cursor()
with connection.cursor() as cursor:
cursor.execute("""
SELECT id, headline, pub_date
FROM custom_methods_article

View File

@ -28,7 +28,7 @@ class InitialSQLTests(TestCase):
connection = connections[DEFAULT_DB_ALIAS]
custom_sql = custom_sql_for_model(Simple, no_style(), connection)
self.assertEqual(len(custom_sql), 9)
cursor = connection.cursor()
with connection.cursor() as cursor:
for sql in custom_sql:
cursor.execute(sql)
self.assertEqual(Simple.objects.count(), 9)

View File

@ -23,7 +23,7 @@ class IntrospectionTests(TestCase):
"'%s' isn't in table_list()." % Article._meta.db_table)
def test_django_table_names(self):
cursor = connection.cursor()
with connection.cursor() as cursor:
cursor.execute('CREATE TABLE django_ixn_test_table (id INTEGER);')
tl = connection.introspection.django_table_names()
cursor.execute("DROP TABLE django_ixn_test_table;")
@ -32,7 +32,7 @@ class IntrospectionTests(TestCase):
def test_django_table_names_retval_type(self):
# Ticket #15216
cursor = connection.cursor()
with connection.cursor() as cursor:
cursor.execute('CREATE TABLE django_ixn_test_table (id INTEGER);')
tl = connection.introspection.django_table_names(only_existing=True)
@ -53,13 +53,13 @@ class IntrospectionTests(TestCase):
'Reporter sequence not found in sequence_list()')
def test_get_table_description_names(self):
cursor = connection.cursor()
with connection.cursor() as cursor:
desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table)
self.assertEqual([r[0] for r in desc],
[f.column for f in Reporter._meta.fields])
def test_get_table_description_types(self):
cursor = connection.cursor()
with connection.cursor() as cursor:
desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table)
# The MySQL exception is due to the cursor.description returning the same constant for
# text and blob columns. TODO: use information_schema database to retrieve the proper
@ -75,7 +75,7 @@ class IntrospectionTests(TestCase):
# inspect the length of character columns).
@expectedFailureOnOracle
def test_get_table_description_col_lengths(self):
cursor = connection.cursor()
with connection.cursor() as cursor:
desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table)
self.assertEqual(
[r[3] for r in desc if datatype(r[1], r) == 'CharField'],
@ -87,7 +87,7 @@ class IntrospectionTests(TestCase):
# so its idea about null_ok in cursor.description is different from ours.
@skipIfDBFeature('interprets_empty_strings_as_nulls')
def test_get_table_description_nullable(self):
cursor = connection.cursor()
with connection.cursor() as cursor:
desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table)
self.assertEqual(
[r[6] for r in desc],
@ -97,14 +97,14 @@ class IntrospectionTests(TestCase):
# Regression test for #9991 - 'real' types in postgres
@skipUnlessDBFeature('has_real_datatype')
def test_postgresql_real_type(self):
cursor = connection.cursor()
with connection.cursor() as cursor:
cursor.execute("CREATE TABLE django_ixn_real_test_table (number REAL);")
desc = connection.introspection.get_table_description(cursor, 'django_ixn_real_test_table')
cursor.execute('DROP TABLE django_ixn_real_test_table;')
self.assertEqual(datatype(desc[0][1], desc[0]), 'FloatField')
def test_get_relations(self):
cursor = connection.cursor()
with connection.cursor() as cursor:
relations = connection.introspection.get_relations(cursor, Article._meta.db_table)
# Older versions of MySQL don't have the chops to report on this stuff,
@ -117,7 +117,7 @@ class IntrospectionTests(TestCase):
@skipUnlessDBFeature('can_introspect_foreign_keys')
def test_get_key_columns(self):
cursor = connection.cursor()
with connection.cursor() as cursor:
key_columns = connection.introspection.get_key_columns(cursor, Article._meta.db_table)
self.assertEqual(
set(key_columns),
@ -125,12 +125,12 @@ class IntrospectionTests(TestCase):
('response_to_id', Article._meta.db_table, 'id')]))
def test_get_primary_key_column(self):
cursor = connection.cursor()
with connection.cursor() as cursor:
primary_key_column = connection.introspection.get_primary_key_column(cursor, Article._meta.db_table)
self.assertEqual(primary_key_column, 'id')
def test_get_indexes(self):
cursor = connection.cursor()
with connection.cursor() as cursor:
indexes = connection.introspection.get_indexes(cursor, Article._meta.db_table)
self.assertEqual(indexes['reporter_id'], {'unique': False, 'primary_key': False})
@ -139,7 +139,7 @@ class IntrospectionTests(TestCase):
Test that multicolumn indexes are not included in the introspection
results.
"""
cursor = connection.cursor()
with connection.cursor() as cursor:
indexes = connection.introspection.get_indexes(cursor, Reporter._meta.db_table)
self.assertNotIn('first_name', indexes)
self.assertIn('id', indexes)

View File

@ -9,30 +9,37 @@ class MigrationTestBase(TransactionTestCase):
available_apps = ["migrations"]
def get_table_description(self, table):
with connection.cursor() as cursor:
return connection.introspection.get_table_description(cursor, table)
def assertTableExists(self, table):
self.assertIn(table, connection.introspection.get_table_list(connection.cursor()))
with connection.cursor() as cursor:
self.assertIn(table, connection.introspection.get_table_list(cursor))
def assertTableNotExists(self, table):
self.assertNotIn(table, connection.introspection.get_table_list(connection.cursor()))
with connection.cursor() as cursor:
self.assertNotIn(table, connection.introspection.get_table_list(cursor))
def assertColumnExists(self, table, column):
self.assertIn(column, [c.name for c in connection.introspection.get_table_description(connection.cursor(), table)])
self.assertIn(column, [c.name for c in self.get_table_description(table)])
def assertColumnNotExists(self, table, column):
self.assertNotIn(column, [c.name for c in connection.introspection.get_table_description(connection.cursor(), table)])
self.assertNotIn(column, [c.name for c in self.get_table_description(table)])
def assertColumnNull(self, table, column):
self.assertEqual([c.null_ok for c in connection.introspection.get_table_description(connection.cursor(), table) if c.name == column][0], True)
self.assertEqual([c.null_ok for c in self.get_table_description(table) if c.name == column][0], True)
def assertColumnNotNull(self, table, column):
self.assertEqual([c.null_ok for c in connection.introspection.get_table_description(connection.cursor(), table) if c.name == column][0], False)
self.assertEqual([c.null_ok for c in self.get_table_description(table) if c.name == column][0], False)
def assertIndexExists(self, table, columns, value=True):
with connection.cursor() as cursor:
self.assertEqual(
value,
any(
c["index"]
for c in connection.introspection.get_constraints(connection.cursor(), table).values()
for c in connection.introspection.get_constraints(cursor, table).values()
if c['columns'] == list(columns)
),
)

View File

@ -19,7 +19,7 @@ class OperationTests(MigrationTestBase):
Creates a test model state and database table.
"""
# Delete the tables if they already exist
cursor = connection.cursor()
with connection.cursor() as cursor:
try:
cursor.execute("DROP TABLE %s_pony" % app_label)
except:
@ -348,21 +348,21 @@ class OperationTests(MigrationTestBase):
operation.state_forwards("test_alflpkfk", new_state)
self.assertIsInstance(project_state.models["test_alflpkfk", "pony"].get_field_by_name("id"), models.AutoField)
self.assertIsInstance(new_state.models["test_alflpkfk", "pony"].get_field_by_name("id"), models.FloatField)
# Test the database alteration
id_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_pony") if c.name == "id"][0]
fk_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_rider") if c.name == "pony_id"][0]
def assertIdTypeEqualsFkType(self):
with connection.cursor() as cursor:
id_type = [c.type_code for c in connection.introspection.get_table_description(cursor, "test_alflpkfk_pony") if c.name == "id"][0]
fk_type = [c.type_code for c in connection.introspection.get_table_description(cursor, "test_alflpkfk_rider") if c.name == "pony_id"][0]
self.assertEqual(id_type, fk_type)
assertIdTypeEqualsFkType()
# Test the database alteration
with connection.schema_editor() as editor:
operation.database_forwards("test_alflpkfk", editor, project_state, new_state)
id_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_pony") if c.name == "id"][0]
fk_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_rider") if c.name == "pony_id"][0]
self.assertEqual(id_type, fk_type)
assertIdTypeEqualsFkType()
# And test reversal
with connection.schema_editor() as editor:
operation.database_backwards("test_alflpkfk", editor, new_state, project_state)
id_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_pony") if c.name == "id"][0]
fk_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_rider") if c.name == "pony_id"][0]
self.assertEqual(id_type, fk_type)
assertIdTypeEqualsFkType()
def test_rename_field(self):
"""
@ -400,7 +400,7 @@ class OperationTests(MigrationTestBase):
self.assertEqual(len(project_state.models["test_alunto", "pony"].options.get("unique_together", set())), 0)
self.assertEqual(len(new_state.models["test_alunto", "pony"].options.get("unique_together", set())), 1)
# Make sure we can insert duplicate rows
cursor = connection.cursor()
with connection.cursor() as cursor:
cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (1, 1, 1)")
cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (2, 1, 1)")
cursor.execute("DELETE FROM test_alunto_pony")

View File

@ -725,7 +725,7 @@ class DatabaseConnectionHandlingTests(TransactionTestCase):
# request_finished signal.
response = self.client.get('/')
# Make sure there is an open connection
connection.cursor()
self.connection.ensure_connection()
connection.enter_transaction_management()
signals.request_finished.send(sender=response._handler_class)
self.assertEqual(len(connection.transaction_state), 0)

View File

@ -37,7 +37,7 @@ class SchemaTests(TransactionTestCase):
def delete_tables(self):
"Deletes all model tables for our models for a clean test environment"
cursor = connection.cursor()
with connection.cursor() as cursor:
connection.disable_constraint_checking()
table_names = connection.introspection.table_names(cursor)
for model in self.models:
@ -61,7 +61,7 @@ class SchemaTests(TransactionTestCase):
connection.enable_constraint_checking()
def column_classes(self, model):
cursor = connection.cursor()
with connection.cursor() as cursor:
columns = dict(
(d[0], (connection.introspection.get_field_type(d[1], d), d))
for d in connection.introspection.get_table_description(
@ -78,6 +78,20 @@ class SchemaTests(TransactionTestCase):
raise DatabaseError("Table does not exist (empty pragma)")
return columns
def get_indexes(self, table):
"""
Get the indexes on the table using a new cursor.
"""
with connection.cursor() as cursor:
return connection.introspection.get_indexes(cursor, table)
def get_constraints(self, table):
"""
Get the constraints on a table using a new cursor.
"""
with connection.cursor() as cursor:
return connection.introspection.get_constraints(cursor, table)
# Tests
def test_creation_deletion(self):
@ -127,7 +141,7 @@ class SchemaTests(TransactionTestCase):
strict=True,
)
# Make sure the new FK constraint is present
constraints = connection.introspection.get_constraints(connection.cursor(), Book._meta.db_table)
constraints = self.get_constraints(Book._meta.db_table)
for name, details in constraints.items():
if details['columns'] == ["author_id"] and details['foreign_key']:
self.assertEqual(details['foreign_key'], ('schema_tag', 'id'))
@ -342,7 +356,7 @@ class SchemaTests(TransactionTestCase):
editor.create_model(TagM2MTest)
editor.create_model(UniqueTest)
# Ensure the M2M exists and points to TagM2MTest
constraints = connection.introspection.get_constraints(connection.cursor(), BookWithM2M._meta.get_field_by_name("tags")[0].rel.through._meta.db_table)
constraints = self.get_constraints(BookWithM2M._meta.get_field_by_name("tags")[0].rel.through._meta.db_table)
if connection.features.supports_foreign_keys:
for name, details in constraints.items():
if details['columns'] == ["tagm2mtest_id"] and details['foreign_key']:
@ -363,7 +377,7 @@ class SchemaTests(TransactionTestCase):
# Ensure old M2M is gone
self.assertRaises(DatabaseError, self.column_classes, BookWithM2M._meta.get_field_by_name("tags")[0].rel.through)
# Ensure the new M2M exists and points to UniqueTest
constraints = connection.introspection.get_constraints(connection.cursor(), new_field.rel.through._meta.db_table)
constraints = self.get_constraints(new_field.rel.through._meta.db_table)
if connection.features.supports_foreign_keys:
for name, details in constraints.items():
if details['columns'] == ["uniquetest_id"] and details['foreign_key']:
@ -388,7 +402,7 @@ class SchemaTests(TransactionTestCase):
with connection.schema_editor() as editor:
editor.create_model(Author)
# Ensure the constraint exists
constraints = connection.introspection.get_constraints(connection.cursor(), Author._meta.db_table)
constraints = self.get_constraints(Author._meta.db_table)
for name, details in constraints.items():
if details['columns'] == ["height"] and details['check']:
break
@ -404,7 +418,7 @@ class SchemaTests(TransactionTestCase):
new_field,
strict=True,
)
constraints = connection.introspection.get_constraints(connection.cursor(), Author._meta.db_table)
constraints = self.get_constraints(Author._meta.db_table)
for name, details in constraints.items():
if details['columns'] == ["height"] and details['check']:
self.fail("Check constraint for height found")
@ -416,7 +430,7 @@ class SchemaTests(TransactionTestCase):
Author._meta.get_field_by_name("height")[0],
strict=True,
)
constraints = connection.introspection.get_constraints(connection.cursor(), Author._meta.db_table)
constraints = self.get_constraints(Author._meta.db_table)
for name, details in constraints.items():
if details['columns'] == ["height"] and details['check']:
break
@ -527,7 +541,7 @@ class SchemaTests(TransactionTestCase):
False,
any(
c["index"]
for c in connection.introspection.get_constraints(connection.cursor(), "schema_tag").values()
for c in self.get_constraints("schema_tag").values()
if c['columns'] == ["slug", "title"]
),
)
@ -543,7 +557,7 @@ class SchemaTests(TransactionTestCase):
True,
any(
c["index"]
for c in connection.introspection.get_constraints(connection.cursor(), "schema_tag").values()
for c in self.get_constraints("schema_tag").values()
if c['columns'] == ["slug", "title"]
),
)
@ -561,7 +575,7 @@ class SchemaTests(TransactionTestCase):
False,
any(
c["index"]
for c in connection.introspection.get_constraints(connection.cursor(), "schema_tag").values()
for c in self.get_constraints("schema_tag").values()
if c['columns'] == ["slug", "title"]
),
)
@ -578,7 +592,7 @@ class SchemaTests(TransactionTestCase):
True,
any(
c["index"]
for c in connection.introspection.get_constraints(connection.cursor(), "schema_tagindexed").values()
for c in self.get_constraints("schema_tagindexed").values()
if c['columns'] == ["slug", "title"]
),
)
@ -627,7 +641,7 @@ class SchemaTests(TransactionTestCase):
# Ensure the table is there and has the right index
self.assertIn(
"title",
connection.introspection.get_indexes(connection.cursor(), Book._meta.db_table),
self.get_indexes(Book._meta.db_table),
)
# Alter to remove the index
new_field = CharField(max_length=100, db_index=False)
@ -642,7 +656,7 @@ class SchemaTests(TransactionTestCase):
# Ensure the table is there and has no index
self.assertNotIn(
"title",
connection.introspection.get_indexes(connection.cursor(), Book._meta.db_table),
self.get_indexes(Book._meta.db_table),
)
# Alter to re-add the index
with connection.schema_editor() as editor:
@ -655,7 +669,7 @@ class SchemaTests(TransactionTestCase):
# Ensure the table is there and has the index again
self.assertIn(
"title",
connection.introspection.get_indexes(connection.cursor(), Book._meta.db_table),
self.get_indexes(Book._meta.db_table),
)
# Add a unique column, verify that creates an implicit index
with connection.schema_editor() as editor:
@ -665,7 +679,7 @@ class SchemaTests(TransactionTestCase):
)
self.assertIn(
"slug",
connection.introspection.get_indexes(connection.cursor(), Book._meta.db_table),
self.get_indexes(Book._meta.db_table),
)
# Remove the unique, check the index goes with it
new_field2 = CharField(max_length=20, unique=False)
@ -679,7 +693,7 @@ class SchemaTests(TransactionTestCase):
)
self.assertNotIn(
"slug",
connection.introspection.get_indexes(connection.cursor(), Book._meta.db_table),
self.get_indexes(Book._meta.db_table),
)
def test_primary_key(self):
@ -691,7 +705,7 @@ class SchemaTests(TransactionTestCase):
editor.create_model(Tag)
# Ensure the table is there and has the right PK
self.assertTrue(
connection.introspection.get_indexes(connection.cursor(), Tag._meta.db_table)['id']['primary_key'],
self.get_indexes(Tag._meta.db_table)['id']['primary_key'],
)
# Alter to change the PK
new_field = SlugField(primary_key=True)
@ -707,10 +721,10 @@ class SchemaTests(TransactionTestCase):
# Ensure the PK changed
self.assertNotIn(
'id',
connection.introspection.get_indexes(connection.cursor(), Tag._meta.db_table),
self.get_indexes(Tag._meta.db_table),
)
self.assertTrue(
connection.introspection.get_indexes(connection.cursor(), Tag._meta.db_table)['slug']['primary_key'],
self.get_indexes(Tag._meta.db_table)['slug']['primary_key'],
)
def test_context_manager_exit(self):
@ -741,7 +755,7 @@ class SchemaTests(TransactionTestCase):
# Ensure the table is there and has an index on the column
self.assertIn(
column_name,
connection.introspection.get_indexes(connection.cursor(), BookWithLongName._meta.db_table),
self.get_indexes(BookWithLongName._meta.db_table),
)
def test_creation_deletion_reserved_names(self):

View File

@ -202,7 +202,8 @@ class AtomicTests(TransactionTestCase):
# trigger a database error inside an inner atomic without savepoint
with self.assertRaises(DatabaseError):
with transaction.atomic(savepoint=False):
connection.cursor().execute(
with connection.cursor() as cursor:
cursor.execute(
"SELECT no_such_col FROM transactions_reporter")
# prevent atomic from rolling back since we're recovering manually
self.assertTrue(transaction.get_rollback())
@ -534,7 +535,7 @@ class TransactionRollbackTests(IgnoreDeprecationWarningsMixin, TransactionTestCa
available_apps = ['transactions']
def execute_bad_sql(self):
cursor = connection.cursor()
with connection.cursor() as cursor:
cursor.execute("INSERT INTO transactions_reporter (first_name, last_name) VALUES ('Douglas', 'Adams');")
@skipUnlessDBFeature('requires_rollback_on_dirty_transaction')
@ -678,6 +679,6 @@ class TransactionContextManagerTests(IgnoreDeprecationWarningsMixin, Transaction
"""
with self.assertRaises(IntegrityError):
with transaction.commit_on_success():
cursor = connection.cursor()
with connection.cursor() as cursor:
cursor.execute("INSERT INTO transactions_reporter (first_name, last_name) VALUES ('Douglas', 'Adams');")
transaction.rollback()

View File

@ -54,7 +54,7 @@ class TestTransactionClosing(IgnoreDeprecationWarningsMixin, TransactionTestCase
@commit_on_success
def raw_sql():
"Write a record using raw sql under a commit_on_success decorator"
cursor = connection.cursor()
with connection.cursor() as cursor:
cursor.execute("INSERT into transactions_regress_mod (fld) values (18)")
raw_sql()
@ -143,7 +143,7 @@ class TestTransactionClosing(IgnoreDeprecationWarningsMixin, TransactionTestCase
(reference). All this under commit_on_success, so the second insert should
be committed.
"""
cursor = connection.cursor()
with connection.cursor() as cursor:
cursor.execute("INSERT into transactions_regress_mod (fld) values (2)")
transaction.rollback()
cursor.execute("INSERT into transactions_regress_mod (fld) values (2)")