diff --git a/django/test/runner.py b/django/test/runner.py index db7628df4cd..20bc93fca74 100644 --- a/django/test/runner.py +++ b/django/test/runner.py @@ -16,9 +16,9 @@ from django.core.management import call_command from django.db import connections from django.test import SimpleTestCase, TestCase from django.test.utils import ( - NullTimeKeeper, TimeKeeper, setup_databases as _setup_databases, - setup_test_environment, teardown_databases as _teardown_databases, - teardown_test_environment, + NullTimeKeeper, TimeKeeper, iter_test_cases, + setup_databases as _setup_databases, setup_test_environment, + teardown_databases as _teardown_databases, teardown_test_environment, ) from django.utils.datastructures import OrderedSet @@ -683,19 +683,16 @@ class DiscoverRunner: def _get_databases(self, suite): databases = {} - for test in suite: - if isinstance(test, unittest.TestCase): - test_databases = getattr(test, 'databases', None) - if test_databases == '__all__': - test_databases = connections - if test_databases: - serialized_rollback = getattr(test, 'serialized_rollback', False) - databases.update( - (alias, serialized_rollback or databases.get(alias, False)) - for alias in test_databases - ) - else: - databases.update(self._get_databases(test)) + for test in iter_test_cases(suite): + test_databases = getattr(test, 'databases', None) + if test_databases == '__all__': + test_databases = connections + if test_databases: + serialized_rollback = getattr(test, 'serialized_rollback', False) + databases.update( + (alias, serialized_rollback or databases.get(alias, False)) + for alias in test_databases + ) return databases def get_databases(self, suite): @@ -800,49 +797,39 @@ def partition_suite_by_type(suite, classes, bins, reverse=False): Tests of type classes[i] are added to bins[i], tests with no match found in classes are place in bins[-1] """ - suite_class = type(suite) - if reverse: - suite = reversed(tuple(suite)) - for test in suite: - if isinstance(test, suite_class): - partition_suite_by_type(test, classes, bins, reverse=reverse) + for test in iter_test_cases(suite, reverse=reverse): + for i in range(len(classes)): + if isinstance(test, classes[i]): + bins[i].add(test) + break else: - for i in range(len(classes)): - if isinstance(test, classes[i]): - bins[i].add(test) - break - else: - bins[-1].add(test) + bins[-1].add(test) def partition_suite_by_case(suite): """Partition a test suite by test case, preserving the order of tests.""" - groups = [] + subsuites = [] suite_class = type(suite) - for test_type, test_group in itertools.groupby(suite, type): - if issubclass(test_type, unittest.TestCase): - groups.append(suite_class(test_group)) - else: - for item in test_group: - groups.extend(partition_suite_by_case(item)) - return groups + tests = iter_test_cases(suite) + for test_type, test_group in itertools.groupby(tests, type): + subsuite = suite_class(test_group) + subsuites.append(subsuite) + + return subsuites def filter_tests_by_tags(suite, tags, exclude_tags): suite_class = type(suite) filtered_suite = suite_class() - for test in suite: - if isinstance(test, suite_class): - filtered_suite.addTests(filter_tests_by_tags(test, tags, exclude_tags)) - else: - test_tags = set(getattr(test, 'tags', set())) - test_fn_name = getattr(test, '_testMethodName', str(test)) - test_fn = getattr(test, test_fn_name, test) - test_fn_tags = set(getattr(test_fn, 'tags', set())) - all_tags = test_tags.union(test_fn_tags) - matched_tags = all_tags.intersection(tags) - if (matched_tags or not tags) and not all_tags.intersection(exclude_tags): - filtered_suite.addTest(test) + for test in iter_test_cases(suite): + test_tags = set(getattr(test, 'tags', set())) + test_fn_name = getattr(test, '_testMethodName', str(test)) + test_fn = getattr(test, test_fn_name, test) + test_fn_tags = set(getattr(test_fn, 'tags', set())) + all_tags = test_tags.union(test_fn_tags) + matched_tags = all_tags.intersection(tags) + if (matched_tags or not tags) and not all_tags.intersection(exclude_tags): + filtered_suite.addTest(test) return filtered_suite diff --git a/django/test/utils.py b/django/test/utils.py index c019d773545..638bcd6fb5f 100644 --- a/django/test/utils.py +++ b/django/test/utils.py @@ -235,6 +235,18 @@ def setup_databases( return old_names +def iter_test_cases(suite, reverse=False): + """Return an iterator over a test suite's unittest.TestCase objects.""" + if reverse: + suite = reversed(tuple(suite)) + for test in suite: + if isinstance(test, TestCase): + yield test + else: + # Otherwise, assume it is a test suite. + yield from iter_test_cases(test, reverse=reverse) + + def dependency_ordered(test_databases, dependencies): """ Reorder test_databases into an order that honors the dependencies diff --git a/tests/test_runner/tests.py b/tests/test_runner/tests.py index 5dc31265812..8a071d0c8d8 100644 --- a/tests/test_runner/tests.py +++ b/tests/test_runner/tests.py @@ -18,12 +18,93 @@ from django.test.runner import DiscoverRunner from django.test.testcases import connections_support_transactions from django.test.utils import ( captured_stderr, dependency_ordered, get_unique_databases_and_mirrors, + iter_test_cases, ) from django.utils.deprecation import RemovedInDjango50Warning from .models import B, Person, Through +class MySuite: + def __init__(self): + self.tests = [] + + def addTest(self, test): + self.tests.append(test) + + def __iter__(self): + yield from self.tests + + +class IterTestCasesTests(unittest.TestCase): + def make_test_suite(self, suite=None, suite_class=None): + if suite_class is None: + suite_class = unittest.TestSuite + if suite is None: + suite = suite_class() + + class Tests1(unittest.TestCase): + def test1(self): + pass + + def test2(self): + pass + + class Tests2(unittest.TestCase): + def test1(self): + pass + + def test2(self): + pass + + loader = unittest.defaultTestLoader + for test_cls in (Tests1, Tests2): + tests = loader.loadTestsFromTestCase(test_cls) + subsuite = suite_class() + # Only use addTest() to simplify testing a custom TestSuite. + for test in tests: + subsuite.addTest(test) + suite.addTest(subsuite) + + return suite + + def assertTestNames(self, tests, expected): + # Each test.id() has a form like the following: + # "test_runner.tests.IterTestCasesTests.test_iter_test_cases..Tests1.test1". + # It suffices to check only the last two parts. + names = ['.'.join(test.id().split('.')[-2:]) for test in tests] + self.assertEqual(names, expected) + + def test_basic(self): + suite = self.make_test_suite() + tests = iter_test_cases(suite) + self.assertTestNames(tests, expected=[ + 'Tests1.test1', 'Tests1.test2', 'Tests2.test1', 'Tests2.test2', + ]) + + def test_reverse(self): + suite = self.make_test_suite() + tests = iter_test_cases(suite, reverse=True) + self.assertTestNames(tests, expected=[ + 'Tests2.test2', 'Tests2.test1', 'Tests1.test2', 'Tests1.test1', + ]) + + def test_custom_test_suite_class(self): + suite = self.make_test_suite(suite_class=MySuite) + tests = iter_test_cases(suite) + self.assertTestNames(tests, expected=[ + 'Tests1.test1', 'Tests1.test2', 'Tests2.test1', 'Tests2.test2', + ]) + + def test_mixed_test_suite_classes(self): + suite = self.make_test_suite(suite=MySuite()) + child_suite = list(suite)[0] + self.assertNotIsInstance(child_suite, MySuite) + tests = list(iter_test_cases(suite)) + self.assertEqual(len(tests), 4) + self.assertNotIsInstance(tests[0], unittest.TestSuite) + + class DependencyOrderingTests(unittest.TestCase): def test_simple_dependencies(self):