Fixed #29501 -- Allowed dbshell to pass options to underlying tool.
This commit is contained in:
parent
8e8c3f964e
commit
5b884d45ac
|
@ -1,3 +1,5 @@
|
||||||
|
import subprocess
|
||||||
|
|
||||||
from django.core.management.base import BaseCommand, CommandError
|
from django.core.management.base import BaseCommand, CommandError
|
||||||
from django.db import DEFAULT_DB_ALIAS, connections
|
from django.db import DEFAULT_DB_ALIAS, connections
|
||||||
|
|
||||||
|
@ -15,11 +17,13 @@ class Command(BaseCommand):
|
||||||
'--database', default=DEFAULT_DB_ALIAS,
|
'--database', default=DEFAULT_DB_ALIAS,
|
||||||
help='Nominates a database onto which to open a shell. Defaults to the "default" database.',
|
help='Nominates a database onto which to open a shell. Defaults to the "default" database.',
|
||||||
)
|
)
|
||||||
|
parameters = parser.add_argument_group('parameters', prefix_chars='--')
|
||||||
|
parameters.add_argument('parameters', nargs='*')
|
||||||
|
|
||||||
def handle(self, **options):
|
def handle(self, **options):
|
||||||
connection = connections[options['database']]
|
connection = connections[options['database']]
|
||||||
try:
|
try:
|
||||||
connection.client.runshell()
|
connection.client.runshell(options['parameters'])
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
# Note that we're assuming the FileNotFoundError relates to the
|
# Note that we're assuming the FileNotFoundError relates to the
|
||||||
# command missing. It could be raised for some other reason, in
|
# command missing. It could be raised for some other reason, in
|
||||||
|
@ -29,3 +33,11 @@ class Command(BaseCommand):
|
||||||
'You appear not to have the %r program installed or on your path.' %
|
'You appear not to have the %r program installed or on your path.' %
|
||||||
connection.client.executable_name
|
connection.client.executable_name
|
||||||
)
|
)
|
||||||
|
except subprocess.CalledProcessError as e:
|
||||||
|
raise CommandError(
|
||||||
|
'"%s" returned non-zero exit status %s.' % (
|
||||||
|
' '.join(e.cmd),
|
||||||
|
e.returncode,
|
||||||
|
),
|
||||||
|
returncode=e.returncode,
|
||||||
|
)
|
||||||
|
|
|
@ -8,5 +8,5 @@ class BaseDatabaseClient:
|
||||||
# connection is an instance of BaseDatabaseWrapper.
|
# connection is an instance of BaseDatabaseWrapper.
|
||||||
self.connection = connection
|
self.connection = connection
|
||||||
|
|
||||||
def runshell(self):
|
def runshell(self, parameters):
|
||||||
raise NotImplementedError('subclasses of BaseDatabaseClient must provide a runshell() method')
|
raise NotImplementedError('subclasses of BaseDatabaseClient must provide a runshell() method')
|
||||||
|
|
|
@ -7,7 +7,7 @@ class DatabaseClient(BaseDatabaseClient):
|
||||||
executable_name = 'mysql'
|
executable_name = 'mysql'
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def settings_to_cmd_args(cls, settings_dict):
|
def settings_to_cmd_args(cls, settings_dict, parameters):
|
||||||
args = [cls.executable_name]
|
args = [cls.executable_name]
|
||||||
db = settings_dict['OPTIONS'].get('db', settings_dict['NAME'])
|
db = settings_dict['OPTIONS'].get('db', settings_dict['NAME'])
|
||||||
user = settings_dict['OPTIONS'].get('user', settings_dict['USER'])
|
user = settings_dict['OPTIONS'].get('user', settings_dict['USER'])
|
||||||
|
@ -41,8 +41,9 @@ class DatabaseClient(BaseDatabaseClient):
|
||||||
args += ["--ssl-key=%s" % client_key]
|
args += ["--ssl-key=%s" % client_key]
|
||||||
if db:
|
if db:
|
||||||
args += [db]
|
args += [db]
|
||||||
|
args.extend(parameters)
|
||||||
return args
|
return args
|
||||||
|
|
||||||
def runshell(self):
|
def runshell(self, parameters):
|
||||||
args = DatabaseClient.settings_to_cmd_args(self.connection.settings_dict)
|
args = DatabaseClient.settings_to_cmd_args(self.connection.settings_dict, parameters)
|
||||||
subprocess.run(args, check=True)
|
subprocess.run(args, check=True)
|
||||||
|
|
|
@ -55,9 +55,9 @@ class DatabaseCreation(BaseDatabaseCreation):
|
||||||
self._clone_db(source_database_name, target_database_name)
|
self._clone_db(source_database_name, target_database_name)
|
||||||
|
|
||||||
def _clone_db(self, source_database_name, target_database_name):
|
def _clone_db(self, source_database_name, target_database_name):
|
||||||
dump_args = DatabaseClient.settings_to_cmd_args(self.connection.settings_dict)[1:]
|
dump_args = DatabaseClient.settings_to_cmd_args(self.connection.settings_dict, [])[1:]
|
||||||
dump_cmd = ['mysqldump', *dump_args[:-1], '--routines', '--events', source_database_name]
|
dump_cmd = ['mysqldump', *dump_args[:-1], '--routines', '--events', source_database_name]
|
||||||
load_cmd = DatabaseClient.settings_to_cmd_args(self.connection.settings_dict)
|
load_cmd = DatabaseClient.settings_to_cmd_args(self.connection.settings_dict, [])
|
||||||
load_cmd[-1] = target_database_name
|
load_cmd[-1] = target_database_name
|
||||||
|
|
||||||
with subprocess.Popen(dump_cmd, stdout=subprocess.PIPE) as dump_proc:
|
with subprocess.Popen(dump_cmd, stdout=subprocess.PIPE) as dump_proc:
|
||||||
|
|
|
@ -8,10 +8,11 @@ class DatabaseClient(BaseDatabaseClient):
|
||||||
executable_name = 'sqlplus'
|
executable_name = 'sqlplus'
|
||||||
wrapper_name = 'rlwrap'
|
wrapper_name = 'rlwrap'
|
||||||
|
|
||||||
def runshell(self):
|
def runshell(self, parameters):
|
||||||
conn_string = self.connection._connect_string()
|
conn_string = self.connection._connect_string()
|
||||||
args = [self.executable_name, "-L", conn_string]
|
args = [self.executable_name, "-L", conn_string]
|
||||||
wrapper_path = shutil.which(self.wrapper_name)
|
wrapper_path = shutil.which(self.wrapper_name)
|
||||||
if wrapper_path:
|
if wrapper_path:
|
||||||
args = [wrapper_path, *args]
|
args = [wrapper_path, *args]
|
||||||
|
args.extend(parameters)
|
||||||
subprocess.run(args, check=True)
|
subprocess.run(args, check=True)
|
||||||
|
|
|
@ -9,7 +9,7 @@ class DatabaseClient(BaseDatabaseClient):
|
||||||
executable_name = 'psql'
|
executable_name = 'psql'
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def runshell_db(cls, conn_params):
|
def runshell_db(cls, conn_params, parameters):
|
||||||
args = [cls.executable_name]
|
args = [cls.executable_name]
|
||||||
|
|
||||||
host = conn_params.get('host', '')
|
host = conn_params.get('host', '')
|
||||||
|
@ -29,6 +29,7 @@ class DatabaseClient(BaseDatabaseClient):
|
||||||
if port:
|
if port:
|
||||||
args += ['-p', str(port)]
|
args += ['-p', str(port)]
|
||||||
args += [dbname]
|
args += [dbname]
|
||||||
|
args.extend(parameters)
|
||||||
|
|
||||||
sigint_handler = signal.getsignal(signal.SIGINT)
|
sigint_handler = signal.getsignal(signal.SIGINT)
|
||||||
subprocess_env = os.environ.copy()
|
subprocess_env = os.environ.copy()
|
||||||
|
@ -50,5 +51,5 @@ class DatabaseClient(BaseDatabaseClient):
|
||||||
# Restore the original SIGINT handler.
|
# Restore the original SIGINT handler.
|
||||||
signal.signal(signal.SIGINT, sigint_handler)
|
signal.signal(signal.SIGINT, sigint_handler)
|
||||||
|
|
||||||
def runshell(self):
|
def runshell(self, parameters):
|
||||||
DatabaseClient.runshell_db(self.connection.get_connection_params())
|
self.runshell_db(self.connection.get_connection_params(), parameters)
|
||||||
|
|
|
@ -6,9 +6,10 @@ from django.db.backends.base.client import BaseDatabaseClient
|
||||||
class DatabaseClient(BaseDatabaseClient):
|
class DatabaseClient(BaseDatabaseClient):
|
||||||
executable_name = 'sqlite3'
|
executable_name = 'sqlite3'
|
||||||
|
|
||||||
def runshell(self):
|
def runshell(self, parameters):
|
||||||
# TODO: Remove str() when dropping support for PY37.
|
# TODO: Remove str() when dropping support for PY37.
|
||||||
# args parameter accepts path-like objects on Windows since Python 3.8.
|
# args parameter accepts path-like objects on Windows since Python 3.8.
|
||||||
args = [self.executable_name,
|
args = [self.executable_name,
|
||||||
str(self.connection.settings_dict['NAME'])]
|
str(self.connection.settings_dict['NAME'])]
|
||||||
|
args.extend(parameters)
|
||||||
subprocess.run(args, check=True)
|
subprocess.run(args, check=True)
|
||||||
|
|
|
@ -226,6 +226,33 @@ program manually.
|
||||||
|
|
||||||
Specifies the database onto which to open a shell. Defaults to ``default``.
|
Specifies the database onto which to open a shell. Defaults to ``default``.
|
||||||
|
|
||||||
|
.. django-admin-option:: -- ARGUMENTS
|
||||||
|
|
||||||
|
.. versionadded:: 3.1
|
||||||
|
|
||||||
|
Any arguments following a ``--`` divider will be passed on to the underlying
|
||||||
|
command-line client. For example, with PostgreSQL you can use the ``psql``
|
||||||
|
command's ``-c`` flag to execute a raw SQL query directly:
|
||||||
|
|
||||||
|
.. console::
|
||||||
|
|
||||||
|
$ django-admin dbshell -- -c 'select current_user'
|
||||||
|
current_user
|
||||||
|
--------------
|
||||||
|
postgres
|
||||||
|
(1 row)
|
||||||
|
|
||||||
|
On MySQL/MariaDB, you can do this with the ``mysql`` command's ``-e`` flag:
|
||||||
|
|
||||||
|
.. console::
|
||||||
|
|
||||||
|
$ django-admin dbshell -- -e "select user()"
|
||||||
|
+----------------------+
|
||||||
|
| user() |
|
||||||
|
+----------------------+
|
||||||
|
| djangonaut@localhost |
|
||||||
|
+----------------------+
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
|
|
||||||
Be aware that not all options set in the :setting:`OPTIONS` part of your
|
Be aware that not all options set in the :setting:`OPTIONS` part of your
|
||||||
|
|
|
@ -317,6 +317,9 @@ Management Commands
|
||||||
:attr:`~django.core.management.CommandError` allows customizing the exit
|
:attr:`~django.core.management.CommandError` allows customizing the exit
|
||||||
status for management commands.
|
status for management commands.
|
||||||
|
|
||||||
|
* The new :option:`dbshell -- ARGUMENTS <dbshell -->` option allows passing
|
||||||
|
extra arguments to the command-line client for the database.
|
||||||
|
|
||||||
Migrations
|
Migrations
|
||||||
~~~~~~~~~~
|
~~~~~~~~~~
|
||||||
|
|
||||||
|
@ -505,6 +508,9 @@ backends.
|
||||||
yields a cursor and automatically closes the cursor and connection upon
|
yields a cursor and automatically closes the cursor and connection upon
|
||||||
exiting the ``with`` statement.
|
exiting the ``with`` statement.
|
||||||
|
|
||||||
|
* ``DatabaseClient.runshell()`` now requires an additional ``parameters``
|
||||||
|
argument as a list of extra arguments to pass on to the command-line client.
|
||||||
|
|
||||||
Dropped support for MariaDB 10.1
|
Dropped support for MariaDB 10.1
|
||||||
--------------------------------
|
--------------------------------
|
||||||
|
|
||||||
|
|
|
@ -76,5 +76,23 @@ class MySqlDbshellCommandTestCase(SimpleTestCase):
|
||||||
},
|
},
|
||||||
}))
|
}))
|
||||||
|
|
||||||
def get_command_line_arguments(self, connection_settings):
|
def test_parameters(self):
|
||||||
return DatabaseClient.settings_to_cmd_args(connection_settings)
|
self.assertEqual(
|
||||||
|
['mysql', 'somedbname', '--help'],
|
||||||
|
self.get_command_line_arguments(
|
||||||
|
{
|
||||||
|
'NAME': 'somedbname',
|
||||||
|
'USER': None,
|
||||||
|
'PASSWORD': None,
|
||||||
|
'HOST': None,
|
||||||
|
'PORT': None,
|
||||||
|
'OPTIONS': {},
|
||||||
|
},
|
||||||
|
['--help'],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_command_line_arguments(self, connection_settings, parameters=None):
|
||||||
|
if parameters is None:
|
||||||
|
parameters = []
|
||||||
|
return DatabaseClient.settings_to_cmd_args(connection_settings, parameters)
|
||||||
|
|
|
@ -8,17 +8,19 @@ from django.test import SimpleTestCase
|
||||||
|
|
||||||
@skipUnless(connection.vendor == 'oracle', 'Oracle tests')
|
@skipUnless(connection.vendor == 'oracle', 'Oracle tests')
|
||||||
class OracleDbshellTests(SimpleTestCase):
|
class OracleDbshellTests(SimpleTestCase):
|
||||||
def _run_dbshell(self, rlwrap=False):
|
def _run_dbshell(self, rlwrap=False, parameters=None):
|
||||||
"""Run runshell command and capture its arguments."""
|
"""Run runshell command and capture its arguments."""
|
||||||
def _mock_subprocess_run(*args, **kwargs):
|
def _mock_subprocess_run(*args, **kwargs):
|
||||||
self.subprocess_args = list(*args)
|
self.subprocess_args = list(*args)
|
||||||
return CompletedProcess(self.subprocess_args, 0)
|
return CompletedProcess(self.subprocess_args, 0)
|
||||||
|
|
||||||
|
if parameters is None:
|
||||||
|
parameters = []
|
||||||
client = DatabaseClient(connection)
|
client = DatabaseClient(connection)
|
||||||
self.subprocess_args = None
|
self.subprocess_args = None
|
||||||
with mock.patch('subprocess.run', new=_mock_subprocess_run):
|
with mock.patch('subprocess.run', new=_mock_subprocess_run):
|
||||||
with mock.patch('shutil.which', return_value='/usr/bin/rlwrap' if rlwrap else None):
|
with mock.patch('shutil.which', return_value='/usr/bin/rlwrap' if rlwrap else None):
|
||||||
client.runshell()
|
client.runshell(parameters)
|
||||||
return self.subprocess_args
|
return self.subprocess_args
|
||||||
|
|
||||||
def test_without_rlwrap(self):
|
def test_without_rlwrap(self):
|
||||||
|
@ -32,3 +34,9 @@ class OracleDbshellTests(SimpleTestCase):
|
||||||
self._run_dbshell(rlwrap=True),
|
self._run_dbshell(rlwrap=True),
|
||||||
['/usr/bin/rlwrap', 'sqlplus', '-L', connection._connect_string()],
|
['/usr/bin/rlwrap', 'sqlplus', '-L', connection._connect_string()],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_parameters(self):
|
||||||
|
self.assertEqual(
|
||||||
|
self._run_dbshell(parameters=['-HELP']),
|
||||||
|
['sqlplus', '-L', connection._connect_string(), '-HELP'],
|
||||||
|
)
|
||||||
|
|
|
@ -9,7 +9,7 @@ from django.test import SimpleTestCase
|
||||||
|
|
||||||
class PostgreSqlDbshellCommandTestCase(SimpleTestCase):
|
class PostgreSqlDbshellCommandTestCase(SimpleTestCase):
|
||||||
|
|
||||||
def _run_it(self, dbinfo):
|
def _run_it(self, dbinfo, parameters=None):
|
||||||
"""
|
"""
|
||||||
That function invokes the runshell command, while mocking
|
That function invokes the runshell command, while mocking
|
||||||
subprocess.run(). It returns a 2-tuple with:
|
subprocess.run(). It returns a 2-tuple with:
|
||||||
|
@ -21,8 +21,11 @@ class PostgreSqlDbshellCommandTestCase(SimpleTestCase):
|
||||||
# PostgreSQL environment variables.
|
# PostgreSQL environment variables.
|
||||||
self.pg_env = {key: env[key] for key in env if key.startswith('PG')}
|
self.pg_env = {key: env[key] for key in env if key.startswith('PG')}
|
||||||
return subprocess.CompletedProcess(self.subprocess_args, 0)
|
return subprocess.CompletedProcess(self.subprocess_args, 0)
|
||||||
|
|
||||||
|
if parameters is None:
|
||||||
|
parameters = []
|
||||||
with mock.patch('subprocess.run', new=_mock_subprocess_run):
|
with mock.patch('subprocess.run', new=_mock_subprocess_run):
|
||||||
DatabaseClient.runshell_db(dbinfo)
|
DatabaseClient.runshell_db(dbinfo, parameters)
|
||||||
return self.subprocess_args, self.pg_env
|
return self.subprocess_args, self.pg_env
|
||||||
|
|
||||||
def test_basic(self):
|
def test_basic(self):
|
||||||
|
@ -104,6 +107,12 @@ class PostgreSqlDbshellCommandTestCase(SimpleTestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_parameters(self):
|
||||||
|
self.assertEqual(
|
||||||
|
self._run_it({'database': 'dbname'}, ['--help']),
|
||||||
|
(['psql', 'dbname', '--help'], {}),
|
||||||
|
)
|
||||||
|
|
||||||
def test_sigint_handler(self):
|
def test_sigint_handler(self):
|
||||||
"""SIGINT is ignored in Python and passed to psql to abort queries."""
|
"""SIGINT is ignored in Python and passed to psql to abort queries."""
|
||||||
def _mock_subprocess_run(*args, **kwargs):
|
def _mock_subprocess_run(*args, **kwargs):
|
||||||
|
@ -114,6 +123,6 @@ class PostgreSqlDbshellCommandTestCase(SimpleTestCase):
|
||||||
# The default handler isn't SIG_IGN.
|
# The default handler isn't SIG_IGN.
|
||||||
self.assertNotEqual(sigint_handler, signal.SIG_IGN)
|
self.assertNotEqual(sigint_handler, signal.SIG_IGN)
|
||||||
with mock.patch('subprocess.run', new=_mock_subprocess_run):
|
with mock.patch('subprocess.run', new=_mock_subprocess_run):
|
||||||
DatabaseClient.runshell_db({})
|
DatabaseClient.runshell_db({}, [])
|
||||||
# dbshell restores the original handler.
|
# dbshell restores the original handler.
|
||||||
self.assertEqual(sigint_handler, signal.getsignal(signal.SIGINT))
|
self.assertEqual(sigint_handler, signal.getsignal(signal.SIGINT))
|
||||||
|
|
|
@ -9,15 +9,17 @@ from django.test import SimpleTestCase
|
||||||
|
|
||||||
@skipUnless(connection.vendor == 'sqlite', 'SQLite tests.')
|
@skipUnless(connection.vendor == 'sqlite', 'SQLite tests.')
|
||||||
class SqliteDbshellCommandTestCase(SimpleTestCase):
|
class SqliteDbshellCommandTestCase(SimpleTestCase):
|
||||||
def _run_dbshell(self):
|
def _run_dbshell(self, parameters=None):
|
||||||
"""Run runshell command and capture its arguments."""
|
"""Run runshell command and capture its arguments."""
|
||||||
def _mock_subprocess_run(*args, **kwargs):
|
def _mock_subprocess_run(*args, **kwargs):
|
||||||
self.subprocess_args = list(*args)
|
self.subprocess_args = list(*args)
|
||||||
return CompletedProcess(self.subprocess_args, 0)
|
return CompletedProcess(self.subprocess_args, 0)
|
||||||
|
|
||||||
|
if parameters is None:
|
||||||
|
parameters = []
|
||||||
client = DatabaseClient(connection)
|
client = DatabaseClient(connection)
|
||||||
with mock.patch('subprocess.run', new=_mock_subprocess_run):
|
with mock.patch('subprocess.run', new=_mock_subprocess_run):
|
||||||
client.runshell()
|
client.runshell(parameters)
|
||||||
return self.subprocess_args
|
return self.subprocess_args
|
||||||
|
|
||||||
def test_path_name(self):
|
def test_path_name(self):
|
||||||
|
@ -29,3 +31,13 @@ class SqliteDbshellCommandTestCase(SimpleTestCase):
|
||||||
self._run_dbshell(),
|
self._run_dbshell(),
|
||||||
['sqlite3', 'test.db.sqlite3'],
|
['sqlite3', 'test.db.sqlite3'],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_parameters(self):
|
||||||
|
with mock.patch.dict(
|
||||||
|
connection.settings_dict,
|
||||||
|
{'NAME': Path('test.db.sqlite3')},
|
||||||
|
):
|
||||||
|
self.assertEqual(
|
||||||
|
self._run_dbshell(['-help']),
|
||||||
|
['sqlite3', 'test.db.sqlite3', '-help'],
|
||||||
|
)
|
||||||
|
|
Loading…
Reference in New Issue