diff --git a/django/db/backends/mysql/client.py b/django/db/backends/mysql/client.py index 0c09a2ca1e..6aa11b2e1f 100644 --- a/django/db/backends/mysql/client.py +++ b/django/db/backends/mysql/client.py @@ -1,3 +1,5 @@ +import signal + from django.db.backends.base.client import BaseDatabaseClient @@ -58,3 +60,13 @@ class DatabaseClient(BaseDatabaseClient): args += [database] args.extend(parameters) return args, env + + def runshell(self, parameters): + sigint_handler = signal.getsignal(signal.SIGINT) + try: + # Allow SIGINT to pass to mysql to abort queries. + signal.signal(signal.SIGINT, signal.SIG_IGN) + super().runshell(parameters) + finally: + # Restore the original SIGINT handler. + signal.signal(signal.SIGINT, sigint_handler) diff --git a/tests/dbshell/test_mysql.py b/tests/dbshell/test_mysql.py index 28410b05e1..13007ec037 100644 --- a/tests/dbshell/test_mysql.py +++ b/tests/dbshell/test_mysql.py @@ -1,8 +1,11 @@ import os +import signal import subprocess import sys from pathlib import Path +from unittest import mock, skipUnless +from django.db import connection from django.db.backends.mysql.client import DatabaseClient from django.test import SimpleTestCase @@ -218,3 +221,19 @@ class MySqlDbshellCommandTestCase(SimpleTestCase): with self.assertRaises(subprocess.CalledProcessError) as ctx: subprocess.run(args, check=True, env=env) self.assertNotIn("somepassword", str(ctx.exception)) + + @skipUnless(connection.vendor == "mysql", "Requires a MySQL connection") + def test_sigint_handler(self): + """SIGINT is ignored in Python and passed to mysql to abort queries.""" + + def _mock_subprocess_run(*args, **kwargs): + handler = signal.getsignal(signal.SIGINT) + self.assertEqual(handler, signal.SIG_IGN) + + sigint_handler = signal.getsignal(signal.SIGINT) + # The default handler isn't SIG_IGN. + self.assertNotEqual(sigint_handler, signal.SIG_IGN) + with mock.patch("subprocess.run", new=_mock_subprocess_run): + connection.client.runshell([]) + # dbshell restores the original handler. + self.assertEqual(sigint_handler, signal.getsignal(signal.SIGINT))