Merge pull request #2154 from manfre/close-cursors

Fixed #21751 -- Explicitly closed cursors.

Merged by: Aymeric Augustin, 2014-02-02 10:37:27 -08:00
commit 54bfa4caab
33 changed files with 725 additions and 641 deletions
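
The change below is mechanical but broad: call sites that previously did `cursor = connection.cursor()` and never closed the cursor now either wrap the work in `with connection.cursor() as cursor:` or close the cursor in a `finally` clause. The `with` form works because Django's cursor wrapper (`CursorWrapper` in django.db.backends.utils, imported by the test changes below) implements the context-manager protocol. As a rough sketch of that idea -- illustrative only, not Django's actual implementation:

    # Minimal sketch: what a cursor wrapper needs so that
    # "with connection.cursor() as cursor:" closes it deterministically.
    class ClosingCursorWrapper(object):
        def __init__(self, cursor):
            self.cursor = cursor

        def __enter__(self):
            # Hand the wrapper to the "as" target.
            return self

        def __exit__(self, exc_type, exc_value, traceback):
            # Always close, whether or not the block raised.
            self.cursor.close()

        def __getattr__(self, attr):
            # Delegate execute(), fetchone(), etc. to the wrapped cursor.
            return getattr(self.cursor, attr)

With that protocol in place, each cursor is released as soon as its block exits instead of whenever garbage collection gets to it.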

View File

@@ -11,10 +11,10 @@ class PostGISCreation(DatabaseCreation):
     @cached_property
     def template_postgis(self):
         template_postgis = getattr(settings, 'POSTGIS_TEMPLATE', 'template_postgis')
-        cursor = self.connection.cursor()
-        cursor.execute('SELECT 1 FROM pg_database WHERE datname = %s LIMIT 1;', (template_postgis,))
-        if cursor.fetchone():
-            return template_postgis
+        with self.connection.cursor() as cursor:
+            cursor.execute('SELECT 1 FROM pg_database WHERE datname = %s LIMIT 1;', (template_postgis,))
+            if cursor.fetchone():
+                return template_postgis
         return None

     def sql_indexes_for_field(self, model, f, style):
@@ -88,8 +88,8 @@ class PostGISCreation(DatabaseCreation):
         # Connect to the test database in order to create the postgis extension
         self.connection.close()
         self.connection.settings_dict["NAME"] = test_database_name
-        cursor = self.connection.cursor()
-        cursor.execute("CREATE EXTENSION IF NOT EXISTS postgis")
-        cursor.connection.commit()
+        with self.connection.cursor() as cursor:
+            cursor.execute("CREATE EXTENSION IF NOT EXISTS postgis")
+            cursor.connection.commit()

         return test_database_name

View File

@@ -55,9 +55,8 @@ class SpatiaLiteCreation(DatabaseCreation):
         call_command('createcachetable', database=self.connection.alias)

-        # Get a cursor (even though we don't need one yet). This has
-        # the side effect of initializing the test database.
-        self.connection.cursor()
+        # Ensure a connection for the side effect of initializing the test database.
+        self.connection.ensure_connection()

         return test_database_name

View File

@@ -33,9 +33,9 @@ def create_default_site(app_config, verbosity=2, interactive=True, db=DEFAULT_DB
         if sequence_sql:
             if verbosity >= 2:
                 print("Resetting sequence")
-            cursor = connections[db].cursor()
-            for command in sequence_sql:
-                cursor.execute(command)
+            with connections[db].cursor() as cursor:
+                for command in sequence_sql:
+                    cursor.execute(command)

     Site.objects.clear_cache()

View File

@@ -59,11 +59,11 @@ class DatabaseCache(BaseDatabaseCache):
         self.validate_key(key)
         db = router.db_for_read(self.cache_model_class)
         table = connections[db].ops.quote_name(self._table)
-        cursor = connections[db].cursor()
-        cursor.execute("SELECT cache_key, value, expires FROM %s "
-                       "WHERE cache_key = %%s" % table, [key])
-        row = cursor.fetchone()
+        with connections[db].cursor() as cursor:
+            cursor.execute("SELECT cache_key, value, expires FROM %s "
+                           "WHERE cache_key = %%s" % table, [key])
+            row = cursor.fetchone()

         if row is None:
             return default
         now = timezone.now()
@@ -75,9 +75,9 @@ class DatabaseCache(BaseDatabaseCache):
             expires = typecast_timestamp(str(expires))
         if expires < now:
             db = router.db_for_write(self.cache_model_class)
-            cursor = connections[db].cursor()
-            cursor.execute("DELETE FROM %s "
-                           "WHERE cache_key = %%s" % table, [key])
+            with connections[db].cursor() as cursor:
+                cursor.execute("DELETE FROM %s "
+                               "WHERE cache_key = %%s" % table, [key])
             return default
         value = connections[db].ops.process_clob(row[1])
         return pickle.loads(base64.b64decode(force_bytes(value)))
@@ -96,55 +96,55 @@ class DatabaseCache(BaseDatabaseCache):
         timeout = self.get_backend_timeout(timeout)
         db = router.db_for_write(self.cache_model_class)
         table = connections[db].ops.quote_name(self._table)
-        cursor = connections[db].cursor()
-        cursor.execute("SELECT COUNT(*) FROM %s" % table)
-        num = cursor.fetchone()[0]
-        now = timezone.now()
-        now = now.replace(microsecond=0)
-        if timeout is None:
-            exp = datetime.max
-        elif settings.USE_TZ:
-            exp = datetime.utcfromtimestamp(timeout)
-        else:
-            exp = datetime.fromtimestamp(timeout)
-        exp = exp.replace(microsecond=0)
-        if num > self._max_entries:
-            self._cull(db, cursor, now)
-        pickled = pickle.dumps(value, pickle.HIGHEST_PROTOCOL)
-        b64encoded = base64.b64encode(pickled)
-        # The DB column is expecting a string, so make sure the value is a
-        # string, not bytes. Refs #19274.
-        if six.PY3:
-            b64encoded = b64encoded.decode('latin1')
-        try:
-            # Note: typecasting for datetimes is needed by some 3rd party
-            # database backends. All core backends work without typecasting,
-            # so be careful about changes here - test suite will NOT pick
-            # regressions.
-            with transaction.atomic(using=db):
-                cursor.execute("SELECT cache_key, expires FROM %s "
-                               "WHERE cache_key = %%s" % table, [key])
-                result = cursor.fetchone()
-                if result:
-                    current_expires = result[1]
-                    if (connections[db].features.needs_datetime_string_cast and not
-                            isinstance(current_expires, datetime)):
-                        current_expires = typecast_timestamp(str(current_expires))
-                exp = connections[db].ops.value_to_db_datetime(exp)
-                if result and (mode == 'set' or (mode == 'add' and current_expires < now)):
-                    cursor.execute("UPDATE %s SET value = %%s, expires = %%s "
-                                   "WHERE cache_key = %%s" % table,
-                                   [b64encoded, exp, key])
-                else:
-                    cursor.execute("INSERT INTO %s (cache_key, value, expires) "
-                                   "VALUES (%%s, %%s, %%s)" % table,
-                                   [key, b64encoded, exp])
-        except DatabaseError:
-            # To be threadsafe, updates/inserts are allowed to fail silently
-            return False
-        else:
-            return True
+        with connections[db].cursor() as cursor:
+            cursor.execute("SELECT COUNT(*) FROM %s" % table)
+            num = cursor.fetchone()[0]
+            now = timezone.now()
+            now = now.replace(microsecond=0)
+            if timeout is None:
+                exp = datetime.max
+            elif settings.USE_TZ:
+                exp = datetime.utcfromtimestamp(timeout)
+            else:
+                exp = datetime.fromtimestamp(timeout)
+            exp = exp.replace(microsecond=0)
+            if num > self._max_entries:
+                self._cull(db, cursor, now)
+            pickled = pickle.dumps(value, pickle.HIGHEST_PROTOCOL)
+            b64encoded = base64.b64encode(pickled)
+            # The DB column is expecting a string, so make sure the value is a
+            # string, not bytes. Refs #19274.
+            if six.PY3:
+                b64encoded = b64encoded.decode('latin1')
+            try:
+                # Note: typecasting for datetimes is needed by some 3rd party
+                # database backends. All core backends work without typecasting,
+                # so be careful about changes here - test suite will NOT pick
+                # regressions.
+                with transaction.atomic(using=db):
+                    cursor.execute("SELECT cache_key, expires FROM %s "
+                                   "WHERE cache_key = %%s" % table, [key])
+                    result = cursor.fetchone()
+                    if result:
+                        current_expires = result[1]
+                        if (connections[db].features.needs_datetime_string_cast and not
+                                isinstance(current_expires, datetime)):
+                            current_expires = typecast_timestamp(str(current_expires))
+                    exp = connections[db].ops.value_to_db_datetime(exp)
+                    if result and (mode == 'set' or (mode == 'add' and current_expires < now)):
+                        cursor.execute("UPDATE %s SET value = %%s, expires = %%s "
+                                       "WHERE cache_key = %%s" % table,
+                                       [b64encoded, exp, key])
+                    else:
+                        cursor.execute("INSERT INTO %s (cache_key, value, expires) "
+                                       "VALUES (%%s, %%s, %%s)" % table,
+                                       [key, b64encoded, exp])
+            except DatabaseError:
+                # To be threadsafe, updates/inserts are allowed to fail silently
+                return False
+            else:
+                return True

     def delete(self, key, version=None):
         key = self.make_key(key, version=version)
@@ -152,9 +152,9 @@ class DatabaseCache(BaseDatabaseCache):
         db = router.db_for_write(self.cache_model_class)
         table = connections[db].ops.quote_name(self._table)
-        cursor = connections[db].cursor()
-        cursor.execute("DELETE FROM %s WHERE cache_key = %%s" % table, [key])
+        with connections[db].cursor() as cursor:
+            cursor.execute("DELETE FROM %s WHERE cache_key = %%s" % table, [key])

     def has_key(self, key, version=None):
         key = self.make_key(key, version=version)
@@ -162,17 +162,18 @@ class DatabaseCache(BaseDatabaseCache):
         db = router.db_for_read(self.cache_model_class)
         table = connections[db].ops.quote_name(self._table)
-        cursor = connections[db].cursor()

         if settings.USE_TZ:
             now = datetime.utcnow()
         else:
             now = datetime.now()
         now = now.replace(microsecond=0)
-        cursor.execute("SELECT cache_key FROM %s "
-                       "WHERE cache_key = %%s and expires > %%s" % table,
-                       [key, connections[db].ops.value_to_db_datetime(now)])
-        return cursor.fetchone() is not None
+
+        with connections[db].cursor() as cursor:
+            cursor.execute("SELECT cache_key FROM %s "
+                           "WHERE cache_key = %%s and expires > %%s" % table,
+                           [key, connections[db].ops.value_to_db_datetime(now)])
+            return cursor.fetchone() is not None

     def _cull(self, db, cursor, now):
         if self._cull_frequency == 0:
@@ -197,8 +198,8 @@ class DatabaseCache(BaseDatabaseCache):
     def clear(self):
         db = router.db_for_write(self.cache_model_class)
         table = connections[db].ops.quote_name(self._table)
-        cursor = connections[db].cursor()
-        cursor.execute('DELETE FROM %s' % table)
+        with connections[db].cursor() as cursor:
+            cursor.execute('DELETE FROM %s' % table)

 # For backwards compatibility

View File

@@ -72,14 +72,14 @@ class Command(BaseCommand):
             full_statement.append('    %s%s' % (line, ',' if i < len(table_output) - 1 else ''))
         full_statement.append(');')

         with transaction.commit_on_success_unless_managed():
-            curs = connection.cursor()
-            try:
-                curs.execute("\n".join(full_statement))
-            except DatabaseError as e:
-                raise CommandError(
-                    "Cache table '%s' could not be created.\nThe error was: %s." %
-                    (tablename, force_text(e)))
-            for statement in index_output:
-                curs.execute(statement)
+            with connection.cursor() as curs:
+                try:
+                    curs.execute("\n".join(full_statement))
+                except DatabaseError as e:
+                    raise CommandError(
+                        "Cache table '%s' could not be created.\nThe error was: %s." %
+                        (tablename, force_text(e)))
+                for statement in index_output:
+                    curs.execute(statement)

         if self.verbosity > 1:
             self.stdout.write("Cache table '%s' created." % tablename)

View File

@@ -64,9 +64,9 @@ Are you sure you want to do this?
         if confirm == 'yes':
             try:
                 with transaction.commit_on_success_unless_managed():
-                    cursor = connection.cursor()
-                    for sql in sql_list:
-                        cursor.execute(sql)
+                    with connection.cursor() as cursor:
+                        for sql in sql_list:
+                            cursor.execute(sql)
             except Exception as e:
                 new_msg = (
                     "Database %s couldn't be flushed. Possible reasons:\n"

View File

@@ -37,108 +37,108 @@ class Command(NoArgsCommand):
         table2model = lambda table_name: table_name.title().replace('_', '').replace(' ', '').replace('-', '')
         strip_prefix = lambda s: s[1:] if s.startswith("u'") else s

-        cursor = connection.cursor()
-        yield "# This is an auto-generated Django model module."
-        yield "# You'll have to do the following manually to clean this up:"
-        yield "#   * Rearrange models' order"
-        yield "#   * Make sure each model has one field with primary_key=True"
-        yield "#   * Remove `managed = False` lines for those models you wish to give write DB access"
-        yield "# Feel free to rename the models, but don't rename db_table values or field names."
-        yield "#"
-        yield "# Also note: You'll have to insert the output of 'django-admin.py sqlcustom [app_label]'"
-        yield "# into your database."
-        yield "from __future__ import unicode_literals"
-        yield ''
-        yield 'from %s import models' % self.db_module
-        known_models = []
-        for table_name in connection.introspection.table_names(cursor):
-            if table_name_filter is not None and callable(table_name_filter):
-                if not table_name_filter(table_name):
-                    continue
-            yield ''
-            yield ''
-            yield 'class %s(models.Model):' % table2model(table_name)
-            known_models.append(table2model(table_name))
-            try:
-                relations = connection.introspection.get_relations(cursor, table_name)
-            except NotImplementedError:
-                relations = {}
-            try:
-                indexes = connection.introspection.get_indexes(cursor, table_name)
-            except NotImplementedError:
-                indexes = {}
-            used_column_names = []  # Holds column names used in the table so far
-            for i, row in enumerate(connection.introspection.get_table_description(cursor, table_name)):
-                comment_notes = []  # Holds Field notes, to be displayed in a Python comment.
-                extra_params = OrderedDict()  # Holds Field parameters such as 'db_column'.
-                column_name = row[0]
-                is_relation = i in relations
-
-                att_name, params, notes = self.normalize_col_name(
-                    column_name, used_column_names, is_relation)
-                extra_params.update(params)
-                comment_notes.extend(notes)
-
-                used_column_names.append(att_name)
-
-                # Add primary_key and unique, if necessary.
-                if column_name in indexes:
-                    if indexes[column_name]['primary_key']:
-                        extra_params['primary_key'] = True
-                    elif indexes[column_name]['unique']:
-                        extra_params['unique'] = True
-
-                if is_relation:
-                    rel_to = "self" if relations[i][1] == table_name else table2model(relations[i][1])
-                    if rel_to in known_models:
-                        field_type = 'ForeignKey(%s' % rel_to
-                    else:
-                        field_type = "ForeignKey('%s'" % rel_to
-                else:
-                    # Calling `get_field_type` to get the field type string and any
-                    # additional paramters and notes.
-                    field_type, field_params, field_notes = self.get_field_type(connection, table_name, row)
-                    extra_params.update(field_params)
-                    comment_notes.extend(field_notes)
-
-                field_type += '('
-
-                # Don't output 'id = meta.AutoField(primary_key=True)', because
-                # that's assumed if it doesn't exist.
-                if att_name == 'id' and extra_params == {'primary_key': True}:
-                    if field_type == 'AutoField(':
-                        continue
-                    elif field_type == 'IntegerField(' and not connection.features.can_introspect_autofield:
-                        comment_notes.append('AutoField?')
-
-                # Add 'null' and 'blank', if the 'null_ok' flag was present in the
-                # table description.
-                if row[6]:  # If it's NULL...
-                    if field_type == 'BooleanField(':
-                        field_type = 'NullBooleanField('
-                    else:
-                        extra_params['blank'] = True
-                        if not field_type in ('TextField(', 'CharField('):
-                            extra_params['null'] = True
-
-                field_desc = '%s = %s%s' % (
-                    att_name,
-                    # Custom fields will have a dotted path
-                    '' if '.' in field_type else 'models.',
-                    field_type,
-                )
-                if extra_params:
-                    if not field_desc.endswith('('):
-                        field_desc += ', '
-                    field_desc += ', '.join([
-                        '%s=%s' % (k, strip_prefix(repr(v)))
-                        for k, v in extra_params.items()])
-                field_desc += ')'
-                if comment_notes:
-                    field_desc += '  # ' + ' '.join(comment_notes)
-                yield '    %s' % field_desc
-            for meta_line in self.get_meta(table_name):
-                yield meta_line
+        with connection.cursor() as cursor:
+            yield "# This is an auto-generated Django model module."
+            yield "# You'll have to do the following manually to clean this up:"
+            yield "#   * Rearrange models' order"
+            yield "#   * Make sure each model has one field with primary_key=True"
+            yield "#   * Remove `managed = False` lines for those models you wish to give write DB access"
+            yield "# Feel free to rename the models, but don't rename db_table values or field names."
+            yield "#"
+            yield "# Also note: You'll have to insert the output of 'django-admin.py sqlcustom [app_label]'"
+            yield "# into your database."
+            yield "from __future__ import unicode_literals"
+            yield ''
+            yield 'from %s import models' % self.db_module
+            known_models = []
+            for table_name in connection.introspection.table_names(cursor):
+                if table_name_filter is not None and callable(table_name_filter):
+                    if not table_name_filter(table_name):
+                        continue
+                yield ''
+                yield ''
+                yield 'class %s(models.Model):' % table2model(table_name)
+                known_models.append(table2model(table_name))
+                try:
+                    relations = connection.introspection.get_relations(cursor, table_name)
+                except NotImplementedError:
+                    relations = {}
+                try:
+                    indexes = connection.introspection.get_indexes(cursor, table_name)
+                except NotImplementedError:
+                    indexes = {}
+                used_column_names = []  # Holds column names used in the table so far
+                for i, row in enumerate(connection.introspection.get_table_description(cursor, table_name)):
+                    comment_notes = []  # Holds Field notes, to be displayed in a Python comment.
+                    extra_params = OrderedDict()  # Holds Field parameters such as 'db_column'.
+                    column_name = row[0]
+                    is_relation = i in relations
+
+                    att_name, params, notes = self.normalize_col_name(
+                        column_name, used_column_names, is_relation)
+                    extra_params.update(params)
+                    comment_notes.extend(notes)
+
+                    used_column_names.append(att_name)
+
+                    # Add primary_key and unique, if necessary.
+                    if column_name in indexes:
+                        if indexes[column_name]['primary_key']:
+                            extra_params['primary_key'] = True
+                        elif indexes[column_name]['unique']:
+                            extra_params['unique'] = True
+
+                    if is_relation:
+                        rel_to = "self" if relations[i][1] == table_name else table2model(relations[i][1])
+                        if rel_to in known_models:
+                            field_type = 'ForeignKey(%s' % rel_to
+                        else:
+                            field_type = "ForeignKey('%s'" % rel_to
+                    else:
+                        # Calling `get_field_type` to get the field type string and any
+                        # additional paramters and notes.
+                        field_type, field_params, field_notes = self.get_field_type(connection, table_name, row)
+                        extra_params.update(field_params)
+                        comment_notes.extend(field_notes)
+
+                    field_type += '('
+
+                    # Don't output 'id = meta.AutoField(primary_key=True)', because
+                    # that's assumed if it doesn't exist.
+                    if att_name == 'id' and extra_params == {'primary_key': True}:
+                        if field_type == 'AutoField(':
+                            continue
+                        elif field_type == 'IntegerField(' and not connection.features.can_introspect_autofield:
+                            comment_notes.append('AutoField?')
+
+                    # Add 'null' and 'blank', if the 'null_ok' flag was present in the
+                    # table description.
+                    if row[6]:  # If it's NULL...
+                        if field_type == 'BooleanField(':
+                            field_type = 'NullBooleanField('
+                        else:
+                            extra_params['blank'] = True
+                            if not field_type in ('TextField(', 'CharField('):
+                                extra_params['null'] = True
+
+                    field_desc = '%s = %s%s' % (
+                        att_name,
+                        # Custom fields will have a dotted path
+                        '' if '.' in field_type else 'models.',
+                        field_type,
+                    )
+                    if extra_params:
+                        if not field_desc.endswith('('):
+                            field_desc += ', '
+                        field_desc += ', '.join([
+                            '%s=%s' % (k, strip_prefix(repr(v)))
+                            for k, v in extra_params.items()])
+                    field_desc += ')'
+                    if comment_notes:
+                        field_desc += '  # ' + ' '.join(comment_notes)
+                    yield '    %s' % field_desc
+                for meta_line in self.get_meta(table_name):
+                    yield meta_line

     def normalize_col_name(self, col_name, used_column_names, is_relation):
         """

View File

@@ -100,10 +100,9 @@ class Command(BaseCommand):
         if sequence_sql:
             if self.verbosity >= 2:
                 self.stdout.write("Resetting sequences\n")
-            cursor = connection.cursor()
-            for line in sequence_sql:
-                cursor.execute(line)
-            cursor.close()
+            with connection.cursor() as cursor:
+                for line in sequence_sql:
+                    cursor.execute(line)

         if self.verbosity >= 1:
             if self.fixture_object_count == self.loaded_object_count:

View File

@@ -171,105 +171,110 @@ class Command(BaseCommand):
         "Runs the old syncdb-style operation on a list of app_labels."
         cursor = connection.cursor()

-        # Get a list of already installed *models* so that references work right.
-        tables = connection.introspection.table_names()
-        seen_models = connection.introspection.installed_models(tables)
-        created_models = set()
-        pending_references = {}
+        try:
+            # Get a list of already installed *models* so that references work right.
+            tables = connection.introspection.table_names(cursor)
+            seen_models = connection.introspection.installed_models(tables)
+            created_models = set()
+            pending_references = {}

-        # Build the manifest of apps and models that are to be synchronized
-        all_models = [
-            (app_config.label,
-                router.get_migratable_models(app_config, connection.alias, include_auto_created=True))
-            for app_config in apps.get_app_configs()
-            if app_config.models_module is not None and app_config.label in app_labels
-        ]
+            # Build the manifest of apps and models that are to be synchronized
+            all_models = [
+                (app_config.label,
+                    router.get_migratable_models(app_config, connection.alias, include_auto_created=True))
+                for app_config in apps.get_app_configs()
+                if app_config.models_module is not None and app_config.label in app_labels
+            ]

-        def model_installed(model):
-            opts = model._meta
-            converter = connection.introspection.table_name_converter
-            # Note that if a model is unmanaged we short-circuit and never try to install it
-            return not ((converter(opts.db_table) in tables) or
-                (opts.auto_created and converter(opts.auto_created._meta.db_table) in tables))
+            def model_installed(model):
+                opts = model._meta
+                converter = connection.introspection.table_name_converter
+                # Note that if a model is unmanaged we short-circuit and never try to install it
+                return not ((converter(opts.db_table) in tables) or
+                    (opts.auto_created and converter(opts.auto_created._meta.db_table) in tables))

-        manifest = OrderedDict(
-            (app_name, list(filter(model_installed, model_list)))
-            for app_name, model_list in all_models
-        )
+            manifest = OrderedDict(
+                (app_name, list(filter(model_installed, model_list)))
+                for app_name, model_list in all_models
+            )

-        create_models = set(itertools.chain(*manifest.values()))
-        emit_pre_migrate_signal(create_models, self.verbosity, self.interactive, connection.alias)
+            create_models = set(itertools.chain(*manifest.values()))
+            emit_pre_migrate_signal(create_models, self.verbosity, self.interactive, connection.alias)

-        # Create the tables for each model
-        if self.verbosity >= 1:
-            self.stdout.write("  Creating tables...\n")
-        with transaction.atomic(using=connection.alias, savepoint=False):
-            for app_name, model_list in manifest.items():
-                for model in model_list:
-                    # Create the model's database table, if it doesn't already exist.
-                    if self.verbosity >= 3:
-                        self.stdout.write("    Processing %s.%s model\n" % (app_name, model._meta.object_name))
-                    sql, references = connection.creation.sql_create_model(model, no_style(), seen_models)
-                    seen_models.add(model)
-                    created_models.add(model)
-                    for refto, refs in references.items():
-                        pending_references.setdefault(refto, []).extend(refs)
-                        if refto in seen_models:
-                            sql.extend(connection.creation.sql_for_pending_references(refto, no_style(), pending_references))
-                    sql.extend(connection.creation.sql_for_pending_references(model, no_style(), pending_references))
-                    if self.verbosity >= 1 and sql:
-                        self.stdout.write("    Creating table %s\n" % model._meta.db_table)
-                    for statement in sql:
-                        cursor.execute(statement)
-                    tables.append(connection.introspection.table_name_converter(model._meta.db_table))
+            # Create the tables for each model
+            if self.verbosity >= 1:
+                self.stdout.write("  Creating tables...\n")
+            with transaction.atomic(using=connection.alias, savepoint=False):
+                for app_name, model_list in manifest.items():
+                    for model in model_list:
+                        # Create the model's database table, if it doesn't already exist.
+                        if self.verbosity >= 3:
+                            self.stdout.write("    Processing %s.%s model\n" % (app_name, model._meta.object_name))
+                        sql, references = connection.creation.sql_create_model(model, no_style(), seen_models)
+                        seen_models.add(model)
+                        created_models.add(model)
+                        for refto, refs in references.items():
+                            pending_references.setdefault(refto, []).extend(refs)
+                            if refto in seen_models:
+                                sql.extend(connection.creation.sql_for_pending_references(refto, no_style(), pending_references))
+                        sql.extend(connection.creation.sql_for_pending_references(model, no_style(), pending_references))
+                        if self.verbosity >= 1 and sql:
+                            self.stdout.write("    Creating table %s\n" % model._meta.db_table)
+                        for statement in sql:
+                            cursor.execute(statement)
+                        tables.append(connection.introspection.table_name_converter(model._meta.db_table))

-        # We force a commit here, as that was the previous behaviour.
-        # If you can prove we don't need this, remove it.
-        transaction.set_dirty(using=connection.alias)
+            # We force a commit here, as that was the previous behaviour.
+            # If you can prove we don't need this, remove it.
+            transaction.set_dirty(using=connection.alias)
+        finally:
+            cursor.close()

         # The connection may have been closed by a syncdb handler.
         cursor = connection.cursor()
+        try:
-        # Install custom SQL for the app (but only if this
-        # is a model we've just created)
-        if self.verbosity >= 1:
-            self.stdout.write("  Installing custom SQL...\n")
-        for app_name, model_list in manifest.items():
-            for model in model_list:
-                if model in created_models:
-                    custom_sql = custom_sql_for_model(model, no_style(), connection)
-                    if custom_sql:
-                        if self.verbosity >= 2:
-                            self.stdout.write("    Installing custom SQL for %s.%s model\n" % (app_name, model._meta.object_name))
-                        try:
-                            with transaction.commit_on_success_unless_managed(using=connection.alias):
-                                for sql in custom_sql:
-                                    cursor.execute(sql)
-                        except Exception as e:
-                            self.stderr.write("    Failed to install custom SQL for %s.%s model: %s\n" % (app_name, model._meta.object_name, e))
-                            if self.show_traceback:
-                                traceback.print_exc()
-                    else:
-                        if self.verbosity >= 3:
-                            self.stdout.write("    No custom SQL for %s.%s model\n" % (app_name, model._meta.object_name))
+            # Install custom SQL for the app (but only if this
+            # is a model we've just created)
+            if self.verbosity >= 1:
+                self.stdout.write("  Installing custom SQL...\n")
+            for app_name, model_list in manifest.items():
+                for model in model_list:
+                    if model in created_models:
+                        custom_sql = custom_sql_for_model(model, no_style(), connection)
+                        if custom_sql:
+                            if self.verbosity >= 2:
+                                self.stdout.write("    Installing custom SQL for %s.%s model\n" % (app_name, model._meta.object_name))
+                            try:
+                                with transaction.commit_on_success_unless_managed(using=connection.alias):
+                                    for sql in custom_sql:
+                                        cursor.execute(sql)
+                            except Exception as e:
+                                self.stderr.write("    Failed to install custom SQL for %s.%s model: %s\n" % (app_name, model._meta.object_name, e))
+                                if self.show_traceback:
+                                    traceback.print_exc()
+                        else:
+                            if self.verbosity >= 3:
+                                self.stdout.write("    No custom SQL for %s.%s model\n" % (app_name, model._meta.object_name))

-        if self.verbosity >= 1:
-            self.stdout.write("  Installing indexes...\n")
+            if self.verbosity >= 1:
+                self.stdout.write("  Installing indexes...\n")

-        # Install SQL indices for all newly created models
-        for app_name, model_list in manifest.items():
-            for model in model_list:
-                if model in created_models:
-                    index_sql = connection.creation.sql_indexes_for_model(model, no_style())
-                    if index_sql:
-                        if self.verbosity >= 2:
-                            self.stdout.write("    Installing index for %s.%s model\n" % (app_name, model._meta.object_name))
-                        try:
-                            with transaction.commit_on_success_unless_managed(using=connection.alias):
-                                for sql in index_sql:
-                                    cursor.execute(sql)
-                        except Exception as e:
-                            self.stderr.write("    Failed to install index for %s.%s model: %s\n" % (app_name, model._meta.object_name, e))
+            # Install SQL indices for all newly created models
+            for app_name, model_list in manifest.items():
+                for model in model_list:
+                    if model in created_models:
+                        index_sql = connection.creation.sql_indexes_for_model(model, no_style())
+                        if index_sql:
+                            if self.verbosity >= 2:
+                                self.stdout.write("    Installing index for %s.%s model\n" % (app_name, model._meta.object_name))
+                            try:
+                                with transaction.commit_on_success_unless_managed(using=connection.alias):
+                                    for sql in index_sql:
+                                        cursor.execute(sql)
+                            except Exception as e:
+                                self.stderr.write("    Failed to install index for %s.%s model: %s\n" % (app_name, model._meta.object_name, e))
+        finally:
+            cursor.close()

         # Load initial_data fixtures (unless that has been disabled)
         if self.load_initial_data:
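
Where a block of work genuinely needs one shared cursor across many statements, as in the syncdb-style hunk above, the commit uses try/finally instead of a with block, because the cursor is deliberately reopened partway through (the comment notes a syncdb handler may have closed the connection). A condensed sketch of that shape; the function and parameter names here are invented for illustration:

    # Condensed sketch of the try/finally shape in the hunk above.
    def run_in_two_phases(connection, statements, followup_statements):
        cursor = connection.cursor()
        try:
            for sql in statements:
                cursor.execute(sql)
        finally:
            cursor.close()

        # Signal handlers may have closed the connection in between;
        # open a fresh cursor and protect it the same way.
        cursor = connection.cursor()
        try:
            for sql in followup_statements:
                cursor.execute(sql)
        finally:
            cursor.close()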

View File

@@ -67,38 +67,39 @@ def sql_delete(app_config, style, connection):
     except Exception:
         cursor = None

-    # Figure out which tables already exist
-    if cursor:
-        table_names = connection.introspection.table_names(cursor)
-    else:
-        table_names = []
+    try:
+        # Figure out which tables already exist
+        if cursor:
+            table_names = connection.introspection.table_names(cursor)
+        else:
+            table_names = []

-    output = []
+        output = []

-    # Output DROP TABLE statements for standard application tables.
-    to_delete = set()
+        # Output DROP TABLE statements for standard application tables.
+        to_delete = set()

-    references_to_delete = {}
-    app_models = router.get_migratable_models(app_config, connection.alias, include_auto_created=True)
-    for model in app_models:
-        if cursor and connection.introspection.table_name_converter(model._meta.db_table) in table_names:
-            # The table exists, so it needs to be dropped
-            opts = model._meta
-            for f in opts.local_fields:
-                if f.rel and f.rel.to not in to_delete:
-                    references_to_delete.setdefault(f.rel.to, []).append((model, f))
+        references_to_delete = {}
+        app_models = router.get_migratable_models(app_config, connection.alias, include_auto_created=True)
+        for model in app_models:
+            if cursor and connection.introspection.table_name_converter(model._meta.db_table) in table_names:
+                # The table exists, so it needs to be dropped
+                opts = model._meta
+                for f in opts.local_fields:
+                    if f.rel and f.rel.to not in to_delete:
+                        references_to_delete.setdefault(f.rel.to, []).append((model, f))

-            to_delete.add(model)
+                to_delete.add(model)

-    for model in app_models:
-        if connection.introspection.table_name_converter(model._meta.db_table) in table_names:
-            output.extend(connection.creation.sql_destroy_model(model, references_to_delete, style))
-
-    # Close database connection explicitly, in case this output is being piped
-    # directly into a database client, to avoid locking issues.
-    if cursor:
-        cursor.close()
-    connection.close()
+        for model in app_models:
+            if connection.introspection.table_name_converter(model._meta.db_table) in table_names:
+                output.extend(connection.creation.sql_destroy_model(model, references_to_delete, style))
+    finally:
+        # Close database connection explicitly, in case this output is being piped
+        # directly into a database client, to avoid locking issues.
+        if cursor:
+            cursor.close()
+        connection.close()

     return output[::-1]  # Reverse it, to deal with table dependencies.

View File

@@ -194,13 +194,16 @@ class BaseDatabaseWrapper(object):
     ##### Backend-specific savepoint management methods #####

     def _savepoint(self, sid):
-        self.cursor().execute(self.ops.savepoint_create_sql(sid))
+        with self.cursor() as cursor:
+            cursor.execute(self.ops.savepoint_create_sql(sid))

     def _savepoint_rollback(self, sid):
-        self.cursor().execute(self.ops.savepoint_rollback_sql(sid))
+        with self.cursor() as cursor:
+            cursor.execute(self.ops.savepoint_rollback_sql(sid))

     def _savepoint_commit(self, sid):
-        self.cursor().execute(self.ops.savepoint_commit_sql(sid))
+        with self.cursor() as cursor:
+            cursor.execute(self.ops.savepoint_commit_sql(sid))

     def _savepoint_allowed(self):
         # Savepoints cannot be created outside a transaction
@@ -688,15 +691,15 @@ class BaseDatabaseFeatures(object):
             # otherwise autocommit will cause the confimation to
             # fail.
             self.connection.enter_transaction_management()
-            cursor = self.connection.cursor()
-            cursor.execute('CREATE TABLE ROLLBACK_TEST (X INT)')
-            self.connection.commit()
-            cursor.execute('INSERT INTO ROLLBACK_TEST (X) VALUES (8)')
-            self.connection.rollback()
-            cursor.execute('SELECT COUNT(X) FROM ROLLBACK_TEST')
-            count, = cursor.fetchone()
-            cursor.execute('DROP TABLE ROLLBACK_TEST')
-            self.connection.commit()
+            with self.connection.cursor() as cursor:
+                cursor.execute('CREATE TABLE ROLLBACK_TEST (X INT)')
+                self.connection.commit()
+                cursor.execute('INSERT INTO ROLLBACK_TEST (X) VALUES (8)')
+                self.connection.rollback()
+                cursor.execute('SELECT COUNT(X) FROM ROLLBACK_TEST')
+                count, = cursor.fetchone()
+                cursor.execute('DROP TABLE ROLLBACK_TEST')
+                self.connection.commit()
         finally:
             self.connection.leave_transaction_management()
         return count == 0
@@ -1253,7 +1256,8 @@ class BaseDatabaseIntrospection(object):
         in sorting order between databases.
         """
         if cursor is None:
-            cursor = self.connection.cursor()
+            with self.connection.cursor() as cursor:
+                return sorted(self.get_table_list(cursor))
         return sorted(self.get_table_list(cursor))

     def get_table_list(self, cursor):
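
The hunk above shows the optional-cursor pattern this commit settles on: when the caller supplies a cursor, use it and let the caller close it; when not, open a private one in a `with` block and return from inside it. A self-contained sketch of the same idea, assuming `connection.cursor()` returns a context-manager cursor as Django's wrapper does after this commit (the table query is illustrative, not Django's introspection SQL):

    def table_names_sketch(connection, cursor=None):
        def read(c):
            c.execute("SELECT name FROM sqlite_master WHERE type = 'table'")
            return sorted(row[0] for row in c.fetchall())

        if cursor is None:
            # No cursor supplied: open a private one, closed on return.
            with connection.cursor() as cursor:
                return read(cursor)
        # Caller-owned cursor: use it, but don't close it here.
        return read(cursor)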

View File

@@ -378,9 +378,8 @@ class BaseDatabaseCreation(object):
         call_command('createcachetable', database=self.connection.alias)

-        # Get a cursor (even though we don't need one yet). This has
-        # the side effect of initializing the test database.
-        self.connection.cursor()
+        # Ensure a connection for the side effect of initializing the test database.
+        self.connection.ensure_connection()

         return test_database_name
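
Both creation classes swap a throwaway `self.connection.cursor()` call for `self.connection.ensure_connection()`, which states the intent directly and avoids minting a cursor nobody closes. A hedged sketch of what such a method amounts to (simplified; the real method also wraps errors):

    # Simplified idea of ensure_connection(): connect lazily,
    # do nothing if a connection is already established.
    def ensure_connection(self):
        if self.connection is None:
            self.connect()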
@@ -406,34 +405,34 @@ class BaseDatabaseCreation(object):
         qn = self.connection.ops.quote_name

         # Create the test database and connect to it.
-        cursor = self._nodb_connection.cursor()
-        try:
-            cursor.execute(
-                "CREATE DATABASE %s %s" % (qn(test_database_name), suffix))
-        except Exception as e:
-            sys.stderr.write(
-                "Got an error creating the test database: %s\n" % e)
-            if not autoclobber:
-                confirm = input(
-                    "Type 'yes' if you would like to try deleting the test "
-                    "database '%s', or 'no' to cancel: " % test_database_name)
-            if autoclobber or confirm == 'yes':
-                try:
-                    if verbosity >= 1:
-                        print("Destroying old test database '%s'..."
-                              % self.connection.alias)
-                    cursor.execute(
-                        "DROP DATABASE %s" % qn(test_database_name))
-                    cursor.execute(
-                        "CREATE DATABASE %s %s" % (qn(test_database_name),
-                                                   suffix))
-                except Exception as e:
-                    sys.stderr.write(
-                        "Got an error recreating the test database: %s\n" % e)
-                    sys.exit(2)
-            else:
-                print("Tests cancelled.")
-                sys.exit(1)
+        with self._nodb_connection.cursor() as cursor:
+            try:
+                cursor.execute(
+                    "CREATE DATABASE %s %s" % (qn(test_database_name), suffix))
+            except Exception as e:
+                sys.stderr.write(
+                    "Got an error creating the test database: %s\n" % e)
+                if not autoclobber:
+                    confirm = input(
+                        "Type 'yes' if you would like to try deleting the test "
+                        "database '%s', or 'no' to cancel: " % test_database_name)
+                if autoclobber or confirm == 'yes':
+                    try:
+                        if verbosity >= 1:
+                            print("Destroying old test database '%s'..."
+                                  % self.connection.alias)
+                        cursor.execute(
+                            "DROP DATABASE %s" % qn(test_database_name))
+                        cursor.execute(
+                            "CREATE DATABASE %s %s" % (qn(test_database_name),
+                                                       suffix))
+                    except Exception as e:
+                        sys.stderr.write(
+                            "Got an error recreating the test database: %s\n" % e)
+                        sys.exit(2)
+                else:
+                    print("Tests cancelled.")
+                    sys.exit(1)

         return test_database_name
@@ -461,11 +460,11 @@ class BaseDatabaseCreation(object):
         # ourselves. Connect to the previous database (not the test database)
         # to do so, because it's not allowed to delete a database while being
         # connected to it.
-        cursor = self._nodb_connection.cursor()
-        # Wait to avoid "database is being accessed by other users" errors.
-        time.sleep(1)
-        cursor.execute("DROP DATABASE %s"
-                       % self.connection.ops.quote_name(test_database_name))
+        with self._nodb_connection.cursor() as cursor:
+            # Wait to avoid "database is being accessed by other users" errors.
+            time.sleep(1)
+            cursor.execute("DROP DATABASE %s"
+                           % self.connection.ops.quote_name(test_database_name))

     def set_autocommit(self):
         """

View File

@@ -180,15 +180,15 @@ class DatabaseFeatures(BaseDatabaseFeatures):
     @cached_property
     def _mysql_storage_engine(self):
         "Internal method used in Django tests. Don't rely on this from your code"
-        cursor = self.connection.cursor()
-        cursor.execute('CREATE TABLE INTROSPECT_TEST (X INT)')
-        # This command is MySQL specific; the second column
-        # will tell you the default table type of the created
-        # table. Since all Django's test tables will have the same
-        # table type, that's enough to evaluate the feature.
-        cursor.execute("SHOW TABLE STATUS WHERE Name='INTROSPECT_TEST'")
-        result = cursor.fetchone()
-        cursor.execute('DROP TABLE INTROSPECT_TEST')
+        with self.connection.cursor() as cursor:
+            cursor.execute('CREATE TABLE INTROSPECT_TEST (X INT)')
+            # This command is MySQL specific; the second column
+            # will tell you the default table type of the created
+            # table. Since all Django's test tables will have the same
+            # table type, that's enough to evaluate the feature.
+            cursor.execute("SHOW TABLE STATUS WHERE Name='INTROSPECT_TEST'")
+            result = cursor.fetchone()
+            cursor.execute('DROP TABLE INTROSPECT_TEST')
         return result[1]

     @cached_property
@@ -207,9 +207,9 @@ class DatabaseFeatures(BaseDatabaseFeatures):
             return False

         # Test if the time zone definitions are installed.
-        cursor = self.connection.cursor()
-        cursor.execute("SELECT 1 FROM mysql.time_zone LIMIT 1")
-        return cursor.fetchone() is not None
+        with self.connection.cursor() as cursor:
+            cursor.execute("SELECT 1 FROM mysql.time_zone LIMIT 1")
+            return cursor.fetchone() is not None


 class DatabaseOperations(BaseDatabaseOperations):
@@ -461,13 +461,12 @@ class DatabaseWrapper(BaseDatabaseWrapper):
         return conn

     def init_connection_state(self):
-        cursor = self.connection.cursor()
-        # SQL_AUTO_IS_NULL in MySQL controls whether an AUTO_INCREMENT column
-        # on a recently-inserted row will return when the field is tested for
-        # NULL. Disabling this value brings this aspect of MySQL in line with
-        # SQL standards.
-        cursor.execute('SET SQL_AUTO_IS_NULL = 0')
-        cursor.close()
+        with self.connection.cursor() as cursor:
+            # SQL_AUTO_IS_NULL in MySQL controls whether an AUTO_INCREMENT column
+            # on a recently-inserted row will return when the field is tested for
+            # NULL. Disabling this value brings this aspect of MySQL in line with
+            # SQL standards.
+            cursor.execute('SET SQL_AUTO_IS_NULL = 0')

     def create_cursor(self):
         cursor = self.connection.cursor()

View File

@@ -353,8 +353,8 @@ WHEN (new.%(col_name)s IS NULL)
     def regex_lookup(self, lookup_type):
         # If regex_lookup is called before it's been initialized, then create
         # a cursor to initialize it and recur.
-        self.connection.cursor()
-        return self.connection.ops.regex_lookup(lookup_type)
+        with self.connection.cursor():
+            return self.connection.ops.regex_lookup(lookup_type)

     def return_insert_id(self):
         return "RETURNING %s INTO %%s", (InsertIdVar(),)

View File

@@ -149,8 +149,10 @@ class DatabaseWrapper(BaseDatabaseWrapper):
             if conn_tz != tz:
                 cursor = self.connection.cursor()
-                cursor.execute(self.ops.set_time_zone_sql(), [tz])
-                cursor.close()
+                try:
+                    cursor.execute(self.ops.set_time_zone_sql(), [tz])
+                finally:
+                    cursor.close()
                 # Commit after setting the time zone (see #17062)
                 if not self.get_autocommit():
                     self.connection.commit()

View File

@@ -39,6 +39,6 @@ def get_version(connection):
     if hasattr(connection, 'server_version'):
         return connection.server_version
     else:
-        cursor = connection.cursor()
-        cursor.execute("SELECT version()")
-        return _parse_version(cursor.fetchone()[0])
+        with connection.cursor() as cursor:
+            cursor.execute("SELECT version()")
+            return _parse_version(cursor.fetchone()[0])

View File

@@ -86,14 +86,13 @@ class BaseDatabaseSchemaEditor(object):
         """
         Executes the given SQL statement, with optional parameters.
         """
-        # Get the cursor
-        cursor = self.connection.cursor()
         # Log the command we're running, then run it
         logger.debug("%s; (params %r)" % (sql, params))
         if self.collect_sql:
             self.collected_sql.append((sql % tuple(map(self.connection.ops.quote_parameter, params))) + ";")
         else:
-            cursor.execute(sql, params)
+            with self.connection.cursor() as cursor:
+                cursor.execute(sql, params)

     def quote_name(self, name):
         return self.connection.ops.quote_name(name)
@@ -791,7 +790,8 @@ class BaseDatabaseSchemaEditor(object):
         Returns all constraint names matching the columns and conditions
         """
         column_names = list(column_names) if column_names else None
-        constraints = self.connection.introspection.get_constraints(self.connection.cursor(), model._meta.db_table)
+        with self.connection.cursor() as cursor:
+            constraints = self.connection.introspection.get_constraints(cursor, model._meta.db_table)
         result = []
         for name, infodict in constraints.items():
             if column_names is None or column_names == infodict['columns']:

View File

@@ -122,14 +122,14 @@ class DatabaseFeatures(BaseDatabaseFeatures):
         rule out support for STDDEV. We need to manually check
         whether the call works.
         """
-        cursor = self.connection.cursor()
-        cursor.execute('CREATE TABLE STDDEV_TEST (X INT)')
-        try:
-            cursor.execute('SELECT STDDEV(*) FROM STDDEV_TEST')
-            has_support = True
-        except utils.DatabaseError:
-            has_support = False
-        cursor.execute('DROP TABLE STDDEV_TEST')
+        with self.connection.cursor() as cursor:
+            cursor.execute('CREATE TABLE STDDEV_TEST (X INT)')
+            try:
+                cursor.execute('SELECT STDDEV(*) FROM STDDEV_TEST')
+                has_support = True
+            except utils.DatabaseError:
+                has_support = False
+            cursor.execute('DROP TABLE STDDEV_TEST')
         return has_support

     @cached_property

View File

@@ -14,6 +14,7 @@ from django.db.models.fields import AutoField, Empty
 from django.db.models.query_utils import (Q, select_related_descend,
     deferred_class_factory, InvalidQuery)
 from django.db.models.deletion import Collector
+from django.db.models.sql.constants import CURSOR
 from django.db.models import sql
 from django.utils.functional import partition
 from django.utils import six
@@ -574,7 +575,7 @@ class QuerySet(object):
         query = self.query.clone(sql.UpdateQuery)
         query.add_update_values(kwargs)
         with transaction.commit_on_success_unless_managed(using=self.db):
-            rows = query.get_compiler(self.db).execute_sql(None)
+            rows = query.get_compiler(self.db).execute_sql(CURSOR)
         self._result_cache = None
         return rows
     update.alters_data = True
@@ -591,7 +592,7 @@ class QuerySet(object):
         query = self.query.clone(sql.UpdateQuery)
         query.add_update_fields(values)
         self._result_cache = None
-        return query.get_compiler(self.db).execute_sql(None)
+        return query.get_compiler(self.db).execute_sql(CURSOR)
     _update.alters_data = True
     _update.queryset_only = False
@@ -1521,54 +1522,59 @@ class RawQuerySet(object):
         query = iter(self.query)

-        # Find out which columns are model's fields, and which ones should be
-        # annotated to the model.
-        for pos, column in enumerate(self.columns):
-            if column in self.model_fields:
-                model_init_field_names[self.model_fields[column].attname] = pos
-            else:
-                annotation_fields.append((column, pos))
-
-        # Find out which model's fields are not present in the query.
-        skip = set()
-        for field in self.model._meta.fields:
-            if field.attname not in model_init_field_names:
-                skip.add(field.attname)
-        if skip:
-            if self.model._meta.pk.attname in skip:
-                raise InvalidQuery('Raw query must include the primary key')
-            model_cls = deferred_class_factory(self.model, skip)
-        else:
-            model_cls = self.model
-            # All model's fields are present in the query. So, it is possible
-            # to use *args based model instantation. For each field of the model,
-            # record the query column position matching that field.
-            model_init_field_pos = []
-            for field in self.model._meta.fields:
-                model_init_field_pos.append(model_init_field_names[field.attname])
-        if need_resolv_columns:
-            fields = [self.model_fields.get(c, None) for c in self.columns]
-        # Begin looping through the query values.
-        for values in query:
-            if need_resolv_columns:
-                values = compiler.resolve_columns(values, fields)
-            # Associate fields to values
-            if skip:
-                model_init_kwargs = {}
-                for attname, pos in six.iteritems(model_init_field_names):
-                    model_init_kwargs[attname] = values[pos]
-                instance = model_cls(**model_init_kwargs)
-            else:
-                model_init_args = [values[pos] for pos in model_init_field_pos]
-                instance = model_cls(*model_init_args)
-            if annotation_fields:
-                for column, pos in annotation_fields:
-                    setattr(instance, column, values[pos])
-
-            instance._state.db = db
-            instance._state.adding = False
-
-            yield instance
+        try:
+            # Find out which columns are model's fields, and which ones should be
+            # annotated to the model.
+            for pos, column in enumerate(self.columns):
+                if column in self.model_fields:
+                    model_init_field_names[self.model_fields[column].attname] = pos
+                else:
+                    annotation_fields.append((column, pos))
+
+            # Find out which model's fields are not present in the query.
+            skip = set()
+            for field in self.model._meta.fields:
+                if field.attname not in model_init_field_names:
+                    skip.add(field.attname)
+            if skip:
+                if self.model._meta.pk.attname in skip:
+                    raise InvalidQuery('Raw query must include the primary key')
+                model_cls = deferred_class_factory(self.model, skip)
+            else:
+                model_cls = self.model
+                # All model's fields are present in the query. So, it is possible
+                # to use *args based model instantation. For each field of the model,
+                # record the query column position matching that field.
+                model_init_field_pos = []
+                for field in self.model._meta.fields:
+                    model_init_field_pos.append(model_init_field_names[field.attname])
+            if need_resolv_columns:
+                fields = [self.model_fields.get(c, None) for c in self.columns]
+            # Begin looping through the query values.
+            for values in query:
+                if need_resolv_columns:
+                    values = compiler.resolve_columns(values, fields)
+                # Associate fields to values
+                if skip:
+                    model_init_kwargs = {}
+                    for attname, pos in six.iteritems(model_init_field_names):
+                        model_init_kwargs[attname] = values[pos]
+                    instance = model_cls(**model_init_kwargs)
+                else:
+                    model_init_args = [values[pos] for pos in model_init_field_pos]
+                    instance = model_cls(*model_init_args)
+                if annotation_fields:
+                    for column, pos in annotation_fields:
+                        setattr(instance, column, values[pos])
+
+                instance._state.db = db
+                instance._state.adding = False
+
+                yield instance
+        finally:
+            # Done iterating the Query. If it has its own cursor, close it.
+            if hasattr(self.query, 'cursor') and self.query.cursor:
+                self.query.cursor.close()

     def __repr__(self):
         text = self.raw_query

View File

@@ -1,12 +1,13 @@
 import datetime
 import sys

 from django.conf import settings
 from django.core.exceptions import FieldError
 from django.db.backends.utils import truncate_name
 from django.db.models.constants import LOOKUP_SEP
 from django.db.models.query_utils import select_related_descend, QueryWrapper
-from django.db.models.sql.constants import (SINGLE, MULTI, ORDER_DIR,
-    GET_ITERATOR_CHUNK_SIZE, SelectInfo)
+from django.db.models.sql.constants import (CURSOR, SINGLE, MULTI, NO_RESULTS,
+    ORDER_DIR, GET_ITERATOR_CHUNK_SIZE, SelectInfo)
 from django.db.models.sql.datastructures import EmptyResultSet
 from django.db.models.sql.expressions import SQLEvaluator
 from django.db.models.sql.query import get_order_dir, Query
@@ -762,6 +763,8 @@ class SQLCompiler(object):
         is needed, as the filters describe an empty set. In that case, None is
         returned, to avoid any unnecessary database interaction.
         """
+        if not result_type:
+            result_type = NO_RESULTS
         try:
             sql, params = self.as_sql()
             if not sql:
@@ -773,27 +776,44 @@
                 return

         cursor = self.connection.cursor()
-        cursor.execute(sql, params)
+        try:
+            cursor.execute(sql, params)
+        except Exception:
+            cursor.close()
+            raise

-        if not result_type:
+        if result_type == CURSOR:
             # Caller didn't specify a result_type, so just give them back the
             # cursor to process (and close).
             return cursor
         if result_type == SINGLE:
-            if self.ordering_aliases:
-                return cursor.fetchone()[:-len(self.ordering_aliases)]
-            return cursor.fetchone()
+            try:
+                if self.ordering_aliases:
+                    return cursor.fetchone()[:-len(self.ordering_aliases)]
+                return cursor.fetchone()
+            finally:
+                # done with the cursor
+                cursor.close()
+        if result_type == NO_RESULTS:
+            cursor.close()
+            return

         # The MULTI case.
         if self.ordering_aliases:
             result = order_modified_iter(cursor, len(self.ordering_aliases),
                                          self.connection.features.empty_fetchmany_value)
         else:
-            result = iter((lambda: cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)),
-                          self.connection.features.empty_fetchmany_value)
+            result = cursor_iter(cursor,
+                self.connection.features.empty_fetchmany_value)
         if not self.connection.features.can_use_chunked_reads:
-            # If we are using non-chunked reads, we return the same data
-            # structure as normally, but ensure it is all read into memory
-            # before going any further.
-            return list(result)
+            try:
+                # If we are using non-chunked reads, we return the same data
+                # structure as normally, but ensure it is all read into memory
+                # before going any further.
+                return list(result)
+            finally:
+                # done with the cursor
+                cursor.close()
         return result

     def as_subquery_condition(self, alias, columns, qn):
@@ -889,15 +909,15 @@ class SQLInsertCompiler(SQLCompiler):
     def execute_sql(self, return_id=False):
         assert not (return_id and len(self.query.objs) != 1)
         self.return_id = return_id
-        cursor = self.connection.cursor()
-        for sql, params in self.as_sql():
-            cursor.execute(sql, params)
-        if not (return_id and cursor):
-            return
-        if self.connection.features.can_return_id_from_insert:
-            return self.connection.ops.fetch_returned_insert_id(cursor)
-        return self.connection.ops.last_insert_id(cursor,
-            self.query.get_meta().db_table, self.query.get_meta().pk.column)
+        with self.connection.cursor() as cursor:
+            for sql, params in self.as_sql():
+                cursor.execute(sql, params)
+            if not (return_id and cursor):
+                return
+            if self.connection.features.can_return_id_from_insert:
+                return self.connection.ops.fetch_returned_insert_id(cursor)
+            return self.connection.ops.last_insert_id(cursor,
+                self.query.get_meta().db_table, self.query.get_meta().pk.column)


 class SQLDeleteCompiler(SQLCompiler):
@@ -970,12 +990,15 @@ class SQLUpdateCompiler(SQLCompiler):
         related queries are not available.
         """
         cursor = super(SQLUpdateCompiler, self).execute_sql(result_type)
-        rows = cursor.rowcount if cursor else 0
-        is_empty = cursor is None
-        del cursor
+        try:
+            rows = cursor.rowcount if cursor else 0
+            is_empty = cursor is None
+        finally:
+            if cursor:
+                cursor.close()
         for query in self.query.get_related_updates():
             aux_rows = query.get_compiler(self.using).execute_sql(result_type)
-            if is_empty:
+            if is_empty and aux_rows:
                 rows = aux_rows
                 is_empty = False
         return rows
@@ -1111,6 +1134,19 @@ class SQLDateTimeCompiler(SQLCompiler):
                 yield datetime


+def cursor_iter(cursor, sentinel):
+    """
+    Yields blocks of rows from a cursor and ensures the cursor is closed when
+    done.
+    """
+    try:
+        for rows in iter((lambda: cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)),
+                sentinel):
+            yield rows
+    finally:
+        cursor.close()
+
+
 def order_modified_iter(cursor, trim, sentinel):
     """
     Yields blocks of rows from a cursor. We use this iterator in the special
@@ -1118,6 +1154,9 @@ def order_modified_iter(cursor, trim, sentinel):
     requirements. We must trim those extra columns before anything else can use
     the results, since they're only needed to make the SQL valid.
     """
-    for rows in iter((lambda: cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)),
-            sentinel):
-        yield [r[:-trim] for r in rows]
+    try:
+        for rows in iter((lambda: cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)),
+                sentinel):
+            yield [r[:-trim] for r in rows]
+    finally:
+        cursor.close()

View File

@@ -33,6 +33,8 @@ SelectInfo = namedtuple('SelectInfo', 'col field')
 # How many results to expect from a cursor.execute call
 MULTI = 'multi'
 SINGLE = 'single'
+CURSOR = 'cursor'
+NO_RESULTS = 'no results'

 ORDER_PATTERN = re.compile(r'\?|[-+]?[.\w]+$')
 ORDER_DIR = {
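
These constants name the four ways `SQLCompiler.execute_sql` can hand results back: `MULTI` (an iterator of row chunks), `SINGLE` (one row), the new `CURSOR` (the raw cursor, which the caller must close), and the new `NO_RESULTS` (execute for side effects and close immediately). A simplified sketch of the dispatch, condensed from the compiler hunk earlier in this diff and not the full logic:

    def run_query(connection, sql, params, result_type):
        cursor = connection.cursor()
        try:
            cursor.execute(sql, params)
        except Exception:
            cursor.close()  # never leak the cursor on a failed execute
            raise
        if result_type == CURSOR:        # caller takes ownership
            return cursor
        if result_type == SINGLE:        # fetch one row, then close
            try:
                return cursor.fetchone()
            finally:
                cursor.close()
        if result_type == NO_RESULTS:    # side effects only
            cursor.close()
            return None
        # MULTI would wrap the cursor in a chunked iterator (cursor_iter)
        # that closes it once the rows are exhausted.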

View File

@@ -8,7 +8,7 @@ from django.db import connections
 from django.db.models.query_utils import Q
 from django.db.models.constants import LOOKUP_SEP
 from django.db.models.fields import DateField, DateTimeField, FieldDoesNotExist
-from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, SelectInfo
+from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, NO_RESULTS, SelectInfo
 from django.db.models.sql.datastructures import Date, DateTime
 from django.db.models.sql.query import Query
 from django.utils import six
@@ -30,7 +30,7 @@ class DeleteQuery(Query):
     def do_query(self, table, where, using):
         self.tables = [table]
         self.where = where
-        self.get_compiler(using).execute_sql(None)
+        self.get_compiler(using).execute_sql(NO_RESULTS)

     def delete_batch(self, pk_list, using, field=None):
         """
@@ -82,7 +82,7 @@ class DeleteQuery(Query):
             values = innerq
         self.where = self.where_class()
         self.add_q(Q(pk__in=values))
-        self.get_compiler(using).execute_sql(None)
+        self.get_compiler(using).execute_sql(NO_RESULTS)


 class UpdateQuery(Query):
@@ -116,7 +116,7 @@ class UpdateQuery(Query):
         for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
             self.where = self.where_class()
             self.add_q(Q(pk__in=pk_list[offset: offset + GET_ITERATOR_CHUNK_SIZE]))
-            self.get_compiler(using).execute_sql(None)
+            self.get_compiler(using).execute_sql(NO_RESULTS)

     def add_update_values(self, values):
         """

View File

@@ -20,6 +20,7 @@ from django.db.backends.utils import format_number, CursorWrapper
 from django.db.models import Sum, Avg, Variance, StdDev
 from django.db.models.fields import (AutoField, DateField, DateTimeField,
     DecimalField, IntegerField, TimeField)
+from django.db.models.sql.constants import CURSOR
 from django.db.utils import ConnectionHandler
 from django.test import (TestCase, TransactionTestCase, override_settings,
     skipUnlessDBFeature, skipIfDBFeature)
@@ -58,9 +59,9 @@ class OracleChecks(unittest.TestCase):
         # stored procedure through our cursor wrapper.
         from django.db.backends.oracle.base import convert_unicode

-        cursor = connection.cursor()
-        cursor.callproc(convert_unicode('DBMS_SESSION.SET_IDENTIFIER'),
-                        [convert_unicode('_django_testing!')])
+        with connection.cursor() as cursor:
+            cursor.callproc(convert_unicode('DBMS_SESSION.SET_IDENTIFIER'),
+                            [convert_unicode('_django_testing!')])

     @unittest.skipUnless(connection.vendor == 'oracle',
                          "No need to check Oracle cursor semantics")
@@ -69,31 +70,31 @@ class OracleChecks(unittest.TestCase):
         # as query parameters.
         from django.db.backends.oracle.base import Database

-        cursor = connection.cursor()
-        var = cursor.var(Database.STRING)
-        cursor.execute("BEGIN %s := 'X'; END; ", [var])
-        self.assertEqual(var.getvalue(), 'X')
+        with connection.cursor() as cursor:
+            var = cursor.var(Database.STRING)
+            cursor.execute("BEGIN %s := 'X'; END; ", [var])
+            self.assertEqual(var.getvalue(), 'X')

     @unittest.skipUnless(connection.vendor == 'oracle',
                          "No need to check Oracle cursor semantics")
     def test_long_string(self):
         # If the backend is Oracle, test that we can save a text longer
         # than 4000 chars and read it properly
-        c = connection.cursor()
-        c.execute('CREATE TABLE ltext ("TEXT" NCLOB)')
-        long_str = ''.join(six.text_type(x) for x in xrange(4000))
-        c.execute('INSERT INTO ltext VALUES (%s)', [long_str])
-        c.execute('SELECT text FROM ltext')
-        row = c.fetchone()
-        self.assertEqual(long_str, row[0].read())
-        c.execute('DROP TABLE ltext')
+        with connection.cursor() as cursor:
+            cursor.execute('CREATE TABLE ltext ("TEXT" NCLOB)')
+            long_str = ''.join(six.text_type(x) for x in xrange(4000))
+            cursor.execute('INSERT INTO ltext VALUES (%s)', [long_str])
+            cursor.execute('SELECT text FROM ltext')
+            row = cursor.fetchone()
+            self.assertEqual(long_str, row[0].read())
+            cursor.execute('DROP TABLE ltext')

     @unittest.skipUnless(connection.vendor == 'oracle',
                          "No need to check Oracle connection semantics")
     def test_client_encoding(self):
         # If the backend is Oracle, test that the client encoding is set
         # correctly. This was broken under Cygwin prior to r14781.
-        connection.cursor()  # Ensure the connection is initialized.
+        connection.ensure_connection()
         self.assertEqual(connection.connection.encoding, "UTF-8")
         self.assertEqual(connection.connection.nencoding, "UTF-8")
@@ -102,12 +103,12 @@ class OracleChecks(unittest.TestCase):
     def test_order_of_nls_parameters(self):
         # an 'almost right' datetime should work with configured
         # NLS parameters as per #18465.
-        c = connection.cursor()
-        query = "select 1 from dual where '1936-12-29 00:00' < sysdate"
-        # Test that the query succeeds without errors - pre #18465 this
-        # wasn't the case.
-        c.execute(query)
-        self.assertEqual(c.fetchone()[0], 1)
+        with connection.cursor() as cursor:
+            query = "select 1 from dual where '1936-12-29 00:00' < sysdate"
+            # Test that the query succeeds without errors - pre #18465 this
+            # wasn't the case.
+            cursor.execute(query)
+            self.assertEqual(cursor.fetchone()[0], 1)


 class SQLiteTests(TestCase):
@@ -209,7 +210,7 @@ class LastExecutedQueryTest(TestCase):
         """
         persons = models.Reporter.objects.filter(raw_data=b'\x00\x46 \xFE').extra(select={'föö': 1})
         sql, params = persons.query.sql_with_params()
-        cursor = persons.query.get_compiler('default').execute_sql(None)
+        cursor = persons.query.get_compiler('default').execute_sql(CURSOR)
         last_sql = cursor.db.ops.last_executed_query(cursor, sql, params)
         self.assertIsInstance(last_sql, six.text_type)
@@ -327,6 +328,12 @@ class PostgresVersionTest(TestCase):
     def fetchone(self):
         return ["PostgreSQL 8.3"]

+    def __enter__(self):
+        return self
+
+    def __exit__(self, type, value, traceback):
+        pass
+

 class OlderConnectionMock(object):
     "Mock of psycopg2 (< 2.0.12) connection"
     def cursor(self):
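
A knock-on effect of the new `with connection.cursor() as cursor:` sites shows up in this test hunk: any object standing in for a cursor must now also satisfy the context-manager protocol, or the `with` statement raises. A sketch of the minimal additions a cursor test double needs (the class name here is invented):

    class CursorDouble(object):
        def execute(self, sql, params=None):
            pass

        def fetchone(self):
            return ["PostgreSQL 8.3"]

        def __enter__(self):
            return self  # the "as" target

        def __exit__(self, exc_type, exc_value, traceback):
            pass  # nothing to release in a test double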

View File

@@ -896,10 +896,9 @@ class DBCacheTests(BaseCacheTests, TransactionTestCase):
         management.call_command('createcachetable', verbosity=0, interactive=False)

     def drop_table(self):
-        cursor = connection.cursor()
-        table_name = connection.ops.quote_name('test cache table')
-        cursor.execute('DROP TABLE %s' % table_name)
-        cursor.close()
+        with connection.cursor() as cursor:
+            table_name = connection.ops.quote_name('test cache table')
+            cursor.execute('DROP TABLE %s' % table_name)

     def test_zero_cull(self):
         self._perform_cull_test(caches['zero_cull'], 50, 18)

View File

@@ -30,11 +30,11 @@ class Article(models.Model):
         database query for the sake of demonstration.
         """
         from django.db import connection
-        cursor = connection.cursor()
-        cursor.execute("""
-            SELECT id, headline, pub_date
-            FROM custom_methods_article
-            WHERE pub_date = %s
-                AND id != %s""", [connection.ops.value_to_db_date(self.pub_date),
-                                  self.id])
-        return [self.__class__(*row) for row in cursor.fetchall()]
+        with connection.cursor() as cursor:
+            cursor.execute("""
+                SELECT id, headline, pub_date
+                FROM custom_methods_article
+                WHERE pub_date = %s
+                    AND id != %s""", [connection.ops.value_to_db_date(self.pub_date),
+                                      self.id])
+            return [self.__class__(*row) for row in cursor.fetchall()]

View File

@@ -28,9 +28,9 @@ class InitialSQLTests(TestCase):
         connection = connections[DEFAULT_DB_ALIAS]
         custom_sql = custom_sql_for_model(Simple, no_style(), connection)
         self.assertEqual(len(custom_sql), 9)
-        cursor = connection.cursor()
-        for sql in custom_sql:
-            cursor.execute(sql)
+        with connection.cursor() as cursor:
+            for sql in custom_sql:
+                cursor.execute(sql)
         self.assertEqual(Simple.objects.count(), 9)
         self.assertEqual(
             Simple.objects.get(name__contains='placeholders').name,

View File

@@ -23,17 +23,17 @@ class IntrospectionTests(TestCase):
                         "'%s' isn't in table_list()." % Article._meta.db_table)

     def test_django_table_names(self):
-        cursor = connection.cursor()
-        cursor.execute('CREATE TABLE django_ixn_test_table (id INTEGER);')
-        tl = connection.introspection.django_table_names()
-        cursor.execute("DROP TABLE django_ixn_test_table;")
-        self.assertTrue('django_ixn_testcase_table' not in tl,
-                        "django_table_names() returned a non-Django table")
+        with connection.cursor() as cursor:
+            cursor.execute('CREATE TABLE django_ixn_test_table (id INTEGER);')
+            tl = connection.introspection.django_table_names()
+            cursor.execute("DROP TABLE django_ixn_test_table;")
+            self.assertTrue('django_ixn_testcase_table' not in tl,
+                            "django_table_names() returned a non-Django table")

     def test_django_table_names_retval_type(self):
         # Ticket #15216
-        cursor = connection.cursor()
-        cursor.execute('CREATE TABLE django_ixn_test_table (id INTEGER);')
+        with connection.cursor() as cursor:
+            cursor.execute('CREATE TABLE django_ixn_test_table (id INTEGER);')

         tl = connection.introspection.django_table_names(only_existing=True)
         self.assertIs(type(tl), list)
@@ -53,14 +53,14 @@ class IntrospectionTests(TestCase):
                       'Reporter sequence not found in sequence_list()')

     def test_get_table_description_names(self):
-        cursor = connection.cursor()
-        desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table)
+        with connection.cursor() as cursor:
+            desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table)
         self.assertEqual([r[0] for r in desc],
                          [f.column for f in Reporter._meta.fields])

     def test_get_table_description_types(self):
-        cursor = connection.cursor()
-        desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table)
+        with connection.cursor() as cursor:
+            desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table)
         # The MySQL exception is due to the cursor.description returning the same constant for
         # text and blob columns. TODO: use information_schema database to retrieve the proper
         # field type on MySQL
@@ -75,8 +75,8 @@ class IntrospectionTests(TestCase):
     # inspect the length of character columns).
     @expectedFailureOnOracle
     def test_get_table_description_col_lengths(self):
-        cursor = connection.cursor()
-        desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table)
+        with connection.cursor() as cursor:
+            desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table)
         self.assertEqual(
             [r[3] for r in desc if datatype(r[1], r) == 'CharField'],
             [30, 30, 75]
@@ -87,8 +87,8 @@ class IntrospectionTests(TestCase):
     # so its idea about null_ok in cursor.description is different from ours.
     @skipIfDBFeature('interprets_empty_strings_as_nulls')
     def test_get_table_description_nullable(self):
-        cursor = connection.cursor()
-        desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table)
+        with connection.cursor() as cursor:
+            desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table)
         self.assertEqual(
             [r[6] for r in desc],
             [False, False, False, False, True, True]
@@ -97,15 +97,15 @@ class IntrospectionTests(TestCase):
     # Regression test for #9991 - 'real' types in postgres
    @skipUnlessDBFeature('has_real_datatype')
     def test_postgresql_real_type(self):
-        cursor = connection.cursor()
-        cursor.execute("CREATE TABLE django_ixn_real_test_table (number REAL);")
-        desc = connection.introspection.get_table_description(cursor, 'django_ixn_real_test_table')
-        cursor.execute('DROP TABLE django_ixn_real_test_table;')
+        with connection.cursor() as cursor:
+            cursor.execute("CREATE TABLE django_ixn_real_test_table (number REAL);")
+            desc = connection.introspection.get_table_description(cursor, 'django_ixn_real_test_table')
+            cursor.execute('DROP TABLE django_ixn_real_test_table;')
         self.assertEqual(datatype(desc[0][1], desc[0]), 'FloatField')

     def test_get_relations(self):
-        cursor = connection.cursor()
-        relations = connection.introspection.get_relations(cursor, Article._meta.db_table)
+        with connection.cursor() as cursor:
+            relations = connection.introspection.get_relations(cursor, Article._meta.db_table)

         # Older versions of MySQL don't have the chops to report on this stuff,
         # so just skip it if no relations come back. If they do, though, we
@@ -117,21 +117,21 @@ class IntrospectionTests(TestCase):

     @skipUnlessDBFeature('can_introspect_foreign_keys')
     def test_get_key_columns(self):
-        cursor = connection.cursor()
-        key_columns = connection.introspection.get_key_columns(cursor, Article._meta.db_table)
+        with connection.cursor() as cursor:
+            key_columns = connection.introspection.get_key_columns(cursor, Article._meta.db_table)
         self.assertEqual(
             set(key_columns),
             set([('reporter_id', Reporter._meta.db_table, 'id'),
                  ('response_to_id', Article._meta.db_table, 'id')]))

     def test_get_primary_key_column(self):
-        cursor = connection.cursor()
primary_key_column = connection.introspection.get_primary_key_column(cursor, Article._meta.db_table)
with connection.cursor() as cursor:
primary_key_column = connection.introspection.get_primary_key_column(cursor, Article._meta.db_table)
self.assertEqual(primary_key_column, 'id')
def test_get_indexes(self):
cursor = connection.cursor()
indexes = connection.introspection.get_indexes(cursor, Article._meta.db_table)
with connection.cursor() as cursor:
indexes = connection.introspection.get_indexes(cursor, Article._meta.db_table)
self.assertEqual(indexes['reporter_id'], {'unique': False, 'primary_key': False})
def test_get_indexes_multicol(self):
@ -139,8 +139,8 @@ class IntrospectionTests(TestCase):
Test that multicolumn indexes are not included in the introspection
results.
"""
cursor = connection.cursor()
indexes = connection.introspection.get_indexes(cursor, Reporter._meta.db_table)
with connection.cursor() as cursor:
indexes = connection.introspection.get_indexes(cursor, Reporter._meta.db_table)
self.assertNotIn('first_name', indexes)
self.assertIn('id', indexes)
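
Every test above follows one shape: open a cursor, pass it to a `connection.introspection` call, and let the `with` block close it. Sketched standalone (assumes at least one table exists; at this point in Django's history get_table_list() returns plain names):

    from django.db import connection

    with connection.cursor() as cursor:
        # Introspection functions take an open cursor rather than opening their own.
        tables = connection.introspection.get_table_list(cursor)
        desc = connection.introspection.get_table_description(cursor, tables[0])
    # The cursor is closed here; the returned data remains usable.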

View File

@ -9,33 +9,40 @@ class MigrationTestBase(TransactionTestCase):
available_apps = ["migrations"]
def get_table_description(self, table):
with connection.cursor() as cursor:
return connection.introspection.get_table_description(cursor, table)
def assertTableExists(self, table):
self.assertIn(table, connection.introspection.get_table_list(connection.cursor()))
with connection.cursor() as cursor:
self.assertIn(table, connection.introspection.get_table_list(cursor))
def assertTableNotExists(self, table):
self.assertNotIn(table, connection.introspection.get_table_list(connection.cursor()))
with connection.cursor() as cursor:
self.assertNotIn(table, connection.introspection.get_table_list(cursor))
def assertColumnExists(self, table, column):
self.assertIn(column, [c.name for c in connection.introspection.get_table_description(connection.cursor(), table)])
self.assertIn(column, [c.name for c in self.get_table_description(table)])
def assertColumnNotExists(self, table, column):
self.assertNotIn(column, [c.name for c in connection.introspection.get_table_description(connection.cursor(), table)])
self.assertNotIn(column, [c.name for c in self.get_table_description(table)])
def assertColumnNull(self, table, column):
self.assertEqual([c.null_ok for c in connection.introspection.get_table_description(connection.cursor(), table) if c.name == column][0], True)
self.assertEqual([c.null_ok for c in self.get_table_description(table) if c.name == column][0], True)
def assertColumnNotNull(self, table, column):
self.assertEqual([c.null_ok for c in connection.introspection.get_table_description(connection.cursor(), table) if c.name == column][0], False)
self.assertEqual([c.null_ok for c in self.get_table_description(table) if c.name == column][0], False)
def assertIndexExists(self, table, columns, value=True):
self.assertEqual(
value,
any(
c["index"]
for c in connection.introspection.get_constraints(connection.cursor(), table).values()
if c['columns'] == list(columns)
),
)
with connection.cursor() as cursor:
self.assertEqual(
value,
any(
c["index"]
for c in connection.introspection.get_constraints(cursor, table).values()
if c['columns'] == list(columns)
),
)
def assertIndexNotExists(self, table, columns):
return self.assertIndexExists(table, columns, False)
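
The refactor above funnels each introspection call through a small helper so every assertion gets a fresh, properly closed cursor. A hypothetical helper in the same spirit (not part of the patch; uses django.db.connection as above):

    def assertTableHasColumn(self, table, column):
        # Same shape as get_table_description() above: the cursor lives only
        # as long as the one introspection call that needs it.
        with connection.cursor() as cursor:
            names = [c.name for c in connection.introspection.get_table_description(cursor, table)]
        self.assertIn(column, names)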

View File

@ -19,15 +19,15 @@ class OperationTests(MigrationTestBase):
Creates a test model state and database table.
"""
# Delete the tables if they already exist
cursor = connection.cursor()
try:
cursor.execute("DROP TABLE %s_pony" % app_label)
except:
pass
try:
cursor.execute("DROP TABLE %s_stable" % app_label)
except:
pass
with connection.cursor() as cursor:
try:
cursor.execute("DROP TABLE %s_pony" % app_label)
except:
pass
try:
cursor.execute("DROP TABLE %s_stable" % app_label)
except:
pass
# Make the "current" state
operations = [migrations.CreateModel(
"Pony",
@ -348,21 +348,21 @@ class OperationTests(MigrationTestBase):
operation.state_forwards("test_alflpkfk", new_state)
self.assertIsInstance(project_state.models["test_alflpkfk", "pony"].get_field_by_name("id"), models.AutoField)
self.assertIsInstance(new_state.models["test_alflpkfk", "pony"].get_field_by_name("id"), models.FloatField)
def assertIdTypeEqualsFkType():
with connection.cursor() as cursor:
id_type = [c.type_code for c in connection.introspection.get_table_description(cursor, "test_alflpkfk_pony") if c.name == "id"][0]
fk_type = [c.type_code for c in connection.introspection.get_table_description(cursor, "test_alflpkfk_rider") if c.name == "pony_id"][0]
self.assertEqual(id_type, fk_type)
assertIdTypeEqualsFkType()
# Test the database alteration
id_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_pony") if c.name == "id"][0]
fk_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_rider") if c.name == "pony_id"][0]
self.assertEqual(id_type, fk_type)
with connection.schema_editor() as editor:
operation.database_forwards("test_alflpkfk", editor, project_state, new_state)
id_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_pony") if c.name == "id"][0]
fk_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_rider") if c.name == "pony_id"][0]
self.assertEqual(id_type, fk_type)
assertIdTypeEqualsFkType()
# And test reversal
with connection.schema_editor() as editor:
operation.database_backwards("test_alflpkfk", editor, new_state, project_state)
id_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_pony") if c.name == "id"][0]
fk_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_rider") if c.name == "pony_id"][0]
self.assertEqual(id_type, fk_type)
assertIdTypeEqualsFkType()
def test_rename_field(self):
"""
@ -400,24 +400,24 @@ class OperationTests(MigrationTestBase):
self.assertEqual(len(project_state.models["test_alunto", "pony"].options.get("unique_together", set())), 0)
self.assertEqual(len(new_state.models["test_alunto", "pony"].options.get("unique_together", set())), 1)
# Make sure we can insert duplicate rows
cursor = connection.cursor()
cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (1, 1, 1)")
cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (2, 1, 1)")
cursor.execute("DELETE FROM test_alunto_pony")
# Test the database alteration
with connection.schema_editor() as editor:
operation.database_forwards("test_alunto", editor, project_state, new_state)
cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (1, 1, 1)")
with self.assertRaises(IntegrityError):
with atomic():
cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (2, 1, 1)")
cursor.execute("DELETE FROM test_alunto_pony")
# And test reversal
with connection.schema_editor() as editor:
operation.database_backwards("test_alunto", editor, new_state, project_state)
cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (1, 1, 1)")
cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (2, 1, 1)")
cursor.execute("DELETE FROM test_alunto_pony")
with connection.cursor() as cursor:
cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (1, 1, 1)")
cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (2, 1, 1)")
cursor.execute("DELETE FROM test_alunto_pony")
# Test the database alteration
with connection.schema_editor() as editor:
operation.database_forwards("test_alunto", editor, project_state, new_state)
cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (1, 1, 1)")
with self.assertRaises(IntegrityError):
with atomic():
cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (2, 1, 1)")
cursor.execute("DELETE FROM test_alunto_pony")
# And test reversal
with connection.schema_editor() as editor:
operation.database_backwards("test_alunto", editor, new_state, project_state)
cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (1, 1, 1)")
cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (2, 1, 1)")
cursor.execute("DELETE FROM test_alunto_pony")
# Test flat unique_together
operation = migrations.AlterUniqueTogether("Pony", ("pink", "weight"))
operation.state_forwards("test_alunto", new_state)
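
The set-up code at the top of this file wraps best-effort DROP TABLE statements in a single cursor block. The same idea as a reusable function (hypothetical; the patch keeps it inline with a bare except):

    from django.db import DatabaseError, connection

    def drop_table_if_exists(table):
        with connection.cursor() as cursor:
            try:
                cursor.execute("DROP TABLE %s" % connection.ops.quote_name(table))
            except DatabaseError:
                # The table may simply not exist yet; ignore it, as the test set-up does.
                pass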

View File

@ -725,7 +725,7 @@ class DatabaseConnectionHandlingTests(TransactionTestCase):
# request_finished signal.
response = self.client.get('/')
# Make sure there is an open connection
connection.cursor()
connection.ensure_connection()
connection.enter_transaction_management()
signals.request_finished.send(sender=response._handler_class)
self.assertEqual(len(connection.transaction_state), 0)
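
Where a cursor was opened purely for its side effect of connecting, ensure_connection() now states that intent directly, with no cursor to leak. A minimal illustration:

    from django.db import connection

    connection.ensure_connection()
    # The underlying DB-API connection now exists, no throwaway cursor required.
    assert connection.connection is not None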

View File

@ -37,38 +37,38 @@ class SchemaTests(TransactionTestCase):
def delete_tables(self):
"Deletes all model tables for our models for a clean test environment"
cursor = connection.cursor()
connection.disable_constraint_checking()
table_names = connection.introspection.table_names(cursor)
for model in self.models:
# Remove any M2M tables first
for field in model._meta.local_many_to_many:
with connection.cursor() as cursor:
connection.disable_constraint_checking()
table_names = connection.introspection.table_names(cursor)
for model in self.models:
# Remove any M2M tables first
for field in model._meta.local_many_to_many:
with atomic():
tbl = field.rel.through._meta.db_table
if tbl in table_names:
cursor.execute(connection.schema_editor().sql_delete_table % {
"table": connection.ops.quote_name(tbl),
})
table_names.remove(tbl)
# Then remove the main tables
with atomic():
tbl = field.rel.through._meta.db_table
tbl = model._meta.db_table
if tbl in table_names:
cursor.execute(connection.schema_editor().sql_delete_table % {
"table": connection.ops.quote_name(tbl),
})
table_names.remove(tbl)
# Then remove the main tables
with atomic():
tbl = model._meta.db_table
if tbl in table_names:
cursor.execute(connection.schema_editor().sql_delete_table % {
"table": connection.ops.quote_name(tbl),
})
table_names.remove(tbl)
connection.enable_constraint_checking()
def column_classes(self, model):
cursor = connection.cursor()
columns = dict(
(d[0], (connection.introspection.get_field_type(d[1], d), d))
for d in connection.introspection.get_table_description(
cursor,
model._meta.db_table,
with connection.cursor() as cursor:
columns = dict(
(d[0], (connection.introspection.get_field_type(d[1], d), d))
for d in connection.introspection.get_table_description(
cursor,
model._meta.db_table,
)
)
)
# SQLite has a different format for field_type
for name, (type, desc) in columns.items():
if isinstance(type, tuple):
@ -78,6 +78,20 @@ class SchemaTests(TransactionTestCase):
raise DatabaseError("Table does not exist (empty pragma)")
return columns
def get_indexes(self, table):
"""
Get the indexes on the table using a new cursor.
"""
with connection.cursor() as cursor:
return connection.introspection.get_indexes(cursor, table)
def get_constraints(self, table):
"""
Get the constraints on a table using a new cursor.
"""
with connection.cursor() as cursor:
return connection.introspection.get_constraints(cursor, table)
# Tests
def test_creation_deletion(self):
@ -127,7 +141,7 @@ class SchemaTests(TransactionTestCase):
strict=True,
)
# Make sure the new FK constraint is present
constraints = connection.introspection.get_constraints(connection.cursor(), Book._meta.db_table)
constraints = self.get_constraints(Book._meta.db_table)
for name, details in constraints.items():
if details['columns'] == ["author_id"] and details['foreign_key']:
self.assertEqual(details['foreign_key'], ('schema_tag', 'id'))
@ -342,7 +356,7 @@ class SchemaTests(TransactionTestCase):
editor.create_model(TagM2MTest)
editor.create_model(UniqueTest)
# Ensure the M2M exists and points to TagM2MTest
constraints = connection.introspection.get_constraints(connection.cursor(), BookWithM2M._meta.get_field_by_name("tags")[0].rel.through._meta.db_table)
constraints = self.get_constraints(BookWithM2M._meta.get_field_by_name("tags")[0].rel.through._meta.db_table)
if connection.features.supports_foreign_keys:
for name, details in constraints.items():
if details['columns'] == ["tagm2mtest_id"] and details['foreign_key']:
@ -363,7 +377,7 @@ class SchemaTests(TransactionTestCase):
# Ensure old M2M is gone
self.assertRaises(DatabaseError, self.column_classes, BookWithM2M._meta.get_field_by_name("tags")[0].rel.through)
# Ensure the new M2M exists and points to UniqueTest
constraints = connection.introspection.get_constraints(connection.cursor(), new_field.rel.through._meta.db_table)
constraints = self.get_constraints(new_field.rel.through._meta.db_table)
if connection.features.supports_foreign_keys:
for name, details in constraints.items():
if details['columns'] == ["uniquetest_id"] and details['foreign_key']:
@ -388,7 +402,7 @@ class SchemaTests(TransactionTestCase):
with connection.schema_editor() as editor:
editor.create_model(Author)
# Ensure the constraint exists
constraints = connection.introspection.get_constraints(connection.cursor(), Author._meta.db_table)
constraints = self.get_constraints(Author._meta.db_table)
for name, details in constraints.items():
if details['columns'] == ["height"] and details['check']:
break
@ -404,7 +418,7 @@ class SchemaTests(TransactionTestCase):
new_field,
strict=True,
)
constraints = connection.introspection.get_constraints(connection.cursor(), Author._meta.db_table)
constraints = self.get_constraints(Author._meta.db_table)
for name, details in constraints.items():
if details['columns'] == ["height"] and details['check']:
self.fail("Check constraint for height found")
@ -416,7 +430,7 @@ class SchemaTests(TransactionTestCase):
Author._meta.get_field_by_name("height")[0],
strict=True,
)
constraints = connection.introspection.get_constraints(connection.cursor(), Author._meta.db_table)
constraints = self.get_constraints(Author._meta.db_table)
for name, details in constraints.items():
if details['columns'] == ["height"] and details['check']:
break
@ -527,7 +541,7 @@ class SchemaTests(TransactionTestCase):
False,
any(
c["index"]
for c in connection.introspection.get_constraints(connection.cursor(), "schema_tag").values()
for c in self.get_constraints("schema_tag").values()
if c['columns'] == ["slug", "title"]
),
)
@ -543,7 +557,7 @@ class SchemaTests(TransactionTestCase):
True,
any(
c["index"]
for c in connection.introspection.get_constraints(connection.cursor(), "schema_tag").values()
for c in self.get_constraints("schema_tag").values()
if c['columns'] == ["slug", "title"]
),
)
@ -561,7 +575,7 @@ class SchemaTests(TransactionTestCase):
False,
any(
c["index"]
for c in connection.introspection.get_constraints(connection.cursor(), "schema_tag").values()
for c in self.get_constraints("schema_tag").values()
if c['columns'] == ["slug", "title"]
),
)
@ -578,7 +592,7 @@ class SchemaTests(TransactionTestCase):
True,
any(
c["index"]
for c in connection.introspection.get_constraints(connection.cursor(), "schema_tagindexed").values()
for c in self.get_constraints("schema_tagindexed").values()
if c['columns'] == ["slug", "title"]
),
)
@ -627,7 +641,7 @@ class SchemaTests(TransactionTestCase):
# Ensure the table is there and has the right index
self.assertIn(
"title",
connection.introspection.get_indexes(connection.cursor(), Book._meta.db_table),
self.get_indexes(Book._meta.db_table),
)
# Alter to remove the index
new_field = CharField(max_length=100, db_index=False)
@ -642,7 +656,7 @@ class SchemaTests(TransactionTestCase):
# Ensure the table is there and has no index
self.assertNotIn(
"title",
connection.introspection.get_indexes(connection.cursor(), Book._meta.db_table),
self.get_indexes(Book._meta.db_table),
)
# Alter to re-add the index
with connection.schema_editor() as editor:
@ -655,7 +669,7 @@ class SchemaTests(TransactionTestCase):
# Ensure the table is there and has the index again
self.assertIn(
"title",
connection.introspection.get_indexes(connection.cursor(), Book._meta.db_table),
self.get_indexes(Book._meta.db_table),
)
# Add a unique column, verify that creates an implicit index
with connection.schema_editor() as editor:
@ -665,7 +679,7 @@ class SchemaTests(TransactionTestCase):
)
self.assertIn(
"slug",
connection.introspection.get_indexes(connection.cursor(), Book._meta.db_table),
self.get_indexes(Book._meta.db_table),
)
# Remove the unique, check the index goes with it
new_field2 = CharField(max_length=20, unique=False)
@ -679,7 +693,7 @@ class SchemaTests(TransactionTestCase):
)
self.assertNotIn(
"slug",
connection.introspection.get_indexes(connection.cursor(), Book._meta.db_table),
self.get_indexes(Book._meta.db_table),
)
def test_primary_key(self):
@ -691,7 +705,7 @@ class SchemaTests(TransactionTestCase):
editor.create_model(Tag)
# Ensure the table is there and has the right PK
self.assertTrue(
connection.introspection.get_indexes(connection.cursor(), Tag._meta.db_table)['id']['primary_key'],
self.get_indexes(Tag._meta.db_table)['id']['primary_key'],
)
# Alter to change the PK
new_field = SlugField(primary_key=True)
@ -707,10 +721,10 @@ class SchemaTests(TransactionTestCase):
# Ensure the PK changed
self.assertNotIn(
'id',
connection.introspection.get_indexes(connection.cursor(), Tag._meta.db_table),
self.get_indexes(Tag._meta.db_table),
)
self.assertTrue(
connection.introspection.get_indexes(connection.cursor(), Tag._meta.db_table)['slug']['primary_key'],
self.get_indexes(Tag._meta.db_table)['slug']['primary_key'],
)
def test_context_manager_exit(self):
@ -741,7 +755,7 @@ class SchemaTests(TransactionTestCase):
# Ensure the table is there and has an index on the column
self.assertIn(
column_name,
connection.introspection.get_indexes(connection.cursor(), BookWithLongName._meta.db_table),
self.get_indexes(BookWithLongName._meta.db_table),
)
def test_creation_deletion_reserved_names(self):
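
With get_indexes() and get_constraints() factored out above, the assertions become plain dictionary lookups over introspection data. A hypothetical convenience in the same style:

    def find_unique_constraint(self, table, columns):
        # get_constraints() maps constraint name -> details; details includes
        # 'columns', 'unique', 'index', 'check' and 'foreign_key'.
        for name, details in self.get_constraints(table).items():
            if details['unique'] and details['columns'] == list(columns):
                return name
        return None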

View File

@ -202,8 +202,9 @@ class AtomicTests(TransactionTestCase):
# trigger a database error inside an inner atomic without savepoint
with self.assertRaises(DatabaseError):
with transaction.atomic(savepoint=False):
connection.cursor().execute(
"SELECT no_such_col FROM transactions_reporter")
with connection.cursor() as cursor:
cursor.execute(
"SELECT no_such_col FROM transactions_reporter")
# prevent atomic from rolling back since we're recovering manually
self.assertTrue(transaction.get_rollback())
transaction.set_rollback(False)
@ -534,8 +535,8 @@ class TransactionRollbackTests(IgnoreDeprecationWarningsMixin, TransactionTestCa
available_apps = ['transactions']
def execute_bad_sql(self):
cursor = connection.cursor()
cursor.execute("INSERT INTO transactions_reporter (first_name, last_name) VALUES ('Douglas', 'Adams');")
with connection.cursor() as cursor:
cursor.execute("INSERT INTO transactions_reporter (first_name, last_name) VALUES ('Douglas', 'Adams');")
@skipUnlessDBFeature('requires_rollback_on_dirty_transaction')
def test_bad_sql(self):
@ -678,6 +679,6 @@ class TransactionContextManagerTests(IgnoreDeprecationWarningsMixin, Transaction
"""
with self.assertRaises(IntegrityError):
with transaction.commit_on_success():
cursor = connection.cursor()
cursor.execute("INSERT INTO transactions_reporter (first_name, last_name) VALUES ('Douglas', 'Adams');")
with connection.cursor() as cursor:
cursor.execute("INSERT INTO transactions_reporter (first_name, last_name) VALUES ('Douglas', 'Adams');")
transaction.rollback()
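
The transaction tests above cover the failure paths: whether the statement errors or the surrounding transaction is rolled back, leaving the `with` block still closes the cursor. Sketch (the bad column name is deliberate):

    from django.db import DatabaseError, connection, transaction

    try:
        with transaction.atomic():
            with connection.cursor() as cursor:
                cursor.execute("SELECT no_such_col FROM transactions_reporter")
    except DatabaseError:
        pass  # the cursor was already closed on the way out of the `with` block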

View File

@ -54,8 +54,8 @@ class TestTransactionClosing(IgnoreDeprecationWarningsMixin, TransactionTestCase
@commit_on_success
def raw_sql():
"Write a record using raw sql under a commit_on_success decorator"
cursor = connection.cursor()
cursor.execute("INSERT into transactions_regress_mod (fld) values (18)")
with connection.cursor() as cursor:
cursor.execute("INSERT into transactions_regress_mod (fld) values (18)")
raw_sql()
# Rollback so that if the decorator didn't commit, the record is unwritten
@ -143,10 +143,10 @@ class TestTransactionClosing(IgnoreDeprecationWarningsMixin, TransactionTestCase
(reference). All this under commit_on_success, so the second insert should
be committed.
"""
cursor = connection.cursor()
cursor.execute("INSERT into transactions_regress_mod (fld) values (2)")
transaction.rollback()
cursor.execute("INSERT into transactions_regress_mod (fld) values (2)")
with connection.cursor() as cursor:
cursor.execute("INSERT into transactions_regress_mod (fld) values (2)")
transaction.rollback()
cursor.execute("INSERT into transactions_regress_mod (fld) values (2)")
reuse_cursor_ref()
# Rollback so that if the decorator didn't commit, the record is unwritten