diff --git a/django/test/utils.py b/django/test/utils.py index c1bfc87fc0..e977db8a10 100644 --- a/django/test/utils.py +++ b/django/test/utils.py @@ -235,9 +235,13 @@ def setup_databases( return old_names -def iter_test_cases(suite): - """Return an iterator over a test suite's unittest.TestCase objects.""" - for test in suite: +def iter_test_cases(tests): + """ + Return an iterator over a test suite's unittest.TestCase objects. + + The tests argument can also be an iterable of TestCase objects. + """ + for test in tests: if isinstance(test, TestCase): yield test else: diff --git a/tests/test_runner/tests.py b/tests/test_runner/tests.py index 27a03fdda0..ec84ba31b7 100644 --- a/tests/test_runner/tests.py +++ b/tests/test_runner/tests.py @@ -89,6 +89,20 @@ class TestSuiteTests(unittest.TestCase): 'Tests1.test1', 'Tests1.test2', 'Tests2.test1', 'Tests2.test2', ]) + def test_iter_test_cases_iterable_of_tests(self): + class Tests(unittest.TestCase): + def test1(self): + pass + + def test2(self): + pass + + tests = list(unittest.defaultTestLoader.loadTestsFromTestCase(Tests)) + actual_tests = iter_test_cases(tests) + self.assertTestNames(actual_tests, expected=[ + 'Tests.test1', 'Tests.test2', + ]) + def test_iter_test_cases_custom_test_suite_class(self): suite = self.make_test_suite(suite_class=MySuite) tests = iter_test_cases(suite)