diff --git a/django/test/runner.py b/django/test/runner.py index fe30d2289b..55d902bb09 100644 --- a/django/test/runner.py +++ b/django/test/runner.py @@ -419,7 +419,10 @@ def _init_worker( start_method = multiprocessing.get_start_method() if start_method == "spawn": - process_setup(*process_setup_args) + if process_setup and callable(process_setup): + if process_setup_args is None: + process_setup_args = () + process_setup(*process_setup_args) setup_test_environment() for alias in connections: @@ -465,6 +468,7 @@ class ParallelTestSuite(unittest.TestSuite): # In case someone wants to modify these in a subclass. init_worker = _init_worker + process_setup_args = () run_subsuite = _run_subsuite runner_class = RemoteTestRunner @@ -477,6 +481,14 @@ class ParallelTestSuite(unittest.TestSuite): self.serialized_contents = None super().__init__() + def process_setup(self, *args): + """ + Stub method to simplify run() implementation. "self" is never actually + passed because a function implementing this method (__func__) is + always used, not the method itself. + """ + pass + def run(self, result): """ Distribute test cases across workers. @@ -496,11 +508,13 @@ class ParallelTestSuite(unittest.TestSuite): counter = multiprocessing.Value(ctypes.c_int, 0) pool = multiprocessing.Pool( processes=self.processes, - initializer=self.init_worker, + initializer=self.init_worker.__func__, initargs=[ counter, self.initial_settings, self.serialized_contents, + self.process_setup.__func__, + self.process_setup_args, ], ) args = [ diff --git a/tests/runtests.py b/tests/runtests.py index b5cea631b0..3bbaf5ada7 100755 --- a/tests/runtests.py +++ b/tests/runtests.py @@ -11,7 +11,6 @@ import subprocess import sys import tempfile import warnings -from functools import partial from pathlib import Path try: @@ -26,7 +25,7 @@ else: from django.core.exceptions import ImproperlyConfigured from django.db import connection, connections from django.test import TestCase, TransactionTestCase - from django.test.runner import _init_worker, get_max_test_processes, parallel_type + from django.test.runner import get_max_test_processes, parallel_type from django.test.selenium import SeleniumTestCaseBase from django.test.utils import NullTimeKeeper, TimeKeeper, get_runner from django.utils.deprecation import ( @@ -405,11 +404,8 @@ def django_tests( parallel = 1 TestRunner = get_runner(settings) - TestRunner.parallel_test_suite.init_worker = partial( - _init_worker, - process_setup=setup_run_tests, - process_setup_args=process_setup_args, - ) + TestRunner.parallel_test_suite.process_setup = setup_run_tests + TestRunner.parallel_test_suite.process_setup_args = process_setup_args test_runner = TestRunner( verbosity=verbosity, interactive=interactive, diff --git a/tests/test_runner/tests.py b/tests/test_runner/tests.py index c022a6fb86..b3c7cc5a55 100644 --- a/tests/test_runner/tests.py +++ b/tests/test_runner/tests.py @@ -19,6 +19,7 @@ from django.test import SimpleTestCase, TransactionTestCase, skipUnlessDBFeature from django.test.runner import ( DiscoverRunner, Shuffler, + _init_worker, reorder_test_bin, reorder_tests, shuffle_tests, @@ -684,6 +685,46 @@ class NoInitializeSuiteTestRunnerTests(SimpleTestCase): ) +class TestRunnerInitializerTests(SimpleTestCase): + + # Raise an exception to don't actually run tests. + @mock.patch.object( + multiprocessing, "Pool", side_effect=Exception("multiprocessing.Pool()") + ) + def test_no_initialize_suite_test_runner(self, mocked_pool): + class StubTestRunner(DiscoverRunner): + def setup_test_environment(self, **kwargs): + return + + def setup_databases(self, **kwargs): + return + + def run_checks(self, databases): + return + + def teardown_databases(self, old_config, **kwargs): + return + + def teardown_test_environment(self, **kwargs): + return + + def run_suite(self, suite, **kwargs): + kwargs = self.get_test_runner_kwargs() + runner = self.test_runner(**kwargs) + return runner.run(suite) + + runner = StubTestRunner(verbosity=0, interactive=False, parallel=2) + with self.assertRaisesMessage(Exception, "multiprocessing.Pool()"): + runner.run_tests( + [ + "test_runner_apps.sample.tests_sample.TestDjangoTestCase", + "test_runner_apps.simple.tests", + ] + ) + # Initializer must be a function. + self.assertIs(mocked_pool.call_args.kwargs["initializer"], _init_worker) + + class Ticket17477RegressionTests(AdminScriptTestCase): def setUp(self): super().setUp()