diff --git a/django/db/backends/creation.py b/django/db/backends/creation.py index ce1da6d4c4..2db0acca7c 100644 --- a/django/db/backends/creation.py +++ b/django/db/backends/creation.py @@ -2,11 +2,13 @@ import sys import time from django.conf import settings +from django.db.utils import load_backend # The prefix to put on the default database name when creating # the test database. TEST_DATABASE_PREFIX = 'test_' + class BaseDatabaseCreation(object): """ This class encapsulates all backend-specific differences that pertain to @@ -57,35 +59,45 @@ class BaseDatabaseCreation(object): if tablespace and f.unique: # We must specify the index tablespace inline, because we # won't be generating a CREATE INDEX statement for this field. - tablespace_sql = self.connection.ops.tablespace_sql(tablespace, inline=True) + tablespace_sql = self.connection.ops.tablespace_sql( + tablespace, inline=True) if tablespace_sql: field_output.append(tablespace_sql) if f.rel: - ref_output, pending = self.sql_for_inline_foreign_key_references(f, known_models, style) + ref_output, pending = self.sql_for_inline_foreign_key_references( + f, known_models, style) if pending: - pending_references.setdefault(f.rel.to, []).append((model, f)) + pending_references.setdefault(f.rel.to, []).append( + (model, f)) else: field_output.extend(ref_output) table_output.append(' '.join(field_output)) for field_constraints in opts.unique_together: - table_output.append(style.SQL_KEYWORD('UNIQUE') + ' (%s)' % \ - ", ".join([style.SQL_FIELD(qn(opts.get_field(f).column)) for f in field_constraints])) + table_output.append(style.SQL_KEYWORD('UNIQUE') + ' (%s)' % + ", ".join( + [style.SQL_FIELD(qn(opts.get_field(f).column)) + for f in field_constraints])) - full_statement = [style.SQL_KEYWORD('CREATE TABLE') + ' ' + style.SQL_TABLE(qn(opts.db_table)) + ' ('] + full_statement = [style.SQL_KEYWORD('CREATE TABLE') + ' ' + + style.SQL_TABLE(qn(opts.db_table)) + ' ('] for i, line in enumerate(table_output): # Combine and add commas. - full_statement.append(' %s%s' % (line, i < len(table_output)-1 and ',' or '')) + full_statement.append( + ' %s%s' % (line, i < len(table_output)-1 and ',' or '')) full_statement.append(')') if opts.db_tablespace: - tablespace_sql = self.connection.ops.tablespace_sql(opts.db_tablespace) + tablespace_sql = self.connection.ops.tablespace_sql( + opts.db_tablespace) if tablespace_sql: full_statement.append(tablespace_sql) full_statement.append(';') final_output.append('\n'.join(full_statement)) if opts.has_auto_field: - # Add any extra SQL needed to support auto-incrementing primary keys. + # Add any extra SQL needed to support auto-incrementing primary + # keys. auto_column = opts.auto_field.db_column or opts.auto_field.name - autoinc_sql = self.connection.ops.autoinc_sql(opts.db_table, auto_column) + autoinc_sql = self.connection.ops.autoinc_sql(opts.db_table, + auto_column) if autoinc_sql: for stmt in autoinc_sql: final_output.append(stmt) @@ -93,12 +105,15 @@ class BaseDatabaseCreation(object): return final_output, pending_references def sql_for_inline_foreign_key_references(self, field, known_models, style): - "Return the SQL snippet defining the foreign key reference for a field" + """ + Return the SQL snippet defining the foreign key reference for a field. + """ qn = self.connection.ops.quote_name if field.rel.to in known_models: - output = [style.SQL_KEYWORD('REFERENCES') + ' ' + \ - style.SQL_TABLE(qn(field.rel.to._meta.db_table)) + ' (' + \ - style.SQL_FIELD(qn(field.rel.to._meta.get_field(field.rel.field_name).column)) + ')' + + output = [style.SQL_KEYWORD('REFERENCES') + ' ' + + style.SQL_TABLE(qn(field.rel.to._meta.db_table)) + ' (' + + style.SQL_FIELD(qn(field.rel.to._meta.get_field( + field.rel.field_name).column)) + ')' + self.connection.ops.deferrable_sql() ] pending = False @@ -111,7 +126,9 @@ class BaseDatabaseCreation(object): return output, pending def sql_for_pending_references(self, model, style, pending_references): - "Returns any ALTER TABLE statements to add constraints after the fact." + """ + Returns any ALTER TABLE statements to add constraints after the fact. + """ from django.db.backends.util import truncate_name if not model._meta.managed or model._meta.proxy: @@ -128,16 +145,21 @@ class BaseDatabaseCreation(object): col = opts.get_field(f.rel.field_name).column # For MySQL, r_name must be unique in the first 64 characters. # So we are careful with character usage here. - r_name = '%s_refs_%s_%s' % (r_col, col, self._digest(r_table, table)) - final_output.append(style.SQL_KEYWORD('ALTER TABLE') + ' %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s)%s;' % \ - (qn(r_table), qn(truncate_name(r_name, self.connection.ops.max_name_length())), + r_name = '%s_refs_%s_%s' % ( + r_col, col, self._digest(r_table, table)) + final_output.append(style.SQL_KEYWORD('ALTER TABLE') + + ' %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s)%s;' % + (qn(r_table), qn(truncate_name( + r_name, self.connection.ops.max_name_length())), qn(r_col), qn(table), qn(col), self.connection.ops.deferrable_sql())) del pending_references[model] return final_output def sql_indexes_for_model(self, model, style): - "Returns the CREATE INDEX SQL statements for a single model" + """ + Returns the CREATE INDEX SQL statements for a single model. + """ if not model._meta.managed or model._meta.proxy: return [] output = [] @@ -146,7 +168,9 @@ class BaseDatabaseCreation(object): return output def sql_indexes_for_field(self, model, f, style): - "Return the CREATE INDEX SQL statements for a single model field" + """ + Return the CREATE INDEX SQL statements for a single model field. + """ from django.db.backends.util import truncate_name if f.db_index and not f.unique: @@ -160,7 +184,8 @@ class BaseDatabaseCreation(object): tablespace_sql = '' i_name = '%s_%s' % (model._meta.db_table, self._digest(f.column)) output = [style.SQL_KEYWORD('CREATE INDEX') + ' ' + - style.SQL_TABLE(qn(truncate_name(i_name, self.connection.ops.max_name_length()))) + ' ' + + style.SQL_TABLE(qn(truncate_name( + i_name, self.connection.ops.max_name_length()))) + ' ' + style.SQL_KEYWORD('ON') + ' ' + style.SQL_TABLE(qn(model._meta.db_table)) + ' ' + "(%s)" % style.SQL_FIELD(qn(f.column)) + @@ -170,7 +195,10 @@ class BaseDatabaseCreation(object): return output def sql_destroy_model(self, model, references_to_delete, style): - "Return the DROP TABLE and restraint dropping statements for a single model" + """ + Return the DROP TABLE and restraint dropping statements for a single + model. + """ if not model._meta.managed or model._meta.proxy: return [] # Drop the table now @@ -178,8 +206,8 @@ class BaseDatabaseCreation(object): output = ['%s %s;' % (style.SQL_KEYWORD('DROP TABLE'), style.SQL_TABLE(qn(model._meta.db_table)))] if model in references_to_delete: - output.extend(self.sql_remove_table_constraints(model, references_to_delete, style)) - + output.extend(self.sql_remove_table_constraints( + model, references_to_delete, style)) if model._meta.has_auto_field: ds = self.connection.ops.drop_sequence_sql(model._meta.db_table) if ds: @@ -188,7 +216,6 @@ class BaseDatabaseCreation(object): def sql_remove_table_constraints(self, model, references_to_delete, style): from django.db.backends.util import truncate_name - if not model._meta.managed or model._meta.proxy: return [] output = [] @@ -198,12 +225,14 @@ class BaseDatabaseCreation(object): col = f.column r_table = model._meta.db_table r_col = model._meta.get_field(f.rel.field_name).column - r_name = '%s_refs_%s_%s' % (col, r_col, self._digest(table, r_table)) + r_name = '%s_refs_%s_%s' % ( + col, r_col, self._digest(table, r_table)) output.append('%s %s %s %s;' % \ (style.SQL_KEYWORD('ALTER TABLE'), style.SQL_TABLE(qn(table)), style.SQL_KEYWORD(self.connection.ops.drop_foreignkey_sql()), - style.SQL_FIELD(qn(truncate_name(r_name, self.connection.ops.max_name_length()))))) + style.SQL_FIELD(qn(truncate_name( + r_name, self.connection.ops.max_name_length()))))) del references_to_delete[model] return output @@ -221,7 +250,8 @@ class BaseDatabaseCreation(object): test_db_repr = '' if verbosity >= 2: test_db_repr = " ('%s')" % test_database_name - print "Creating test database for alias '%s'%s..." % (self.connection.alias, test_db_repr) + print "Creating test database for alias '%s'%s..." % ( + self.connection.alias, test_db_repr) self._create_test_db(verbosity, autoclobber) @@ -255,7 +285,8 @@ class BaseDatabaseCreation(object): for cache_alias in settings.CACHES: cache = get_cache(cache_alias) if isinstance(cache, BaseDatabaseCache): - call_command('createcachetable', cache._table, database=self.connection.alias) + call_command('createcachetable', cache._table, + database=self.connection.alias) # Get a cursor (even though we don't need one yet). This has # the side effect of initializing the test database. @@ -275,7 +306,9 @@ class BaseDatabaseCreation(object): return TEST_DATABASE_PREFIX + self.connection.settings_dict['NAME'] def _create_test_db(self, verbosity, autoclobber): - "Internal implementation - creates the test db tables." + """ + Internal implementation - creates the test db tables. + """ suffix = self.sql_table_creation_suffix() test_database_name = self._get_test_db_name() @@ -288,19 +321,28 @@ class BaseDatabaseCreation(object): cursor = self.connection.cursor() self._prepare_for_test_db_ddl() try: - cursor.execute("CREATE DATABASE %s %s" % (qn(test_database_name), suffix)) + cursor.execute( + "CREATE DATABASE %s %s" % (qn(test_database_name), suffix)) except Exception, e: - sys.stderr.write("Got an error creating the test database: %s\n" % e) + sys.stderr.write( + "Got an error creating the test database: %s\n" % e) if not autoclobber: - confirm = raw_input("Type 'yes' if you would like to try deleting the test database '%s', or 'no' to cancel: " % test_database_name) + confirm = raw_input( + "Type 'yes' if you would like to try deleting the test " + "database '%s', or 'no' to cancel: " % test_database_name) if autoclobber or confirm == 'yes': try: if verbosity >= 1: - print "Destroying old test database '%s'..." % self.connection.alias - cursor.execute("DROP DATABASE %s" % qn(test_database_name)) - cursor.execute("CREATE DATABASE %s %s" % (qn(test_database_name), suffix)) + print ("Destroying old test database '%s'..." + % self.connection.alias) + cursor.execute( + "DROP DATABASE %s" % qn(test_database_name)) + cursor.execute( + "CREATE DATABASE %s %s" % (qn(test_database_name), + suffix)) except Exception, e: - sys.stderr.write("Got an error recreating the test database: %s\n" % e) + sys.stderr.write( + "Got an error recreating the test database: %s\n" % e) sys.exit(2) else: print "Tests cancelled." @@ -319,21 +361,36 @@ class BaseDatabaseCreation(object): test_db_repr = '' if verbosity >= 2: test_db_repr = " ('%s')" % test_database_name - print "Destroying test database for alias '%s'%s..." % (self.connection.alias, test_db_repr) - self.connection.settings_dict['NAME'] = old_database_name + print "Destroying test database for alias '%s'%s..." % ( + self.connection.alias, test_db_repr) - self._destroy_test_db(test_database_name, verbosity) + # Temporarily use a new connection and a copy of the settings dict. + # This prevents the production database from being exposed to potential + # child threads while (or after) the test database is destroyed. + # Refs #10868. + settings_dict = self.connection.settings_dict.copy() + settings_dict['NAME'] = old_database_name + backend = load_backend(settings_dict['ENGINE']) + new_connection = backend.DatabaseWrapper( + settings_dict, + alias='__destroy_test_db__', + allow_thread_sharing=False) + new_connection.creation._destroy_test_db(test_database_name, verbosity) def _destroy_test_db(self, test_database_name, verbosity): - "Internal implementation - remove the test db tables." + """ + Internal implementation - remove the test db tables. + """ # Remove the test database to clean up after # ourselves. Connect to the previous database (not the test database) # to do so, because it's not allowed to delete a database while being # connected to it. cursor = self.connection.cursor() self._prepare_for_test_db_ddl() - time.sleep(1) # To avoid "database is being accessed by other users" errors. - cursor.execute("DROP DATABASE %s" % self.connection.ops.quote_name(test_database_name)) + # Wait to avoid "database is being accessed by other users" errors. + time.sleep(1) + cursor.execute("DROP DATABASE %s" + % self.connection.ops.quote_name(test_database_name)) self.connection.close() def set_autocommit(self): @@ -346,15 +403,17 @@ class BaseDatabaseCreation(object): def _prepare_for_test_db_ddl(self): """ - Internal implementation - Hook for tasks that should be performed before - the ``CREATE DATABASE``/``DROP DATABASE`` clauses used by testing code - to create/ destroy test databases. Needed e.g. in PostgreSQL to rollback - and close any active transaction. + Internal implementation - Hook for tasks that should be performed + before the ``CREATE DATABASE``/``DROP DATABASE`` clauses used by + testing code to create/ destroy test databases. Needed e.g. in + PostgreSQL to rollback and close any active transaction. """ pass def sql_table_creation_suffix(self): - "SQL to append to the end of the test table creation statements" + """ + SQL to append to the end of the test table creation statements. + """ return '' def test_db_signature(self): diff --git a/django/test/simple.py b/django/test/simple.py index 1534011c17..0c6a21bb37 100644 --- a/django/test/simple.py +++ b/django/test/simple.py @@ -17,15 +17,18 @@ TEST_MODULE = 'tests' doctestOutputChecker = OutputChecker() + class DjangoTestRunner(unittest.TextTestRunner): def __init__(self, *args, **kwargs): import warnings warnings.warn( - "DjangoTestRunner is deprecated; it's functionality is indistinguishable from TextTestRunner", + "DjangoTestRunner is deprecated; it's functionality is " + "indistinguishable from TextTestRunner", DeprecationWarning ) super(DjangoTestRunner, self).__init__(*args, **kwargs) + def get_tests(app_module): parts = app_module.__name__.split('.') prefix, last = parts[:-1], parts[-1] @@ -49,8 +52,11 @@ def get_tests(app_module): raise return test_module + def build_suite(app_module): - "Create a complete Django test suite for the provided application module" + """ + Create a complete Django test suite for the provided application module. + """ suite = unittest.TestSuite() # Load unit and doctests in the models.py module. If module has @@ -58,7 +64,8 @@ def build_suite(app_module): if hasattr(app_module, 'suite'): suite.addTest(app_module.suite()) else: - suite.addTest(unittest.defaultTestLoader.loadTestsFromModule(app_module)) + suite.addTest(unittest.defaultTestLoader.loadTestsFromModule( + app_module)) try: suite.addTest(doctest.DocTestSuite(app_module, checker=doctestOutputChecker, @@ -76,25 +83,29 @@ def build_suite(app_module): if hasattr(test_module, 'suite'): suite.addTest(test_module.suite()) else: - suite.addTest(unittest.defaultTestLoader.loadTestsFromModule(test_module)) + suite.addTest(unittest.defaultTestLoader.loadTestsFromModule( + test_module)) try: - suite.addTest(doctest.DocTestSuite(test_module, - checker=doctestOutputChecker, - runner=DocTestRunner)) + suite.addTest(doctest.DocTestSuite( + test_module, checker=doctestOutputChecker, + runner=DocTestRunner)) except ValueError: # No doc tests in tests.py pass return suite + def build_test(label): - """Construct a test case with the specified label. Label should be of the + """ + Construct a test case with the specified label. Label should be of the form model.TestClass or model.TestClass.test_method. Returns an instantiated test or test suite corresponding to the label provided. """ parts = label.split('.') if len(parts) < 2 or len(parts) > 3: - raise ValueError("Test label '%s' should be of the form app.TestCase or app.TestCase.test_method" % label) + raise ValueError("Test label '%s' should be of the form app.TestCase " + "or app.TestCase.test_method" % label) # # First, look for TestCase instances with a name that matches @@ -112,9 +123,12 @@ def build_test(label): if issubclass(TestClass, (unittest.TestCase, real_unittest.TestCase)): if len(parts) == 2: # label is app.TestClass try: - return unittest.TestLoader().loadTestsFromTestCase(TestClass) + return unittest.TestLoader().loadTestsFromTestCase( + TestClass) except TypeError: - raise ValueError("Test label '%s' does not refer to a test class" % label) + raise ValueError( + "Test label '%s' does not refer to a test class" + % label) else: # label is app.TestClass.test_method return TestClass(parts[2]) except TypeError: @@ -135,7 +149,8 @@ def build_test(label): for test in doctests: if test._dt_test.name in ( '%s.%s' % (module.__name__, '.'.join(parts[1:])), - '%s.__test__.%s' % (module.__name__, '.'.join(parts[1:]))): + '%s.__test__.%s' % ( + module.__name__, '.'.join(parts[1:]))): tests.append(test) except ValueError: # No doctests found. @@ -148,6 +163,7 @@ def build_test(label): # Construct a suite out of the tests that matched. return unittest.TestSuite(tests) + def partition_suite(suite, classes, bins): """ Partitions a test suite by test type. @@ -169,14 +185,15 @@ def partition_suite(suite, classes, bins): else: bins[-1].addTest(test) + def reorder_suite(suite, classes): """ Reorders a test suite by test type. - classes is a sequence of types + `classes` is a sequence of types - All tests of type clases[0] are placed first, then tests of type classes[1], etc. - Tests with no match in classes are placed last. + All tests of type classes[0] are placed first, then tests of type + classes[1], etc. Tests with no match in classes are placed last. """ class_count = len(classes) bins = [unittest.TestSuite() for i in range(class_count+1)] @@ -185,6 +202,7 @@ def reorder_suite(suite, classes): bins[0].addTests(bins[i+1]) return bins[0] + def dependency_ordered(test_databases, dependencies): """Reorder test_databases into an order that honors the dependencies described in TEST_DEPENDENCIES. @@ -200,7 +218,8 @@ def dependency_ordered(test_databases, dependencies): dependencies_satisfied = True for alias in aliases: if alias in dependencies: - if all(a in resolved_databases for a in dependencies[alias]): + if all(a in resolved_databases + for a in dependencies[alias]): # all dependencies for this alias are satisfied dependencies.pop(alias) resolved_databases.add(alias) @@ -216,10 +235,12 @@ def dependency_ordered(test_databases, dependencies): deferred.append((signature, (db_name, aliases))) if not changed: - raise ImproperlyConfigured("Circular dependency in TEST_DEPENDENCIES") + raise ImproperlyConfigured( + "Circular dependency in TEST_DEPENDENCIES") test_databases = deferred return ordered_test_databases + class DjangoTestSuiteRunner(object): def __init__(self, verbosity=1, interactive=True, failfast=True, **kwargs): self.verbosity = verbosity @@ -264,7 +285,8 @@ class DjangoTestSuiteRunner(object): if connection.settings_dict['TEST_MIRROR']: # If the database is marked as a test mirror, save # the alias. - mirrored_aliases[alias] = connection.settings_dict['TEST_MIRROR'] + mirrored_aliases[alias] = ( + connection.settings_dict['TEST_MIRROR']) else: # Store a tuple with DB parameters that uniquely identify it. # If we have two aliases with the same values for that tuple, @@ -276,53 +298,57 @@ class DjangoTestSuiteRunner(object): item[1].append(alias) if 'TEST_DEPENDENCIES' in connection.settings_dict: - dependencies[alias] = connection.settings_dict['TEST_DEPENDENCIES'] + dependencies[alias] = ( + connection.settings_dict['TEST_DEPENDENCIES']) else: if alias != DEFAULT_DB_ALIAS: - dependencies[alias] = connection.settings_dict.get('TEST_DEPENDENCIES', [DEFAULT_DB_ALIAS]) + dependencies[alias] = connection.settings_dict.get( + 'TEST_DEPENDENCIES', [DEFAULT_DB_ALIAS]) # Second pass -- actually create the databases. old_names = [] mirrors = [] - for signature, (db_name, aliases) in dependency_ordered(test_databases.items(), dependencies): + for signature, (db_name, aliases) in dependency_ordered( + test_databases.items(), dependencies): # Actually create the database for the first connection connection = connections[aliases[0]] old_names.append((connection, db_name, True)) - test_db_name = connection.creation.create_test_db(self.verbosity, autoclobber=not self.interactive) + test_db_name = connection.creation.create_test_db( + self.verbosity, autoclobber=not self.interactive) for alias in aliases[1:]: connection = connections[alias] if db_name: old_names.append((connection, db_name, False)) connection.settings_dict['NAME'] = test_db_name else: - # If settings_dict['NAME'] isn't defined, we have a backend where - # the name isn't important -- e.g., SQLite, which uses :memory:. - # Force create the database instead of assuming it's a duplicate. + # If settings_dict['NAME'] isn't defined, we have a backend + # where the name isn't important -- e.g., SQLite, which + # uses :memory:. Force create the database instead of + # assuming it's a duplicate. old_names.append((connection, db_name, True)) - connection.creation.create_test_db(self.verbosity, autoclobber=not self.interactive) + connection.creation.create_test_db( + self.verbosity, autoclobber=not self.interactive) for alias, mirror_alias in mirrored_aliases.items(): mirrors.append((alias, connections[alias].settings_dict['NAME'])) - connections[alias].settings_dict['NAME'] = connections[mirror_alias].settings_dict['NAME'] + connections[alias].settings_dict['NAME'] = ( + connections[mirror_alias].settings_dict['NAME']) connections[alias].features = connections[mirror_alias].features return old_names, mirrors def run_suite(self, suite, **kwargs): - return unittest.TextTestRunner(verbosity=self.verbosity, failfast=self.failfast).run(suite) + return unittest.TextTestRunner( + verbosity=self.verbosity, failfast=self.failfast).run(suite) def teardown_databases(self, old_config, **kwargs): - from django.db import connections + """ + Destroys all the non-mirror databases. + """ old_names, mirrors = old_config - # Point all the mirrors back to the originals - for alias, old_name in mirrors: - connections[alias].settings_dict['NAME'] = old_name - # Destroy all the non-mirror databases for connection, old_name, destroy in old_names: if destroy: connection.creation.destroy_test_db(old_name, self.verbosity) - else: - connection.settings_dict['NAME'] = old_name def teardown_test_environment(self, **kwargs): unittest.removeHandler() diff --git a/docs/releases/1.4.txt b/docs/releases/1.4.txt index c5128b10f2..8a7bfb2011 100644 --- a/docs/releases/1.4.txt +++ b/docs/releases/1.4.txt @@ -946,6 +946,19 @@ apply URL escaping again. This is wrong for URLs whose unquoted form contains a ``%xx`` sequence, but such URLs are very unlikely to happen in the wild, since they would confuse browsers too. +Database connections after running the test suite +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The default test runner now does not restore the database connections after the +tests' execution any more. This prevents the production database from being +exposed to potential threads that would still be running and attempting to +create new connections. + +If your code relied on connections to the production database being created +after the tests' execution, then you may restore the previous behavior by +subclassing ``DjangoTestRunner`` and overriding its ``teardown_databases()`` +method. + Features deprecated in 1.4 ==========================