From 0837eacc4e1fa7916e48135e8ba43f54a7a64997 Mon Sep 17 00:00:00 2001 From: Michael Manfre Date: Wed, 8 Jan 2014 23:31:34 -0500 Subject: [PATCH 1/2] Made SQLCompiler.execute_sql(result_type) more explicit. Updated SQLUpdateCompiler.execute_sql to match the behavior described in the docstring; the 'first non-empty query' will now include all queries, not just the main and first related update. Added CURSOR and NO_RESULTS result_type constants to make the usages more self documenting and allow execute_sql to explicitly close the cursor when it is no longer needed. --- django/db/models/query.py | 5 +- django/db/models/sql/compiler.py | 78 ++++++++++++++++++++++-------- django/db/models/sql/constants.py | 2 + django/db/models/sql/subqueries.py | 8 +-- tests/backends/tests.py | 3 +- 5 files changed, 69 insertions(+), 27 deletions(-) diff --git a/django/db/models/query.py b/django/db/models/query.py index 48d295ccca4..353dd957943 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -14,6 +14,7 @@ from django.db.models.fields import AutoField, Empty from django.db.models.query_utils import (Q, select_related_descend, deferred_class_factory, InvalidQuery) from django.db.models.deletion import Collector +from django.db.models.sql.constants import CURSOR from django.db.models import sql from django.utils.functional import partition from django.utils import six @@ -574,7 +575,7 @@ class QuerySet(object): query = self.query.clone(sql.UpdateQuery) query.add_update_values(kwargs) with transaction.commit_on_success_unless_managed(using=self.db): - rows = query.get_compiler(self.db).execute_sql(None) + rows = query.get_compiler(self.db).execute_sql(CURSOR) self._result_cache = None return rows update.alters_data = True @@ -591,7 +592,7 @@ class QuerySet(object): query = self.query.clone(sql.UpdateQuery) query.add_update_fields(values) self._result_cache = None - return query.get_compiler(self.db).execute_sql(None) + return query.get_compiler(self.db).execute_sql(CURSOR) _update.alters_data = True _update.queryset_only = False diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 123427cf8b4..536a66d1399 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -5,8 +5,8 @@ from django.core.exceptions import FieldError from django.db.backends.utils import truncate_name from django.db.models.constants import LOOKUP_SEP from django.db.models.query_utils import select_related_descend, QueryWrapper -from django.db.models.sql.constants import (SINGLE, MULTI, ORDER_DIR, - GET_ITERATOR_CHUNK_SIZE, SelectInfo) +from django.db.models.sql.constants import (CURSOR, SINGLE, MULTI, NO_RESULTS, + ORDER_DIR, GET_ITERATOR_CHUNK_SIZE, SelectInfo) from django.db.models.sql.datastructures import EmptyResultSet from django.db.models.sql.expressions import SQLEvaluator from django.db.models.sql.query import get_order_dir, Query @@ -762,6 +762,8 @@ class SQLCompiler(object): is needed, as the filters describe an empty set. In that case, None is returned, to avoid any unnecessary database interaction. """ + if not result_type: + result_type = NO_RESULTS try: sql, params = self.as_sql() if not sql: @@ -773,27 +775,44 @@ class SQLCompiler(object): return cursor = self.connection.cursor() - cursor.execute(sql, params) + try: + cursor.execute(sql, params) + except: + cursor.close() + raise - if not result_type: + if result_type == CURSOR: + # Caller didn't specify a result_type, so just give them back the + # cursor to process (and close). return cursor if result_type == SINGLE: - if self.ordering_aliases: - return cursor.fetchone()[:-len(self.ordering_aliases)] - return cursor.fetchone() + try: + if self.ordering_aliases: + return cursor.fetchone()[:-len(self.ordering_aliases)] + return cursor.fetchone() + finally: + # done with the cursor + cursor.close() + if result_type == NO_RESULTS: + cursor.close() + return # The MULTI case. if self.ordering_aliases: result = order_modified_iter(cursor, len(self.ordering_aliases), self.connection.features.empty_fetchmany_value) else: - result = iter((lambda: cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)), - self.connection.features.empty_fetchmany_value) + result = cursor_iter(cursor, + self.connection.features.empty_fetchmany_value) if not self.connection.features.can_use_chunked_reads: - # If we are using non-chunked reads, we return the same data - # structure as normally, but ensure it is all read into memory - # before going any further. - return list(result) + try: + # If we are using non-chunked reads, we return the same data + # structure as normally, but ensure it is all read into memory + # before going any further. + return list(result) + finally: + # done with the cursor + cursor.close() return result def as_subquery_condition(self, alias, columns, qn): @@ -970,12 +989,15 @@ class SQLUpdateCompiler(SQLCompiler): related queries are not available. """ cursor = super(SQLUpdateCompiler, self).execute_sql(result_type) - rows = cursor.rowcount if cursor else 0 - is_empty = cursor is None - del cursor + try: + rows = cursor.rowcount if cursor else 0 + is_empty = cursor is None + finally: + if cursor: + cursor.close() for query in self.query.get_related_updates(): aux_rows = query.get_compiler(self.using).execute_sql(result_type) - if is_empty: + if is_empty and aux_rows: rows = aux_rows is_empty = False return rows @@ -1111,6 +1133,19 @@ class SQLDateTimeCompiler(SQLCompiler): yield datetime +def cursor_iter(cursor, sentinel): + """ + Yields blocks of rows from a cursor and ensures the cursor is closed when + done. + """ + try: + for rows in iter((lambda: cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)), + sentinel): + yield rows + finally: + cursor.close() + + def order_modified_iter(cursor, trim, sentinel): """ Yields blocks of rows from a cursor. We use this iterator in the special @@ -1118,6 +1153,9 @@ def order_modified_iter(cursor, trim, sentinel): requirements. We must trim those extra columns before anything else can use the results, since they're only needed to make the SQL valid. """ - for rows in iter((lambda: cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)), - sentinel): - yield [r[:-trim] for r in rows] + try: + for rows in iter((lambda: cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)), + sentinel): + yield [r[:-trim] for r in rows] + finally: + cursor.close() diff --git a/django/db/models/sql/constants.py b/django/db/models/sql/constants.py index 904f7b2c8b4..36aab23bae0 100644 --- a/django/db/models/sql/constants.py +++ b/django/db/models/sql/constants.py @@ -33,6 +33,8 @@ SelectInfo = namedtuple('SelectInfo', 'col field') # How many results to expect from a cursor.execute call MULTI = 'multi' SINGLE = 'single' +CURSOR = 'cursor' +NO_RESULTS = 'no results' ORDER_PATTERN = re.compile(r'\?|[-+]?[.\w]+$') ORDER_DIR = { diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py index 86b1efd3f8d..cfda1f552c8 100644 --- a/django/db/models/sql/subqueries.py +++ b/django/db/models/sql/subqueries.py @@ -8,7 +8,7 @@ from django.db import connections from django.db.models.query_utils import Q from django.db.models.constants import LOOKUP_SEP from django.db.models.fields import DateField, DateTimeField, FieldDoesNotExist -from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, SelectInfo +from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, NO_RESULTS, SelectInfo from django.db.models.sql.datastructures import Date, DateTime from django.db.models.sql.query import Query from django.utils import six @@ -30,7 +30,7 @@ class DeleteQuery(Query): def do_query(self, table, where, using): self.tables = [table] self.where = where - self.get_compiler(using).execute_sql(None) + self.get_compiler(using).execute_sql(NO_RESULTS) def delete_batch(self, pk_list, using, field=None): """ @@ -82,7 +82,7 @@ class DeleteQuery(Query): values = innerq self.where = self.where_class() self.add_q(Q(pk__in=values)) - self.get_compiler(using).execute_sql(None) + self.get_compiler(using).execute_sql(NO_RESULTS) class UpdateQuery(Query): @@ -116,7 +116,7 @@ class UpdateQuery(Query): for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): self.where = self.where_class() self.add_q(Q(pk__in=pk_list[offset: offset + GET_ITERATOR_CHUNK_SIZE])) - self.get_compiler(using).execute_sql(None) + self.get_compiler(using).execute_sql(NO_RESULTS) def add_update_values(self, values): """ diff --git a/tests/backends/tests.py b/tests/backends/tests.py index 0ff3ad0bba5..4a3fc31b7aa 100644 --- a/tests/backends/tests.py +++ b/tests/backends/tests.py @@ -20,6 +20,7 @@ from django.db.backends.utils import format_number, CursorWrapper from django.db.models import Sum, Avg, Variance, StdDev from django.db.models.fields import (AutoField, DateField, DateTimeField, DecimalField, IntegerField, TimeField) +from django.db.models.sql.constants import CURSOR from django.db.utils import ConnectionHandler from django.test import (TestCase, TransactionTestCase, override_settings, skipUnlessDBFeature, skipIfDBFeature) @@ -209,7 +210,7 @@ class LastExecutedQueryTest(TestCase): """ persons = models.Reporter.objects.filter(raw_data=b'\x00\x46 \xFE').extra(select={'föö': 1}) sql, params = persons.query.sql_with_params() - cursor = persons.query.get_compiler('default').execute_sql(None) + cursor = persons.query.get_compiler('default').execute_sql(CURSOR) last_sql = cursor.db.ops.last_executed_query(cursor, sql, params) self.assertIsInstance(last_sql, six.text_type) From 3ffeb931869cc68a8e0916219702ee282afc6e9d Mon Sep 17 00:00:00 2001 From: Michael Manfre Date: Thu, 9 Jan 2014 10:05:15 -0500 Subject: [PATCH 2/2] Ensure cursors are closed when no longer needed. This commit touchs various parts of the code base and test framework. Any found usage of opening a cursor for the sake of initializing a connection has been replaced with 'ensure_connection()'. --- .../gis/db/backends/postgis/creation.py | 14 +- .../gis/db/backends/spatialite/creation.py | 5 +- django/contrib/sites/management.py | 6 +- django/core/cache/backends/db.py | 129 ++++++------ .../management/commands/createcachetable.py | 18 +- django/core/management/commands/flush.py | 6 +- django/core/management/commands/inspectdb.py | 194 +++++++++--------- django/core/management/commands/loaddata.py | 7 +- django/core/management/commands/migrate.py | 181 ++++++++-------- django/core/management/sql.py | 55 ++--- django/db/backends/__init__.py | 30 +-- django/db/backends/creation.py | 71 ++++--- django/db/backends/mysql/base.py | 37 ++-- django/db/backends/oracle/base.py | 4 +- .../db/backends/postgresql_psycopg2/base.py | 6 +- .../backends/postgresql_psycopg2/version.py | 6 +- django/db/backends/schema.py | 8 +- django/db/backends/sqlite3/base.py | 16 +- django/db/models/query.py | 89 ++++---- django/db/models/sql/compiler.py | 21 +- tests/backends/tests.py | 50 +++-- tests/cache/tests.py | 7 +- tests/custom_methods/models.py | 16 +- tests/initial_sql_regress/tests.py | 6 +- tests/introspection/tests.py | 60 +++--- tests/migrations/test_base.py | 35 ++-- tests/migrations/test_operations.py | 72 +++---- tests/requests/tests.py | 2 +- tests/schema/tests.py | 96 +++++---- tests/transactions/tests.py | 13 +- tests/transactions_regress/tests.py | 12 +- 31 files changed, 657 insertions(+), 615 deletions(-) diff --git a/django/contrib/gis/db/backends/postgis/creation.py b/django/contrib/gis/db/backends/postgis/creation.py index 51ac197b8e9..82be18cb652 100644 --- a/django/contrib/gis/db/backends/postgis/creation.py +++ b/django/contrib/gis/db/backends/postgis/creation.py @@ -11,10 +11,10 @@ class PostGISCreation(DatabaseCreation): @cached_property def template_postgis(self): template_postgis = getattr(settings, 'POSTGIS_TEMPLATE', 'template_postgis') - cursor = self.connection.cursor() - cursor.execute('SELECT 1 FROM pg_database WHERE datname = %s LIMIT 1;', (template_postgis,)) - if cursor.fetchone(): - return template_postgis + with self.connection.cursor() as cursor: + cursor.execute('SELECT 1 FROM pg_database WHERE datname = %s LIMIT 1;', (template_postgis,)) + if cursor.fetchone(): + return template_postgis return None def sql_indexes_for_field(self, model, f, style): @@ -88,8 +88,8 @@ class PostGISCreation(DatabaseCreation): # Connect to the test database in order to create the postgis extension self.connection.close() self.connection.settings_dict["NAME"] = test_database_name - cursor = self.connection.cursor() - cursor.execute("CREATE EXTENSION IF NOT EXISTS postgis") - cursor.connection.commit() + with self.connection.cursor() as cursor: + cursor.execute("CREATE EXTENSION IF NOT EXISTS postgis") + cursor.connection.commit() return test_database_name diff --git a/django/contrib/gis/db/backends/spatialite/creation.py b/django/contrib/gis/db/backends/spatialite/creation.py index 521985259e8..06f105d5630 100644 --- a/django/contrib/gis/db/backends/spatialite/creation.py +++ b/django/contrib/gis/db/backends/spatialite/creation.py @@ -55,9 +55,8 @@ class SpatiaLiteCreation(DatabaseCreation): call_command('createcachetable', database=self.connection.alias) - # Get a cursor (even though we don't need one yet). This has - # the side effect of initializing the test database. - self.connection.cursor() + # Ensure a connection for the side effect of initializing the test database. + self.connection.ensure_connection() return test_database_name diff --git a/django/contrib/sites/management.py b/django/contrib/sites/management.py index e7624e75cf5..8353a6f496c 100644 --- a/django/contrib/sites/management.py +++ b/django/contrib/sites/management.py @@ -33,9 +33,9 @@ def create_default_site(app_config, verbosity=2, interactive=True, db=DEFAULT_DB if sequence_sql: if verbosity >= 2: print("Resetting sequence") - cursor = connections[db].cursor() - for command in sequence_sql: - cursor.execute(command) + with connections[db].cursor() as cursor: + for command in sequence_sql: + cursor.execute(command) Site.objects.clear_cache() diff --git a/django/core/cache/backends/db.py b/django/core/cache/backends/db.py index a21777aaba1..959095026bd 100644 --- a/django/core/cache/backends/db.py +++ b/django/core/cache/backends/db.py @@ -59,11 +59,11 @@ class DatabaseCache(BaseDatabaseCache): self.validate_key(key) db = router.db_for_read(self.cache_model_class) table = connections[db].ops.quote_name(self._table) - cursor = connections[db].cursor() - cursor.execute("SELECT cache_key, value, expires FROM %s " - "WHERE cache_key = %%s" % table, [key]) - row = cursor.fetchone() + with connections[db].cursor() as cursor: + cursor.execute("SELECT cache_key, value, expires FROM %s " + "WHERE cache_key = %%s" % table, [key]) + row = cursor.fetchone() if row is None: return default now = timezone.now() @@ -75,9 +75,9 @@ class DatabaseCache(BaseDatabaseCache): expires = typecast_timestamp(str(expires)) if expires < now: db = router.db_for_write(self.cache_model_class) - cursor = connections[db].cursor() - cursor.execute("DELETE FROM %s " - "WHERE cache_key = %%s" % table, [key]) + with connections[db].cursor() as cursor: + cursor.execute("DELETE FROM %s " + "WHERE cache_key = %%s" % table, [key]) return default value = connections[db].ops.process_clob(row[1]) return pickle.loads(base64.b64decode(force_bytes(value))) @@ -96,55 +96,55 @@ class DatabaseCache(BaseDatabaseCache): timeout = self.get_backend_timeout(timeout) db = router.db_for_write(self.cache_model_class) table = connections[db].ops.quote_name(self._table) - cursor = connections[db].cursor() - cursor.execute("SELECT COUNT(*) FROM %s" % table) - num = cursor.fetchone()[0] - now = timezone.now() - now = now.replace(microsecond=0) - if timeout is None: - exp = datetime.max - elif settings.USE_TZ: - exp = datetime.utcfromtimestamp(timeout) - else: - exp = datetime.fromtimestamp(timeout) - exp = exp.replace(microsecond=0) - if num > self._max_entries: - self._cull(db, cursor, now) - pickled = pickle.dumps(value, pickle.HIGHEST_PROTOCOL) - b64encoded = base64.b64encode(pickled) - # The DB column is expecting a string, so make sure the value is a - # string, not bytes. Refs #19274. - if six.PY3: - b64encoded = b64encoded.decode('latin1') - try: - # Note: typecasting for datetimes is needed by some 3rd party - # database backends. All core backends work without typecasting, - # so be careful about changes here - test suite will NOT pick - # regressions. - with transaction.atomic(using=db): - cursor.execute("SELECT cache_key, expires FROM %s " - "WHERE cache_key = %%s" % table, [key]) - result = cursor.fetchone() - if result: - current_expires = result[1] - if (connections[db].features.needs_datetime_string_cast and not - isinstance(current_expires, datetime)): - current_expires = typecast_timestamp(str(current_expires)) - exp = connections[db].ops.value_to_db_datetime(exp) - if result and (mode == 'set' or (mode == 'add' and current_expires < now)): - cursor.execute("UPDATE %s SET value = %%s, expires = %%s " - "WHERE cache_key = %%s" % table, - [b64encoded, exp, key]) - else: - cursor.execute("INSERT INTO %s (cache_key, value, expires) " - "VALUES (%%s, %%s, %%s)" % table, - [key, b64encoded, exp]) - except DatabaseError: - # To be threadsafe, updates/inserts are allowed to fail silently - return False - else: - return True + with connections[db].cursor() as cursor: + cursor.execute("SELECT COUNT(*) FROM %s" % table) + num = cursor.fetchone()[0] + now = timezone.now() + now = now.replace(microsecond=0) + if timeout is None: + exp = datetime.max + elif settings.USE_TZ: + exp = datetime.utcfromtimestamp(timeout) + else: + exp = datetime.fromtimestamp(timeout) + exp = exp.replace(microsecond=0) + if num > self._max_entries: + self._cull(db, cursor, now) + pickled = pickle.dumps(value, pickle.HIGHEST_PROTOCOL) + b64encoded = base64.b64encode(pickled) + # The DB column is expecting a string, so make sure the value is a + # string, not bytes. Refs #19274. + if six.PY3: + b64encoded = b64encoded.decode('latin1') + try: + # Note: typecasting for datetimes is needed by some 3rd party + # database backends. All core backends work without typecasting, + # so be careful about changes here - test suite will NOT pick + # regressions. + with transaction.atomic(using=db): + cursor.execute("SELECT cache_key, expires FROM %s " + "WHERE cache_key = %%s" % table, [key]) + result = cursor.fetchone() + if result: + current_expires = result[1] + if (connections[db].features.needs_datetime_string_cast and not + isinstance(current_expires, datetime)): + current_expires = typecast_timestamp(str(current_expires)) + exp = connections[db].ops.value_to_db_datetime(exp) + if result and (mode == 'set' or (mode == 'add' and current_expires < now)): + cursor.execute("UPDATE %s SET value = %%s, expires = %%s " + "WHERE cache_key = %%s" % table, + [b64encoded, exp, key]) + else: + cursor.execute("INSERT INTO %s (cache_key, value, expires) " + "VALUES (%%s, %%s, %%s)" % table, + [key, b64encoded, exp]) + except DatabaseError: + # To be threadsafe, updates/inserts are allowed to fail silently + return False + else: + return True def delete(self, key, version=None): key = self.make_key(key, version=version) @@ -152,9 +152,9 @@ class DatabaseCache(BaseDatabaseCache): db = router.db_for_write(self.cache_model_class) table = connections[db].ops.quote_name(self._table) - cursor = connections[db].cursor() - cursor.execute("DELETE FROM %s WHERE cache_key = %%s" % table, [key]) + with connections[db].cursor() as cursor: + cursor.execute("DELETE FROM %s WHERE cache_key = %%s" % table, [key]) def has_key(self, key, version=None): key = self.make_key(key, version=version) @@ -162,17 +162,18 @@ class DatabaseCache(BaseDatabaseCache): db = router.db_for_read(self.cache_model_class) table = connections[db].ops.quote_name(self._table) - cursor = connections[db].cursor() if settings.USE_TZ: now = datetime.utcnow() else: now = datetime.now() now = now.replace(microsecond=0) - cursor.execute("SELECT cache_key FROM %s " - "WHERE cache_key = %%s and expires > %%s" % table, - [key, connections[db].ops.value_to_db_datetime(now)]) - return cursor.fetchone() is not None + + with connections[db].cursor() as cursor: + cursor.execute("SELECT cache_key FROM %s " + "WHERE cache_key = %%s and expires > %%s" % table, + [key, connections[db].ops.value_to_db_datetime(now)]) + return cursor.fetchone() is not None def _cull(self, db, cursor, now): if self._cull_frequency == 0: @@ -197,8 +198,8 @@ class DatabaseCache(BaseDatabaseCache): def clear(self): db = router.db_for_write(self.cache_model_class) table = connections[db].ops.quote_name(self._table) - cursor = connections[db].cursor() - cursor.execute('DELETE FROM %s' % table) + with connections[db].cursor() as cursor: + cursor.execute('DELETE FROM %s' % table) # For backwards compatibility diff --git a/django/core/management/commands/createcachetable.py b/django/core/management/commands/createcachetable.py index 10506525fce..909a5d08c8e 100644 --- a/django/core/management/commands/createcachetable.py +++ b/django/core/management/commands/createcachetable.py @@ -72,14 +72,14 @@ class Command(BaseCommand): full_statement.append(' %s%s' % (line, ',' if i < len(table_output) - 1 else '')) full_statement.append(');') with transaction.commit_on_success_unless_managed(): - curs = connection.cursor() - try: - curs.execute("\n".join(full_statement)) - except DatabaseError as e: - raise CommandError( - "Cache table '%s' could not be created.\nThe error was: %s." % - (tablename, force_text(e))) - for statement in index_output: - curs.execute(statement) + with connection.cursor() as curs: + try: + curs.execute("\n".join(full_statement)) + except DatabaseError as e: + raise CommandError( + "Cache table '%s' could not be created.\nThe error was: %s." % + (tablename, force_text(e))) + for statement in index_output: + curs.execute(statement) if self.verbosity > 1: self.stdout.write("Cache table '%s' created." % tablename) diff --git a/django/core/management/commands/flush.py b/django/core/management/commands/flush.py index 4a3f7c2d8bc..d99deb951ef 100644 --- a/django/core/management/commands/flush.py +++ b/django/core/management/commands/flush.py @@ -64,9 +64,9 @@ Are you sure you want to do this? if confirm == 'yes': try: with transaction.commit_on_success_unless_managed(): - cursor = connection.cursor() - for sql in sql_list: - cursor.execute(sql) + with connection.cursor() as cursor: + for sql in sql_list: + cursor.execute(sql) except Exception as e: new_msg = ( "Database %s couldn't be flushed. Possible reasons:\n" diff --git a/django/core/management/commands/inspectdb.py b/django/core/management/commands/inspectdb.py index 54fdad20017..4a51892e5ab 100644 --- a/django/core/management/commands/inspectdb.py +++ b/django/core/management/commands/inspectdb.py @@ -37,108 +37,108 @@ class Command(NoArgsCommand): table2model = lambda table_name: table_name.title().replace('_', '').replace(' ', '').replace('-', '') strip_prefix = lambda s: s[1:] if s.startswith("u'") else s - cursor = connection.cursor() - yield "# This is an auto-generated Django model module." - yield "# You'll have to do the following manually to clean this up:" - yield "# * Rearrange models' order" - yield "# * Make sure each model has one field with primary_key=True" - yield "# * Remove `managed = False` lines for those models you wish to give write DB access" - yield "# Feel free to rename the models, but don't rename db_table values or field names." - yield "#" - yield "# Also note: You'll have to insert the output of 'django-admin.py sqlcustom [app_label]'" - yield "# into your database." - yield "from __future__ import unicode_literals" - yield '' - yield 'from %s import models' % self.db_module - known_models = [] - for table_name in connection.introspection.table_names(cursor): - if table_name_filter is not None and callable(table_name_filter): - if not table_name_filter(table_name): - continue + with connection.cursor() as cursor: + yield "# This is an auto-generated Django model module." + yield "# You'll have to do the following manually to clean this up:" + yield "# * Rearrange models' order" + yield "# * Make sure each model has one field with primary_key=True" + yield "# * Remove `managed = False` lines for those models you wish to give write DB access" + yield "# Feel free to rename the models, but don't rename db_table values or field names." + yield "#" + yield "# Also note: You'll have to insert the output of 'django-admin.py sqlcustom [app_label]'" + yield "# into your database." + yield "from __future__ import unicode_literals" yield '' - yield '' - yield 'class %s(models.Model):' % table2model(table_name) - known_models.append(table2model(table_name)) - try: - relations = connection.introspection.get_relations(cursor, table_name) - except NotImplementedError: - relations = {} - try: - indexes = connection.introspection.get_indexes(cursor, table_name) - except NotImplementedError: - indexes = {} - used_column_names = [] # Holds column names used in the table so far - for i, row in enumerate(connection.introspection.get_table_description(cursor, table_name)): - comment_notes = [] # Holds Field notes, to be displayed in a Python comment. - extra_params = OrderedDict() # Holds Field parameters such as 'db_column'. - column_name = row[0] - is_relation = i in relations - - att_name, params, notes = self.normalize_col_name( - column_name, used_column_names, is_relation) - extra_params.update(params) - comment_notes.extend(notes) - - used_column_names.append(att_name) - - # Add primary_key and unique, if necessary. - if column_name in indexes: - if indexes[column_name]['primary_key']: - extra_params['primary_key'] = True - elif indexes[column_name]['unique']: - extra_params['unique'] = True - - if is_relation: - rel_to = "self" if relations[i][1] == table_name else table2model(relations[i][1]) - if rel_to in known_models: - field_type = 'ForeignKey(%s' % rel_to - else: - field_type = "ForeignKey('%s'" % rel_to - else: - # Calling `get_field_type` to get the field type string and any - # additional paramters and notes. - field_type, field_params, field_notes = self.get_field_type(connection, table_name, row) - extra_params.update(field_params) - comment_notes.extend(field_notes) - - field_type += '(' - - # Don't output 'id = meta.AutoField(primary_key=True)', because - # that's assumed if it doesn't exist. - if att_name == 'id' and extra_params == {'primary_key': True}: - if field_type == 'AutoField(': + yield 'from %s import models' % self.db_module + known_models = [] + for table_name in connection.introspection.table_names(cursor): + if table_name_filter is not None and callable(table_name_filter): + if not table_name_filter(table_name): continue - elif field_type == 'IntegerField(' and not connection.features.can_introspect_autofield: - comment_notes.append('AutoField?') + yield '' + yield '' + yield 'class %s(models.Model):' % table2model(table_name) + known_models.append(table2model(table_name)) + try: + relations = connection.introspection.get_relations(cursor, table_name) + except NotImplementedError: + relations = {} + try: + indexes = connection.introspection.get_indexes(cursor, table_name) + except NotImplementedError: + indexes = {} + used_column_names = [] # Holds column names used in the table so far + for i, row in enumerate(connection.introspection.get_table_description(cursor, table_name)): + comment_notes = [] # Holds Field notes, to be displayed in a Python comment. + extra_params = OrderedDict() # Holds Field parameters such as 'db_column'. + column_name = row[0] + is_relation = i in relations - # Add 'null' and 'blank', if the 'null_ok' flag was present in the - # table description. - if row[6]: # If it's NULL... - if field_type == 'BooleanField(': - field_type = 'NullBooleanField(' + att_name, params, notes = self.normalize_col_name( + column_name, used_column_names, is_relation) + extra_params.update(params) + comment_notes.extend(notes) + + used_column_names.append(att_name) + + # Add primary_key and unique, if necessary. + if column_name in indexes: + if indexes[column_name]['primary_key']: + extra_params['primary_key'] = True + elif indexes[column_name]['unique']: + extra_params['unique'] = True + + if is_relation: + rel_to = "self" if relations[i][1] == table_name else table2model(relations[i][1]) + if rel_to in known_models: + field_type = 'ForeignKey(%s' % rel_to + else: + field_type = "ForeignKey('%s'" % rel_to else: - extra_params['blank'] = True - if not field_type in ('TextField(', 'CharField('): - extra_params['null'] = True + # Calling `get_field_type` to get the field type string and any + # additional paramters and notes. + field_type, field_params, field_notes = self.get_field_type(connection, table_name, row) + extra_params.update(field_params) + comment_notes.extend(field_notes) - field_desc = '%s = %s%s' % ( - att_name, - # Custom fields will have a dotted path - '' if '.' in field_type else 'models.', - field_type, - ) - if extra_params: - if not field_desc.endswith('('): - field_desc += ', ' - field_desc += ', '.join([ - '%s=%s' % (k, strip_prefix(repr(v))) - for k, v in extra_params.items()]) - field_desc += ')' - if comment_notes: - field_desc += ' # ' + ' '.join(comment_notes) - yield ' %s' % field_desc - for meta_line in self.get_meta(table_name): - yield meta_line + field_type += '(' + + # Don't output 'id = meta.AutoField(primary_key=True)', because + # that's assumed if it doesn't exist. + if att_name == 'id' and extra_params == {'primary_key': True}: + if field_type == 'AutoField(': + continue + elif field_type == 'IntegerField(' and not connection.features.can_introspect_autofield: + comment_notes.append('AutoField?') + + # Add 'null' and 'blank', if the 'null_ok' flag was present in the + # table description. + if row[6]: # If it's NULL... + if field_type == 'BooleanField(': + field_type = 'NullBooleanField(' + else: + extra_params['blank'] = True + if not field_type in ('TextField(', 'CharField('): + extra_params['null'] = True + + field_desc = '%s = %s%s' % ( + att_name, + # Custom fields will have a dotted path + '' if '.' in field_type else 'models.', + field_type, + ) + if extra_params: + if not field_desc.endswith('('): + field_desc += ', ' + field_desc += ', '.join([ + '%s=%s' % (k, strip_prefix(repr(v))) + for k, v in extra_params.items()]) + field_desc += ')' + if comment_notes: + field_desc += ' # ' + ' '.join(comment_notes) + yield ' %s' % field_desc + for meta_line in self.get_meta(table_name): + yield meta_line def normalize_col_name(self, col_name, used_column_names, is_relation): """ diff --git a/django/core/management/commands/loaddata.py b/django/core/management/commands/loaddata.py index 65bc96ba99b..31a3af6225f 100644 --- a/django/core/management/commands/loaddata.py +++ b/django/core/management/commands/loaddata.py @@ -100,10 +100,9 @@ class Command(BaseCommand): if sequence_sql: if self.verbosity >= 2: self.stdout.write("Resetting sequences\n") - cursor = connection.cursor() - for line in sequence_sql: - cursor.execute(line) - cursor.close() + with connection.cursor() as cursor: + for line in sequence_sql: + cursor.execute(line) if self.verbosity >= 1: if self.fixture_object_count == self.loaded_object_count: diff --git a/django/core/management/commands/migrate.py b/django/core/management/commands/migrate.py index cb863509c89..60899ef09ff 100644 --- a/django/core/management/commands/migrate.py +++ b/django/core/management/commands/migrate.py @@ -171,105 +171,110 @@ class Command(BaseCommand): "Runs the old syncdb-style operation on a list of app_labels." cursor = connection.cursor() - # Get a list of already installed *models* so that references work right. - tables = connection.introspection.table_names() - seen_models = connection.introspection.installed_models(tables) - created_models = set() - pending_references = {} + try: + # Get a list of already installed *models* so that references work right. + tables = connection.introspection.table_names(cursor) + seen_models = connection.introspection.installed_models(tables) + created_models = set() + pending_references = {} - # Build the manifest of apps and models that are to be synchronized - all_models = [ - (app_config.label, - router.get_migratable_models(app_config, connection.alias, include_auto_created=True)) - for app_config in apps.get_app_configs() - if app_config.models_module is not None and app_config.label in app_labels - ] + # Build the manifest of apps and models that are to be synchronized + all_models = [ + (app_config.label, + router.get_migratable_models(app_config, connection.alias, include_auto_created=True)) + for app_config in apps.get_app_configs() + if app_config.models_module is not None and app_config.label in app_labels + ] - def model_installed(model): - opts = model._meta - converter = connection.introspection.table_name_converter - # Note that if a model is unmanaged we short-circuit and never try to install it - return not ((converter(opts.db_table) in tables) or - (opts.auto_created and converter(opts.auto_created._meta.db_table) in tables)) + def model_installed(model): + opts = model._meta + converter = connection.introspection.table_name_converter + # Note that if a model is unmanaged we short-circuit and never try to install it + return not ((converter(opts.db_table) in tables) or + (opts.auto_created and converter(opts.auto_created._meta.db_table) in tables)) - manifest = OrderedDict( - (app_name, list(filter(model_installed, model_list))) - for app_name, model_list in all_models - ) + manifest = OrderedDict( + (app_name, list(filter(model_installed, model_list))) + for app_name, model_list in all_models + ) - create_models = set(itertools.chain(*manifest.values())) - emit_pre_migrate_signal(create_models, self.verbosity, self.interactive, connection.alias) + create_models = set(itertools.chain(*manifest.values())) + emit_pre_migrate_signal(create_models, self.verbosity, self.interactive, connection.alias) - # Create the tables for each model - if self.verbosity >= 1: - self.stdout.write(" Creating tables...\n") - with transaction.atomic(using=connection.alias, savepoint=False): - for app_name, model_list in manifest.items(): - for model in model_list: - # Create the model's database table, if it doesn't already exist. - if self.verbosity >= 3: - self.stdout.write(" Processing %s.%s model\n" % (app_name, model._meta.object_name)) - sql, references = connection.creation.sql_create_model(model, no_style(), seen_models) - seen_models.add(model) - created_models.add(model) - for refto, refs in references.items(): - pending_references.setdefault(refto, []).extend(refs) - if refto in seen_models: - sql.extend(connection.creation.sql_for_pending_references(refto, no_style(), pending_references)) - sql.extend(connection.creation.sql_for_pending_references(model, no_style(), pending_references)) - if self.verbosity >= 1 and sql: - self.stdout.write(" Creating table %s\n" % model._meta.db_table) - for statement in sql: - cursor.execute(statement) - tables.append(connection.introspection.table_name_converter(model._meta.db_table)) + # Create the tables for each model + if self.verbosity >= 1: + self.stdout.write(" Creating tables...\n") + with transaction.atomic(using=connection.alias, savepoint=False): + for app_name, model_list in manifest.items(): + for model in model_list: + # Create the model's database table, if it doesn't already exist. + if self.verbosity >= 3: + self.stdout.write(" Processing %s.%s model\n" % (app_name, model._meta.object_name)) + sql, references = connection.creation.sql_create_model(model, no_style(), seen_models) + seen_models.add(model) + created_models.add(model) + for refto, refs in references.items(): + pending_references.setdefault(refto, []).extend(refs) + if refto in seen_models: + sql.extend(connection.creation.sql_for_pending_references(refto, no_style(), pending_references)) + sql.extend(connection.creation.sql_for_pending_references(model, no_style(), pending_references)) + if self.verbosity >= 1 and sql: + self.stdout.write(" Creating table %s\n" % model._meta.db_table) + for statement in sql: + cursor.execute(statement) + tables.append(connection.introspection.table_name_converter(model._meta.db_table)) - # We force a commit here, as that was the previous behaviour. - # If you can prove we don't need this, remove it. - transaction.set_dirty(using=connection.alias) + # We force a commit here, as that was the previous behaviour. + # If you can prove we don't need this, remove it. + transaction.set_dirty(using=connection.alias) + finally: + cursor.close() # The connection may have been closed by a syncdb handler. cursor = connection.cursor() + try: + # Install custom SQL for the app (but only if this + # is a model we've just created) + if self.verbosity >= 1: + self.stdout.write(" Installing custom SQL...\n") + for app_name, model_list in manifest.items(): + for model in model_list: + if model in created_models: + custom_sql = custom_sql_for_model(model, no_style(), connection) + if custom_sql: + if self.verbosity >= 2: + self.stdout.write(" Installing custom SQL for %s.%s model\n" % (app_name, model._meta.object_name)) + try: + with transaction.commit_on_success_unless_managed(using=connection.alias): + for sql in custom_sql: + cursor.execute(sql) + except Exception as e: + self.stderr.write(" Failed to install custom SQL for %s.%s model: %s\n" % (app_name, model._meta.object_name, e)) + if self.show_traceback: + traceback.print_exc() + else: + if self.verbosity >= 3: + self.stdout.write(" No custom SQL for %s.%s model\n" % (app_name, model._meta.object_name)) - # Install custom SQL for the app (but only if this - # is a model we've just created) - if self.verbosity >= 1: - self.stdout.write(" Installing custom SQL...\n") - for app_name, model_list in manifest.items(): - for model in model_list: - if model in created_models: - custom_sql = custom_sql_for_model(model, no_style(), connection) - if custom_sql: - if self.verbosity >= 2: - self.stdout.write(" Installing custom SQL for %s.%s model\n" % (app_name, model._meta.object_name)) - try: - with transaction.commit_on_success_unless_managed(using=connection.alias): - for sql in custom_sql: - cursor.execute(sql) - except Exception as e: - self.stderr.write(" Failed to install custom SQL for %s.%s model: %s\n" % (app_name, model._meta.object_name, e)) - if self.show_traceback: - traceback.print_exc() - else: - if self.verbosity >= 3: - self.stdout.write(" No custom SQL for %s.%s model\n" % (app_name, model._meta.object_name)) + if self.verbosity >= 1: + self.stdout.write(" Installing indexes...\n") - if self.verbosity >= 1: - self.stdout.write(" Installing indexes...\n") - - # Install SQL indices for all newly created models - for app_name, model_list in manifest.items(): - for model in model_list: - if model in created_models: - index_sql = connection.creation.sql_indexes_for_model(model, no_style()) - if index_sql: - if self.verbosity >= 2: - self.stdout.write(" Installing index for %s.%s model\n" % (app_name, model._meta.object_name)) - try: - with transaction.commit_on_success_unless_managed(using=connection.alias): - for sql in index_sql: - cursor.execute(sql) - except Exception as e: - self.stderr.write(" Failed to install index for %s.%s model: %s\n" % (app_name, model._meta.object_name, e)) + # Install SQL indices for all newly created models + for app_name, model_list in manifest.items(): + for model in model_list: + if model in created_models: + index_sql = connection.creation.sql_indexes_for_model(model, no_style()) + if index_sql: + if self.verbosity >= 2: + self.stdout.write(" Installing index for %s.%s model\n" % (app_name, model._meta.object_name)) + try: + with transaction.commit_on_success_unless_managed(using=connection.alias): + for sql in index_sql: + cursor.execute(sql) + except Exception as e: + self.stderr.write(" Failed to install index for %s.%s model: %s\n" % (app_name, model._meta.object_name, e)) + finally: + cursor.close() # Load initial_data fixtures (unless that has been disabled) if self.load_initial_data: diff --git a/django/core/management/sql.py b/django/core/management/sql.py index ad91ca36c63..ccab11d2bd7 100644 --- a/django/core/management/sql.py +++ b/django/core/management/sql.py @@ -67,38 +67,39 @@ def sql_delete(app_config, style, connection): except Exception: cursor = None - # Figure out which tables already exist - if cursor: - table_names = connection.introspection.table_names(cursor) - else: - table_names = [] + try: + # Figure out which tables already exist + if cursor: + table_names = connection.introspection.table_names(cursor) + else: + table_names = [] - output = [] + output = [] - # Output DROP TABLE statements for standard application tables. - to_delete = set() + # Output DROP TABLE statements for standard application tables. + to_delete = set() - references_to_delete = {} - app_models = router.get_migratable_models(app_config, connection.alias, include_auto_created=True) - for model in app_models: - if cursor and connection.introspection.table_name_converter(model._meta.db_table) in table_names: - # The table exists, so it needs to be dropped - opts = model._meta - for f in opts.local_fields: - if f.rel and f.rel.to not in to_delete: - references_to_delete.setdefault(f.rel.to, []).append((model, f)) + references_to_delete = {} + app_models = router.get_migratable_models(app_config, connection.alias, include_auto_created=True) + for model in app_models: + if cursor and connection.introspection.table_name_converter(model._meta.db_table) in table_names: + # The table exists, so it needs to be dropped + opts = model._meta + for f in opts.local_fields: + if f.rel and f.rel.to not in to_delete: + references_to_delete.setdefault(f.rel.to, []).append((model, f)) - to_delete.add(model) + to_delete.add(model) - for model in app_models: - if connection.introspection.table_name_converter(model._meta.db_table) in table_names: - output.extend(connection.creation.sql_destroy_model(model, references_to_delete, style)) - - # Close database connection explicitly, in case this output is being piped - # directly into a database client, to avoid locking issues. - if cursor: - cursor.close() - connection.close() + for model in app_models: + if connection.introspection.table_name_converter(model._meta.db_table) in table_names: + output.extend(connection.creation.sql_destroy_model(model, references_to_delete, style)) + finally: + # Close database connection explicitly, in case this output is being piped + # directly into a database client, to avoid locking issues. + if cursor: + cursor.close() + connection.close() return output[::-1] # Reverse it, to deal with table dependencies. diff --git a/django/db/backends/__init__.py b/django/db/backends/__init__.py index b96407056a1..50b8745f475 100644 --- a/django/db/backends/__init__.py +++ b/django/db/backends/__init__.py @@ -194,13 +194,16 @@ class BaseDatabaseWrapper(object): ##### Backend-specific savepoint management methods ##### def _savepoint(self, sid): - self.cursor().execute(self.ops.savepoint_create_sql(sid)) + with self.cursor() as cursor: + cursor.execute(self.ops.savepoint_create_sql(sid)) def _savepoint_rollback(self, sid): - self.cursor().execute(self.ops.savepoint_rollback_sql(sid)) + with self.cursor() as cursor: + cursor.execute(self.ops.savepoint_rollback_sql(sid)) def _savepoint_commit(self, sid): - self.cursor().execute(self.ops.savepoint_commit_sql(sid)) + with self.cursor() as cursor: + cursor.execute(self.ops.savepoint_commit_sql(sid)) def _savepoint_allowed(self): # Savepoints cannot be created outside a transaction @@ -688,15 +691,15 @@ class BaseDatabaseFeatures(object): # otherwise autocommit will cause the confimation to # fail. self.connection.enter_transaction_management() - cursor = self.connection.cursor() - cursor.execute('CREATE TABLE ROLLBACK_TEST (X INT)') - self.connection.commit() - cursor.execute('INSERT INTO ROLLBACK_TEST (X) VALUES (8)') - self.connection.rollback() - cursor.execute('SELECT COUNT(X) FROM ROLLBACK_TEST') - count, = cursor.fetchone() - cursor.execute('DROP TABLE ROLLBACK_TEST') - self.connection.commit() + with self.connection.cursor() as cursor: + cursor.execute('CREATE TABLE ROLLBACK_TEST (X INT)') + self.connection.commit() + cursor.execute('INSERT INTO ROLLBACK_TEST (X) VALUES (8)') + self.connection.rollback() + cursor.execute('SELECT COUNT(X) FROM ROLLBACK_TEST') + count, = cursor.fetchone() + cursor.execute('DROP TABLE ROLLBACK_TEST') + self.connection.commit() finally: self.connection.leave_transaction_management() return count == 0 @@ -1253,7 +1256,8 @@ class BaseDatabaseIntrospection(object): in sorting order between databases. """ if cursor is None: - cursor = self.connection.cursor() + with self.connection.cursor() as cursor: + return sorted(self.get_table_list(cursor)) return sorted(self.get_table_list(cursor)) def get_table_list(self, cursor): diff --git a/django/db/backends/creation.py b/django/db/backends/creation.py index ff62f30e712..3ee1e8448e1 100644 --- a/django/db/backends/creation.py +++ b/django/db/backends/creation.py @@ -378,9 +378,8 @@ class BaseDatabaseCreation(object): call_command('createcachetable', database=self.connection.alias) - # Get a cursor (even though we don't need one yet). This has - # the side effect of initializing the test database. - self.connection.cursor() + # Ensure a connection for the side effect of initializing the test database. + self.connection.ensure_connection() return test_database_name @@ -406,34 +405,34 @@ class BaseDatabaseCreation(object): qn = self.connection.ops.quote_name # Create the test database and connect to it. - cursor = self._nodb_connection.cursor() - try: - cursor.execute( - "CREATE DATABASE %s %s" % (qn(test_database_name), suffix)) - except Exception as e: - sys.stderr.write( - "Got an error creating the test database: %s\n" % e) - if not autoclobber: - confirm = input( - "Type 'yes' if you would like to try deleting the test " - "database '%s', or 'no' to cancel: " % test_database_name) - if autoclobber or confirm == 'yes': - try: - if verbosity >= 1: - print("Destroying old test database '%s'..." - % self.connection.alias) - cursor.execute( - "DROP DATABASE %s" % qn(test_database_name)) - cursor.execute( - "CREATE DATABASE %s %s" % (qn(test_database_name), - suffix)) - except Exception as e: - sys.stderr.write( - "Got an error recreating the test database: %s\n" % e) - sys.exit(2) - else: - print("Tests cancelled.") - sys.exit(1) + with self._nodb_connection.cursor() as cursor: + try: + cursor.execute( + "CREATE DATABASE %s %s" % (qn(test_database_name), suffix)) + except Exception as e: + sys.stderr.write( + "Got an error creating the test database: %s\n" % e) + if not autoclobber: + confirm = input( + "Type 'yes' if you would like to try deleting the test " + "database '%s', or 'no' to cancel: " % test_database_name) + if autoclobber or confirm == 'yes': + try: + if verbosity >= 1: + print("Destroying old test database '%s'..." + % self.connection.alias) + cursor.execute( + "DROP DATABASE %s" % qn(test_database_name)) + cursor.execute( + "CREATE DATABASE %s %s" % (qn(test_database_name), + suffix)) + except Exception as e: + sys.stderr.write( + "Got an error recreating the test database: %s\n" % e) + sys.exit(2) + else: + print("Tests cancelled.") + sys.exit(1) return test_database_name @@ -461,11 +460,11 @@ class BaseDatabaseCreation(object): # ourselves. Connect to the previous database (not the test database) # to do so, because it's not allowed to delete a database while being # connected to it. - cursor = self._nodb_connection.cursor() - # Wait to avoid "database is being accessed by other users" errors. - time.sleep(1) - cursor.execute("DROP DATABASE %s" - % self.connection.ops.quote_name(test_database_name)) + with self._nodb_connection.cursor() as cursor: + # Wait to avoid "database is being accessed by other users" errors. + time.sleep(1) + cursor.execute("DROP DATABASE %s" + % self.connection.ops.quote_name(test_database_name)) def set_autocommit(self): """ diff --git a/django/db/backends/mysql/base.py b/django/db/backends/mysql/base.py index e7932dd800f..9d3935dc549 100644 --- a/django/db/backends/mysql/base.py +++ b/django/db/backends/mysql/base.py @@ -180,15 +180,15 @@ class DatabaseFeatures(BaseDatabaseFeatures): @cached_property def _mysql_storage_engine(self): "Internal method used in Django tests. Don't rely on this from your code" - cursor = self.connection.cursor() - cursor.execute('CREATE TABLE INTROSPECT_TEST (X INT)') - # This command is MySQL specific; the second column - # will tell you the default table type of the created - # table. Since all Django's test tables will have the same - # table type, that's enough to evaluate the feature. - cursor.execute("SHOW TABLE STATUS WHERE Name='INTROSPECT_TEST'") - result = cursor.fetchone() - cursor.execute('DROP TABLE INTROSPECT_TEST') + with self.connection.cursor() as cursor: + cursor.execute('CREATE TABLE INTROSPECT_TEST (X INT)') + # This command is MySQL specific; the second column + # will tell you the default table type of the created + # table. Since all Django's test tables will have the same + # table type, that's enough to evaluate the feature. + cursor.execute("SHOW TABLE STATUS WHERE Name='INTROSPECT_TEST'") + result = cursor.fetchone() + cursor.execute('DROP TABLE INTROSPECT_TEST') return result[1] @cached_property @@ -207,9 +207,9 @@ class DatabaseFeatures(BaseDatabaseFeatures): return False # Test if the time zone definitions are installed. - cursor = self.connection.cursor() - cursor.execute("SELECT 1 FROM mysql.time_zone LIMIT 1") - return cursor.fetchone() is not None + with self.connection.cursor() as cursor: + cursor.execute("SELECT 1 FROM mysql.time_zone LIMIT 1") + return cursor.fetchone() is not None class DatabaseOperations(BaseDatabaseOperations): @@ -461,13 +461,12 @@ class DatabaseWrapper(BaseDatabaseWrapper): return conn def init_connection_state(self): - cursor = self.connection.cursor() - # SQL_AUTO_IS_NULL in MySQL controls whether an AUTO_INCREMENT column - # on a recently-inserted row will return when the field is tested for - # NULL. Disabling this value brings this aspect of MySQL in line with - # SQL standards. - cursor.execute('SET SQL_AUTO_IS_NULL = 0') - cursor.close() + with self.connection.cursor() as cursor: + # SQL_AUTO_IS_NULL in MySQL controls whether an AUTO_INCREMENT column + # on a recently-inserted row will return when the field is tested for + # NULL. Disabling this value brings this aspect of MySQL in line with + # SQL standards. + cursor.execute('SET SQL_AUTO_IS_NULL = 0') def create_cursor(self): cursor = self.connection.cursor() diff --git a/django/db/backends/oracle/base.py b/django/db/backends/oracle/base.py index cdb101d20ca..2495986a02f 100644 --- a/django/db/backends/oracle/base.py +++ b/django/db/backends/oracle/base.py @@ -353,8 +353,8 @@ WHEN (new.%(col_name)s IS NULL) def regex_lookup(self, lookup_type): # If regex_lookup is called before it's been initialized, then create # a cursor to initialize it and recur. - self.connection.cursor() - return self.connection.ops.regex_lookup(lookup_type) + with self.connection.cursor(): + return self.connection.ops.regex_lookup(lookup_type) def return_insert_id(self): return "RETURNING %s INTO %%s", (InsertIdVar(),) diff --git a/django/db/backends/postgresql_psycopg2/base.py b/django/db/backends/postgresql_psycopg2/base.py index 33f885d50c7..e89a4e604af 100644 --- a/django/db/backends/postgresql_psycopg2/base.py +++ b/django/db/backends/postgresql_psycopg2/base.py @@ -149,8 +149,10 @@ class DatabaseWrapper(BaseDatabaseWrapper): if conn_tz != tz: cursor = self.connection.cursor() - cursor.execute(self.ops.set_time_zone_sql(), [tz]) - cursor.close() + try: + cursor.execute(self.ops.set_time_zone_sql(), [tz]) + finally: + cursor.close() # Commit after setting the time zone (see #17062) if not self.get_autocommit(): self.connection.commit() diff --git a/django/db/backends/postgresql_psycopg2/version.py b/django/db/backends/postgresql_psycopg2/version.py index dae94f2dacd..64fd7c82985 100644 --- a/django/db/backends/postgresql_psycopg2/version.py +++ b/django/db/backends/postgresql_psycopg2/version.py @@ -39,6 +39,6 @@ def get_version(connection): if hasattr(connection, 'server_version'): return connection.server_version else: - cursor = connection.cursor() - cursor.execute("SELECT version()") - return _parse_version(cursor.fetchone()[0]) + with connection.cursor() as cursor: + cursor.execute("SELECT version()") + return _parse_version(cursor.fetchone()[0]) diff --git a/django/db/backends/schema.py b/django/db/backends/schema.py index e6905677386..88cc8944372 100644 --- a/django/db/backends/schema.py +++ b/django/db/backends/schema.py @@ -86,14 +86,13 @@ class BaseDatabaseSchemaEditor(object): """ Executes the given SQL statement, with optional parameters. """ - # Get the cursor - cursor = self.connection.cursor() # Log the command we're running, then run it logger.debug("%s; (params %r)" % (sql, params)) if self.collect_sql: self.collected_sql.append((sql % tuple(map(self.connection.ops.quote_parameter, params))) + ";") else: - cursor.execute(sql, params) + with self.connection.cursor() as cursor: + cursor.execute(sql, params) def quote_name(self, name): return self.connection.ops.quote_name(name) @@ -791,7 +790,8 @@ class BaseDatabaseSchemaEditor(object): Returns all constraint names matching the columns and conditions """ column_names = list(column_names) if column_names else None - constraints = self.connection.introspection.get_constraints(self.connection.cursor(), model._meta.db_table) + with self.connection.cursor() as cursor: + constraints = self.connection.introspection.get_constraints(cursor, model._meta.db_table) result = [] for name, infodict in constraints.items(): if column_names is None or column_names == infodict['columns']: diff --git a/django/db/backends/sqlite3/base.py b/django/db/backends/sqlite3/base.py index 3c8f170b7d9..2adfbacaa95 100644 --- a/django/db/backends/sqlite3/base.py +++ b/django/db/backends/sqlite3/base.py @@ -122,14 +122,14 @@ class DatabaseFeatures(BaseDatabaseFeatures): rule out support for STDDEV. We need to manually check whether the call works. """ - cursor = self.connection.cursor() - cursor.execute('CREATE TABLE STDDEV_TEST (X INT)') - try: - cursor.execute('SELECT STDDEV(*) FROM STDDEV_TEST') - has_support = True - except utils.DatabaseError: - has_support = False - cursor.execute('DROP TABLE STDDEV_TEST') + with self.connection.cursor() as cursor: + cursor.execute('CREATE TABLE STDDEV_TEST (X INT)') + try: + cursor.execute('SELECT STDDEV(*) FROM STDDEV_TEST') + has_support = True + except utils.DatabaseError: + has_support = False + cursor.execute('DROP TABLE STDDEV_TEST') return has_support @cached_property diff --git a/django/db/models/query.py b/django/db/models/query.py index 353dd957943..6051b9f859a 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -1522,54 +1522,59 @@ class RawQuerySet(object): query = iter(self.query) - # Find out which columns are model's fields, and which ones should be - # annotated to the model. - for pos, column in enumerate(self.columns): - if column in self.model_fields: - model_init_field_names[self.model_fields[column].attname] = pos - else: - annotation_fields.append((column, pos)) + try: + # Find out which columns are model's fields, and which ones should be + # annotated to the model. + for pos, column in enumerate(self.columns): + if column in self.model_fields: + model_init_field_names[self.model_fields[column].attname] = pos + else: + annotation_fields.append((column, pos)) - # Find out which model's fields are not present in the query. - skip = set() - for field in self.model._meta.fields: - if field.attname not in model_init_field_names: - skip.add(field.attname) - if skip: - if self.model._meta.pk.attname in skip: - raise InvalidQuery('Raw query must include the primary key') - model_cls = deferred_class_factory(self.model, skip) - else: - model_cls = self.model - # All model's fields are present in the query. So, it is possible - # to use *args based model instantation. For each field of the model, - # record the query column position matching that field. - model_init_field_pos = [] + # Find out which model's fields are not present in the query. + skip = set() for field in self.model._meta.fields: - model_init_field_pos.append(model_init_field_names[field.attname]) - if need_resolv_columns: - fields = [self.model_fields.get(c, None) for c in self.columns] - # Begin looping through the query values. - for values in query: - if need_resolv_columns: - values = compiler.resolve_columns(values, fields) - # Associate fields to values + if field.attname not in model_init_field_names: + skip.add(field.attname) if skip: - model_init_kwargs = {} - for attname, pos in six.iteritems(model_init_field_names): - model_init_kwargs[attname] = values[pos] - instance = model_cls(**model_init_kwargs) + if self.model._meta.pk.attname in skip: + raise InvalidQuery('Raw query must include the primary key') + model_cls = deferred_class_factory(self.model, skip) else: - model_init_args = [values[pos] for pos in model_init_field_pos] - instance = model_cls(*model_init_args) - if annotation_fields: - for column, pos in annotation_fields: - setattr(instance, column, values[pos]) + model_cls = self.model + # All model's fields are present in the query. So, it is possible + # to use *args based model instantation. For each field of the model, + # record the query column position matching that field. + model_init_field_pos = [] + for field in self.model._meta.fields: + model_init_field_pos.append(model_init_field_names[field.attname]) + if need_resolv_columns: + fields = [self.model_fields.get(c, None) for c in self.columns] + # Begin looping through the query values. + for values in query: + if need_resolv_columns: + values = compiler.resolve_columns(values, fields) + # Associate fields to values + if skip: + model_init_kwargs = {} + for attname, pos in six.iteritems(model_init_field_names): + model_init_kwargs[attname] = values[pos] + instance = model_cls(**model_init_kwargs) + else: + model_init_args = [values[pos] for pos in model_init_field_pos] + instance = model_cls(*model_init_args) + if annotation_fields: + for column, pos in annotation_fields: + setattr(instance, column, values[pos]) - instance._state.db = db - instance._state.adding = False + instance._state.db = db + instance._state.adding = False - yield instance + yield instance + finally: + # Done iterating the Query. If it has its own cursor, close it. + if hasattr(self.query, 'cursor') and self.query.cursor: + self.query.cursor.close() def __repr__(self): text = self.raw_query diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 536a66d1399..d9161d820c2 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -1,4 +1,5 @@ import datetime +import sys from django.conf import settings from django.core.exceptions import FieldError @@ -777,7 +778,7 @@ class SQLCompiler(object): cursor = self.connection.cursor() try: cursor.execute(sql, params) - except: + except Exception: cursor.close() raise @@ -908,15 +909,15 @@ class SQLInsertCompiler(SQLCompiler): def execute_sql(self, return_id=False): assert not (return_id and len(self.query.objs) != 1) self.return_id = return_id - cursor = self.connection.cursor() - for sql, params in self.as_sql(): - cursor.execute(sql, params) - if not (return_id and cursor): - return - if self.connection.features.can_return_id_from_insert: - return self.connection.ops.fetch_returned_insert_id(cursor) - return self.connection.ops.last_insert_id(cursor, - self.query.get_meta().db_table, self.query.get_meta().pk.column) + with self.connection.cursor() as cursor: + for sql, params in self.as_sql(): + cursor.execute(sql, params) + if not (return_id and cursor): + return + if self.connection.features.can_return_id_from_insert: + return self.connection.ops.fetch_returned_insert_id(cursor) + return self.connection.ops.last_insert_id(cursor, + self.query.get_meta().db_table, self.query.get_meta().pk.column) class SQLDeleteCompiler(SQLCompiler): diff --git a/tests/backends/tests.py b/tests/backends/tests.py index 4a3fc31b7aa..f3c38893f46 100644 --- a/tests/backends/tests.py +++ b/tests/backends/tests.py @@ -59,9 +59,9 @@ class OracleChecks(unittest.TestCase): # stored procedure through our cursor wrapper. from django.db.backends.oracle.base import convert_unicode - cursor = connection.cursor() - cursor.callproc(convert_unicode('DBMS_SESSION.SET_IDENTIFIER'), - [convert_unicode('_django_testing!')]) + with connection.cursor() as cursor: + cursor.callproc(convert_unicode('DBMS_SESSION.SET_IDENTIFIER'), + [convert_unicode('_django_testing!')]) @unittest.skipUnless(connection.vendor == 'oracle', "No need to check Oracle cursor semantics") @@ -70,31 +70,31 @@ class OracleChecks(unittest.TestCase): # as query parameters. from django.db.backends.oracle.base import Database - cursor = connection.cursor() - var = cursor.var(Database.STRING) - cursor.execute("BEGIN %s := 'X'; END; ", [var]) - self.assertEqual(var.getvalue(), 'X') + with connection.cursor() as cursor: + var = cursor.var(Database.STRING) + cursor.execute("BEGIN %s := 'X'; END; ", [var]) + self.assertEqual(var.getvalue(), 'X') @unittest.skipUnless(connection.vendor == 'oracle', "No need to check Oracle cursor semantics") def test_long_string(self): # If the backend is Oracle, test that we can save a text longer # than 4000 chars and read it properly - c = connection.cursor() - c.execute('CREATE TABLE ltext ("TEXT" NCLOB)') - long_str = ''.join(six.text_type(x) for x in xrange(4000)) - c.execute('INSERT INTO ltext VALUES (%s)', [long_str]) - c.execute('SELECT text FROM ltext') - row = c.fetchone() - self.assertEqual(long_str, row[0].read()) - c.execute('DROP TABLE ltext') + with connection.cursor() as cursor: + cursor.execute('CREATE TABLE ltext ("TEXT" NCLOB)') + long_str = ''.join(six.text_type(x) for x in xrange(4000)) + cursor.execute('INSERT INTO ltext VALUES (%s)', [long_str]) + cursor.execute('SELECT text FROM ltext') + row = cursor.fetchone() + self.assertEqual(long_str, row[0].read()) + cursor.execute('DROP TABLE ltext') @unittest.skipUnless(connection.vendor == 'oracle', "No need to check Oracle connection semantics") def test_client_encoding(self): # If the backend is Oracle, test that the client encoding is set # correctly. This was broken under Cygwin prior to r14781. - connection.cursor() # Ensure the connection is initialized. + self.connection.ensure_connection() self.assertEqual(connection.connection.encoding, "UTF-8") self.assertEqual(connection.connection.nencoding, "UTF-8") @@ -103,12 +103,12 @@ class OracleChecks(unittest.TestCase): def test_order_of_nls_parameters(self): # an 'almost right' datetime should work with configured # NLS parameters as per #18465. - c = connection.cursor() - query = "select 1 from dual where '1936-12-29 00:00' < sysdate" - # Test that the query succeeds without errors - pre #18465 this - # wasn't the case. - c.execute(query) - self.assertEqual(c.fetchone()[0], 1) + with connection.cursor() as cursor: + query = "select 1 from dual where '1936-12-29 00:00' < sysdate" + # Test that the query succeeds without errors - pre #18465 this + # wasn't the case. + cursor.execute(query) + self.assertEqual(cursor.fetchone()[0], 1) class SQLiteTests(TestCase): @@ -328,6 +328,12 @@ class PostgresVersionTest(TestCase): def fetchone(self): return ["PostgreSQL 8.3"] + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + pass + class OlderConnectionMock(object): "Mock of psycopg2 (< 2.0.12) connection" def cursor(self): diff --git a/tests/cache/tests.py b/tests/cache/tests.py index 94790ed740e..bc0f705375e 100644 --- a/tests/cache/tests.py +++ b/tests/cache/tests.py @@ -896,10 +896,9 @@ class DBCacheTests(BaseCacheTests, TransactionTestCase): management.call_command('createcachetable', verbosity=0, interactive=False) def drop_table(self): - cursor = connection.cursor() - table_name = connection.ops.quote_name('test cache table') - cursor.execute('DROP TABLE %s' % table_name) - cursor.close() + with connection.cursor() as cursor: + table_name = connection.ops.quote_name('test cache table') + cursor.execute('DROP TABLE %s' % table_name) def test_zero_cull(self): self._perform_cull_test(caches['zero_cull'], 50, 18) diff --git a/tests/custom_methods/models.py b/tests/custom_methods/models.py index cef3fd722be..78e00a99b8a 100644 --- a/tests/custom_methods/models.py +++ b/tests/custom_methods/models.py @@ -30,11 +30,11 @@ class Article(models.Model): database query for the sake of demonstration. """ from django.db import connection - cursor = connection.cursor() - cursor.execute(""" - SELECT id, headline, pub_date - FROM custom_methods_article - WHERE pub_date = %s - AND id != %s""", [connection.ops.value_to_db_date(self.pub_date), - self.id]) - return [self.__class__(*row) for row in cursor.fetchall()] + with connection.cursor() as cursor: + cursor.execute(""" + SELECT id, headline, pub_date + FROM custom_methods_article + WHERE pub_date = %s + AND id != %s""", [connection.ops.value_to_db_date(self.pub_date), + self.id]) + return [self.__class__(*row) for row in cursor.fetchall()] diff --git a/tests/initial_sql_regress/tests.py b/tests/initial_sql_regress/tests.py index e725f4b102a..428d993667e 100644 --- a/tests/initial_sql_regress/tests.py +++ b/tests/initial_sql_regress/tests.py @@ -28,9 +28,9 @@ class InitialSQLTests(TestCase): connection = connections[DEFAULT_DB_ALIAS] custom_sql = custom_sql_for_model(Simple, no_style(), connection) self.assertEqual(len(custom_sql), 9) - cursor = connection.cursor() - for sql in custom_sql: - cursor.execute(sql) + with connection.cursor() as cursor: + for sql in custom_sql: + cursor.execute(sql) self.assertEqual(Simple.objects.count(), 9) self.assertEqual( Simple.objects.get(name__contains='placeholders').name, diff --git a/tests/introspection/tests.py b/tests/introspection/tests.py index 8ec3d39903e..0c339bc8ead 100644 --- a/tests/introspection/tests.py +++ b/tests/introspection/tests.py @@ -23,17 +23,17 @@ class IntrospectionTests(TestCase): "'%s' isn't in table_list()." % Article._meta.db_table) def test_django_table_names(self): - cursor = connection.cursor() - cursor.execute('CREATE TABLE django_ixn_test_table (id INTEGER);') - tl = connection.introspection.django_table_names() - cursor.execute("DROP TABLE django_ixn_test_table;") - self.assertTrue('django_ixn_testcase_table' not in tl, - "django_table_names() returned a non-Django table") + with connection.cursor() as cursor: + cursor.execute('CREATE TABLE django_ixn_test_table (id INTEGER);') + tl = connection.introspection.django_table_names() + cursor.execute("DROP TABLE django_ixn_test_table;") + self.assertTrue('django_ixn_testcase_table' not in tl, + "django_table_names() returned a non-Django table") def test_django_table_names_retval_type(self): # Ticket #15216 - cursor = connection.cursor() - cursor.execute('CREATE TABLE django_ixn_test_table (id INTEGER);') + with connection.cursor() as cursor: + cursor.execute('CREATE TABLE django_ixn_test_table (id INTEGER);') tl = connection.introspection.django_table_names(only_existing=True) self.assertIs(type(tl), list) @@ -53,14 +53,14 @@ class IntrospectionTests(TestCase): 'Reporter sequence not found in sequence_list()') def test_get_table_description_names(self): - cursor = connection.cursor() - desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table) + with connection.cursor() as cursor: + desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table) self.assertEqual([r[0] for r in desc], [f.column for f in Reporter._meta.fields]) def test_get_table_description_types(self): - cursor = connection.cursor() - desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table) + with connection.cursor() as cursor: + desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table) # The MySQL exception is due to the cursor.description returning the same constant for # text and blob columns. TODO: use information_schema database to retrieve the proper # field type on MySQL @@ -75,8 +75,8 @@ class IntrospectionTests(TestCase): # inspect the length of character columns). @expectedFailureOnOracle def test_get_table_description_col_lengths(self): - cursor = connection.cursor() - desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table) + with connection.cursor() as cursor: + desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table) self.assertEqual( [r[3] for r in desc if datatype(r[1], r) == 'CharField'], [30, 30, 75] @@ -87,8 +87,8 @@ class IntrospectionTests(TestCase): # so its idea about null_ok in cursor.description is different from ours. @skipIfDBFeature('interprets_empty_strings_as_nulls') def test_get_table_description_nullable(self): - cursor = connection.cursor() - desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table) + with connection.cursor() as cursor: + desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table) self.assertEqual( [r[6] for r in desc], [False, False, False, False, True, True] @@ -97,15 +97,15 @@ class IntrospectionTests(TestCase): # Regression test for #9991 - 'real' types in postgres @skipUnlessDBFeature('has_real_datatype') def test_postgresql_real_type(self): - cursor = connection.cursor() - cursor.execute("CREATE TABLE django_ixn_real_test_table (number REAL);") - desc = connection.introspection.get_table_description(cursor, 'django_ixn_real_test_table') - cursor.execute('DROP TABLE django_ixn_real_test_table;') + with connection.cursor() as cursor: + cursor.execute("CREATE TABLE django_ixn_real_test_table (number REAL);") + desc = connection.introspection.get_table_description(cursor, 'django_ixn_real_test_table') + cursor.execute('DROP TABLE django_ixn_real_test_table;') self.assertEqual(datatype(desc[0][1], desc[0]), 'FloatField') def test_get_relations(self): - cursor = connection.cursor() - relations = connection.introspection.get_relations(cursor, Article._meta.db_table) + with connection.cursor() as cursor: + relations = connection.introspection.get_relations(cursor, Article._meta.db_table) # Older versions of MySQL don't have the chops to report on this stuff, # so just skip it if no relations come back. If they do, though, we @@ -117,21 +117,21 @@ class IntrospectionTests(TestCase): @skipUnlessDBFeature('can_introspect_foreign_keys') def test_get_key_columns(self): - cursor = connection.cursor() - key_columns = connection.introspection.get_key_columns(cursor, Article._meta.db_table) + with connection.cursor() as cursor: + key_columns = connection.introspection.get_key_columns(cursor, Article._meta.db_table) self.assertEqual( set(key_columns), set([('reporter_id', Reporter._meta.db_table, 'id'), ('response_to_id', Article._meta.db_table, 'id')])) def test_get_primary_key_column(self): - cursor = connection.cursor() - primary_key_column = connection.introspection.get_primary_key_column(cursor, Article._meta.db_table) + with connection.cursor() as cursor: + primary_key_column = connection.introspection.get_primary_key_column(cursor, Article._meta.db_table) self.assertEqual(primary_key_column, 'id') def test_get_indexes(self): - cursor = connection.cursor() - indexes = connection.introspection.get_indexes(cursor, Article._meta.db_table) + with connection.cursor() as cursor: + indexes = connection.introspection.get_indexes(cursor, Article._meta.db_table) self.assertEqual(indexes['reporter_id'], {'unique': False, 'primary_key': False}) def test_get_indexes_multicol(self): @@ -139,8 +139,8 @@ class IntrospectionTests(TestCase): Test that multicolumn indexes are not included in the introspection results. """ - cursor = connection.cursor() - indexes = connection.introspection.get_indexes(cursor, Reporter._meta.db_table) + with connection.cursor() as cursor: + indexes = connection.introspection.get_indexes(cursor, Reporter._meta.db_table) self.assertNotIn('first_name', indexes) self.assertIn('id', indexes) diff --git a/tests/migrations/test_base.py b/tests/migrations/test_base.py index 7ab09b04a5b..2dba30b2aab 100644 --- a/tests/migrations/test_base.py +++ b/tests/migrations/test_base.py @@ -9,33 +9,40 @@ class MigrationTestBase(TransactionTestCase): available_apps = ["migrations"] + def get_table_description(self, table): + with connection.cursor() as cursor: + return connection.introspection.get_table_description(cursor, table) + def assertTableExists(self, table): - self.assertIn(table, connection.introspection.get_table_list(connection.cursor())) + with connection.cursor() as cursor: + self.assertIn(table, connection.introspection.get_table_list(cursor)) def assertTableNotExists(self, table): - self.assertNotIn(table, connection.introspection.get_table_list(connection.cursor())) + with connection.cursor() as cursor: + self.assertNotIn(table, connection.introspection.get_table_list(cursor)) def assertColumnExists(self, table, column): - self.assertIn(column, [c.name for c in connection.introspection.get_table_description(connection.cursor(), table)]) + self.assertIn(column, [c.name for c in self.get_table_description(table)]) def assertColumnNotExists(self, table, column): - self.assertNotIn(column, [c.name for c in connection.introspection.get_table_description(connection.cursor(), table)]) + self.assertNotIn(column, [c.name for c in self.get_table_description(table)]) def assertColumnNull(self, table, column): - self.assertEqual([c.null_ok for c in connection.introspection.get_table_description(connection.cursor(), table) if c.name == column][0], True) + self.assertEqual([c.null_ok for c in self.get_table_description(table) if c.name == column][0], True) def assertColumnNotNull(self, table, column): - self.assertEqual([c.null_ok for c in connection.introspection.get_table_description(connection.cursor(), table) if c.name == column][0], False) + self.assertEqual([c.null_ok for c in self.get_table_description(table) if c.name == column][0], False) def assertIndexExists(self, table, columns, value=True): - self.assertEqual( - value, - any( - c["index"] - for c in connection.introspection.get_constraints(connection.cursor(), table).values() - if c['columns'] == list(columns) - ), - ) + with connection.cursor() as cursor: + self.assertEqual( + value, + any( + c["index"] + for c in connection.introspection.get_constraints(cursor, table).values() + if c['columns'] == list(columns) + ), + ) def assertIndexNotExists(self, table, columns): return self.assertIndexExists(table, columns, False) diff --git a/tests/migrations/test_operations.py b/tests/migrations/test_operations.py index eda356fd5d4..375a9ccc54c 100644 --- a/tests/migrations/test_operations.py +++ b/tests/migrations/test_operations.py @@ -19,15 +19,15 @@ class OperationTests(MigrationTestBase): Creates a test model state and database table. """ # Delete the tables if they already exist - cursor = connection.cursor() - try: - cursor.execute("DROP TABLE %s_pony" % app_label) - except: - pass - try: - cursor.execute("DROP TABLE %s_stable" % app_label) - except: - pass + with connection.cursor() as cursor: + try: + cursor.execute("DROP TABLE %s_pony" % app_label) + except: + pass + try: + cursor.execute("DROP TABLE %s_stable" % app_label) + except: + pass # Make the "current" state operations = [migrations.CreateModel( "Pony", @@ -348,21 +348,21 @@ class OperationTests(MigrationTestBase): operation.state_forwards("test_alflpkfk", new_state) self.assertIsInstance(project_state.models["test_alflpkfk", "pony"].get_field_by_name("id"), models.AutoField) self.assertIsInstance(new_state.models["test_alflpkfk", "pony"].get_field_by_name("id"), models.FloatField) + + def assertIdTypeEqualsFkType(self): + with connection.cursor() as cursor: + id_type = [c.type_code for c in connection.introspection.get_table_description(cursor, "test_alflpkfk_pony") if c.name == "id"][0] + fk_type = [c.type_code for c in connection.introspection.get_table_description(cursor, "test_alflpkfk_rider") if c.name == "pony_id"][0] + self.assertEqual(id_type, fk_type) + assertIdTypeEqualsFkType() # Test the database alteration - id_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_pony") if c.name == "id"][0] - fk_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_rider") if c.name == "pony_id"][0] - self.assertEqual(id_type, fk_type) with connection.schema_editor() as editor: operation.database_forwards("test_alflpkfk", editor, project_state, new_state) - id_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_pony") if c.name == "id"][0] - fk_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_rider") if c.name == "pony_id"][0] - self.assertEqual(id_type, fk_type) + assertIdTypeEqualsFkType() # And test reversal with connection.schema_editor() as editor: operation.database_backwards("test_alflpkfk", editor, new_state, project_state) - id_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_pony") if c.name == "id"][0] - fk_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_rider") if c.name == "pony_id"][0] - self.assertEqual(id_type, fk_type) + assertIdTypeEqualsFkType() def test_rename_field(self): """ @@ -400,24 +400,24 @@ class OperationTests(MigrationTestBase): self.assertEqual(len(project_state.models["test_alunto", "pony"].options.get("unique_together", set())), 0) self.assertEqual(len(new_state.models["test_alunto", "pony"].options.get("unique_together", set())), 1) # Make sure we can insert duplicate rows - cursor = connection.cursor() - cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (1, 1, 1)") - cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (2, 1, 1)") - cursor.execute("DELETE FROM test_alunto_pony") - # Test the database alteration - with connection.schema_editor() as editor: - operation.database_forwards("test_alunto", editor, project_state, new_state) - cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (1, 1, 1)") - with self.assertRaises(IntegrityError): - with atomic(): - cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (2, 1, 1)") - cursor.execute("DELETE FROM test_alunto_pony") - # And test reversal - with connection.schema_editor() as editor: - operation.database_backwards("test_alunto", editor, new_state, project_state) - cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (1, 1, 1)") - cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (2, 1, 1)") - cursor.execute("DELETE FROM test_alunto_pony") + with connection.cursor() as cursor: + cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (1, 1, 1)") + cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (2, 1, 1)") + cursor.execute("DELETE FROM test_alunto_pony") + # Test the database alteration + with connection.schema_editor() as editor: + operation.database_forwards("test_alunto", editor, project_state, new_state) + cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (1, 1, 1)") + with self.assertRaises(IntegrityError): + with atomic(): + cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (2, 1, 1)") + cursor.execute("DELETE FROM test_alunto_pony") + # And test reversal + with connection.schema_editor() as editor: + operation.database_backwards("test_alunto", editor, new_state, project_state) + cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (1, 1, 1)") + cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (2, 1, 1)") + cursor.execute("DELETE FROM test_alunto_pony") # Test flat unique_together operation = migrations.AlterUniqueTogether("Pony", ("pink", "weight")) operation.state_forwards("test_alunto", new_state) diff --git a/tests/requests/tests.py b/tests/requests/tests.py index bb563693717..3e594376aa2 100644 --- a/tests/requests/tests.py +++ b/tests/requests/tests.py @@ -725,7 +725,7 @@ class DatabaseConnectionHandlingTests(TransactionTestCase): # request_finished signal. response = self.client.get('/') # Make sure there is an open connection - connection.cursor() + self.connection.ensure_connection() connection.enter_transaction_management() signals.request_finished.send(sender=response._handler_class) self.assertEqual(len(connection.transaction_state), 0) diff --git a/tests/schema/tests.py b/tests/schema/tests.py index 450362ecdab..a502ef6002e 100644 --- a/tests/schema/tests.py +++ b/tests/schema/tests.py @@ -37,38 +37,38 @@ class SchemaTests(TransactionTestCase): def delete_tables(self): "Deletes all model tables for our models for a clean test environment" - cursor = connection.cursor() - connection.disable_constraint_checking() - table_names = connection.introspection.table_names(cursor) - for model in self.models: - # Remove any M2M tables first - for field in model._meta.local_many_to_many: + with connection.cursor() as cursor: + connection.disable_constraint_checking() + table_names = connection.introspection.table_names(cursor) + for model in self.models: + # Remove any M2M tables first + for field in model._meta.local_many_to_many: + with atomic(): + tbl = field.rel.through._meta.db_table + if tbl in table_names: + cursor.execute(connection.schema_editor().sql_delete_table % { + "table": connection.ops.quote_name(tbl), + }) + table_names.remove(tbl) + # Then remove the main tables with atomic(): - tbl = field.rel.through._meta.db_table + tbl = model._meta.db_table if tbl in table_names: cursor.execute(connection.schema_editor().sql_delete_table % { "table": connection.ops.quote_name(tbl), }) table_names.remove(tbl) - # Then remove the main tables - with atomic(): - tbl = model._meta.db_table - if tbl in table_names: - cursor.execute(connection.schema_editor().sql_delete_table % { - "table": connection.ops.quote_name(tbl), - }) - table_names.remove(tbl) connection.enable_constraint_checking() def column_classes(self, model): - cursor = connection.cursor() - columns = dict( - (d[0], (connection.introspection.get_field_type(d[1], d), d)) - for d in connection.introspection.get_table_description( - cursor, - model._meta.db_table, + with connection.cursor() as cursor: + columns = dict( + (d[0], (connection.introspection.get_field_type(d[1], d), d)) + for d in connection.introspection.get_table_description( + cursor, + model._meta.db_table, + ) ) - ) # SQLite has a different format for field_type for name, (type, desc) in columns.items(): if isinstance(type, tuple): @@ -78,6 +78,20 @@ class SchemaTests(TransactionTestCase): raise DatabaseError("Table does not exist (empty pragma)") return columns + def get_indexes(self, table): + """ + Get the indexes on the table using a new cursor. + """ + with connection.cursor() as cursor: + return connection.introspection.get_indexes(cursor, table) + + def get_constraints(self, table): + """ + Get the constraints on a table using a new cursor. + """ + with connection.cursor() as cursor: + return connection.introspection.get_constraints(cursor, table) + # Tests def test_creation_deletion(self): @@ -127,7 +141,7 @@ class SchemaTests(TransactionTestCase): strict=True, ) # Make sure the new FK constraint is present - constraints = connection.introspection.get_constraints(connection.cursor(), Book._meta.db_table) + constraints = self.get_constraints(Book._meta.db_table) for name, details in constraints.items(): if details['columns'] == ["author_id"] and details['foreign_key']: self.assertEqual(details['foreign_key'], ('schema_tag', 'id')) @@ -342,7 +356,7 @@ class SchemaTests(TransactionTestCase): editor.create_model(TagM2MTest) editor.create_model(UniqueTest) # Ensure the M2M exists and points to TagM2MTest - constraints = connection.introspection.get_constraints(connection.cursor(), BookWithM2M._meta.get_field_by_name("tags")[0].rel.through._meta.db_table) + constraints = self.get_constraints(BookWithM2M._meta.get_field_by_name("tags")[0].rel.through._meta.db_table) if connection.features.supports_foreign_keys: for name, details in constraints.items(): if details['columns'] == ["tagm2mtest_id"] and details['foreign_key']: @@ -363,7 +377,7 @@ class SchemaTests(TransactionTestCase): # Ensure old M2M is gone self.assertRaises(DatabaseError, self.column_classes, BookWithM2M._meta.get_field_by_name("tags")[0].rel.through) # Ensure the new M2M exists and points to UniqueTest - constraints = connection.introspection.get_constraints(connection.cursor(), new_field.rel.through._meta.db_table) + constraints = self.get_constraints(new_field.rel.through._meta.db_table) if connection.features.supports_foreign_keys: for name, details in constraints.items(): if details['columns'] == ["uniquetest_id"] and details['foreign_key']: @@ -388,7 +402,7 @@ class SchemaTests(TransactionTestCase): with connection.schema_editor() as editor: editor.create_model(Author) # Ensure the constraint exists - constraints = connection.introspection.get_constraints(connection.cursor(), Author._meta.db_table) + constraints = self.get_constraints(Author._meta.db_table) for name, details in constraints.items(): if details['columns'] == ["height"] and details['check']: break @@ -404,7 +418,7 @@ class SchemaTests(TransactionTestCase): new_field, strict=True, ) - constraints = connection.introspection.get_constraints(connection.cursor(), Author._meta.db_table) + constraints = self.get_constraints(Author._meta.db_table) for name, details in constraints.items(): if details['columns'] == ["height"] and details['check']: self.fail("Check constraint for height found") @@ -416,7 +430,7 @@ class SchemaTests(TransactionTestCase): Author._meta.get_field_by_name("height")[0], strict=True, ) - constraints = connection.introspection.get_constraints(connection.cursor(), Author._meta.db_table) + constraints = self.get_constraints(Author._meta.db_table) for name, details in constraints.items(): if details['columns'] == ["height"] and details['check']: break @@ -527,7 +541,7 @@ class SchemaTests(TransactionTestCase): False, any( c["index"] - for c in connection.introspection.get_constraints(connection.cursor(), "schema_tag").values() + for c in self.get_constraints("schema_tag").values() if c['columns'] == ["slug", "title"] ), ) @@ -543,7 +557,7 @@ class SchemaTests(TransactionTestCase): True, any( c["index"] - for c in connection.introspection.get_constraints(connection.cursor(), "schema_tag").values() + for c in self.get_constraints("schema_tag").values() if c['columns'] == ["slug", "title"] ), ) @@ -561,7 +575,7 @@ class SchemaTests(TransactionTestCase): False, any( c["index"] - for c in connection.introspection.get_constraints(connection.cursor(), "schema_tag").values() + for c in self.get_constraints("schema_tag").values() if c['columns'] == ["slug", "title"] ), ) @@ -578,7 +592,7 @@ class SchemaTests(TransactionTestCase): True, any( c["index"] - for c in connection.introspection.get_constraints(connection.cursor(), "schema_tagindexed").values() + for c in self.get_constraints("schema_tagindexed").values() if c['columns'] == ["slug", "title"] ), ) @@ -627,7 +641,7 @@ class SchemaTests(TransactionTestCase): # Ensure the table is there and has the right index self.assertIn( "title", - connection.introspection.get_indexes(connection.cursor(), Book._meta.db_table), + self.get_indexes(Book._meta.db_table), ) # Alter to remove the index new_field = CharField(max_length=100, db_index=False) @@ -642,7 +656,7 @@ class SchemaTests(TransactionTestCase): # Ensure the table is there and has no index self.assertNotIn( "title", - connection.introspection.get_indexes(connection.cursor(), Book._meta.db_table), + self.get_indexes(Book._meta.db_table), ) # Alter to re-add the index with connection.schema_editor() as editor: @@ -655,7 +669,7 @@ class SchemaTests(TransactionTestCase): # Ensure the table is there and has the index again self.assertIn( "title", - connection.introspection.get_indexes(connection.cursor(), Book._meta.db_table), + self.get_indexes(Book._meta.db_table), ) # Add a unique column, verify that creates an implicit index with connection.schema_editor() as editor: @@ -665,7 +679,7 @@ class SchemaTests(TransactionTestCase): ) self.assertIn( "slug", - connection.introspection.get_indexes(connection.cursor(), Book._meta.db_table), + self.get_indexes(Book._meta.db_table), ) # Remove the unique, check the index goes with it new_field2 = CharField(max_length=20, unique=False) @@ -679,7 +693,7 @@ class SchemaTests(TransactionTestCase): ) self.assertNotIn( "slug", - connection.introspection.get_indexes(connection.cursor(), Book._meta.db_table), + self.get_indexes(Book._meta.db_table), ) def test_primary_key(self): @@ -691,7 +705,7 @@ class SchemaTests(TransactionTestCase): editor.create_model(Tag) # Ensure the table is there and has the right PK self.assertTrue( - connection.introspection.get_indexes(connection.cursor(), Tag._meta.db_table)['id']['primary_key'], + self.get_indexes(Tag._meta.db_table)['id']['primary_key'], ) # Alter to change the PK new_field = SlugField(primary_key=True) @@ -707,10 +721,10 @@ class SchemaTests(TransactionTestCase): # Ensure the PK changed self.assertNotIn( 'id', - connection.introspection.get_indexes(connection.cursor(), Tag._meta.db_table), + self.get_indexes(Tag._meta.db_table), ) self.assertTrue( - connection.introspection.get_indexes(connection.cursor(), Tag._meta.db_table)['slug']['primary_key'], + self.get_indexes(Tag._meta.db_table)['slug']['primary_key'], ) def test_context_manager_exit(self): @@ -741,7 +755,7 @@ class SchemaTests(TransactionTestCase): # Ensure the table is there and has an index on the column self.assertIn( column_name, - connection.introspection.get_indexes(connection.cursor(), BookWithLongName._meta.db_table), + self.get_indexes(BookWithLongName._meta.db_table), ) def test_creation_deletion_reserved_names(self): diff --git a/tests/transactions/tests.py b/tests/transactions/tests.py index 5c38bc8ef26..e7ce43cd935 100644 --- a/tests/transactions/tests.py +++ b/tests/transactions/tests.py @@ -202,8 +202,9 @@ class AtomicTests(TransactionTestCase): # trigger a database error inside an inner atomic without savepoint with self.assertRaises(DatabaseError): with transaction.atomic(savepoint=False): - connection.cursor().execute( - "SELECT no_such_col FROM transactions_reporter") + with connection.cursor() as cursor: + cursor.execute( + "SELECT no_such_col FROM transactions_reporter") # prevent atomic from rolling back since we're recovering manually self.assertTrue(transaction.get_rollback()) transaction.set_rollback(False) @@ -534,8 +535,8 @@ class TransactionRollbackTests(IgnoreDeprecationWarningsMixin, TransactionTestCa available_apps = ['transactions'] def execute_bad_sql(self): - cursor = connection.cursor() - cursor.execute("INSERT INTO transactions_reporter (first_name, last_name) VALUES ('Douglas', 'Adams');") + with connection.cursor() as cursor: + cursor.execute("INSERT INTO transactions_reporter (first_name, last_name) VALUES ('Douglas', 'Adams');") @skipUnlessDBFeature('requires_rollback_on_dirty_transaction') def test_bad_sql(self): @@ -678,6 +679,6 @@ class TransactionContextManagerTests(IgnoreDeprecationWarningsMixin, Transaction """ with self.assertRaises(IntegrityError): with transaction.commit_on_success(): - cursor = connection.cursor() - cursor.execute("INSERT INTO transactions_reporter (first_name, last_name) VALUES ('Douglas', 'Adams');") + with connection.cursor() as cursor: + cursor.execute("INSERT INTO transactions_reporter (first_name, last_name) VALUES ('Douglas', 'Adams');") transaction.rollback() diff --git a/tests/transactions_regress/tests.py b/tests/transactions_regress/tests.py index cada46edb2f..1f9f2913077 100644 --- a/tests/transactions_regress/tests.py +++ b/tests/transactions_regress/tests.py @@ -54,8 +54,8 @@ class TestTransactionClosing(IgnoreDeprecationWarningsMixin, TransactionTestCase @commit_on_success def raw_sql(): "Write a record using raw sql under a commit_on_success decorator" - cursor = connection.cursor() - cursor.execute("INSERT into transactions_regress_mod (fld) values (18)") + with connection.cursor() as cursor: + cursor.execute("INSERT into transactions_regress_mod (fld) values (18)") raw_sql() # Rollback so that if the decorator didn't commit, the record is unwritten @@ -143,10 +143,10 @@ class TestTransactionClosing(IgnoreDeprecationWarningsMixin, TransactionTestCase (reference). All this under commit_on_success, so the second insert should be committed. """ - cursor = connection.cursor() - cursor.execute("INSERT into transactions_regress_mod (fld) values (2)") - transaction.rollback() - cursor.execute("INSERT into transactions_regress_mod (fld) values (2)") + with connection.cursor() as cursor: + cursor.execute("INSERT into transactions_regress_mod (fld) values (2)") + transaction.rollback() + cursor.execute("INSERT into transactions_regress_mod (fld) values (2)") reuse_cursor_ref() # Rollback so that if the decorator didn't commit, the record is unwritten