Fixed #16990 -- Fixed a couple of small docstring typos in the `django/test/testcases.py` module and did some minor cleanup while I was in the area.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@16928 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Julien Phalip 2011-10-05 12:50:44 +00:00
parent 05cf72bace
commit 03d4a8d1b6
1 changed files with 50 additions and 26 deletions

View File

@ -19,7 +19,8 @@ from django.forms.fields import CharField
from django.http import QueryDict from django.http import QueryDict
from django.test import _doctest as doctest from django.test import _doctest as doctest
from django.test.client import Client from django.test.client import Client
from django.test.utils import get_warnings_state, restore_warnings_state, override_settings from django.test.utils import (get_warnings_state, restore_warnings_state,
override_settings)
from django.utils import simplejson, unittest as ut2 from django.utils import simplejson, unittest as ut2
from django.utils.encoding import smart_str from django.utils.encoding import smart_str
@ -27,7 +28,8 @@ __all__ = ('DocTestRunner', 'OutputChecker', 'TestCase', 'TransactionTestCase',
'SimpleTestCase', 'skipIfDBFeature', 'skipUnlessDBFeature') 'SimpleTestCase', 'skipIfDBFeature', 'skipUnlessDBFeature')
normalize_long_ints = lambda s: re.sub(r'(?<![\w])(\d+)L(?![\w])', '\\1', s) normalize_long_ints = lambda s: re.sub(r'(?<![\w])(\d+)L(?![\w])', '\\1', s)
normalize_decimals = lambda s: re.sub(r"Decimal\('(\d+(\.\d*)?)'\)", lambda m: "Decimal(\"%s\")" % m.groups()[0], s) normalize_decimals = lambda s: re.sub(r"Decimal\('(\d+(\.\d*)?)'\)",
lambda m: "Decimal(\"%s\")" % m.groups()[0], s)
def to_list(value): def to_list(value):
""" """
@ -65,7 +67,9 @@ def restore_transaction_methods():
class OutputChecker(doctest.OutputChecker): class OutputChecker(doctest.OutputChecker):
def check_output(self, want, got, optionflags): def check_output(self, want, got, optionflags):
"The entry method for doctest output checking. Defers to a sequence of child checkers" """
The entry method for doctest output checking. Defers to a sequence of child checkers
"""
checks = (self.check_output_default, checks = (self.check_output_default,
self.check_output_numeric, self.check_output_numeric,
self.check_output_xml, self.check_output_xml,
@ -76,7 +80,10 @@ class OutputChecker(doctest.OutputChecker):
return False return False
def check_output_default(self, want, got, optionflags): def check_output_default(self, want, got, optionflags):
"The default comparator provided by doctest - not perfect, but good for most purposes" """
The default comparator provided by doctest - not perfect, but good for
most purposes
"""
return doctest.OutputChecker.check_output(self, want, got, optionflags) return doctest.OutputChecker.check_output(self, want, got, optionflags)
def check_output_numeric(self, want, got, optionflags): def check_output_numeric(self, want, got, optionflags):
@ -147,17 +154,19 @@ class OutputChecker(doctest.OutputChecker):
try: try:
want_root = parseString(want).firstChild want_root = parseString(want).firstChild
got_root = parseString(got).firstChild got_root = parseString(got).firstChild
except: except Exception:
return False return False
return check_element(want_root, got_root) return check_element(want_root, got_root)
def check_output_json(self, want, got, optionsflags): def check_output_json(self, want, got, optionsflags):
"Tries to compare want and got as if they were JSON-encoded data" """
Tries to compare want and got as if they were JSON-encoded data
"""
want, got = self._strip_quotes(want, got) want, got = self._strip_quotes(want, got)
try: try:
want_json = simplejson.loads(want) want_json = simplejson.loads(want)
got_json = simplejson.loads(got) got_json = simplejson.loads(got)
except: except Exception:
return False return False
return want_json == got_json return want_json == got_json
@ -248,7 +257,7 @@ class SimpleTestCase(ut2.TestCase):
def restore_warnings_state(self): def restore_warnings_state(self):
""" """
Restores the sate of the warnings module to the state Restores the state of the warnings module to the state
saved by save_warnings_state() saved by save_warnings_state()
""" """
restore_warnings_state(self._warnings_state) restore_warnings_state(self._warnings_state)
@ -262,7 +271,9 @@ class SimpleTestCase(ut2.TestCase):
def assertRaisesMessage(self, expected_exception, expected_message, def assertRaisesMessage(self, expected_exception, expected_message,
callable_obj=None, *args, **kwargs): callable_obj=None, *args, **kwargs):
"""Asserts that the message in a raised exception matches the passe value. """
Asserts that the message in a raised exception matches the passed
value.
Args: Args:
expected_exception: Exception class expected to be raised. expected_exception: Exception class expected to be raised.
@ -295,7 +306,8 @@ class SimpleTestCase(ut2.TestCase):
if field_kwargs is None: if field_kwargs is None:
field_kwargs = {} field_kwargs = {}
required = fieldclass(*field_args, **field_kwargs) required = fieldclass(*field_args, **field_kwargs)
optional = fieldclass(*field_args, **dict(field_kwargs, required=False)) optional = fieldclass(*field_args,
**dict(field_kwargs, required=False))
# test valid inputs # test valid inputs
for input, output in valid.items(): for input, output in valid.items():
self.assertEqual(required.clean(input), output) self.assertEqual(required.clean(input), output)
@ -314,12 +326,14 @@ class SimpleTestCase(ut2.TestCase):
for e in EMPTY_VALUES: for e in EMPTY_VALUES:
with self.assertRaises(ValidationError) as context_manager: with self.assertRaises(ValidationError) as context_manager:
required.clean(e) required.clean(e)
self.assertEqual(context_manager.exception.messages, error_required) self.assertEqual(context_manager.exception.messages,
error_required)
self.assertEqual(optional.clean(e), empty_value) self.assertEqual(optional.clean(e), empty_value)
# test that max_length and min_length are always accepted # test that max_length and min_length are always accepted
if issubclass(fieldclass, CharField): if issubclass(fieldclass, CharField):
field_kwargs.update({'min_length':2, 'max_length':20}) field_kwargs.update({'min_length':2, 'max_length':20})
self.assertTrue(isinstance(fieldclass(*field_args, **field_kwargs), fieldclass)) self.assertTrue(isinstance(fieldclass(*field_args, **field_kwargs),
fieldclass))
class TransactionTestCase(SimpleTestCase): class TransactionTestCase(SimpleTestCase):
# The class we'll use for the test client self.client. # The class we'll use for the test client self.client.
@ -353,7 +367,8 @@ class TransactionTestCase(SimpleTestCase):
if hasattr(self, 'fixtures'): if hasattr(self, 'fixtures'):
# We have to use this slightly awkward syntax due to the fact # We have to use this slightly awkward syntax due to the fact
# that we're using *args and **kwargs together. # that we're using *args and **kwargs together.
call_command('loaddata', *self.fixtures, **{'verbosity': 0, 'database': db}) call_command('loaddata', *self.fixtures,
**{'verbosity': 0, 'database': db})
def _urlconf_setup(self): def _urlconf_setup(self):
if hasattr(self, 'urls'): if hasattr(self, 'urls'):
@ -466,7 +481,8 @@ class TransactionTestCase(SimpleTestCase):
" response code was %d (expected %d)" % " response code was %d (expected %d)" %
(path, redirect_response.status_code, target_status_code)) (path, redirect_response.status_code, target_status_code))
e_scheme, e_netloc, e_path, e_query, e_fragment = urlsplit(expected_url) e_scheme, e_netloc, e_path, e_query, e_fragment = urlsplit(
expected_url)
if not (e_scheme or e_netloc): if not (e_scheme or e_netloc):
expected_url = urlunsplit(('http', host or 'testserver', e_path, expected_url = urlunsplit(('http', host or 'testserver', e_path,
e_query, e_fragment)) e_query, e_fragment))
@ -617,14 +633,16 @@ def connections_support_transactions():
""" """
Returns True if all connections support transactions. Returns True if all connections support transactions.
""" """
return all(conn.features.supports_transactions for conn in connections.all()) return all(conn.features.supports_transactions
for conn in connections.all())
class TestCase(TransactionTestCase): class TestCase(TransactionTestCase):
""" """
Does basically the same as TransactionTestCase, but surrounds every test Does basically the same as TransactionTestCase, but surrounds every test
with a transaction, monkey-patches the real transaction management routines to with a transaction, monkey-patches the real transaction management routines
do nothing, and rollsback the test transaction at the end of the test. You have to do nothing, and rollsback the test transaction at the end of the test.
to use TransactionTestCase, if you need transaction management inside a test. You have to use TransactionTestCase, if you need transaction management
inside a test.
""" """
def _fixture_setup(self): def _fixture_setup(self):
@ -648,11 +666,12 @@ class TestCase(TransactionTestCase):
for db in databases: for db in databases:
if hasattr(self, 'fixtures'): if hasattr(self, 'fixtures'):
call_command('loaddata', *self.fixtures, **{ call_command('loaddata', *self.fixtures,
'verbosity': 0, **{
'commit': False, 'verbosity': 0,
'database': db 'commit': False,
}) 'database': db
})
def _fixture_teardown(self): def _fixture_teardown(self):
if not connections_support_transactions(): if not connections_support_transactions():
@ -672,7 +691,8 @@ class TestCase(TransactionTestCase):
def _deferredSkip(condition, reason): def _deferredSkip(condition, reason):
def decorator(test_func): def decorator(test_func):
if not (isinstance(test_func, type) and issubclass(test_func, TestCase)): if not (isinstance(test_func, type) and
issubclass(test_func, TestCase)):
@wraps(test_func) @wraps(test_func)
def skip_wrapper(*args, **kwargs): def skip_wrapper(*args, **kwargs):
if condition(): if condition():
@ -686,11 +706,15 @@ def _deferredSkip(condition, reason):
return decorator return decorator
def skipIfDBFeature(feature): def skipIfDBFeature(feature):
"Skip a test if a database has the named feature" """
Skip a test if a database has the named feature
"""
return _deferredSkip(lambda: getattr(connection.features, feature), return _deferredSkip(lambda: getattr(connection.features, feature),
"Database has feature %s" % feature) "Database has feature %s" % feature)
def skipUnlessDBFeature(feature): def skipUnlessDBFeature(feature):
"Skip a test unless a database has the named feature" """
Skip a test unless a database has the named feature
"""
return _deferredSkip(lambda: not getattr(connection.features, feature), return _deferredSkip(lambda: not getattr(connection.features, feature),
"Database doesn't support feature %s" % feature) "Database doesn't support feature %s" % feature)