Ensure cursors are closed when no longer needed.

This commit touches various parts of the code base and test framework. Any
found usage of opening a cursor for the sake of initializing a connection
has been replaced with 'ensure_connection()'.
This commit is contained in:
Michael Manfre 2014-01-09 10:05:15 -05:00
parent 0837eacc4e
commit 3ffeb93186
31 changed files with 657 additions and 615 deletions

View File

@ -11,7 +11,7 @@ class PostGISCreation(DatabaseCreation):
@cached_property @cached_property
def template_postgis(self): def template_postgis(self):
template_postgis = getattr(settings, 'POSTGIS_TEMPLATE', 'template_postgis') template_postgis = getattr(settings, 'POSTGIS_TEMPLATE', 'template_postgis')
cursor = self.connection.cursor() with self.connection.cursor() as cursor:
cursor.execute('SELECT 1 FROM pg_database WHERE datname = %s LIMIT 1;', (template_postgis,)) cursor.execute('SELECT 1 FROM pg_database WHERE datname = %s LIMIT 1;', (template_postgis,))
if cursor.fetchone(): if cursor.fetchone():
return template_postgis return template_postgis
@ -88,7 +88,7 @@ class PostGISCreation(DatabaseCreation):
# Connect to the test database in order to create the postgis extension # Connect to the test database in order to create the postgis extension
self.connection.close() self.connection.close()
self.connection.settings_dict["NAME"] = test_database_name self.connection.settings_dict["NAME"] = test_database_name
cursor = self.connection.cursor() with self.connection.cursor() as cursor:
cursor.execute("CREATE EXTENSION IF NOT EXISTS postgis") cursor.execute("CREATE EXTENSION IF NOT EXISTS postgis")
cursor.connection.commit() cursor.connection.commit()

View File

@ -55,9 +55,8 @@ class SpatiaLiteCreation(DatabaseCreation):
call_command('createcachetable', database=self.connection.alias) call_command('createcachetable', database=self.connection.alias)
# Get a cursor (even though we don't need one yet). This has # Ensure a connection for the side effect of initializing the test database.
# the side effect of initializing the test database. self.connection.ensure_connection()
self.connection.cursor()
return test_database_name return test_database_name

View File

@ -33,7 +33,7 @@ def create_default_site(app_config, verbosity=2, interactive=True, db=DEFAULT_DB
if sequence_sql: if sequence_sql:
if verbosity >= 2: if verbosity >= 2:
print("Resetting sequence") print("Resetting sequence")
cursor = connections[db].cursor() with connections[db].cursor() as cursor:
for command in sequence_sql: for command in sequence_sql:
cursor.execute(command) cursor.execute(command)

View File

@ -59,8 +59,8 @@ class DatabaseCache(BaseDatabaseCache):
self.validate_key(key) self.validate_key(key)
db = router.db_for_read(self.cache_model_class) db = router.db_for_read(self.cache_model_class)
table = connections[db].ops.quote_name(self._table) table = connections[db].ops.quote_name(self._table)
cursor = connections[db].cursor()
with connections[db].cursor() as cursor:
cursor.execute("SELECT cache_key, value, expires FROM %s " cursor.execute("SELECT cache_key, value, expires FROM %s "
"WHERE cache_key = %%s" % table, [key]) "WHERE cache_key = %%s" % table, [key])
row = cursor.fetchone() row = cursor.fetchone()
@ -75,7 +75,7 @@ class DatabaseCache(BaseDatabaseCache):
expires = typecast_timestamp(str(expires)) expires = typecast_timestamp(str(expires))
if expires < now: if expires < now:
db = router.db_for_write(self.cache_model_class) db = router.db_for_write(self.cache_model_class)
cursor = connections[db].cursor() with connections[db].cursor() as cursor:
cursor.execute("DELETE FROM %s " cursor.execute("DELETE FROM %s "
"WHERE cache_key = %%s" % table, [key]) "WHERE cache_key = %%s" % table, [key])
return default return default
@ -96,8 +96,8 @@ class DatabaseCache(BaseDatabaseCache):
timeout = self.get_backend_timeout(timeout) timeout = self.get_backend_timeout(timeout)
db = router.db_for_write(self.cache_model_class) db = router.db_for_write(self.cache_model_class)
table = connections[db].ops.quote_name(self._table) table = connections[db].ops.quote_name(self._table)
cursor = connections[db].cursor()
with connections[db].cursor() as cursor:
cursor.execute("SELECT COUNT(*) FROM %s" % table) cursor.execute("SELECT COUNT(*) FROM %s" % table)
num = cursor.fetchone()[0] num = cursor.fetchone()[0]
now = timezone.now() now = timezone.now()
@ -152,8 +152,8 @@ class DatabaseCache(BaseDatabaseCache):
db = router.db_for_write(self.cache_model_class) db = router.db_for_write(self.cache_model_class)
table = connections[db].ops.quote_name(self._table) table = connections[db].ops.quote_name(self._table)
cursor = connections[db].cursor()
with connections[db].cursor() as cursor:
cursor.execute("DELETE FROM %s WHERE cache_key = %%s" % table, [key]) cursor.execute("DELETE FROM %s WHERE cache_key = %%s" % table, [key])
def has_key(self, key, version=None): def has_key(self, key, version=None):
@ -162,13 +162,14 @@ class DatabaseCache(BaseDatabaseCache):
db = router.db_for_read(self.cache_model_class) db = router.db_for_read(self.cache_model_class)
table = connections[db].ops.quote_name(self._table) table = connections[db].ops.quote_name(self._table)
cursor = connections[db].cursor()
if settings.USE_TZ: if settings.USE_TZ:
now = datetime.utcnow() now = datetime.utcnow()
else: else:
now = datetime.now() now = datetime.now()
now = now.replace(microsecond=0) now = now.replace(microsecond=0)
with connections[db].cursor() as cursor:
cursor.execute("SELECT cache_key FROM %s " cursor.execute("SELECT cache_key FROM %s "
"WHERE cache_key = %%s and expires > %%s" % table, "WHERE cache_key = %%s and expires > %%s" % table,
[key, connections[db].ops.value_to_db_datetime(now)]) [key, connections[db].ops.value_to_db_datetime(now)])
@ -197,7 +198,7 @@ class DatabaseCache(BaseDatabaseCache):
def clear(self): def clear(self):
db = router.db_for_write(self.cache_model_class) db = router.db_for_write(self.cache_model_class)
table = connections[db].ops.quote_name(self._table) table = connections[db].ops.quote_name(self._table)
cursor = connections[db].cursor() with connections[db].cursor() as cursor:
cursor.execute('DELETE FROM %s' % table) cursor.execute('DELETE FROM %s' % table)

View File

@ -72,7 +72,7 @@ class Command(BaseCommand):
full_statement.append(' %s%s' % (line, ',' if i < len(table_output) - 1 else '')) full_statement.append(' %s%s' % (line, ',' if i < len(table_output) - 1 else ''))
full_statement.append(');') full_statement.append(');')
with transaction.commit_on_success_unless_managed(): with transaction.commit_on_success_unless_managed():
curs = connection.cursor() with connection.cursor() as curs:
try: try:
curs.execute("\n".join(full_statement)) curs.execute("\n".join(full_statement))
except DatabaseError as e: except DatabaseError as e:

View File

@ -64,7 +64,7 @@ Are you sure you want to do this?
if confirm == 'yes': if confirm == 'yes':
try: try:
with transaction.commit_on_success_unless_managed(): with transaction.commit_on_success_unless_managed():
cursor = connection.cursor() with connection.cursor() as cursor:
for sql in sql_list: for sql in sql_list:
cursor.execute(sql) cursor.execute(sql)
except Exception as e: except Exception as e:

View File

@ -37,7 +37,7 @@ class Command(NoArgsCommand):
table2model = lambda table_name: table_name.title().replace('_', '').replace(' ', '').replace('-', '') table2model = lambda table_name: table_name.title().replace('_', '').replace(' ', '').replace('-', '')
strip_prefix = lambda s: s[1:] if s.startswith("u'") else s strip_prefix = lambda s: s[1:] if s.startswith("u'") else s
cursor = connection.cursor() with connection.cursor() as cursor:
yield "# This is an auto-generated Django model module." yield "# This is an auto-generated Django model module."
yield "# You'll have to do the following manually to clean this up:" yield "# You'll have to do the following manually to clean this up:"
yield "# * Rearrange models' order" yield "# * Rearrange models' order"

View File

@ -100,10 +100,9 @@ class Command(BaseCommand):
if sequence_sql: if sequence_sql:
if self.verbosity >= 2: if self.verbosity >= 2:
self.stdout.write("Resetting sequences\n") self.stdout.write("Resetting sequences\n")
cursor = connection.cursor() with connection.cursor() as cursor:
for line in sequence_sql: for line in sequence_sql:
cursor.execute(line) cursor.execute(line)
cursor.close()
if self.verbosity >= 1: if self.verbosity >= 1:
if self.fixture_object_count == self.loaded_object_count: if self.fixture_object_count == self.loaded_object_count:

View File

@ -171,8 +171,9 @@ class Command(BaseCommand):
"Runs the old syncdb-style operation on a list of app_labels." "Runs the old syncdb-style operation on a list of app_labels."
cursor = connection.cursor() cursor = connection.cursor()
try:
# Get a list of already installed *models* so that references work right. # Get a list of already installed *models* so that references work right.
tables = connection.introspection.table_names() tables = connection.introspection.table_names(cursor)
seen_models = connection.introspection.installed_models(tables) seen_models = connection.introspection.installed_models(tables)
created_models = set() created_models = set()
pending_references = {} pending_references = {}
@ -226,10 +227,12 @@ class Command(BaseCommand):
# We force a commit here, as that was the previous behaviour. # We force a commit here, as that was the previous behaviour.
# If you can prove we don't need this, remove it. # If you can prove we don't need this, remove it.
transaction.set_dirty(using=connection.alias) transaction.set_dirty(using=connection.alias)
finally:
cursor.close()
# The connection may have been closed by a syncdb handler. # The connection may have been closed by a syncdb handler.
cursor = connection.cursor() cursor = connection.cursor()
try:
# Install custom SQL for the app (but only if this # Install custom SQL for the app (but only if this
# is a model we've just created) # is a model we've just created)
if self.verbosity >= 1: if self.verbosity >= 1:
@ -270,6 +273,8 @@ class Command(BaseCommand):
cursor.execute(sql) cursor.execute(sql)
except Exception as e: except Exception as e:
self.stderr.write(" Failed to install index for %s.%s model: %s\n" % (app_name, model._meta.object_name, e)) self.stderr.write(" Failed to install index for %s.%s model: %s\n" % (app_name, model._meta.object_name, e))
finally:
cursor.close()
# Load initial_data fixtures (unless that has been disabled) # Load initial_data fixtures (unless that has been disabled)
if self.load_initial_data: if self.load_initial_data:

View File

@ -67,6 +67,7 @@ def sql_delete(app_config, style, connection):
except Exception: except Exception:
cursor = None cursor = None
try:
# Figure out which tables already exist # Figure out which tables already exist
if cursor: if cursor:
table_names = connection.introspection.table_names(cursor) table_names = connection.introspection.table_names(cursor)
@ -93,7 +94,7 @@ def sql_delete(app_config, style, connection):
for model in app_models: for model in app_models:
if connection.introspection.table_name_converter(model._meta.db_table) in table_names: if connection.introspection.table_name_converter(model._meta.db_table) in table_names:
output.extend(connection.creation.sql_destroy_model(model, references_to_delete, style)) output.extend(connection.creation.sql_destroy_model(model, references_to_delete, style))
finally:
# Close database connection explicitly, in case this output is being piped # Close database connection explicitly, in case this output is being piped
# directly into a database client, to avoid locking issues. # directly into a database client, to avoid locking issues.
if cursor: if cursor:

View File

@ -194,13 +194,16 @@ class BaseDatabaseWrapper(object):
##### Backend-specific savepoint management methods ##### ##### Backend-specific savepoint management methods #####
def _savepoint(self, sid): def _savepoint(self, sid):
self.cursor().execute(self.ops.savepoint_create_sql(sid)) with self.cursor() as cursor:
cursor.execute(self.ops.savepoint_create_sql(sid))
def _savepoint_rollback(self, sid): def _savepoint_rollback(self, sid):
self.cursor().execute(self.ops.savepoint_rollback_sql(sid)) with self.cursor() as cursor:
cursor.execute(self.ops.savepoint_rollback_sql(sid))
def _savepoint_commit(self, sid): def _savepoint_commit(self, sid):
self.cursor().execute(self.ops.savepoint_commit_sql(sid)) with self.cursor() as cursor:
cursor.execute(self.ops.savepoint_commit_sql(sid))
def _savepoint_allowed(self): def _savepoint_allowed(self):
# Savepoints cannot be created outside a transaction # Savepoints cannot be created outside a transaction
@ -688,7 +691,7 @@ class BaseDatabaseFeatures(object):
# otherwise autocommit will cause the confimation to # otherwise autocommit will cause the confimation to
# fail. # fail.
self.connection.enter_transaction_management() self.connection.enter_transaction_management()
cursor = self.connection.cursor() with self.connection.cursor() as cursor:
cursor.execute('CREATE TABLE ROLLBACK_TEST (X INT)') cursor.execute('CREATE TABLE ROLLBACK_TEST (X INT)')
self.connection.commit() self.connection.commit()
cursor.execute('INSERT INTO ROLLBACK_TEST (X) VALUES (8)') cursor.execute('INSERT INTO ROLLBACK_TEST (X) VALUES (8)')
@ -1253,7 +1256,8 @@ class BaseDatabaseIntrospection(object):
in sorting order between databases. in sorting order between databases.
""" """
if cursor is None: if cursor is None:
cursor = self.connection.cursor() with self.connection.cursor() as cursor:
return sorted(self.get_table_list(cursor))
return sorted(self.get_table_list(cursor)) return sorted(self.get_table_list(cursor))
def get_table_list(self, cursor): def get_table_list(self, cursor):

View File

@ -378,9 +378,8 @@ class BaseDatabaseCreation(object):
call_command('createcachetable', database=self.connection.alias) call_command('createcachetable', database=self.connection.alias)
# Get a cursor (even though we don't need one yet). This has # Ensure a connection for the side effect of initializing the test database.
# the side effect of initializing the test database. self.connection.ensure_connection()
self.connection.cursor()
return test_database_name return test_database_name
@ -406,7 +405,7 @@ class BaseDatabaseCreation(object):
qn = self.connection.ops.quote_name qn = self.connection.ops.quote_name
# Create the test database and connect to it. # Create the test database and connect to it.
cursor = self._nodb_connection.cursor() with self._nodb_connection.cursor() as cursor:
try: try:
cursor.execute( cursor.execute(
"CREATE DATABASE %s %s" % (qn(test_database_name), suffix)) "CREATE DATABASE %s %s" % (qn(test_database_name), suffix))
@ -461,7 +460,7 @@ class BaseDatabaseCreation(object):
# ourselves. Connect to the previous database (not the test database) # ourselves. Connect to the previous database (not the test database)
# to do so, because it's not allowed to delete a database while being # to do so, because it's not allowed to delete a database while being
# connected to it. # connected to it.
cursor = self._nodb_connection.cursor() with self._nodb_connection.cursor() as cursor:
# Wait to avoid "database is being accessed by other users" errors. # Wait to avoid "database is being accessed by other users" errors.
time.sleep(1) time.sleep(1)
cursor.execute("DROP DATABASE %s" cursor.execute("DROP DATABASE %s"

View File

@ -180,7 +180,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
@cached_property @cached_property
def _mysql_storage_engine(self): def _mysql_storage_engine(self):
"Internal method used in Django tests. Don't rely on this from your code" "Internal method used in Django tests. Don't rely on this from your code"
cursor = self.connection.cursor() with self.connection.cursor() as cursor:
cursor.execute('CREATE TABLE INTROSPECT_TEST (X INT)') cursor.execute('CREATE TABLE INTROSPECT_TEST (X INT)')
# This command is MySQL specific; the second column # This command is MySQL specific; the second column
# will tell you the default table type of the created # will tell you the default table type of the created
@ -207,7 +207,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
return False return False
# Test if the time zone definitions are installed. # Test if the time zone definitions are installed.
cursor = self.connection.cursor() with self.connection.cursor() as cursor:
cursor.execute("SELECT 1 FROM mysql.time_zone LIMIT 1") cursor.execute("SELECT 1 FROM mysql.time_zone LIMIT 1")
return cursor.fetchone() is not None return cursor.fetchone() is not None
@ -461,13 +461,12 @@ class DatabaseWrapper(BaseDatabaseWrapper):
return conn return conn
def init_connection_state(self): def init_connection_state(self):
cursor = self.connection.cursor() with self.connection.cursor() as cursor:
# SQL_AUTO_IS_NULL in MySQL controls whether an AUTO_INCREMENT column # SQL_AUTO_IS_NULL in MySQL controls whether an AUTO_INCREMENT column
# on a recently-inserted row will return when the field is tested for # on a recently-inserted row will return when the field is tested for
# NULL. Disabling this value brings this aspect of MySQL in line with # NULL. Disabling this value brings this aspect of MySQL in line with
# SQL standards. # SQL standards.
cursor.execute('SET SQL_AUTO_IS_NULL = 0') cursor.execute('SET SQL_AUTO_IS_NULL = 0')
cursor.close()
def create_cursor(self): def create_cursor(self):
cursor = self.connection.cursor() cursor = self.connection.cursor()

View File

@ -353,7 +353,7 @@ WHEN (new.%(col_name)s IS NULL)
def regex_lookup(self, lookup_type): def regex_lookup(self, lookup_type):
# If regex_lookup is called before it's been initialized, then create # If regex_lookup is called before it's been initialized, then create
# a cursor to initialize it and recur. # a cursor to initialize it and recur.
self.connection.cursor() with self.connection.cursor():
return self.connection.ops.regex_lookup(lookup_type) return self.connection.ops.regex_lookup(lookup_type)
def return_insert_id(self): def return_insert_id(self):

View File

@ -149,7 +149,9 @@ class DatabaseWrapper(BaseDatabaseWrapper):
if conn_tz != tz: if conn_tz != tz:
cursor = self.connection.cursor() cursor = self.connection.cursor()
try:
cursor.execute(self.ops.set_time_zone_sql(), [tz]) cursor.execute(self.ops.set_time_zone_sql(), [tz])
finally:
cursor.close() cursor.close()
# Commit after setting the time zone (see #17062) # Commit after setting the time zone (see #17062)
if not self.get_autocommit(): if not self.get_autocommit():

View File

@ -39,6 +39,6 @@ def get_version(connection):
if hasattr(connection, 'server_version'): if hasattr(connection, 'server_version'):
return connection.server_version return connection.server_version
else: else:
cursor = connection.cursor() with connection.cursor() as cursor:
cursor.execute("SELECT version()") cursor.execute("SELECT version()")
return _parse_version(cursor.fetchone()[0]) return _parse_version(cursor.fetchone()[0])

View File

@ -86,13 +86,12 @@ class BaseDatabaseSchemaEditor(object):
""" """
Executes the given SQL statement, with optional parameters. Executes the given SQL statement, with optional parameters.
""" """
# Get the cursor
cursor = self.connection.cursor()
# Log the command we're running, then run it # Log the command we're running, then run it
logger.debug("%s; (params %r)" % (sql, params)) logger.debug("%s; (params %r)" % (sql, params))
if self.collect_sql: if self.collect_sql:
self.collected_sql.append((sql % tuple(map(self.connection.ops.quote_parameter, params))) + ";") self.collected_sql.append((sql % tuple(map(self.connection.ops.quote_parameter, params))) + ";")
else: else:
with self.connection.cursor() as cursor:
cursor.execute(sql, params) cursor.execute(sql, params)
def quote_name(self, name): def quote_name(self, name):
@ -791,7 +790,8 @@ class BaseDatabaseSchemaEditor(object):
Returns all constraint names matching the columns and conditions Returns all constraint names matching the columns and conditions
""" """
column_names = list(column_names) if column_names else None column_names = list(column_names) if column_names else None
constraints = self.connection.introspection.get_constraints(self.connection.cursor(), model._meta.db_table) with self.connection.cursor() as cursor:
constraints = self.connection.introspection.get_constraints(cursor, model._meta.db_table)
result = [] result = []
for name, infodict in constraints.items(): for name, infodict in constraints.items():
if column_names is None or column_names == infodict['columns']: if column_names is None or column_names == infodict['columns']:

View File

@ -122,7 +122,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
rule out support for STDDEV. We need to manually check rule out support for STDDEV. We need to manually check
whether the call works. whether the call works.
""" """
cursor = self.connection.cursor() with self.connection.cursor() as cursor:
cursor.execute('CREATE TABLE STDDEV_TEST (X INT)') cursor.execute('CREATE TABLE STDDEV_TEST (X INT)')
try: try:
cursor.execute('SELECT STDDEV(*) FROM STDDEV_TEST') cursor.execute('SELECT STDDEV(*) FROM STDDEV_TEST')

View File

@ -1522,6 +1522,7 @@ class RawQuerySet(object):
query = iter(self.query) query = iter(self.query)
try:
# Find out which columns are model's fields, and which ones should be # Find out which columns are model's fields, and which ones should be
# annotated to the model. # annotated to the model.
for pos, column in enumerate(self.columns): for pos, column in enumerate(self.columns):
@ -1570,6 +1571,10 @@ class RawQuerySet(object):
instance._state.adding = False instance._state.adding = False
yield instance yield instance
finally:
# Done iterating the Query. If it has its own cursor, close it.
if hasattr(self.query, 'cursor') and self.query.cursor:
self.query.cursor.close()
def __repr__(self): def __repr__(self):
text = self.raw_query text = self.raw_query

View File

@ -1,4 +1,5 @@
import datetime import datetime
import sys
from django.conf import settings from django.conf import settings
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
@ -777,7 +778,7 @@ class SQLCompiler(object):
cursor = self.connection.cursor() cursor = self.connection.cursor()
try: try:
cursor.execute(sql, params) cursor.execute(sql, params)
except: except Exception:
cursor.close() cursor.close()
raise raise
@ -908,7 +909,7 @@ class SQLInsertCompiler(SQLCompiler):
def execute_sql(self, return_id=False): def execute_sql(self, return_id=False):
assert not (return_id and len(self.query.objs) != 1) assert not (return_id and len(self.query.objs) != 1)
self.return_id = return_id self.return_id = return_id
cursor = self.connection.cursor() with self.connection.cursor() as cursor:
for sql, params in self.as_sql(): for sql, params in self.as_sql():
cursor.execute(sql, params) cursor.execute(sql, params)
if not (return_id and cursor): if not (return_id and cursor):

View File

@ -59,7 +59,7 @@ class OracleChecks(unittest.TestCase):
# stored procedure through our cursor wrapper. # stored procedure through our cursor wrapper.
from django.db.backends.oracle.base import convert_unicode from django.db.backends.oracle.base import convert_unicode
cursor = connection.cursor() with connection.cursor() as cursor:
cursor.callproc(convert_unicode('DBMS_SESSION.SET_IDENTIFIER'), cursor.callproc(convert_unicode('DBMS_SESSION.SET_IDENTIFIER'),
[convert_unicode('_django_testing!')]) [convert_unicode('_django_testing!')])
@ -70,7 +70,7 @@ class OracleChecks(unittest.TestCase):
# as query parameters. # as query parameters.
from django.db.backends.oracle.base import Database from django.db.backends.oracle.base import Database
cursor = connection.cursor() with connection.cursor() as cursor:
var = cursor.var(Database.STRING) var = cursor.var(Database.STRING)
cursor.execute("BEGIN %s := 'X'; END; ", [var]) cursor.execute("BEGIN %s := 'X'; END; ", [var])
self.assertEqual(var.getvalue(), 'X') self.assertEqual(var.getvalue(), 'X')
@ -80,21 +80,21 @@ class OracleChecks(unittest.TestCase):
def test_long_string(self): def test_long_string(self):
# If the backend is Oracle, test that we can save a text longer # If the backend is Oracle, test that we can save a text longer
# than 4000 chars and read it properly # than 4000 chars and read it properly
c = connection.cursor() with connection.cursor() as cursor:
c.execute('CREATE TABLE ltext ("TEXT" NCLOB)') cursor.execute('CREATE TABLE ltext ("TEXT" NCLOB)')
long_str = ''.join(six.text_type(x) for x in xrange(4000)) long_str = ''.join(six.text_type(x) for x in xrange(4000))
c.execute('INSERT INTO ltext VALUES (%s)', [long_str]) cursor.execute('INSERT INTO ltext VALUES (%s)', [long_str])
c.execute('SELECT text FROM ltext') cursor.execute('SELECT text FROM ltext')
row = c.fetchone() row = cursor.fetchone()
self.assertEqual(long_str, row[0].read()) self.assertEqual(long_str, row[0].read())
c.execute('DROP TABLE ltext') cursor.execute('DROP TABLE ltext')
@unittest.skipUnless(connection.vendor == 'oracle', @unittest.skipUnless(connection.vendor == 'oracle',
"No need to check Oracle connection semantics") "No need to check Oracle connection semantics")
def test_client_encoding(self): def test_client_encoding(self):
# If the backend is Oracle, test that the client encoding is set # If the backend is Oracle, test that the client encoding is set
# correctly. This was broken under Cygwin prior to r14781. # correctly. This was broken under Cygwin prior to r14781.
connection.cursor() # Ensure the connection is initialized. connection.ensure_connection()
self.assertEqual(connection.connection.encoding, "UTF-8") self.assertEqual(connection.connection.encoding, "UTF-8")
self.assertEqual(connection.connection.nencoding, "UTF-8") self.assertEqual(connection.connection.nencoding, "UTF-8")
@ -103,12 +103,12 @@ class OracleChecks(unittest.TestCase):
def test_order_of_nls_parameters(self): def test_order_of_nls_parameters(self):
# an 'almost right' datetime should work with configured # an 'almost right' datetime should work with configured
# NLS parameters as per #18465. # NLS parameters as per #18465.
c = connection.cursor() with connection.cursor() as cursor:
query = "select 1 from dual where '1936-12-29 00:00' < sysdate" query = "select 1 from dual where '1936-12-29 00:00' < sysdate"
# Test that the query succeeds without errors - pre #18465 this # Test that the query succeeds without errors - pre #18465 this
# wasn't the case. # wasn't the case.
c.execute(query) cursor.execute(query)
self.assertEqual(c.fetchone()[0], 1) self.assertEqual(cursor.fetchone()[0], 1)
class SQLiteTests(TestCase): class SQLiteTests(TestCase):
@ -328,6 +328,12 @@ class PostgresVersionTest(TestCase):
def fetchone(self): def fetchone(self):
return ["PostgreSQL 8.3"] return ["PostgreSQL 8.3"]
def __enter__(self):
return self
def __exit__(self, type, value, traceback):
pass
class OlderConnectionMock(object): class OlderConnectionMock(object):
"Mock of psycopg2 (< 2.0.12) connection" "Mock of psycopg2 (< 2.0.12) connection"
def cursor(self): def cursor(self):

View File

@ -896,10 +896,9 @@ class DBCacheTests(BaseCacheTests, TransactionTestCase):
management.call_command('createcachetable', verbosity=0, interactive=False) management.call_command('createcachetable', verbosity=0, interactive=False)
def drop_table(self): def drop_table(self):
cursor = connection.cursor() with connection.cursor() as cursor:
table_name = connection.ops.quote_name('test cache table') table_name = connection.ops.quote_name('test cache table')
cursor.execute('DROP TABLE %s' % table_name) cursor.execute('DROP TABLE %s' % table_name)
cursor.close()
def test_zero_cull(self): def test_zero_cull(self):
self._perform_cull_test(caches['zero_cull'], 50, 18) self._perform_cull_test(caches['zero_cull'], 50, 18)

View File

@ -30,7 +30,7 @@ class Article(models.Model):
database query for the sake of demonstration. database query for the sake of demonstration.
""" """
from django.db import connection from django.db import connection
cursor = connection.cursor() with connection.cursor() as cursor:
cursor.execute(""" cursor.execute("""
SELECT id, headline, pub_date SELECT id, headline, pub_date
FROM custom_methods_article FROM custom_methods_article

View File

@ -28,7 +28,7 @@ class InitialSQLTests(TestCase):
connection = connections[DEFAULT_DB_ALIAS] connection = connections[DEFAULT_DB_ALIAS]
custom_sql = custom_sql_for_model(Simple, no_style(), connection) custom_sql = custom_sql_for_model(Simple, no_style(), connection)
self.assertEqual(len(custom_sql), 9) self.assertEqual(len(custom_sql), 9)
cursor = connection.cursor() with connection.cursor() as cursor:
for sql in custom_sql: for sql in custom_sql:
cursor.execute(sql) cursor.execute(sql)
self.assertEqual(Simple.objects.count(), 9) self.assertEqual(Simple.objects.count(), 9)

View File

@ -23,7 +23,7 @@ class IntrospectionTests(TestCase):
"'%s' isn't in table_list()." % Article._meta.db_table) "'%s' isn't in table_list()." % Article._meta.db_table)
def test_django_table_names(self): def test_django_table_names(self):
cursor = connection.cursor() with connection.cursor() as cursor:
cursor.execute('CREATE TABLE django_ixn_test_table (id INTEGER);') cursor.execute('CREATE TABLE django_ixn_test_table (id INTEGER);')
tl = connection.introspection.django_table_names() tl = connection.introspection.django_table_names()
cursor.execute("DROP TABLE django_ixn_test_table;") cursor.execute("DROP TABLE django_ixn_test_table;")
@ -32,7 +32,7 @@ class IntrospectionTests(TestCase):
def test_django_table_names_retval_type(self): def test_django_table_names_retval_type(self):
# Ticket #15216 # Ticket #15216
cursor = connection.cursor() with connection.cursor() as cursor:
cursor.execute('CREATE TABLE django_ixn_test_table (id INTEGER);') cursor.execute('CREATE TABLE django_ixn_test_table (id INTEGER);')
tl = connection.introspection.django_table_names(only_existing=True) tl = connection.introspection.django_table_names(only_existing=True)
@ -53,13 +53,13 @@ class IntrospectionTests(TestCase):
'Reporter sequence not found in sequence_list()') 'Reporter sequence not found in sequence_list()')
def test_get_table_description_names(self): def test_get_table_description_names(self):
cursor = connection.cursor() with connection.cursor() as cursor:
desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table) desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table)
self.assertEqual([r[0] for r in desc], self.assertEqual([r[0] for r in desc],
[f.column for f in Reporter._meta.fields]) [f.column for f in Reporter._meta.fields])
def test_get_table_description_types(self): def test_get_table_description_types(self):
cursor = connection.cursor() with connection.cursor() as cursor:
desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table) desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table)
# The MySQL exception is due to the cursor.description returning the same constant for # The MySQL exception is due to the cursor.description returning the same constant for
# text and blob columns. TODO: use information_schema database to retrieve the proper # text and blob columns. TODO: use information_schema database to retrieve the proper
@ -75,7 +75,7 @@ class IntrospectionTests(TestCase):
# inspect the length of character columns). # inspect the length of character columns).
@expectedFailureOnOracle @expectedFailureOnOracle
def test_get_table_description_col_lengths(self): def test_get_table_description_col_lengths(self):
cursor = connection.cursor() with connection.cursor() as cursor:
desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table) desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table)
self.assertEqual( self.assertEqual(
[r[3] for r in desc if datatype(r[1], r) == 'CharField'], [r[3] for r in desc if datatype(r[1], r) == 'CharField'],
@ -87,7 +87,7 @@ class IntrospectionTests(TestCase):
# so its idea about null_ok in cursor.description is different from ours. # so its idea about null_ok in cursor.description is different from ours.
@skipIfDBFeature('interprets_empty_strings_as_nulls') @skipIfDBFeature('interprets_empty_strings_as_nulls')
def test_get_table_description_nullable(self): def test_get_table_description_nullable(self):
cursor = connection.cursor() with connection.cursor() as cursor:
desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table) desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table)
self.assertEqual( self.assertEqual(
[r[6] for r in desc], [r[6] for r in desc],
@ -97,14 +97,14 @@ class IntrospectionTests(TestCase):
# Regression test for #9991 - 'real' types in postgres # Regression test for #9991 - 'real' types in postgres
@skipUnlessDBFeature('has_real_datatype') @skipUnlessDBFeature('has_real_datatype')
def test_postgresql_real_type(self): def test_postgresql_real_type(self):
cursor = connection.cursor() with connection.cursor() as cursor:
cursor.execute("CREATE TABLE django_ixn_real_test_table (number REAL);") cursor.execute("CREATE TABLE django_ixn_real_test_table (number REAL);")
desc = connection.introspection.get_table_description(cursor, 'django_ixn_real_test_table') desc = connection.introspection.get_table_description(cursor, 'django_ixn_real_test_table')
cursor.execute('DROP TABLE django_ixn_real_test_table;') cursor.execute('DROP TABLE django_ixn_real_test_table;')
self.assertEqual(datatype(desc[0][1], desc[0]), 'FloatField') self.assertEqual(datatype(desc[0][1], desc[0]), 'FloatField')
def test_get_relations(self): def test_get_relations(self):
cursor = connection.cursor() with connection.cursor() as cursor:
relations = connection.introspection.get_relations(cursor, Article._meta.db_table) relations = connection.introspection.get_relations(cursor, Article._meta.db_table)
# Older versions of MySQL don't have the chops to report on this stuff, # Older versions of MySQL don't have the chops to report on this stuff,
@ -117,7 +117,7 @@ class IntrospectionTests(TestCase):
@skipUnlessDBFeature('can_introspect_foreign_keys') @skipUnlessDBFeature('can_introspect_foreign_keys')
def test_get_key_columns(self): def test_get_key_columns(self):
cursor = connection.cursor() with connection.cursor() as cursor:
key_columns = connection.introspection.get_key_columns(cursor, Article._meta.db_table) key_columns = connection.introspection.get_key_columns(cursor, Article._meta.db_table)
self.assertEqual( self.assertEqual(
set(key_columns), set(key_columns),
@ -125,12 +125,12 @@ class IntrospectionTests(TestCase):
('response_to_id', Article._meta.db_table, 'id')])) ('response_to_id', Article._meta.db_table, 'id')]))
def test_get_primary_key_column(self): def test_get_primary_key_column(self):
cursor = connection.cursor() with connection.cursor() as cursor:
primary_key_column = connection.introspection.get_primary_key_column(cursor, Article._meta.db_table) primary_key_column = connection.introspection.get_primary_key_column(cursor, Article._meta.db_table)
self.assertEqual(primary_key_column, 'id') self.assertEqual(primary_key_column, 'id')
def test_get_indexes(self): def test_get_indexes(self):
cursor = connection.cursor() with connection.cursor() as cursor:
indexes = connection.introspection.get_indexes(cursor, Article._meta.db_table) indexes = connection.introspection.get_indexes(cursor, Article._meta.db_table)
self.assertEqual(indexes['reporter_id'], {'unique': False, 'primary_key': False}) self.assertEqual(indexes['reporter_id'], {'unique': False, 'primary_key': False})
@ -139,7 +139,7 @@ class IntrospectionTests(TestCase):
Test that multicolumn indexes are not included in the introspection Test that multicolumn indexes are not included in the introspection
results. results.
""" """
cursor = connection.cursor() with connection.cursor() as cursor:
indexes = connection.introspection.get_indexes(cursor, Reporter._meta.db_table) indexes = connection.introspection.get_indexes(cursor, Reporter._meta.db_table)
self.assertNotIn('first_name', indexes) self.assertNotIn('first_name', indexes)
self.assertIn('id', indexes) self.assertIn('id', indexes)

View File

@ -9,30 +9,37 @@ class MigrationTestBase(TransactionTestCase):
available_apps = ["migrations"] available_apps = ["migrations"]
def get_table_description(self, table):
with connection.cursor() as cursor:
return connection.introspection.get_table_description(cursor, table)
def assertTableExists(self, table): def assertTableExists(self, table):
self.assertIn(table, connection.introspection.get_table_list(connection.cursor())) with connection.cursor() as cursor:
self.assertIn(table, connection.introspection.get_table_list(cursor))
def assertTableNotExists(self, table): def assertTableNotExists(self, table):
self.assertNotIn(table, connection.introspection.get_table_list(connection.cursor())) with connection.cursor() as cursor:
self.assertNotIn(table, connection.introspection.get_table_list(cursor))
def assertColumnExists(self, table, column): def assertColumnExists(self, table, column):
self.assertIn(column, [c.name for c in connection.introspection.get_table_description(connection.cursor(), table)]) self.assertIn(column, [c.name for c in self.get_table_description(table)])
def assertColumnNotExists(self, table, column): def assertColumnNotExists(self, table, column):
self.assertNotIn(column, [c.name for c in connection.introspection.get_table_description(connection.cursor(), table)]) self.assertNotIn(column, [c.name for c in self.get_table_description(table)])
def assertColumnNull(self, table, column): def assertColumnNull(self, table, column):
self.assertEqual([c.null_ok for c in connection.introspection.get_table_description(connection.cursor(), table) if c.name == column][0], True) self.assertEqual([c.null_ok for c in self.get_table_description(table) if c.name == column][0], True)
def assertColumnNotNull(self, table, column): def assertColumnNotNull(self, table, column):
self.assertEqual([c.null_ok for c in connection.introspection.get_table_description(connection.cursor(), table) if c.name == column][0], False) self.assertEqual([c.null_ok for c in self.get_table_description(table) if c.name == column][0], False)
def assertIndexExists(self, table, columns, value=True): def assertIndexExists(self, table, columns, value=True):
with connection.cursor() as cursor:
self.assertEqual( self.assertEqual(
value, value,
any( any(
c["index"] c["index"]
for c in connection.introspection.get_constraints(connection.cursor(), table).values() for c in connection.introspection.get_constraints(cursor, table).values()
if c['columns'] == list(columns) if c['columns'] == list(columns)
), ),
) )

View File

@ -19,7 +19,7 @@ class OperationTests(MigrationTestBase):
Creates a test model state and database table. Creates a test model state and database table.
""" """
# Delete the tables if they already exist # Delete the tables if they already exist
cursor = connection.cursor() with connection.cursor() as cursor:
try: try:
cursor.execute("DROP TABLE %s_pony" % app_label) cursor.execute("DROP TABLE %s_pony" % app_label)
except: except:
@ -348,21 +348,21 @@ class OperationTests(MigrationTestBase):
operation.state_forwards("test_alflpkfk", new_state) operation.state_forwards("test_alflpkfk", new_state)
self.assertIsInstance(project_state.models["test_alflpkfk", "pony"].get_field_by_name("id"), models.AutoField) self.assertIsInstance(project_state.models["test_alflpkfk", "pony"].get_field_by_name("id"), models.AutoField)
self.assertIsInstance(new_state.models["test_alflpkfk", "pony"].get_field_by_name("id"), models.FloatField) self.assertIsInstance(new_state.models["test_alflpkfk", "pony"].get_field_by_name("id"), models.FloatField)
# Test the database alteration
id_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_pony") if c.name == "id"][0] def assertIdTypeEqualsFkType(self):
fk_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_rider") if c.name == "pony_id"][0] with connection.cursor() as cursor:
id_type = [c.type_code for c in connection.introspection.get_table_description(cursor, "test_alflpkfk_pony") if c.name == "id"][0]
fk_type = [c.type_code for c in connection.introspection.get_table_description(cursor, "test_alflpkfk_rider") if c.name == "pony_id"][0]
self.assertEqual(id_type, fk_type) self.assertEqual(id_type, fk_type)
assertIdTypeEqualsFkType()
# Test the database alteration
with connection.schema_editor() as editor: with connection.schema_editor() as editor:
operation.database_forwards("test_alflpkfk", editor, project_state, new_state) operation.database_forwards("test_alflpkfk", editor, project_state, new_state)
id_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_pony") if c.name == "id"][0] assertIdTypeEqualsFkType()
fk_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_rider") if c.name == "pony_id"][0]
self.assertEqual(id_type, fk_type)
# And test reversal # And test reversal
with connection.schema_editor() as editor: with connection.schema_editor() as editor:
operation.database_backwards("test_alflpkfk", editor, new_state, project_state) operation.database_backwards("test_alflpkfk", editor, new_state, project_state)
id_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_pony") if c.name == "id"][0] assertIdTypeEqualsFkType()
fk_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_rider") if c.name == "pony_id"][0]
self.assertEqual(id_type, fk_type)
def test_rename_field(self): def test_rename_field(self):
""" """
@ -400,7 +400,7 @@ class OperationTests(MigrationTestBase):
self.assertEqual(len(project_state.models["test_alunto", "pony"].options.get("unique_together", set())), 0) self.assertEqual(len(project_state.models["test_alunto", "pony"].options.get("unique_together", set())), 0)
self.assertEqual(len(new_state.models["test_alunto", "pony"].options.get("unique_together", set())), 1) self.assertEqual(len(new_state.models["test_alunto", "pony"].options.get("unique_together", set())), 1)
# Make sure we can insert duplicate rows # Make sure we can insert duplicate rows
cursor = connection.cursor() with connection.cursor() as cursor:
cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (1, 1, 1)") cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (1, 1, 1)")
cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (2, 1, 1)") cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (2, 1, 1)")
cursor.execute("DELETE FROM test_alunto_pony") cursor.execute("DELETE FROM test_alunto_pony")

View File

@ -725,7 +725,7 @@ class DatabaseConnectionHandlingTests(TransactionTestCase):
# request_finished signal. # request_finished signal.
response = self.client.get('/') response = self.client.get('/')
# Make sure there is an open connection # Make sure there is an open connection
connection.cursor() self.connection.ensure_connection()
connection.enter_transaction_management() connection.enter_transaction_management()
signals.request_finished.send(sender=response._handler_class) signals.request_finished.send(sender=response._handler_class)
self.assertEqual(len(connection.transaction_state), 0) self.assertEqual(len(connection.transaction_state), 0)

View File

@ -37,7 +37,7 @@ class SchemaTests(TransactionTestCase):
def delete_tables(self): def delete_tables(self):
"Deletes all model tables for our models for a clean test environment" "Deletes all model tables for our models for a clean test environment"
cursor = connection.cursor() with connection.cursor() as cursor:
connection.disable_constraint_checking() connection.disable_constraint_checking()
table_names = connection.introspection.table_names(cursor) table_names = connection.introspection.table_names(cursor)
for model in self.models: for model in self.models:
@ -61,7 +61,7 @@ class SchemaTests(TransactionTestCase):
connection.enable_constraint_checking() connection.enable_constraint_checking()
def column_classes(self, model): def column_classes(self, model):
cursor = connection.cursor() with connection.cursor() as cursor:
columns = dict( columns = dict(
(d[0], (connection.introspection.get_field_type(d[1], d), d)) (d[0], (connection.introspection.get_field_type(d[1], d), d))
for d in connection.introspection.get_table_description( for d in connection.introspection.get_table_description(
@ -78,6 +78,20 @@ class SchemaTests(TransactionTestCase):
raise DatabaseError("Table does not exist (empty pragma)") raise DatabaseError("Table does not exist (empty pragma)")
return columns return columns
def get_indexes(self, table):
"""
Get the indexes on the table using a new cursor.
"""
with connection.cursor() as cursor:
return connection.introspection.get_indexes(cursor, table)
def get_constraints(self, table):
"""
Get the constraints on a table using a new cursor.
"""
with connection.cursor() as cursor:
return connection.introspection.get_constraints(cursor, table)
# Tests # Tests
def test_creation_deletion(self): def test_creation_deletion(self):
@ -127,7 +141,7 @@ class SchemaTests(TransactionTestCase):
strict=True, strict=True,
) )
# Make sure the new FK constraint is present # Make sure the new FK constraint is present
constraints = connection.introspection.get_constraints(connection.cursor(), Book._meta.db_table) constraints = self.get_constraints(Book._meta.db_table)
for name, details in constraints.items(): for name, details in constraints.items():
if details['columns'] == ["author_id"] and details['foreign_key']: if details['columns'] == ["author_id"] and details['foreign_key']:
self.assertEqual(details['foreign_key'], ('schema_tag', 'id')) self.assertEqual(details['foreign_key'], ('schema_tag', 'id'))
@ -342,7 +356,7 @@ class SchemaTests(TransactionTestCase):
editor.create_model(TagM2MTest) editor.create_model(TagM2MTest)
editor.create_model(UniqueTest) editor.create_model(UniqueTest)
# Ensure the M2M exists and points to TagM2MTest # Ensure the M2M exists and points to TagM2MTest
constraints = connection.introspection.get_constraints(connection.cursor(), BookWithM2M._meta.get_field_by_name("tags")[0].rel.through._meta.db_table) constraints = self.get_constraints(BookWithM2M._meta.get_field_by_name("tags")[0].rel.through._meta.db_table)
if connection.features.supports_foreign_keys: if connection.features.supports_foreign_keys:
for name, details in constraints.items(): for name, details in constraints.items():
if details['columns'] == ["tagm2mtest_id"] and details['foreign_key']: if details['columns'] == ["tagm2mtest_id"] and details['foreign_key']:
@ -363,7 +377,7 @@ class SchemaTests(TransactionTestCase):
# Ensure old M2M is gone # Ensure old M2M is gone
self.assertRaises(DatabaseError, self.column_classes, BookWithM2M._meta.get_field_by_name("tags")[0].rel.through) self.assertRaises(DatabaseError, self.column_classes, BookWithM2M._meta.get_field_by_name("tags")[0].rel.through)
# Ensure the new M2M exists and points to UniqueTest # Ensure the new M2M exists and points to UniqueTest
constraints = connection.introspection.get_constraints(connection.cursor(), new_field.rel.through._meta.db_table) constraints = self.get_constraints(new_field.rel.through._meta.db_table)
if connection.features.supports_foreign_keys: if connection.features.supports_foreign_keys:
for name, details in constraints.items(): for name, details in constraints.items():
if details['columns'] == ["uniquetest_id"] and details['foreign_key']: if details['columns'] == ["uniquetest_id"] and details['foreign_key']:
@ -388,7 +402,7 @@ class SchemaTests(TransactionTestCase):
with connection.schema_editor() as editor: with connection.schema_editor() as editor:
editor.create_model(Author) editor.create_model(Author)
# Ensure the constraint exists # Ensure the constraint exists
constraints = connection.introspection.get_constraints(connection.cursor(), Author._meta.db_table) constraints = self.get_constraints(Author._meta.db_table)
for name, details in constraints.items(): for name, details in constraints.items():
if details['columns'] == ["height"] and details['check']: if details['columns'] == ["height"] and details['check']:
break break
@ -404,7 +418,7 @@ class SchemaTests(TransactionTestCase):
new_field, new_field,
strict=True, strict=True,
) )
constraints = connection.introspection.get_constraints(connection.cursor(), Author._meta.db_table) constraints = self.get_constraints(Author._meta.db_table)
for name, details in constraints.items(): for name, details in constraints.items():
if details['columns'] == ["height"] and details['check']: if details['columns'] == ["height"] and details['check']:
self.fail("Check constraint for height found") self.fail("Check constraint for height found")
@ -416,7 +430,7 @@ class SchemaTests(TransactionTestCase):
Author._meta.get_field_by_name("height")[0], Author._meta.get_field_by_name("height")[0],
strict=True, strict=True,
) )
constraints = connection.introspection.get_constraints(connection.cursor(), Author._meta.db_table) constraints = self.get_constraints(Author._meta.db_table)
for name, details in constraints.items(): for name, details in constraints.items():
if details['columns'] == ["height"] and details['check']: if details['columns'] == ["height"] and details['check']:
break break
@ -527,7 +541,7 @@ class SchemaTests(TransactionTestCase):
False, False,
any( any(
c["index"] c["index"]
for c in connection.introspection.get_constraints(connection.cursor(), "schema_tag").values() for c in self.get_constraints("schema_tag").values()
if c['columns'] == ["slug", "title"] if c['columns'] == ["slug", "title"]
), ),
) )
@ -543,7 +557,7 @@ class SchemaTests(TransactionTestCase):
True, True,
any( any(
c["index"] c["index"]
for c in connection.introspection.get_constraints(connection.cursor(), "schema_tag").values() for c in self.get_constraints("schema_tag").values()
if c['columns'] == ["slug", "title"] if c['columns'] == ["slug", "title"]
), ),
) )
@ -561,7 +575,7 @@ class SchemaTests(TransactionTestCase):
False, False,
any( any(
c["index"] c["index"]
for c in connection.introspection.get_constraints(connection.cursor(), "schema_tag").values() for c in self.get_constraints("schema_tag").values()
if c['columns'] == ["slug", "title"] if c['columns'] == ["slug", "title"]
), ),
) )
@ -578,7 +592,7 @@ class SchemaTests(TransactionTestCase):
True, True,
any( any(
c["index"] c["index"]
for c in connection.introspection.get_constraints(connection.cursor(), "schema_tagindexed").values() for c in self.get_constraints("schema_tagindexed").values()
if c['columns'] == ["slug", "title"] if c['columns'] == ["slug", "title"]
), ),
) )
@ -627,7 +641,7 @@ class SchemaTests(TransactionTestCase):
# Ensure the table is there and has the right index # Ensure the table is there and has the right index
self.assertIn( self.assertIn(
"title", "title",
connection.introspection.get_indexes(connection.cursor(), Book._meta.db_table), self.get_indexes(Book._meta.db_table),
) )
# Alter to remove the index # Alter to remove the index
new_field = CharField(max_length=100, db_index=False) new_field = CharField(max_length=100, db_index=False)
@ -642,7 +656,7 @@ class SchemaTests(TransactionTestCase):
# Ensure the table is there and has no index # Ensure the table is there and has no index
self.assertNotIn( self.assertNotIn(
"title", "title",
connection.introspection.get_indexes(connection.cursor(), Book._meta.db_table), self.get_indexes(Book._meta.db_table),
) )
# Alter to re-add the index # Alter to re-add the index
with connection.schema_editor() as editor: with connection.schema_editor() as editor:
@ -655,7 +669,7 @@ class SchemaTests(TransactionTestCase):
# Ensure the table is there and has the index again # Ensure the table is there and has the index again
self.assertIn( self.assertIn(
"title", "title",
connection.introspection.get_indexes(connection.cursor(), Book._meta.db_table), self.get_indexes(Book._meta.db_table),
) )
# Add a unique column, verify that creates an implicit index # Add a unique column, verify that creates an implicit index
with connection.schema_editor() as editor: with connection.schema_editor() as editor:
@ -665,7 +679,7 @@ class SchemaTests(TransactionTestCase):
) )
self.assertIn( self.assertIn(
"slug", "slug",
connection.introspection.get_indexes(connection.cursor(), Book._meta.db_table), self.get_indexes(Book._meta.db_table),
) )
# Remove the unique, check the index goes with it # Remove the unique, check the index goes with it
new_field2 = CharField(max_length=20, unique=False) new_field2 = CharField(max_length=20, unique=False)
@ -679,7 +693,7 @@ class SchemaTests(TransactionTestCase):
) )
self.assertNotIn( self.assertNotIn(
"slug", "slug",
connection.introspection.get_indexes(connection.cursor(), Book._meta.db_table), self.get_indexes(Book._meta.db_table),
) )
def test_primary_key(self): def test_primary_key(self):
@ -691,7 +705,7 @@ class SchemaTests(TransactionTestCase):
editor.create_model(Tag) editor.create_model(Tag)
# Ensure the table is there and has the right PK # Ensure the table is there and has the right PK
self.assertTrue( self.assertTrue(
connection.introspection.get_indexes(connection.cursor(), Tag._meta.db_table)['id']['primary_key'], self.get_indexes(Tag._meta.db_table)['id']['primary_key'],
) )
# Alter to change the PK # Alter to change the PK
new_field = SlugField(primary_key=True) new_field = SlugField(primary_key=True)
@ -707,10 +721,10 @@ class SchemaTests(TransactionTestCase):
# Ensure the PK changed # Ensure the PK changed
self.assertNotIn( self.assertNotIn(
'id', 'id',
connection.introspection.get_indexes(connection.cursor(), Tag._meta.db_table), self.get_indexes(Tag._meta.db_table),
) )
self.assertTrue( self.assertTrue(
connection.introspection.get_indexes(connection.cursor(), Tag._meta.db_table)['slug']['primary_key'], self.get_indexes(Tag._meta.db_table)['slug']['primary_key'],
) )
def test_context_manager_exit(self): def test_context_manager_exit(self):
@ -741,7 +755,7 @@ class SchemaTests(TransactionTestCase):
# Ensure the table is there and has an index on the column # Ensure the table is there and has an index on the column
self.assertIn( self.assertIn(
column_name, column_name,
connection.introspection.get_indexes(connection.cursor(), BookWithLongName._meta.db_table), self.get_indexes(BookWithLongName._meta.db_table),
) )
def test_creation_deletion_reserved_names(self): def test_creation_deletion_reserved_names(self):

View File

@ -202,7 +202,8 @@ class AtomicTests(TransactionTestCase):
# trigger a database error inside an inner atomic without savepoint # trigger a database error inside an inner atomic without savepoint
with self.assertRaises(DatabaseError): with self.assertRaises(DatabaseError):
with transaction.atomic(savepoint=False): with transaction.atomic(savepoint=False):
connection.cursor().execute( with connection.cursor() as cursor:
cursor.execute(
"SELECT no_such_col FROM transactions_reporter") "SELECT no_such_col FROM transactions_reporter")
# prevent atomic from rolling back since we're recovering manually # prevent atomic from rolling back since we're recovering manually
self.assertTrue(transaction.get_rollback()) self.assertTrue(transaction.get_rollback())
@ -534,7 +535,7 @@ class TransactionRollbackTests(IgnoreDeprecationWarningsMixin, TransactionTestCa
available_apps = ['transactions'] available_apps = ['transactions']
def execute_bad_sql(self): def execute_bad_sql(self):
cursor = connection.cursor() with connection.cursor() as cursor:
cursor.execute("INSERT INTO transactions_reporter (first_name, last_name) VALUES ('Douglas', 'Adams');") cursor.execute("INSERT INTO transactions_reporter (first_name, last_name) VALUES ('Douglas', 'Adams');")
@skipUnlessDBFeature('requires_rollback_on_dirty_transaction') @skipUnlessDBFeature('requires_rollback_on_dirty_transaction')
@ -678,6 +679,6 @@ class TransactionContextManagerTests(IgnoreDeprecationWarningsMixin, Transaction
""" """
with self.assertRaises(IntegrityError): with self.assertRaises(IntegrityError):
with transaction.commit_on_success(): with transaction.commit_on_success():
cursor = connection.cursor() with connection.cursor() as cursor:
cursor.execute("INSERT INTO transactions_reporter (first_name, last_name) VALUES ('Douglas', 'Adams');") cursor.execute("INSERT INTO transactions_reporter (first_name, last_name) VALUES ('Douglas', 'Adams');")
transaction.rollback() transaction.rollback()

View File

@ -54,7 +54,7 @@ class TestTransactionClosing(IgnoreDeprecationWarningsMixin, TransactionTestCase
@commit_on_success @commit_on_success
def raw_sql(): def raw_sql():
"Write a record using raw sql under a commit_on_success decorator" "Write a record using raw sql under a commit_on_success decorator"
cursor = connection.cursor() with connection.cursor() as cursor:
cursor.execute("INSERT into transactions_regress_mod (fld) values (18)") cursor.execute("INSERT into transactions_regress_mod (fld) values (18)")
raw_sql() raw_sql()
@ -143,7 +143,7 @@ class TestTransactionClosing(IgnoreDeprecationWarningsMixin, TransactionTestCase
(reference). All this under commit_on_success, so the second insert should (reference). All this under commit_on_success, so the second insert should
be committed. be committed.
""" """
cursor = connection.cursor() with connection.cursor() as cursor:
cursor.execute("INSERT into transactions_regress_mod (fld) values (2)") cursor.execute("INSERT into transactions_regress_mod (fld) values (2)")
transaction.rollback() transaction.rollback()
cursor.execute("INSERT into transactions_regress_mod (fld) values (2)") cursor.execute("INSERT into transactions_regress_mod (fld) values (2)")