Fixed #31169 -- Adapted the parallel test runner to use spawn.

Co-authored-by: Valz <ahmadahussein0@gmail.com>
Co-authored-by: Nick Pope <nick@nickpope.me.uk>
This commit is contained in:
David Smith 2022-02-12 20:40:12 +00:00 committed by Carlton Gibson
parent 3eaba13a47
commit 3b3f38b3b0
12 changed files with 161 additions and 23 deletions

View File

@ -1,8 +1,11 @@
import multiprocessing
import os import os
import shutil import shutil
import sqlite3
import sys import sys
from pathlib import Path from pathlib import Path
from django.db import NotSupportedError
from django.db.backends.base.creation import BaseDatabaseCreation from django.db.backends.base.creation import BaseDatabaseCreation
@ -51,16 +54,26 @@ class DatabaseCreation(BaseDatabaseCreation):
def get_test_db_clone_settings(self, suffix): def get_test_db_clone_settings(self, suffix):
orig_settings_dict = self.connection.settings_dict orig_settings_dict = self.connection.settings_dict
source_database_name = orig_settings_dict["NAME"] source_database_name = orig_settings_dict["NAME"]
if self.is_in_memory_db(source_database_name):
if not self.is_in_memory_db(source_database_name):
root, ext = os.path.splitext(source_database_name)
return {**orig_settings_dict, "NAME": f"{root}_{suffix}{ext}"}
start_method = multiprocessing.get_start_method()
if start_method == "fork":
return orig_settings_dict return orig_settings_dict
else: if start_method == "spawn":
root, ext = os.path.splitext(orig_settings_dict["NAME"]) return {
return {**orig_settings_dict, "NAME": "{}_{}{}".format(root, suffix, ext)} **orig_settings_dict,
"NAME": f"{self.connection.alias}_{suffix}.sqlite3",
}
raise NotSupportedError(
f"Cloning with start method {start_method!r} is not supported."
)
def _clone_test_db(self, suffix, verbosity, keepdb=False): def _clone_test_db(self, suffix, verbosity, keepdb=False):
source_database_name = self.connection.settings_dict["NAME"] source_database_name = self.connection.settings_dict["NAME"]
target_database_name = self.get_test_db_clone_settings(suffix)["NAME"] target_database_name = self.get_test_db_clone_settings(suffix)["NAME"]
# Forking automatically makes a copy of an in-memory database.
if not self.is_in_memory_db(source_database_name): if not self.is_in_memory_db(source_database_name):
# Erase the old test database # Erase the old test database
if os.access(target_database_name, os.F_OK): if os.access(target_database_name, os.F_OK):
@ -85,6 +98,12 @@ class DatabaseCreation(BaseDatabaseCreation):
except Exception as e: except Exception as e:
self.log("Got an error cloning the test database: %s" % e) self.log("Got an error cloning the test database: %s" % e)
sys.exit(2) sys.exit(2)
# Forking automatically makes a copy of an in-memory database.
# Spawn requires migrating to disk which will be re-opened in
# setup_worker_connection.
elif multiprocessing.get_start_method() == "spawn":
ondisk_db = sqlite3.connect(target_database_name, uri=True)
self.connection.connection.backup(ondisk_db)
def _destroy_test_db(self, test_database_name, verbosity): def _destroy_test_db(self, test_database_name, verbosity):
if test_database_name and not self.is_in_memory_db(test_database_name): if test_database_name and not self.is_in_memory_db(test_database_name):
@ -106,3 +125,34 @@ class DatabaseCreation(BaseDatabaseCreation):
else: else:
sig.append(test_database_name) sig.append(test_database_name)
return tuple(sig) return tuple(sig)
def setup_worker_connection(self, _worker_id):
settings_dict = self.get_test_db_clone_settings(_worker_id)
# connection.settings_dict must be updated in place for changes to be
# reflected in django.db.connections. Otherwise new threads would
# connect to the default database instead of the appropriate clone.
start_method = multiprocessing.get_start_method()
if start_method == "fork":
# Update settings_dict in place.
self.connection.settings_dict.update(settings_dict)
self.connection.close()
elif start_method == "spawn":
alias = self.connection.alias
connection_str = (
f"file:memorydb_{alias}_{_worker_id}?mode=memory&cache=shared"
)
source_db = self.connection.Database.connect(
f"file:{alias}_{_worker_id}.sqlite3", uri=True
)
target_db = sqlite3.connect(connection_str, uri=True)
source_db.backup(target_db)
source_db.close()
# Update settings_dict in place.
self.connection.settings_dict.update(settings_dict)
self.connection.settings_dict["NAME"] = connection_str
# Re-open connection to in-memory database before closing copy
# connection.
self.connection.connect()
target_db.close()
if os.environ.get("RUNNING_DJANGOS_TEST_SUITE") == "true":
self.mark_expected_failures_and_skips()

View File

@ -20,7 +20,12 @@ from io import StringIO
from django.core.management import call_command from django.core.management import call_command
from django.db import connections from django.db import connections
from django.test import SimpleTestCase, TestCase from django.test import SimpleTestCase, TestCase
from django.test.utils import NullTimeKeeper, TimeKeeper, iter_test_cases from django.test.utils import (
NullTimeKeeper,
TimeKeeper,
captured_stdout,
iter_test_cases,
)
from django.test.utils import setup_databases as _setup_databases from django.test.utils import setup_databases as _setup_databases
from django.test.utils import setup_test_environment from django.test.utils import setup_test_environment
from django.test.utils import teardown_databases as _teardown_databases from django.test.utils import teardown_databases as _teardown_databases
@ -367,8 +372,8 @@ def get_max_test_processes():
The maximum number of test processes when using the --parallel option. The maximum number of test processes when using the --parallel option.
""" """
# The current implementation of the parallel test runner requires # The current implementation of the parallel test runner requires
# multiprocessing to start subprocesses with fork(). # multiprocessing to start subprocesses with fork() or spawn().
if multiprocessing.get_start_method() != "fork": if multiprocessing.get_start_method() not in {"fork", "spawn"}:
return 1 return 1
try: try:
return int(os.environ["DJANGO_TEST_PROCESSES"]) return int(os.environ["DJANGO_TEST_PROCESSES"])
@ -391,7 +396,13 @@ def parallel_type(value):
_worker_id = 0 _worker_id = 0
def _init_worker(counter): def _init_worker(
counter,
initial_settings=None,
serialized_contents=None,
process_setup=None,
process_setup_args=None,
):
""" """
Switch to databases dedicated to this worker. Switch to databases dedicated to this worker.
@ -405,9 +416,22 @@ def _init_worker(counter):
counter.value += 1 counter.value += 1
_worker_id = counter.value _worker_id = counter.value
start_method = multiprocessing.get_start_method()
if start_method == "spawn":
process_setup(*process_setup_args)
setup_test_environment()
for alias in connections: for alias in connections:
connection = connections[alias] connection = connections[alias]
if start_method == "spawn":
# Restore initial settings in spawned processes.
connection.settings_dict.update(initial_settings[alias])
if value := serialized_contents.get(alias):
connection._test_serialized_contents = value
connection.creation.setup_worker_connection(_worker_id) connection.creation.setup_worker_connection(_worker_id)
with captured_stdout():
call_command("check", databases=connections)
def _run_subsuite(args): def _run_subsuite(args):
@ -449,6 +473,8 @@ class ParallelTestSuite(unittest.TestSuite):
self.processes = processes self.processes = processes
self.failfast = failfast self.failfast = failfast
self.buffer = buffer self.buffer = buffer
self.initial_settings = None
self.serialized_contents = None
super().__init__() super().__init__()
def run(self, result): def run(self, result):
@ -469,8 +495,12 @@ class ParallelTestSuite(unittest.TestSuite):
counter = multiprocessing.Value(ctypes.c_int, 0) counter = multiprocessing.Value(ctypes.c_int, 0)
pool = multiprocessing.Pool( pool = multiprocessing.Pool(
processes=self.processes, processes=self.processes,
initializer=self.init_worker.__func__, initializer=self.init_worker,
initargs=[counter], initargs=[
counter,
self.initial_settings,
self.serialized_contents,
],
) )
args = [ args = [
(self.runner_class, index, subsuite, self.failfast, self.buffer) (self.runner_class, index, subsuite, self.failfast, self.buffer)
@ -508,6 +538,17 @@ class ParallelTestSuite(unittest.TestSuite):
def __iter__(self): def __iter__(self):
return iter(self.subsuites) return iter(self.subsuites)
def initialize_suite(self):
if multiprocessing.get_start_method() == "spawn":
self.initial_settings = {
alias: connections[alias].settings_dict for alias in connections
}
self.serialized_contents = {
alias: connections[alias]._test_serialized_contents
for alias in connections
if alias in self.serialized_aliases
}
class Shuffler: class Shuffler:
""" """
@ -921,6 +962,8 @@ class DiscoverRunner:
def run_suite(self, suite, **kwargs): def run_suite(self, suite, **kwargs):
kwargs = self.get_test_runner_kwargs() kwargs = self.get_test_runner_kwargs()
runner = self.test_runner(**kwargs) runner = self.test_runner(**kwargs)
if hasattr(suite, "initialize_suite"):
suite.initialize_suite()
try: try:
return runner.run(suite) return runner.run(suite)
finally: finally:
@ -989,13 +1032,13 @@ class DiscoverRunner:
self.setup_test_environment() self.setup_test_environment()
suite = self.build_suite(test_labels, extra_tests) suite = self.build_suite(test_labels, extra_tests)
databases = self.get_databases(suite) databases = self.get_databases(suite)
serialized_aliases = set( suite.serialized_aliases = set(
alias for alias, serialize in databases.items() if serialize alias for alias, serialize in databases.items() if serialize
) )
with self.time_keeper.timed("Total database setup"): with self.time_keeper.timed("Total database setup"):
old_config = self.setup_databases( old_config = self.setup_databases(
aliases=databases, aliases=databases,
serialized_aliases=serialized_aliases, serialized_aliases=suite.serialized_aliases,
) )
run_failed = False run_failed = False
try: try:

View File

@ -130,7 +130,7 @@ def iter_modules_and_files(modules, extra_files):
# cause issues here. # cause issues here.
if not isinstance(module, ModuleType): if not isinstance(module, ModuleType):
continue continue
if module.__name__ == "__main__": if module.__name__ in ("__main__", "__mp_main__"):
# __main__ (usually manage.py) doesn't always have a __spec__ set. # __main__ (usually manage.py) doesn't always have a __spec__ set.
# Handle this by falling back to using __file__, resolved below. # Handle this by falling back to using __file__, resolved below.
# See https://docs.python.org/reference/import.html#main-spec # See https://docs.python.org/reference/import.html#main-spec

View File

@ -70,6 +70,8 @@ class SessionMiddlewareSubclass(SessionMiddleware):
], ],
) )
class SystemChecksTestCase(SimpleTestCase): class SystemChecksTestCase(SimpleTestCase):
databases = "__all__"
def test_checks_are_performed(self): def test_checks_are_performed(self):
admin.site.register(Song, MyAdmin) admin.site.register(Song, MyAdmin)
try: try:

View File

@ -1,7 +1,9 @@
import copy import copy
import multiprocessing
import unittest import unittest
from unittest import mock
from django.db import DEFAULT_DB_ALIAS, connection, connections from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connection, connections
from django.test import SimpleTestCase from django.test import SimpleTestCase
@ -33,3 +35,9 @@ class TestDbSignatureTests(SimpleTestCase):
creation_class = test_connection.creation_class(test_connection) creation_class = test_connection.creation_class(test_connection)
clone_settings_dict = creation_class.get_test_db_clone_settings("1") clone_settings_dict = creation_class.get_test_db_clone_settings("1")
self.assertEqual(clone_settings_dict["NAME"], expected_clone_name) self.assertEqual(clone_settings_dict["NAME"], expected_clone_name)
@mock.patch.object(multiprocessing, "get_start_method", return_value="forkserver")
def test_get_test_db_clone_settings_not_supported(self, *mocked_objects):
msg = "Cloning with start method 'forkserver' is not supported."
with self.assertRaisesMessage(NotSupportedError, msg):
connection.creation.get_test_db_clone_settings(1)

View File

@ -362,5 +362,7 @@ class CheckFrameworkReservedNamesTests(SimpleTestCase):
class ChecksRunDuringTests(SimpleTestCase): class ChecksRunDuringTests(SimpleTestCase):
databases = "__all__"
def test_registered_check_did_run(self): def test_registered_check_did_run(self):
self.assertTrue(my_check.did_run) self.assertTrue(my_check.did_run)

View File

@ -11,6 +11,8 @@ from django.test.utils import isolate_apps
@isolate_apps("contenttypes_tests", attr_name="apps") @isolate_apps("contenttypes_tests", attr_name="apps")
class GenericForeignKeyTests(SimpleTestCase): class GenericForeignKeyTests(SimpleTestCase):
databases = "__all__"
def test_missing_content_type_field(self): def test_missing_content_type_field(self):
class TaggedItem(models.Model): class TaggedItem(models.Model):
# no content_type field # no content_type field

View File

@ -22,6 +22,13 @@ class RemoveStaleContentTypesTests(TestCase):
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
with captured_stdout():
call_command(
"remove_stale_contenttypes",
interactive=False,
include_stale_apps=True,
verbosity=2,
)
cls.before_count = ContentType.objects.count() cls.before_count = ContentType.objects.count()
cls.content_type = ContentType.objects.create( cls.content_type = ContentType.objects.create(
app_label="contenttypes_tests", model="Fake" app_label="contenttypes_tests", model="Fake"

View File

@ -1,5 +1,7 @@
from datetime import date from datetime import date
from django.test import modify_settings
from . import PostgreSQLTestCase from . import PostgreSQLTestCase
from .models import ( from .models import (
HStoreModel, HStoreModel,
@ -16,6 +18,7 @@ except ImportError:
pass # psycopg2 isn't installed. pass # psycopg2 isn't installed.
@modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"})
class BulkSaveTests(PostgreSQLTestCase): class BulkSaveTests(PostgreSQLTestCase):
def test_bulk_update(self): def test_bulk_update(self):
test_data = [ test_data = [

View File

@ -3,6 +3,7 @@ import argparse
import atexit import atexit
import copy import copy
import gc import gc
import multiprocessing
import os import os
import shutil import shutil
import socket import socket
@ -10,6 +11,7 @@ import subprocess
import sys import sys
import tempfile import tempfile
import warnings import warnings
from functools import partial
from pathlib import Path from pathlib import Path
try: try:
@ -24,7 +26,7 @@ else:
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.db import connection, connections from django.db import connection, connections
from django.test import TestCase, TransactionTestCase from django.test import TestCase, TransactionTestCase
from django.test.runner import get_max_test_processes, parallel_type from django.test.runner import _init_worker, get_max_test_processes, parallel_type
from django.test.selenium import SeleniumTestCaseBase from django.test.selenium import SeleniumTestCaseBase
from django.test.utils import NullTimeKeeper, TimeKeeper, get_runner from django.test.utils import NullTimeKeeper, TimeKeeper, get_runner
from django.utils.deprecation import RemovedInDjango50Warning from django.utils.deprecation import RemovedInDjango50Warning
@ -382,7 +384,8 @@ def django_tests(
msg += " with up to %d processes" % max_parallel msg += " with up to %d processes" % max_parallel
print(msg) print(msg)
test_labels, state = setup_run_tests(verbosity, start_at, start_after, test_labels) process_setup_args = (verbosity, start_at, start_after, test_labels)
test_labels, state = setup_run_tests(*process_setup_args)
# Run the test suite, including the extra validation tests. # Run the test suite, including the extra validation tests.
if not hasattr(settings, "TEST_RUNNER"): if not hasattr(settings, "TEST_RUNNER"):
settings.TEST_RUNNER = "django.test.runner.DiscoverRunner" settings.TEST_RUNNER = "django.test.runner.DiscoverRunner"
@ -395,6 +398,11 @@ def django_tests(
parallel = 1 parallel = 1
TestRunner = get_runner(settings) TestRunner = get_runner(settings)
TestRunner.parallel_test_suite.init_worker = partial(
_init_worker,
process_setup=setup_run_tests,
process_setup_args=process_setup_args,
)
test_runner = TestRunner( test_runner = TestRunner(
verbosity=verbosity, verbosity=verbosity,
interactive=interactive, interactive=interactive,
@ -718,6 +726,11 @@ if __name__ == "__main__":
options.settings = os.environ["DJANGO_SETTINGS_MODULE"] options.settings = os.environ["DJANGO_SETTINGS_MODULE"]
if options.selenium: if options.selenium:
if multiprocessing.get_start_method() == "spawn" and options.parallel != 1:
parser.error(
"You cannot use --selenium with parallel tests on this system. "
"Pass --parallel=1 to use --selenium."
)
if not options.tags: if not options.tags:
options.tags = ["selenium"] options.tags = ["selenium"]
elif "selenium" not in options.tags: elif "selenium" not in options.tags:

View File

@ -86,6 +86,16 @@ class DiscoverRunnerParallelArgumentTests(SimpleTestCase):
mocked_cpu_count, mocked_cpu_count,
): ):
mocked_get_start_method.return_value = "spawn" mocked_get_start_method.return_value = "spawn"
self.assertEqual(get_max_test_processes(), 12)
with mock.patch.dict(os.environ, {"DJANGO_TEST_PROCESSES": "7"}):
self.assertEqual(get_max_test_processes(), 7)
def test_get_max_test_processes_forkserver(
self,
mocked_get_start_method,
mocked_cpu_count,
):
mocked_get_start_method.return_value = "forkserver"
self.assertEqual(get_max_test_processes(), 1) self.assertEqual(get_max_test_processes(), 1)
with mock.patch.dict(os.environ, {"DJANGO_TEST_PROCESSES": "7"}): with mock.patch.dict(os.environ, {"DJANGO_TEST_PROCESSES": "7"}):
self.assertEqual(get_max_test_processes(), 1) self.assertEqual(get_max_test_processes(), 1)

View File

@ -480,8 +480,6 @@ class ManageCommandTests(unittest.TestCase):
# Isolate from the real environment. # Isolate from the real environment.
@mock.patch.dict(os.environ, {}, clear=True) @mock.patch.dict(os.environ, {}, clear=True)
@mock.patch.object(multiprocessing, "cpu_count", return_value=12) @mock.patch.object(multiprocessing, "cpu_count", return_value=12)
# Python 3.8 on macOS defaults to 'spawn' mode.
@mock.patch.object(multiprocessing, "get_start_method", return_value="fork")
class ManageCommandParallelTests(SimpleTestCase): class ManageCommandParallelTests(SimpleTestCase):
def test_parallel_default(self, *mocked_objects): def test_parallel_default(self, *mocked_objects):
with captured_stderr() as stderr: with captured_stderr() as stderr:
@ -507,8 +505,8 @@ class ManageCommandParallelTests(SimpleTestCase):
# Parallel is disabled by default. # Parallel is disabled by default.
self.assertEqual(stderr.getvalue(), "") self.assertEqual(stderr.getvalue(), "")
def test_parallel_spawn(self, mocked_get_start_method, mocked_cpu_count): @mock.patch.object(multiprocessing, "get_start_method", return_value="spawn")
mocked_get_start_method.return_value = "spawn" def test_parallel_spawn(self, *mocked_objects):
with captured_stderr() as stderr: with captured_stderr() as stderr:
call_command( call_command(
"test", "test",
@ -517,8 +515,8 @@ class ManageCommandParallelTests(SimpleTestCase):
) )
self.assertIn("parallel=1", stderr.getvalue()) self.assertIn("parallel=1", stderr.getvalue())
def test_no_parallel_spawn(self, mocked_get_start_method, mocked_cpu_count): @mock.patch.object(multiprocessing, "get_start_method", return_value="spawn")
mocked_get_start_method.return_value = "spawn" def test_no_parallel_spawn(self, *mocked_objects):
with captured_stderr() as stderr: with captured_stderr() as stderr:
call_command( call_command(
"test", "test",