From 3b3f38b3b09b0f2373e51406ecb8c9c45d36aebc Mon Sep 17 00:00:00 2001
From: David Smith
Date: Sat, 12 Feb 2022 20:40:12 +0000
Subject: [PATCH] Fixed #31169 -- Adapted the parallel test runner to use spawn.

Co-authored-by: Valz
Co-authored-by: Nick Pope
---
 django/db/backends/sqlite3/creation.py      | 60 +++++++++++++++++++--
 django/test/runner.py                       | 59 +++++++++++++++++---
 django/utils/autoreload.py                  |  2 +-
 tests/admin_checks/tests.py                 |  2 +
 tests/backends/sqlite/test_creation.py      | 10 +++-
 tests/check_framework/tests.py              |  2 +
 tests/contenttypes_tests/test_checks.py     |  2 +
 tests/contenttypes_tests/test_management.py |  7 +++
 tests/postgres_tests/test_bulk_update.py    |  3 ++
 tests/runtests.py                           | 17 +++++-
 tests/test_runner/test_discover_runner.py   | 10 ++++
 tests/test_runner/tests.py                  | 10 ++--
 12 files changed, 161 insertions(+), 23 deletions(-)

diff --git a/django/db/backends/sqlite3/creation.py b/django/db/backends/sqlite3/creation.py
index 9d8d4a63ad3..d15dea4b036 100644
--- a/django/db/backends/sqlite3/creation.py
+++ b/django/db/backends/sqlite3/creation.py
@@ -1,8 +1,11 @@
+import multiprocessing
 import os
 import shutil
+import sqlite3
 import sys
 from pathlib import Path

+from django.db import NotSupportedError
 from django.db.backends.base.creation import BaseDatabaseCreation


@@ -51,16 +54,26 @@ class DatabaseCreation(BaseDatabaseCreation):
     def get_test_db_clone_settings(self, suffix):
         orig_settings_dict = self.connection.settings_dict
         source_database_name = orig_settings_dict["NAME"]
-        if self.is_in_memory_db(source_database_name):
+
+        if not self.is_in_memory_db(source_database_name):
+            root, ext = os.path.splitext(source_database_name)
+            return {**orig_settings_dict, "NAME": f"{root}_{suffix}{ext}"}
+
+        start_method = multiprocessing.get_start_method()
+        if start_method == "fork":
             return orig_settings_dict
-        else:
-            root, ext = os.path.splitext(orig_settings_dict["NAME"])
-            return {**orig_settings_dict, "NAME": "{}_{}{}".format(root, suffix, ext)}
+        if start_method == "spawn":
+            return {
+                **orig_settings_dict,
+                "NAME": f"{self.connection.alias}_{suffix}.sqlite3",
+            }
+        raise NotSupportedError(
+            f"Cloning with start method {start_method!r} is not supported."
+        )

     def _clone_test_db(self, suffix, verbosity, keepdb=False):
         source_database_name = self.connection.settings_dict["NAME"]
         target_database_name = self.get_test_db_clone_settings(suffix)["NAME"]
-        # Forking automatically makes a copy of an in-memory database.
         if not self.is_in_memory_db(source_database_name):
             # Erase the old test database
             if os.access(target_database_name, os.F_OK):
@@ -85,6 +98,12 @@ class DatabaseCreation(BaseDatabaseCreation):
             except Exception as e:
                 self.log("Got an error cloning the test database: %s" % e)
                 sys.exit(2)
+        # Forking automatically makes a copy of an in-memory database.
+        # Spawn requires migrating to disk which will be re-opened in
+        # setup_worker_connection.
+        elif multiprocessing.get_start_method() == "spawn":
+            ondisk_db = sqlite3.connect(target_database_name, uri=True)
+            self.connection.connection.backup(ondisk_db)

     def _destroy_test_db(self, test_database_name, verbosity):
         if test_database_name and not self.is_in_memory_db(test_database_name):
@@ -106,3 +125,34 @@
         else:
             sig.append(test_database_name)
         return tuple(sig)
+
+    def setup_worker_connection(self, _worker_id):
+        settings_dict = self.get_test_db_clone_settings(_worker_id)
+        # connection.settings_dict must be updated in place for changes to be
+        # reflected in django.db.connections. Otherwise new threads would
+        # connect to the default database instead of the appropriate clone.
+        start_method = multiprocessing.get_start_method()
+        if start_method == "fork":
+            # Update settings_dict in place.
+            self.connection.settings_dict.update(settings_dict)
+            self.connection.close()
+        elif start_method == "spawn":
+            alias = self.connection.alias
+            connection_str = (
+                f"file:memorydb_{alias}_{_worker_id}?mode=memory&cache=shared"
+            )
+            source_db = self.connection.Database.connect(
+                f"file:{alias}_{_worker_id}.sqlite3", uri=True
+            )
+            target_db = sqlite3.connect(connection_str, uri=True)
+            source_db.backup(target_db)
+            source_db.close()
+            # Update settings_dict in place.
+            self.connection.settings_dict.update(settings_dict)
+            self.connection.settings_dict["NAME"] = connection_str
+            # Re-open connection to in-memory database before closing copy
+            # connection.
+            self.connection.connect()
+            target_db.close()
+            if os.environ.get("RUNNING_DJANGOS_TEST_SUITE") == "true":
+                self.mark_expected_failures_and_skips()
diff --git a/django/test/runner.py b/django/test/runner.py
index aba515e7351..89bb6cf1fcd 100644
--- a/django/test/runner.py
+++ b/django/test/runner.py
@@ -20,7 +20,12 @@ from io import StringIO
 from django.core.management import call_command
 from django.db import connections
 from django.test import SimpleTestCase, TestCase
-from django.test.utils import NullTimeKeeper, TimeKeeper, iter_test_cases
+from django.test.utils import (
+    NullTimeKeeper,
+    TimeKeeper,
+    captured_stdout,
+    iter_test_cases,
+)
 from django.test.utils import setup_databases as _setup_databases
 from django.test.utils import setup_test_environment
 from django.test.utils import teardown_databases as _teardown_databases
@@ -367,8 +372,8 @@ def get_max_test_processes():
     The maximum number of test processes when using the --parallel option.
     """
     # The current implementation of the parallel test runner requires
-    # multiprocessing to start subprocesses with fork().
-    if multiprocessing.get_start_method() != "fork":
+    # multiprocessing to start subprocesses with fork() or spawn().
+    if multiprocessing.get_start_method() not in {"fork", "spawn"}:
         return 1
     try:
         return int(os.environ["DJANGO_TEST_PROCESSES"])
@@ -391,7 +396,13 @@ def parallel_type(value):
 _worker_id = 0


-def _init_worker(counter):
+def _init_worker(
+    counter,
+    initial_settings=None,
+    serialized_contents=None,
+    process_setup=None,
+    process_setup_args=None,
+):
     """
     Switch to databases dedicated to this worker.

@@ -405,9 +416,22 @@ def _init_worker(counter):
         counter.value += 1
         _worker_id = counter.value

+    start_method = multiprocessing.get_start_method()
+
+    if start_method == "spawn":
+        process_setup(*process_setup_args)
+        setup_test_environment()
+
     for alias in connections:
         connection = connections[alias]
+        if start_method == "spawn":
+            # Restore initial settings in spawned processes.
+            connection.settings_dict.update(initial_settings[alias])
+            if value := serialized_contents.get(alias):
+                connection._test_serialized_contents = value
         connection.creation.setup_worker_connection(_worker_id)
+        with captured_stdout():
+            call_command("check", databases=connections)


 def _run_subsuite(args):
@@ -449,6 +473,8 @@ class ParallelTestSuite(unittest.TestSuite):
         self.processes = processes
         self.failfast = failfast
         self.buffer = buffer
+        self.initial_settings = None
+        self.serialized_contents = None
         super().__init__()

     def run(self, result):
@@ -469,8 +495,12 @@ class ParallelTestSuite(unittest.TestSuite):
         counter = multiprocessing.Value(ctypes.c_int, 0)
         pool = multiprocessing.Pool(
             processes=self.processes,
-            initializer=self.init_worker.__func__,
-            initargs=[counter],
+            initializer=self.init_worker,
+            initargs=[
+                counter,
+                self.initial_settings,
+                self.serialized_contents,
+            ],
         )
         args = [
             (self.runner_class, index, subsuite, self.failfast, self.buffer)
@@ -508,6 +538,17 @@ class ParallelTestSuite(unittest.TestSuite):
     def __iter__(self):
         return iter(self.subsuites)

+    def initialize_suite(self):
+        if multiprocessing.get_start_method() == "spawn":
+            self.initial_settings = {
+                alias: connections[alias].settings_dict for alias in connections
+            }
+            self.serialized_contents = {
+                alias: connections[alias]._test_serialized_contents
+                for alias in connections
+                if alias in self.serialized_aliases
+            }
+

 class Shuffler:
     """
@@ -921,6 +962,8 @@ class DiscoverRunner:
     def run_suite(self, suite, **kwargs):
         kwargs = self.get_test_runner_kwargs()
         runner = self.test_runner(**kwargs)
+        if hasattr(suite, "initialize_suite"):
+            suite.initialize_suite()
         try:
             return runner.run(suite)
         finally:
@@ -989,13 +1032,13 @@ class DiscoverRunner:
         self.setup_test_environment()
         suite = self.build_suite(test_labels, extra_tests)
         databases = self.get_databases(suite)
-        serialized_aliases = set(
+        suite.serialized_aliases = set(
             alias for alias, serialize in databases.items() if serialize
         )
         with self.time_keeper.timed("Total database setup"):
             old_config = self.setup_databases(
                 aliases=databases,
-                serialized_aliases=serialized_aliases,
+                serialized_aliases=suite.serialized_aliases,
             )
         run_failed = False
         try:
diff --git a/django/utils/autoreload.py b/django/utils/autoreload.py
index 7b9219f4c13..1b3652b41d9 100644
--- a/django/utils/autoreload.py
+++ b/django/utils/autoreload.py
@@ -130,7 +130,7 @@ def iter_modules_and_files(modules, extra_files):
         # cause issues here.
         if not isinstance(module, ModuleType):
             continue
-        if module.__name__ == "__main__":
+        if module.__name__ in ("__main__", "__mp_main__"):
             # __main__ (usually manage.py) doesn't always have a __spec__ set.
             # Handle this by falling back to using __file__, resolved below.
             # See https://docs.python.org/reference/import.html#main-spec
diff --git a/tests/admin_checks/tests.py b/tests/admin_checks/tests.py
index 2646837bbc1..aa87649dcec 100644
--- a/tests/admin_checks/tests.py
+++ b/tests/admin_checks/tests.py
@@ -70,6 +70,8 @@ class SessionMiddlewareSubclass(SessionMiddleware):
     ],
 )
 class SystemChecksTestCase(SimpleTestCase):
+    databases = "__all__"
+
     def test_checks_are_performed(self):
         admin.site.register(Song, MyAdmin)
         try:
diff --git a/tests/backends/sqlite/test_creation.py b/tests/backends/sqlite/test_creation.py
index ab1640c04e9..8aa24674d22 100644
--- a/tests/backends/sqlite/test_creation.py
+++ b/tests/backends/sqlite/test_creation.py
@@ -1,7 +1,9 @@
 import copy
+import multiprocessing
 import unittest
+from unittest import mock

-from django.db import DEFAULT_DB_ALIAS, connection, connections
+from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connection, connections
 from django.test import SimpleTestCase


@@ -33,3 +35,9 @@ class TestDbSignatureTests(SimpleTestCase):
         creation_class = test_connection.creation_class(test_connection)
         clone_settings_dict = creation_class.get_test_db_clone_settings("1")
         self.assertEqual(clone_settings_dict["NAME"], expected_clone_name)
+
+    @mock.patch.object(multiprocessing, "get_start_method", return_value="forkserver")
+    def test_get_test_db_clone_settings_not_supported(self, *mocked_objects):
+        msg = "Cloning with start method 'forkserver' is not supported."
+        with self.assertRaisesMessage(NotSupportedError, msg):
+            connection.creation.get_test_db_clone_settings(1)
diff --git a/tests/check_framework/tests.py b/tests/check_framework/tests.py
index 4c09c757b4f..d064a7c4038 100644
--- a/tests/check_framework/tests.py
+++ b/tests/check_framework/tests.py
@@ -362,5 +362,7 @@ class CheckFrameworkReservedNamesTests(SimpleTestCase):


 class ChecksRunDuringTests(SimpleTestCase):
+    databases = "__all__"
+
     def test_registered_check_did_run(self):
         self.assertTrue(my_check.did_run)
diff --git a/tests/contenttypes_tests/test_checks.py b/tests/contenttypes_tests/test_checks.py
index bd36c569a19..50730c5a1d6 100644
--- a/tests/contenttypes_tests/test_checks.py
+++ b/tests/contenttypes_tests/test_checks.py
@@ -11,6 +11,8 @@ from django.test.utils import isolate_apps

 @isolate_apps("contenttypes_tests", attr_name="apps")
 class GenericForeignKeyTests(SimpleTestCase):
+    databases = "__all__"
+
     def test_missing_content_type_field(self):
         class TaggedItem(models.Model):
             # no content_type field
diff --git a/tests/contenttypes_tests/test_management.py b/tests/contenttypes_tests/test_management.py
index d5e14c7df37..eb472d80cef 100644
--- a/tests/contenttypes_tests/test_management.py
+++ b/tests/contenttypes_tests/test_management.py
@@ -22,6 +22,13 @@ class RemoveStaleContentTypesTests(TestCase):

     @classmethod
     def setUpTestData(cls):
+        with captured_stdout():
+            call_command(
+                "remove_stale_contenttypes",
+                interactive=False,
+                include_stale_apps=True,
+                verbosity=2,
+            )
         cls.before_count = ContentType.objects.count()
         cls.content_type = ContentType.objects.create(
             app_label="contenttypes_tests", model="Fake"
diff --git a/tests/postgres_tests/test_bulk_update.py b/tests/postgres_tests/test_bulk_update.py
index f0b473efa7d..5f91f777916 100644
--- a/tests/postgres_tests/test_bulk_update.py
+++ b/tests/postgres_tests/test_bulk_update.py
@@ -1,5 +1,7 @@
 from datetime import date

+from django.test import modify_settings
+
 from . import PostgreSQLTestCase
 from .models import (
     HStoreModel,
@@ -16,6 +18,7 @@ except ImportError:
     pass  # psycopg2 isn't installed.


+@modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"})
 class BulkSaveTests(PostgreSQLTestCase):
     def test_bulk_update(self):
         test_data = [
diff --git a/tests/runtests.py b/tests/runtests.py
index 330c8abd048..e3a60d777bc 100755
--- a/tests/runtests.py
+++ b/tests/runtests.py
@@ -3,6 +3,7 @@ import argparse
 import atexit
 import copy
 import gc
+import multiprocessing
 import os
 import shutil
 import socket
@@ -10,6 +11,7 @@ import subprocess
 import sys
 import tempfile
 import warnings
+from functools import partial
 from pathlib import Path

 try:
@@ -24,7 +26,7 @@ else:
     from django.core.exceptions import ImproperlyConfigured
     from django.db import connection, connections
     from django.test import TestCase, TransactionTestCase
-    from django.test.runner import get_max_test_processes, parallel_type
+    from django.test.runner import _init_worker, get_max_test_processes, parallel_type
     from django.test.selenium import SeleniumTestCaseBase
     from django.test.utils import NullTimeKeeper, TimeKeeper, get_runner
     from django.utils.deprecation import RemovedInDjango50Warning
@@ -382,7 +384,8 @@ def django_tests(
             msg += " with up to %d processes" % max_parallel
         print(msg)

-    test_labels, state = setup_run_tests(verbosity, start_at, start_after, test_labels)
+    process_setup_args = (verbosity, start_at, start_after, test_labels)
+    test_labels, state = setup_run_tests(*process_setup_args)
     # Run the test suite, including the extra validation tests.
     if not hasattr(settings, "TEST_RUNNER"):
         settings.TEST_RUNNER = "django.test.runner.DiscoverRunner"
@@ -395,6 +398,11 @@
         parallel = 1

     TestRunner = get_runner(settings)
+    TestRunner.parallel_test_suite.init_worker = partial(
+        _init_worker,
+        process_setup=setup_run_tests,
+        process_setup_args=process_setup_args,
+    )
     test_runner = TestRunner(
         verbosity=verbosity,
         interactive=interactive,
@@ -718,6 +726,11 @@ if __name__ == "__main__":
         options.settings = os.environ["DJANGO_SETTINGS_MODULE"]

     if options.selenium:
+        if multiprocessing.get_start_method() == "spawn" and options.parallel != 1:
+            parser.error(
+                "You cannot use --selenium with parallel tests on this system. "
+                "Pass --parallel=1 to use --selenium."
+            )
         if not options.tags:
             options.tags = ["selenium"]
         elif "selenium" not in options.tags:
diff --git a/tests/test_runner/test_discover_runner.py b/tests/test_runner/test_discover_runner.py
index fd36f0ab890..bca90374923 100644
--- a/tests/test_runner/test_discover_runner.py
+++ b/tests/test_runner/test_discover_runner.py
@@ -86,6 +86,16 @@ class DiscoverRunnerParallelArgumentTests(SimpleTestCase):
         mocked_cpu_count,
     ):
         mocked_get_start_method.return_value = "spawn"
+        self.assertEqual(get_max_test_processes(), 12)
+        with mock.patch.dict(os.environ, {"DJANGO_TEST_PROCESSES": "7"}):
+            self.assertEqual(get_max_test_processes(), 7)
+
+    def test_get_max_test_processes_forkserver(
+        self,
+        mocked_get_start_method,
+        mocked_cpu_count,
+    ):
+        mocked_get_start_method.return_value = "forkserver"
         self.assertEqual(get_max_test_processes(), 1)
         with mock.patch.dict(os.environ, {"DJANGO_TEST_PROCESSES": "7"}):
             self.assertEqual(get_max_test_processes(), 1)
diff --git a/tests/test_runner/tests.py b/tests/test_runner/tests.py
index 665f2c5ef80..c50c54f25c8 100644
--- a/tests/test_runner/tests.py
+++ b/tests/test_runner/tests.py
@@ -480,8 +480,6 @@ class ManageCommandTests(unittest.TestCase):
 # Isolate from the real environment.
 @mock.patch.dict(os.environ, {}, clear=True)
 @mock.patch.object(multiprocessing, "cpu_count", return_value=12)
-# Python 3.8 on macOS defaults to 'spawn' mode.
-@mock.patch.object(multiprocessing, "get_start_method", return_value="fork")
 class ManageCommandParallelTests(SimpleTestCase):
     def test_parallel_default(self, *mocked_objects):
         with captured_stderr() as stderr:
@@ -507,8 +505,8 @@ class ManageCommandParallelTests(SimpleTestCase):
         # Parallel is disabled by default.
         self.assertEqual(stderr.getvalue(), "")

-    def test_parallel_spawn(self, mocked_get_start_method, mocked_cpu_count):
-        mocked_get_start_method.return_value = "spawn"
+    @mock.patch.object(multiprocessing, "get_start_method", return_value="spawn")
+    def test_parallel_spawn(self, *mocked_objects):
         with captured_stderr() as stderr:
             call_command(
                 "test",
@@ -517,8 +515,8 @@ class ManageCommandParallelTests(SimpleTestCase):
             )
         self.assertIn("parallel=1", stderr.getvalue())

-    def test_no_parallel_spawn(self, mocked_get_start_method, mocked_cpu_count):
-        mocked_get_start_method.return_value = "spawn"
+    @mock.patch.object(multiprocessing, "get_start_method", return_value="spawn")
+    def test_no_parallel_spawn(self, *mocked_objects):
         with captured_stderr() as stderr:
             call_command(
                 "test",
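
Note on the spawn path: it leans on two standard-library SQLite behaviours that the hunks above use together. sqlite3.Connection.backup() copies a whole database between two connections, and a named in-memory database opened with mode=memory&cache=shared stays alive for as long as at least one connection to that URI is open, so it can be handed to the worker's Django connection. The following is a minimal standalone sketch of that round trip, not Django code; the file name, URI, and table are illustrative placeholders rather than the names Django derives from the connection alias and worker id.

import os
import sqlite3

DB_PATH = "clone_example.sqlite3"  # illustrative stand-in for the on-disk clone
if os.path.exists(DB_PATH):
    os.remove(DB_PATH)

# Build a small on-disk database, as _clone_test_db would under spawn.
source = sqlite3.connect(DB_PATH)
source.execute("CREATE TABLE t (id INTEGER PRIMARY KEY, name TEXT)")
source.execute("INSERT INTO t (name) VALUES ('alpha'), ('beta')")
source.commit()

# A named shared-cache in-memory database persists while any connection to
# the same URI stays open.  The name here is illustrative.
memory_uri = "file:memorydb_example_1?mode=memory&cache=shared"
target = sqlite3.connect(memory_uri, uri=True)

# Copy the on-disk clone into the in-memory database page by page.
source.backup(target)
source.close()

# Any later connection to the same URI sees the copied schema and rows,
# which is what the worker's re-opened Django connection relies on.
worker = sqlite3.connect(memory_uri, uri=True)
print(worker.execute("SELECT count(*) FROM t").fetchone())  # (2,)

worker.close()
target.close()  # closing the last connection discards the in-memory copy
os.remove(DB_PATH)

This mirrors the sequence in setup_worker_connection under spawn, where the target URI then becomes the worker connection's NAME.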