Fixed #10868 -- Stopped restoring database connections after the tests' execution in order to prevent the production database from being exposed to potential threads that would still be running. Also did a bit of PEP8-cleaning while I was in the area. Many thanks to ovidiu for the report and to Anssi Kääriäinen for thoroughly investigating this issue.
git-svn-id: http://code.djangoproject.com/svn/django/trunk@17411 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
parent
b5d0cc9091
commit
f1dc83cb98
|
@ -2,11 +2,13 @@ import sys
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
|
from django.db.utils import load_backend
|
||||||
|
|
||||||
# The prefix to put on the default database name when creating
|
# The prefix to put on the default database name when creating
|
||||||
# the test database.
|
# the test database.
|
||||||
TEST_DATABASE_PREFIX = 'test_'
|
TEST_DATABASE_PREFIX = 'test_'
|
||||||
|
|
||||||
|
|
||||||
class BaseDatabaseCreation(object):
|
class BaseDatabaseCreation(object):
|
||||||
"""
|
"""
|
||||||
This class encapsulates all backend-specific differences that pertain to
|
This class encapsulates all backend-specific differences that pertain to
|
||||||
|
@ -57,35 +59,45 @@ class BaseDatabaseCreation(object):
|
||||||
if tablespace and f.unique:
|
if tablespace and f.unique:
|
||||||
# We must specify the index tablespace inline, because we
|
# We must specify the index tablespace inline, because we
|
||||||
# won't be generating a CREATE INDEX statement for this field.
|
# won't be generating a CREATE INDEX statement for this field.
|
||||||
tablespace_sql = self.connection.ops.tablespace_sql(tablespace, inline=True)
|
tablespace_sql = self.connection.ops.tablespace_sql(
|
||||||
|
tablespace, inline=True)
|
||||||
if tablespace_sql:
|
if tablespace_sql:
|
||||||
field_output.append(tablespace_sql)
|
field_output.append(tablespace_sql)
|
||||||
if f.rel:
|
if f.rel:
|
||||||
ref_output, pending = self.sql_for_inline_foreign_key_references(f, known_models, style)
|
ref_output, pending = self.sql_for_inline_foreign_key_references(
|
||||||
|
f, known_models, style)
|
||||||
if pending:
|
if pending:
|
||||||
pending_references.setdefault(f.rel.to, []).append((model, f))
|
pending_references.setdefault(f.rel.to, []).append(
|
||||||
|
(model, f))
|
||||||
else:
|
else:
|
||||||
field_output.extend(ref_output)
|
field_output.extend(ref_output)
|
||||||
table_output.append(' '.join(field_output))
|
table_output.append(' '.join(field_output))
|
||||||
for field_constraints in opts.unique_together:
|
for field_constraints in opts.unique_together:
|
||||||
table_output.append(style.SQL_KEYWORD('UNIQUE') + ' (%s)' % \
|
table_output.append(style.SQL_KEYWORD('UNIQUE') + ' (%s)' %
|
||||||
", ".join([style.SQL_FIELD(qn(opts.get_field(f).column)) for f in field_constraints]))
|
", ".join(
|
||||||
|
[style.SQL_FIELD(qn(opts.get_field(f).column))
|
||||||
|
for f in field_constraints]))
|
||||||
|
|
||||||
full_statement = [style.SQL_KEYWORD('CREATE TABLE') + ' ' + style.SQL_TABLE(qn(opts.db_table)) + ' (']
|
full_statement = [style.SQL_KEYWORD('CREATE TABLE') + ' ' +
|
||||||
|
style.SQL_TABLE(qn(opts.db_table)) + ' (']
|
||||||
for i, line in enumerate(table_output): # Combine and add commas.
|
for i, line in enumerate(table_output): # Combine and add commas.
|
||||||
full_statement.append(' %s%s' % (line, i < len(table_output)-1 and ',' or ''))
|
full_statement.append(
|
||||||
|
' %s%s' % (line, i < len(table_output)-1 and ',' or ''))
|
||||||
full_statement.append(')')
|
full_statement.append(')')
|
||||||
if opts.db_tablespace:
|
if opts.db_tablespace:
|
||||||
tablespace_sql = self.connection.ops.tablespace_sql(opts.db_tablespace)
|
tablespace_sql = self.connection.ops.tablespace_sql(
|
||||||
|
opts.db_tablespace)
|
||||||
if tablespace_sql:
|
if tablespace_sql:
|
||||||
full_statement.append(tablespace_sql)
|
full_statement.append(tablespace_sql)
|
||||||
full_statement.append(';')
|
full_statement.append(';')
|
||||||
final_output.append('\n'.join(full_statement))
|
final_output.append('\n'.join(full_statement))
|
||||||
|
|
||||||
if opts.has_auto_field:
|
if opts.has_auto_field:
|
||||||
# Add any extra SQL needed to support auto-incrementing primary keys.
|
# Add any extra SQL needed to support auto-incrementing primary
|
||||||
|
# keys.
|
||||||
auto_column = opts.auto_field.db_column or opts.auto_field.name
|
auto_column = opts.auto_field.db_column or opts.auto_field.name
|
||||||
autoinc_sql = self.connection.ops.autoinc_sql(opts.db_table, auto_column)
|
autoinc_sql = self.connection.ops.autoinc_sql(opts.db_table,
|
||||||
|
auto_column)
|
||||||
if autoinc_sql:
|
if autoinc_sql:
|
||||||
for stmt in autoinc_sql:
|
for stmt in autoinc_sql:
|
||||||
final_output.append(stmt)
|
final_output.append(stmt)
|
||||||
|
@ -93,12 +105,15 @@ class BaseDatabaseCreation(object):
|
||||||
return final_output, pending_references
|
return final_output, pending_references
|
||||||
|
|
||||||
def sql_for_inline_foreign_key_references(self, field, known_models, style):
|
def sql_for_inline_foreign_key_references(self, field, known_models, style):
|
||||||
"Return the SQL snippet defining the foreign key reference for a field"
|
"""
|
||||||
|
Return the SQL snippet defining the foreign key reference for a field.
|
||||||
|
"""
|
||||||
qn = self.connection.ops.quote_name
|
qn = self.connection.ops.quote_name
|
||||||
if field.rel.to in known_models:
|
if field.rel.to in known_models:
|
||||||
output = [style.SQL_KEYWORD('REFERENCES') + ' ' + \
|
output = [style.SQL_KEYWORD('REFERENCES') + ' ' +
|
||||||
style.SQL_TABLE(qn(field.rel.to._meta.db_table)) + ' (' + \
|
style.SQL_TABLE(qn(field.rel.to._meta.db_table)) + ' (' +
|
||||||
style.SQL_FIELD(qn(field.rel.to._meta.get_field(field.rel.field_name).column)) + ')' +
|
style.SQL_FIELD(qn(field.rel.to._meta.get_field(
|
||||||
|
field.rel.field_name).column)) + ')' +
|
||||||
self.connection.ops.deferrable_sql()
|
self.connection.ops.deferrable_sql()
|
||||||
]
|
]
|
||||||
pending = False
|
pending = False
|
||||||
|
@ -111,7 +126,9 @@ class BaseDatabaseCreation(object):
|
||||||
return output, pending
|
return output, pending
|
||||||
|
|
||||||
def sql_for_pending_references(self, model, style, pending_references):
|
def sql_for_pending_references(self, model, style, pending_references):
|
||||||
"Returns any ALTER TABLE statements to add constraints after the fact."
|
"""
|
||||||
|
Returns any ALTER TABLE statements to add constraints after the fact.
|
||||||
|
"""
|
||||||
from django.db.backends.util import truncate_name
|
from django.db.backends.util import truncate_name
|
||||||
|
|
||||||
if not model._meta.managed or model._meta.proxy:
|
if not model._meta.managed or model._meta.proxy:
|
||||||
|
@ -128,16 +145,21 @@ class BaseDatabaseCreation(object):
|
||||||
col = opts.get_field(f.rel.field_name).column
|
col = opts.get_field(f.rel.field_name).column
|
||||||
# For MySQL, r_name must be unique in the first 64 characters.
|
# For MySQL, r_name must be unique in the first 64 characters.
|
||||||
# So we are careful with character usage here.
|
# So we are careful with character usage here.
|
||||||
r_name = '%s_refs_%s_%s' % (r_col, col, self._digest(r_table, table))
|
r_name = '%s_refs_%s_%s' % (
|
||||||
final_output.append(style.SQL_KEYWORD('ALTER TABLE') + ' %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s)%s;' % \
|
r_col, col, self._digest(r_table, table))
|
||||||
(qn(r_table), qn(truncate_name(r_name, self.connection.ops.max_name_length())),
|
final_output.append(style.SQL_KEYWORD('ALTER TABLE') +
|
||||||
|
' %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s)%s;' %
|
||||||
|
(qn(r_table), qn(truncate_name(
|
||||||
|
r_name, self.connection.ops.max_name_length())),
|
||||||
qn(r_col), qn(table), qn(col),
|
qn(r_col), qn(table), qn(col),
|
||||||
self.connection.ops.deferrable_sql()))
|
self.connection.ops.deferrable_sql()))
|
||||||
del pending_references[model]
|
del pending_references[model]
|
||||||
return final_output
|
return final_output
|
||||||
|
|
||||||
def sql_indexes_for_model(self, model, style):
|
def sql_indexes_for_model(self, model, style):
|
||||||
"Returns the CREATE INDEX SQL statements for a single model"
|
"""
|
||||||
|
Returns the CREATE INDEX SQL statements for a single model.
|
||||||
|
"""
|
||||||
if not model._meta.managed or model._meta.proxy:
|
if not model._meta.managed or model._meta.proxy:
|
||||||
return []
|
return []
|
||||||
output = []
|
output = []
|
||||||
|
@ -146,7 +168,9 @@ class BaseDatabaseCreation(object):
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def sql_indexes_for_field(self, model, f, style):
|
def sql_indexes_for_field(self, model, f, style):
|
||||||
"Return the CREATE INDEX SQL statements for a single model field"
|
"""
|
||||||
|
Return the CREATE INDEX SQL statements for a single model field.
|
||||||
|
"""
|
||||||
from django.db.backends.util import truncate_name
|
from django.db.backends.util import truncate_name
|
||||||
|
|
||||||
if f.db_index and not f.unique:
|
if f.db_index and not f.unique:
|
||||||
|
@ -160,7 +184,8 @@ class BaseDatabaseCreation(object):
|
||||||
tablespace_sql = ''
|
tablespace_sql = ''
|
||||||
i_name = '%s_%s' % (model._meta.db_table, self._digest(f.column))
|
i_name = '%s_%s' % (model._meta.db_table, self._digest(f.column))
|
||||||
output = [style.SQL_KEYWORD('CREATE INDEX') + ' ' +
|
output = [style.SQL_KEYWORD('CREATE INDEX') + ' ' +
|
||||||
style.SQL_TABLE(qn(truncate_name(i_name, self.connection.ops.max_name_length()))) + ' ' +
|
style.SQL_TABLE(qn(truncate_name(
|
||||||
|
i_name, self.connection.ops.max_name_length()))) + ' ' +
|
||||||
style.SQL_KEYWORD('ON') + ' ' +
|
style.SQL_KEYWORD('ON') + ' ' +
|
||||||
style.SQL_TABLE(qn(model._meta.db_table)) + ' ' +
|
style.SQL_TABLE(qn(model._meta.db_table)) + ' ' +
|
||||||
"(%s)" % style.SQL_FIELD(qn(f.column)) +
|
"(%s)" % style.SQL_FIELD(qn(f.column)) +
|
||||||
|
@ -170,7 +195,10 @@ class BaseDatabaseCreation(object):
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def sql_destroy_model(self, model, references_to_delete, style):
|
def sql_destroy_model(self, model, references_to_delete, style):
|
||||||
"Return the DROP TABLE and restraint dropping statements for a single model"
|
"""
|
||||||
|
Return the DROP TABLE and restraint dropping statements for a single
|
||||||
|
model.
|
||||||
|
"""
|
||||||
if not model._meta.managed or model._meta.proxy:
|
if not model._meta.managed or model._meta.proxy:
|
||||||
return []
|
return []
|
||||||
# Drop the table now
|
# Drop the table now
|
||||||
|
@ -178,8 +206,8 @@ class BaseDatabaseCreation(object):
|
||||||
output = ['%s %s;' % (style.SQL_KEYWORD('DROP TABLE'),
|
output = ['%s %s;' % (style.SQL_KEYWORD('DROP TABLE'),
|
||||||
style.SQL_TABLE(qn(model._meta.db_table)))]
|
style.SQL_TABLE(qn(model._meta.db_table)))]
|
||||||
if model in references_to_delete:
|
if model in references_to_delete:
|
||||||
output.extend(self.sql_remove_table_constraints(model, references_to_delete, style))
|
output.extend(self.sql_remove_table_constraints(
|
||||||
|
model, references_to_delete, style))
|
||||||
if model._meta.has_auto_field:
|
if model._meta.has_auto_field:
|
||||||
ds = self.connection.ops.drop_sequence_sql(model._meta.db_table)
|
ds = self.connection.ops.drop_sequence_sql(model._meta.db_table)
|
||||||
if ds:
|
if ds:
|
||||||
|
@ -188,7 +216,6 @@ class BaseDatabaseCreation(object):
|
||||||
|
|
||||||
def sql_remove_table_constraints(self, model, references_to_delete, style):
|
def sql_remove_table_constraints(self, model, references_to_delete, style):
|
||||||
from django.db.backends.util import truncate_name
|
from django.db.backends.util import truncate_name
|
||||||
|
|
||||||
if not model._meta.managed or model._meta.proxy:
|
if not model._meta.managed or model._meta.proxy:
|
||||||
return []
|
return []
|
||||||
output = []
|
output = []
|
||||||
|
@ -198,12 +225,14 @@ class BaseDatabaseCreation(object):
|
||||||
col = f.column
|
col = f.column
|
||||||
r_table = model._meta.db_table
|
r_table = model._meta.db_table
|
||||||
r_col = model._meta.get_field(f.rel.field_name).column
|
r_col = model._meta.get_field(f.rel.field_name).column
|
||||||
r_name = '%s_refs_%s_%s' % (col, r_col, self._digest(table, r_table))
|
r_name = '%s_refs_%s_%s' % (
|
||||||
|
col, r_col, self._digest(table, r_table))
|
||||||
output.append('%s %s %s %s;' % \
|
output.append('%s %s %s %s;' % \
|
||||||
(style.SQL_KEYWORD('ALTER TABLE'),
|
(style.SQL_KEYWORD('ALTER TABLE'),
|
||||||
style.SQL_TABLE(qn(table)),
|
style.SQL_TABLE(qn(table)),
|
||||||
style.SQL_KEYWORD(self.connection.ops.drop_foreignkey_sql()),
|
style.SQL_KEYWORD(self.connection.ops.drop_foreignkey_sql()),
|
||||||
style.SQL_FIELD(qn(truncate_name(r_name, self.connection.ops.max_name_length())))))
|
style.SQL_FIELD(qn(truncate_name(
|
||||||
|
r_name, self.connection.ops.max_name_length())))))
|
||||||
del references_to_delete[model]
|
del references_to_delete[model]
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
@ -221,7 +250,8 @@ class BaseDatabaseCreation(object):
|
||||||
test_db_repr = ''
|
test_db_repr = ''
|
||||||
if verbosity >= 2:
|
if verbosity >= 2:
|
||||||
test_db_repr = " ('%s')" % test_database_name
|
test_db_repr = " ('%s')" % test_database_name
|
||||||
print "Creating test database for alias '%s'%s..." % (self.connection.alias, test_db_repr)
|
print "Creating test database for alias '%s'%s..." % (
|
||||||
|
self.connection.alias, test_db_repr)
|
||||||
|
|
||||||
self._create_test_db(verbosity, autoclobber)
|
self._create_test_db(verbosity, autoclobber)
|
||||||
|
|
||||||
|
@ -255,7 +285,8 @@ class BaseDatabaseCreation(object):
|
||||||
for cache_alias in settings.CACHES:
|
for cache_alias in settings.CACHES:
|
||||||
cache = get_cache(cache_alias)
|
cache = get_cache(cache_alias)
|
||||||
if isinstance(cache, BaseDatabaseCache):
|
if isinstance(cache, BaseDatabaseCache):
|
||||||
call_command('createcachetable', cache._table, database=self.connection.alias)
|
call_command('createcachetable', cache._table,
|
||||||
|
database=self.connection.alias)
|
||||||
|
|
||||||
# Get a cursor (even though we don't need one yet). This has
|
# Get a cursor (even though we don't need one yet). This has
|
||||||
# the side effect of initializing the test database.
|
# the side effect of initializing the test database.
|
||||||
|
@ -275,7 +306,9 @@ class BaseDatabaseCreation(object):
|
||||||
return TEST_DATABASE_PREFIX + self.connection.settings_dict['NAME']
|
return TEST_DATABASE_PREFIX + self.connection.settings_dict['NAME']
|
||||||
|
|
||||||
def _create_test_db(self, verbosity, autoclobber):
|
def _create_test_db(self, verbosity, autoclobber):
|
||||||
"Internal implementation - creates the test db tables."
|
"""
|
||||||
|
Internal implementation - creates the test db tables.
|
||||||
|
"""
|
||||||
suffix = self.sql_table_creation_suffix()
|
suffix = self.sql_table_creation_suffix()
|
||||||
|
|
||||||
test_database_name = self._get_test_db_name()
|
test_database_name = self._get_test_db_name()
|
||||||
|
@ -288,19 +321,28 @@ class BaseDatabaseCreation(object):
|
||||||
cursor = self.connection.cursor()
|
cursor = self.connection.cursor()
|
||||||
self._prepare_for_test_db_ddl()
|
self._prepare_for_test_db_ddl()
|
||||||
try:
|
try:
|
||||||
cursor.execute("CREATE DATABASE %s %s" % (qn(test_database_name), suffix))
|
cursor.execute(
|
||||||
|
"CREATE DATABASE %s %s" % (qn(test_database_name), suffix))
|
||||||
except Exception, e:
|
except Exception, e:
|
||||||
sys.stderr.write("Got an error creating the test database: %s\n" % e)
|
sys.stderr.write(
|
||||||
|
"Got an error creating the test database: %s\n" % e)
|
||||||
if not autoclobber:
|
if not autoclobber:
|
||||||
confirm = raw_input("Type 'yes' if you would like to try deleting the test database '%s', or 'no' to cancel: " % test_database_name)
|
confirm = raw_input(
|
||||||
|
"Type 'yes' if you would like to try deleting the test "
|
||||||
|
"database '%s', or 'no' to cancel: " % test_database_name)
|
||||||
if autoclobber or confirm == 'yes':
|
if autoclobber or confirm == 'yes':
|
||||||
try:
|
try:
|
||||||
if verbosity >= 1:
|
if verbosity >= 1:
|
||||||
print "Destroying old test database '%s'..." % self.connection.alias
|
print ("Destroying old test database '%s'..."
|
||||||
cursor.execute("DROP DATABASE %s" % qn(test_database_name))
|
% self.connection.alias)
|
||||||
cursor.execute("CREATE DATABASE %s %s" % (qn(test_database_name), suffix))
|
cursor.execute(
|
||||||
|
"DROP DATABASE %s" % qn(test_database_name))
|
||||||
|
cursor.execute(
|
||||||
|
"CREATE DATABASE %s %s" % (qn(test_database_name),
|
||||||
|
suffix))
|
||||||
except Exception, e:
|
except Exception, e:
|
||||||
sys.stderr.write("Got an error recreating the test database: %s\n" % e)
|
sys.stderr.write(
|
||||||
|
"Got an error recreating the test database: %s\n" % e)
|
||||||
sys.exit(2)
|
sys.exit(2)
|
||||||
else:
|
else:
|
||||||
print "Tests cancelled."
|
print "Tests cancelled."
|
||||||
|
@ -319,21 +361,36 @@ class BaseDatabaseCreation(object):
|
||||||
test_db_repr = ''
|
test_db_repr = ''
|
||||||
if verbosity >= 2:
|
if verbosity >= 2:
|
||||||
test_db_repr = " ('%s')" % test_database_name
|
test_db_repr = " ('%s')" % test_database_name
|
||||||
print "Destroying test database for alias '%s'%s..." % (self.connection.alias, test_db_repr)
|
print "Destroying test database for alias '%s'%s..." % (
|
||||||
self.connection.settings_dict['NAME'] = old_database_name
|
self.connection.alias, test_db_repr)
|
||||||
|
|
||||||
self._destroy_test_db(test_database_name, verbosity)
|
# Temporarily use a new connection and a copy of the settings dict.
|
||||||
|
# This prevents the production database from being exposed to potential
|
||||||
|
# child threads while (or after) the test database is destroyed.
|
||||||
|
# Refs #10868.
|
||||||
|
settings_dict = self.connection.settings_dict.copy()
|
||||||
|
settings_dict['NAME'] = old_database_name
|
||||||
|
backend = load_backend(settings_dict['ENGINE'])
|
||||||
|
new_connection = backend.DatabaseWrapper(
|
||||||
|
settings_dict,
|
||||||
|
alias='__destroy_test_db__',
|
||||||
|
allow_thread_sharing=False)
|
||||||
|
new_connection.creation._destroy_test_db(test_database_name, verbosity)
|
||||||
|
|
||||||
def _destroy_test_db(self, test_database_name, verbosity):
|
def _destroy_test_db(self, test_database_name, verbosity):
|
||||||
"Internal implementation - remove the test db tables."
|
"""
|
||||||
|
Internal implementation - remove the test db tables.
|
||||||
|
"""
|
||||||
# Remove the test database to clean up after
|
# Remove the test database to clean up after
|
||||||
# ourselves. Connect to the previous database (not the test database)
|
# ourselves. Connect to the previous database (not the test database)
|
||||||
# to do so, because it's not allowed to delete a database while being
|
# to do so, because it's not allowed to delete a database while being
|
||||||
# connected to it.
|
# connected to it.
|
||||||
cursor = self.connection.cursor()
|
cursor = self.connection.cursor()
|
||||||
self._prepare_for_test_db_ddl()
|
self._prepare_for_test_db_ddl()
|
||||||
time.sleep(1) # To avoid "database is being accessed by other users" errors.
|
# Wait to avoid "database is being accessed by other users" errors.
|
||||||
cursor.execute("DROP DATABASE %s" % self.connection.ops.quote_name(test_database_name))
|
time.sleep(1)
|
||||||
|
cursor.execute("DROP DATABASE %s"
|
||||||
|
% self.connection.ops.quote_name(test_database_name))
|
||||||
self.connection.close()
|
self.connection.close()
|
||||||
|
|
||||||
def set_autocommit(self):
|
def set_autocommit(self):
|
||||||
|
@ -346,15 +403,17 @@ class BaseDatabaseCreation(object):
|
||||||
|
|
||||||
def _prepare_for_test_db_ddl(self):
|
def _prepare_for_test_db_ddl(self):
|
||||||
"""
|
"""
|
||||||
Internal implementation - Hook for tasks that should be performed before
|
Internal implementation - Hook for tasks that should be performed
|
||||||
the ``CREATE DATABASE``/``DROP DATABASE`` clauses used by testing code
|
before the ``CREATE DATABASE``/``DROP DATABASE`` clauses used by
|
||||||
to create/ destroy test databases. Needed e.g. in PostgreSQL to rollback
|
testing code to create/ destroy test databases. Needed e.g. in
|
||||||
and close any active transaction.
|
PostgreSQL to rollback and close any active transaction.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def sql_table_creation_suffix(self):
|
def sql_table_creation_suffix(self):
|
||||||
"SQL to append to the end of the test table creation statements"
|
"""
|
||||||
|
SQL to append to the end of the test table creation statements.
|
||||||
|
"""
|
||||||
return ''
|
return ''
|
||||||
|
|
||||||
def test_db_signature(self):
|
def test_db_signature(self):
|
||||||
|
|
|
@ -17,15 +17,18 @@ TEST_MODULE = 'tests'
|
||||||
|
|
||||||
doctestOutputChecker = OutputChecker()
|
doctestOutputChecker = OutputChecker()
|
||||||
|
|
||||||
|
|
||||||
class DjangoTestRunner(unittest.TextTestRunner):
|
class DjangoTestRunner(unittest.TextTestRunner):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
import warnings
|
import warnings
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"DjangoTestRunner is deprecated; it's functionality is indistinguishable from TextTestRunner",
|
"DjangoTestRunner is deprecated; it's functionality is "
|
||||||
|
"indistinguishable from TextTestRunner",
|
||||||
DeprecationWarning
|
DeprecationWarning
|
||||||
)
|
)
|
||||||
super(DjangoTestRunner, self).__init__(*args, **kwargs)
|
super(DjangoTestRunner, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def get_tests(app_module):
|
def get_tests(app_module):
|
||||||
parts = app_module.__name__.split('.')
|
parts = app_module.__name__.split('.')
|
||||||
prefix, last = parts[:-1], parts[-1]
|
prefix, last = parts[:-1], parts[-1]
|
||||||
|
@ -49,8 +52,11 @@ def get_tests(app_module):
|
||||||
raise
|
raise
|
||||||
return test_module
|
return test_module
|
||||||
|
|
||||||
|
|
||||||
def build_suite(app_module):
|
def build_suite(app_module):
|
||||||
"Create a complete Django test suite for the provided application module"
|
"""
|
||||||
|
Create a complete Django test suite for the provided application module.
|
||||||
|
"""
|
||||||
suite = unittest.TestSuite()
|
suite = unittest.TestSuite()
|
||||||
|
|
||||||
# Load unit and doctests in the models.py module. If module has
|
# Load unit and doctests in the models.py module. If module has
|
||||||
|
@ -58,7 +64,8 @@ def build_suite(app_module):
|
||||||
if hasattr(app_module, 'suite'):
|
if hasattr(app_module, 'suite'):
|
||||||
suite.addTest(app_module.suite())
|
suite.addTest(app_module.suite())
|
||||||
else:
|
else:
|
||||||
suite.addTest(unittest.defaultTestLoader.loadTestsFromModule(app_module))
|
suite.addTest(unittest.defaultTestLoader.loadTestsFromModule(
|
||||||
|
app_module))
|
||||||
try:
|
try:
|
||||||
suite.addTest(doctest.DocTestSuite(app_module,
|
suite.addTest(doctest.DocTestSuite(app_module,
|
||||||
checker=doctestOutputChecker,
|
checker=doctestOutputChecker,
|
||||||
|
@ -76,25 +83,29 @@ def build_suite(app_module):
|
||||||
if hasattr(test_module, 'suite'):
|
if hasattr(test_module, 'suite'):
|
||||||
suite.addTest(test_module.suite())
|
suite.addTest(test_module.suite())
|
||||||
else:
|
else:
|
||||||
suite.addTest(unittest.defaultTestLoader.loadTestsFromModule(test_module))
|
suite.addTest(unittest.defaultTestLoader.loadTestsFromModule(
|
||||||
|
test_module))
|
||||||
try:
|
try:
|
||||||
suite.addTest(doctest.DocTestSuite(test_module,
|
suite.addTest(doctest.DocTestSuite(
|
||||||
checker=doctestOutputChecker,
|
test_module, checker=doctestOutputChecker,
|
||||||
runner=DocTestRunner))
|
runner=DocTestRunner))
|
||||||
except ValueError:
|
except ValueError:
|
||||||
# No doc tests in tests.py
|
# No doc tests in tests.py
|
||||||
pass
|
pass
|
||||||
return suite
|
return suite
|
||||||
|
|
||||||
|
|
||||||
def build_test(label):
|
def build_test(label):
|
||||||
"""Construct a test case with the specified label. Label should be of the
|
"""
|
||||||
|
Construct a test case with the specified label. Label should be of the
|
||||||
form model.TestClass or model.TestClass.test_method. Returns an
|
form model.TestClass or model.TestClass.test_method. Returns an
|
||||||
instantiated test or test suite corresponding to the label provided.
|
instantiated test or test suite corresponding to the label provided.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
parts = label.split('.')
|
parts = label.split('.')
|
||||||
if len(parts) < 2 or len(parts) > 3:
|
if len(parts) < 2 or len(parts) > 3:
|
||||||
raise ValueError("Test label '%s' should be of the form app.TestCase or app.TestCase.test_method" % label)
|
raise ValueError("Test label '%s' should be of the form app.TestCase "
|
||||||
|
"or app.TestCase.test_method" % label)
|
||||||
|
|
||||||
#
|
#
|
||||||
# First, look for TestCase instances with a name that matches
|
# First, look for TestCase instances with a name that matches
|
||||||
|
@ -112,9 +123,12 @@ def build_test(label):
|
||||||
if issubclass(TestClass, (unittest.TestCase, real_unittest.TestCase)):
|
if issubclass(TestClass, (unittest.TestCase, real_unittest.TestCase)):
|
||||||
if len(parts) == 2: # label is app.TestClass
|
if len(parts) == 2: # label is app.TestClass
|
||||||
try:
|
try:
|
||||||
return unittest.TestLoader().loadTestsFromTestCase(TestClass)
|
return unittest.TestLoader().loadTestsFromTestCase(
|
||||||
|
TestClass)
|
||||||
except TypeError:
|
except TypeError:
|
||||||
raise ValueError("Test label '%s' does not refer to a test class" % label)
|
raise ValueError(
|
||||||
|
"Test label '%s' does not refer to a test class"
|
||||||
|
% label)
|
||||||
else: # label is app.TestClass.test_method
|
else: # label is app.TestClass.test_method
|
||||||
return TestClass(parts[2])
|
return TestClass(parts[2])
|
||||||
except TypeError:
|
except TypeError:
|
||||||
|
@ -135,7 +149,8 @@ def build_test(label):
|
||||||
for test in doctests:
|
for test in doctests:
|
||||||
if test._dt_test.name in (
|
if test._dt_test.name in (
|
||||||
'%s.%s' % (module.__name__, '.'.join(parts[1:])),
|
'%s.%s' % (module.__name__, '.'.join(parts[1:])),
|
||||||
'%s.__test__.%s' % (module.__name__, '.'.join(parts[1:]))):
|
'%s.__test__.%s' % (
|
||||||
|
module.__name__, '.'.join(parts[1:]))):
|
||||||
tests.append(test)
|
tests.append(test)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
# No doctests found.
|
# No doctests found.
|
||||||
|
@ -148,6 +163,7 @@ def build_test(label):
|
||||||
# Construct a suite out of the tests that matched.
|
# Construct a suite out of the tests that matched.
|
||||||
return unittest.TestSuite(tests)
|
return unittest.TestSuite(tests)
|
||||||
|
|
||||||
|
|
||||||
def partition_suite(suite, classes, bins):
|
def partition_suite(suite, classes, bins):
|
||||||
"""
|
"""
|
||||||
Partitions a test suite by test type.
|
Partitions a test suite by test type.
|
||||||
|
@ -169,14 +185,15 @@ def partition_suite(suite, classes, bins):
|
||||||
else:
|
else:
|
||||||
bins[-1].addTest(test)
|
bins[-1].addTest(test)
|
||||||
|
|
||||||
|
|
||||||
def reorder_suite(suite, classes):
|
def reorder_suite(suite, classes):
|
||||||
"""
|
"""
|
||||||
Reorders a test suite by test type.
|
Reorders a test suite by test type.
|
||||||
|
|
||||||
classes is a sequence of types
|
`classes` is a sequence of types
|
||||||
|
|
||||||
All tests of type clases[0] are placed first, then tests of type classes[1], etc.
|
All tests of type classes[0] are placed first, then tests of type
|
||||||
Tests with no match in classes are placed last.
|
classes[1], etc. Tests with no match in classes are placed last.
|
||||||
"""
|
"""
|
||||||
class_count = len(classes)
|
class_count = len(classes)
|
||||||
bins = [unittest.TestSuite() for i in range(class_count+1)]
|
bins = [unittest.TestSuite() for i in range(class_count+1)]
|
||||||
|
@ -185,6 +202,7 @@ def reorder_suite(suite, classes):
|
||||||
bins[0].addTests(bins[i+1])
|
bins[0].addTests(bins[i+1])
|
||||||
return bins[0]
|
return bins[0]
|
||||||
|
|
||||||
|
|
||||||
def dependency_ordered(test_databases, dependencies):
|
def dependency_ordered(test_databases, dependencies):
|
||||||
"""Reorder test_databases into an order that honors the dependencies
|
"""Reorder test_databases into an order that honors the dependencies
|
||||||
described in TEST_DEPENDENCIES.
|
described in TEST_DEPENDENCIES.
|
||||||
|
@ -200,7 +218,8 @@ def dependency_ordered(test_databases, dependencies):
|
||||||
dependencies_satisfied = True
|
dependencies_satisfied = True
|
||||||
for alias in aliases:
|
for alias in aliases:
|
||||||
if alias in dependencies:
|
if alias in dependencies:
|
||||||
if all(a in resolved_databases for a in dependencies[alias]):
|
if all(a in resolved_databases
|
||||||
|
for a in dependencies[alias]):
|
||||||
# all dependencies for this alias are satisfied
|
# all dependencies for this alias are satisfied
|
||||||
dependencies.pop(alias)
|
dependencies.pop(alias)
|
||||||
resolved_databases.add(alias)
|
resolved_databases.add(alias)
|
||||||
|
@ -216,10 +235,12 @@ def dependency_ordered(test_databases, dependencies):
|
||||||
deferred.append((signature, (db_name, aliases)))
|
deferred.append((signature, (db_name, aliases)))
|
||||||
|
|
||||||
if not changed:
|
if not changed:
|
||||||
raise ImproperlyConfigured("Circular dependency in TEST_DEPENDENCIES")
|
raise ImproperlyConfigured(
|
||||||
|
"Circular dependency in TEST_DEPENDENCIES")
|
||||||
test_databases = deferred
|
test_databases = deferred
|
||||||
return ordered_test_databases
|
return ordered_test_databases
|
||||||
|
|
||||||
|
|
||||||
class DjangoTestSuiteRunner(object):
|
class DjangoTestSuiteRunner(object):
|
||||||
def __init__(self, verbosity=1, interactive=True, failfast=True, **kwargs):
|
def __init__(self, verbosity=1, interactive=True, failfast=True, **kwargs):
|
||||||
self.verbosity = verbosity
|
self.verbosity = verbosity
|
||||||
|
@ -264,7 +285,8 @@ class DjangoTestSuiteRunner(object):
|
||||||
if connection.settings_dict['TEST_MIRROR']:
|
if connection.settings_dict['TEST_MIRROR']:
|
||||||
# If the database is marked as a test mirror, save
|
# If the database is marked as a test mirror, save
|
||||||
# the alias.
|
# the alias.
|
||||||
mirrored_aliases[alias] = connection.settings_dict['TEST_MIRROR']
|
mirrored_aliases[alias] = (
|
||||||
|
connection.settings_dict['TEST_MIRROR'])
|
||||||
else:
|
else:
|
||||||
# Store a tuple with DB parameters that uniquely identify it.
|
# Store a tuple with DB parameters that uniquely identify it.
|
||||||
# If we have two aliases with the same values for that tuple,
|
# If we have two aliases with the same values for that tuple,
|
||||||
|
@ -276,53 +298,57 @@ class DjangoTestSuiteRunner(object):
|
||||||
item[1].append(alias)
|
item[1].append(alias)
|
||||||
|
|
||||||
if 'TEST_DEPENDENCIES' in connection.settings_dict:
|
if 'TEST_DEPENDENCIES' in connection.settings_dict:
|
||||||
dependencies[alias] = connection.settings_dict['TEST_DEPENDENCIES']
|
dependencies[alias] = (
|
||||||
|
connection.settings_dict['TEST_DEPENDENCIES'])
|
||||||
else:
|
else:
|
||||||
if alias != DEFAULT_DB_ALIAS:
|
if alias != DEFAULT_DB_ALIAS:
|
||||||
dependencies[alias] = connection.settings_dict.get('TEST_DEPENDENCIES', [DEFAULT_DB_ALIAS])
|
dependencies[alias] = connection.settings_dict.get(
|
||||||
|
'TEST_DEPENDENCIES', [DEFAULT_DB_ALIAS])
|
||||||
|
|
||||||
# Second pass -- actually create the databases.
|
# Second pass -- actually create the databases.
|
||||||
old_names = []
|
old_names = []
|
||||||
mirrors = []
|
mirrors = []
|
||||||
for signature, (db_name, aliases) in dependency_ordered(test_databases.items(), dependencies):
|
for signature, (db_name, aliases) in dependency_ordered(
|
||||||
|
test_databases.items(), dependencies):
|
||||||
# Actually create the database for the first connection
|
# Actually create the database for the first connection
|
||||||
connection = connections[aliases[0]]
|
connection = connections[aliases[0]]
|
||||||
old_names.append((connection, db_name, True))
|
old_names.append((connection, db_name, True))
|
||||||
test_db_name = connection.creation.create_test_db(self.verbosity, autoclobber=not self.interactive)
|
test_db_name = connection.creation.create_test_db(
|
||||||
|
self.verbosity, autoclobber=not self.interactive)
|
||||||
for alias in aliases[1:]:
|
for alias in aliases[1:]:
|
||||||
connection = connections[alias]
|
connection = connections[alias]
|
||||||
if db_name:
|
if db_name:
|
||||||
old_names.append((connection, db_name, False))
|
old_names.append((connection, db_name, False))
|
||||||
connection.settings_dict['NAME'] = test_db_name
|
connection.settings_dict['NAME'] = test_db_name
|
||||||
else:
|
else:
|
||||||
# If settings_dict['NAME'] isn't defined, we have a backend where
|
# If settings_dict['NAME'] isn't defined, we have a backend
|
||||||
# the name isn't important -- e.g., SQLite, which uses :memory:.
|
# where the name isn't important -- e.g., SQLite, which
|
||||||
# Force create the database instead of assuming it's a duplicate.
|
# uses :memory:. Force create the database instead of
|
||||||
|
# assuming it's a duplicate.
|
||||||
old_names.append((connection, db_name, True))
|
old_names.append((connection, db_name, True))
|
||||||
connection.creation.create_test_db(self.verbosity, autoclobber=not self.interactive)
|
connection.creation.create_test_db(
|
||||||
|
self.verbosity, autoclobber=not self.interactive)
|
||||||
|
|
||||||
for alias, mirror_alias in mirrored_aliases.items():
|
for alias, mirror_alias in mirrored_aliases.items():
|
||||||
mirrors.append((alias, connections[alias].settings_dict['NAME']))
|
mirrors.append((alias, connections[alias].settings_dict['NAME']))
|
||||||
connections[alias].settings_dict['NAME'] = connections[mirror_alias].settings_dict['NAME']
|
connections[alias].settings_dict['NAME'] = (
|
||||||
|
connections[mirror_alias].settings_dict['NAME'])
|
||||||
connections[alias].features = connections[mirror_alias].features
|
connections[alias].features = connections[mirror_alias].features
|
||||||
|
|
||||||
return old_names, mirrors
|
return old_names, mirrors
|
||||||
|
|
||||||
def run_suite(self, suite, **kwargs):
|
def run_suite(self, suite, **kwargs):
|
||||||
return unittest.TextTestRunner(verbosity=self.verbosity, failfast=self.failfast).run(suite)
|
return unittest.TextTestRunner(
|
||||||
|
verbosity=self.verbosity, failfast=self.failfast).run(suite)
|
||||||
|
|
||||||
def teardown_databases(self, old_config, **kwargs):
|
def teardown_databases(self, old_config, **kwargs):
|
||||||
from django.db import connections
|
"""
|
||||||
|
Destroys all the non-mirror databases.
|
||||||
|
"""
|
||||||
old_names, mirrors = old_config
|
old_names, mirrors = old_config
|
||||||
# Point all the mirrors back to the originals
|
|
||||||
for alias, old_name in mirrors:
|
|
||||||
connections[alias].settings_dict['NAME'] = old_name
|
|
||||||
# Destroy all the non-mirror databases
|
|
||||||
for connection, old_name, destroy in old_names:
|
for connection, old_name, destroy in old_names:
|
||||||
if destroy:
|
if destroy:
|
||||||
connection.creation.destroy_test_db(old_name, self.verbosity)
|
connection.creation.destroy_test_db(old_name, self.verbosity)
|
||||||
else:
|
|
||||||
connection.settings_dict['NAME'] = old_name
|
|
||||||
|
|
||||||
def teardown_test_environment(self, **kwargs):
|
def teardown_test_environment(self, **kwargs):
|
||||||
unittest.removeHandler()
|
unittest.removeHandler()
|
||||||
|
|
|
@ -946,6 +946,19 @@ apply URL escaping again. This is wrong for URLs whose unquoted form contains
|
||||||
a ``%xx`` sequence, but such URLs are very unlikely to happen in the wild,
|
a ``%xx`` sequence, but such URLs are very unlikely to happen in the wild,
|
||||||
since they would confuse browsers too.
|
since they would confuse browsers too.
|
||||||
|
|
||||||
|
Database connections after running the test suite
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
The default test runner now does not restore the database connections after the
|
||||||
|
tests' execution any more. This prevents the production database from being
|
||||||
|
exposed to potential threads that would still be running and attempting to
|
||||||
|
create new connections.
|
||||||
|
|
||||||
|
If your code relied on connections to the production database being created
|
||||||
|
after the tests' execution, then you may restore the previous behavior by
|
||||||
|
subclassing ``DjangoTestRunner`` and overriding its ``teardown_databases()``
|
||||||
|
method.
|
||||||
|
|
||||||
Features deprecated in 1.4
|
Features deprecated in 1.4
|
||||||
==========================
|
==========================
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue