Refs #32061 -- Unified DatabaseClient.runshell() in db backends.

This commit is contained in:
Simon Charette 2020-10-04 18:25:29 -04:00 committed by Mariusz Felisiak
parent 4ac2d4fa42
commit bbe6fbb876
16 changed files with 272 additions and 206 deletions

View File

@ -1,3 +1,7 @@
import os
import subprocess
class BaseDatabaseClient: class BaseDatabaseClient:
"""Encapsulate backend-specific methods for opening a client shell.""" """Encapsulate backend-specific methods for opening a client shell."""
# This should be a string representing the name of the executable # This should be a string representing the name of the executable
@ -8,5 +12,15 @@ class BaseDatabaseClient:
# connection is an instance of BaseDatabaseWrapper. # connection is an instance of BaseDatabaseWrapper.
self.connection = connection self.connection = connection
@classmethod
def settings_to_cmd_args_env(cls, settings_dict, parameters):
raise NotImplementedError(
'subclasses of BaseDatabaseClient must provide a '
'settings_to_cmd_args_env() method or override a runshell().'
)
def runshell(self, parameters): def runshell(self, parameters):
raise NotImplementedError('subclasses of BaseDatabaseClient must provide a runshell() method') args, env = self.settings_to_cmd_args_env(self.connection.settings_dict, parameters)
if env:
env = {**os.environ, **env}
subprocess.run(args, env=env, check=True)

View File

@ -1,5 +1,3 @@
import subprocess
from django.db.backends.base.client import BaseDatabaseClient from django.db.backends.base.client import BaseDatabaseClient
@ -7,7 +5,7 @@ class DatabaseClient(BaseDatabaseClient):
executable_name = 'mysql' executable_name = 'mysql'
@classmethod @classmethod
def settings_to_cmd_args(cls, settings_dict, parameters): def settings_to_cmd_args_env(cls, settings_dict, parameters):
args = [cls.executable_name] args = [cls.executable_name]
db = settings_dict['OPTIONS'].get('db', settings_dict['NAME']) db = settings_dict['OPTIONS'].get('db', settings_dict['NAME'])
user = settings_dict['OPTIONS'].get('user', settings_dict['USER']) user = settings_dict['OPTIONS'].get('user', settings_dict['USER'])
@ -48,8 +46,4 @@ class DatabaseClient(BaseDatabaseClient):
if db: if db:
args += [db] args += [db]
args.extend(parameters) args.extend(parameters)
return args return args, None
def runshell(self, parameters):
args = DatabaseClient.settings_to_cmd_args(self.connection.settings_dict, parameters)
subprocess.run(args, check=True)

View File

@ -1,3 +1,4 @@
import os
import subprocess import subprocess
import sys import sys
@ -55,12 +56,13 @@ class DatabaseCreation(BaseDatabaseCreation):
self._clone_db(source_database_name, target_database_name) self._clone_db(source_database_name, target_database_name)
def _clone_db(self, source_database_name, target_database_name): def _clone_db(self, source_database_name, target_database_name):
dump_args = DatabaseClient.settings_to_cmd_args(self.connection.settings_dict, [])[1:] cmd_args, cmd_env = DatabaseClient.settings_to_cmd_args_env(self.connection.settings_dict, [])
dump_cmd = ['mysqldump', *dump_args[:-1], '--routines', '--events', source_database_name] dump_cmd = ['mysqldump', *cmd_args[1:-1], '--routines', '--events', source_database_name]
load_cmd = DatabaseClient.settings_to_cmd_args(self.connection.settings_dict, []) dump_env = load_env = {**os.environ, **cmd_env} if cmd_env else None
load_cmd = cmd_args
load_cmd[-1] = target_database_name load_cmd[-1] = target_database_name
with subprocess.Popen(dump_cmd, stdout=subprocess.PIPE) as dump_proc: with subprocess.Popen(dump_cmd, stdout=subprocess.PIPE, env=dump_env) as dump_proc:
with subprocess.Popen(load_cmd, stdin=dump_proc.stdout, stdout=subprocess.DEVNULL): with subprocess.Popen(load_cmd, stdin=dump_proc.stdout, stdout=subprocess.DEVNULL, env=load_env):
# Allow dump_proc to receive a SIGPIPE if the load process exits. # Allow dump_proc to receive a SIGPIPE if the load process exits.
dump_proc.stdout.close() dump_proc.stdout.close()

View File

@ -56,7 +56,7 @@ from .features import DatabaseFeatures # NOQA isort:skip
from .introspection import DatabaseIntrospection # NOQA isort:skip from .introspection import DatabaseIntrospection # NOQA isort:skip
from .operations import DatabaseOperations # NOQA isort:skip from .operations import DatabaseOperations # NOQA isort:skip
from .schema import DatabaseSchemaEditor # NOQA isort:skip from .schema import DatabaseSchemaEditor # NOQA isort:skip
from .utils import Oracle_datetime # NOQA isort:skip from .utils import dsn, Oracle_datetime # NOQA isort:skip
from .validation import DatabaseValidation # NOQA isort:skip from .validation import DatabaseValidation # NOQA isort:skip
@ -218,17 +218,6 @@ class DatabaseWrapper(BaseDatabaseWrapper):
use_returning_into = self.settings_dict["OPTIONS"].get('use_returning_into', True) use_returning_into = self.settings_dict["OPTIONS"].get('use_returning_into', True)
self.features.can_return_columns_from_insert = use_returning_into self.features.can_return_columns_from_insert = use_returning_into
def _dsn(self):
settings_dict = self.settings_dict
if not settings_dict['HOST'].strip():
settings_dict['HOST'] = 'localhost'
if settings_dict['PORT']:
return Database.makedsn(settings_dict['HOST'], int(settings_dict['PORT']), settings_dict['NAME'])
return settings_dict['NAME']
def _connect_string(self):
return '%s/"%s"@%s' % (self.settings_dict['USER'], self.settings_dict['PASSWORD'], self._dsn())
def get_connection_params(self): def get_connection_params(self):
conn_params = self.settings_dict['OPTIONS'].copy() conn_params = self.settings_dict['OPTIONS'].copy()
if 'use_returning_into' in conn_params: if 'use_returning_into' in conn_params:
@ -240,7 +229,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
return Database.connect( return Database.connect(
user=self.settings_dict['USER'], user=self.settings_dict['USER'],
password=self.settings_dict['PASSWORD'], password=self.settings_dict['PASSWORD'],
dsn=self._dsn(), dsn=dsn(self.settings_dict),
**conn_params, **conn_params,
) )

View File

@ -1,5 +1,4 @@
import shutil import shutil
import subprocess
from django.db.backends.base.client import BaseDatabaseClient from django.db.backends.base.client import BaseDatabaseClient
@ -8,11 +7,21 @@ class DatabaseClient(BaseDatabaseClient):
executable_name = 'sqlplus' executable_name = 'sqlplus'
wrapper_name = 'rlwrap' wrapper_name = 'rlwrap'
def runshell(self, parameters): @staticmethod
conn_string = self.connection._connect_string() def connect_string(settings_dict):
args = [self.executable_name, "-L", conn_string] from django.db.backends.oracle.utils import dsn
wrapper_path = shutil.which(self.wrapper_name)
return '%s/"%s"@%s' % (
settings_dict['USER'],
settings_dict['PASSWORD'],
dsn(settings_dict),
)
@classmethod
def settings_to_cmd_args_env(cls, settings_dict, parameters):
args = [cls.executable_name, '-L', cls.connect_string(settings_dict)]
wrapper_path = shutil.which(cls.wrapper_name)
if wrapper_path: if wrapper_path:
args = [wrapper_path, *args] args = [wrapper_path, *args]
args.extend(parameters) args.extend(parameters)
subprocess.run(args, check=True) return args, None

View File

@ -82,3 +82,10 @@ class BulkInsertMapper:
'TextField': CLOB, 'TextField': CLOB,
'TimeField': TIMESTAMP, 'TimeField': TIMESTAMP,
} }
def dsn(settings_dict):
if settings_dict['PORT']:
host = settings_dict['HOST'].strip() or 'localhost'
return Database.makedsn(host, int(settings_dict['PORT']), settings_dict['NAME'])
return settings_dict['NAME']

View File

@ -1,6 +1,4 @@
import os
import signal import signal
import subprocess
from django.db.backends.base.client import BaseDatabaseClient from django.db.backends.base.client import BaseDatabaseClient
@ -9,18 +7,19 @@ class DatabaseClient(BaseDatabaseClient):
executable_name = 'psql' executable_name = 'psql'
@classmethod @classmethod
def runshell_db(cls, conn_params, parameters): def settings_to_cmd_args_env(cls, settings_dict, parameters):
args = [cls.executable_name] args = [cls.executable_name]
options = settings_dict.get('OPTIONS', {})
host = conn_params.get('host', '') host = settings_dict.get('HOST')
port = conn_params.get('port', '') port = settings_dict.get('PORT')
dbname = conn_params.get('database', '') dbname = settings_dict.get('NAME') or 'postgres'
user = conn_params.get('user', '') user = settings_dict.get('USER')
passwd = conn_params.get('password', '') passwd = settings_dict.get('PASSWORD')
sslmode = conn_params.get('sslmode', '') sslmode = options.get('sslmode')
sslrootcert = conn_params.get('sslrootcert', '') sslrootcert = options.get('sslrootcert')
sslcert = conn_params.get('sslcert', '') sslcert = options.get('sslcert')
sslkey = conn_params.get('sslkey', '') sslkey = options.get('sslkey')
if user: if user:
args += ['-U', user] args += ['-U', user]
@ -31,25 +30,25 @@ class DatabaseClient(BaseDatabaseClient):
args += [dbname] args += [dbname]
args.extend(parameters) args.extend(parameters)
sigint_handler = signal.getsignal(signal.SIGINT) env = {}
subprocess_env = os.environ.copy()
if passwd: if passwd:
subprocess_env['PGPASSWORD'] = str(passwd) env['PGPASSWORD'] = str(passwd)
if sslmode: if sslmode:
subprocess_env['PGSSLMODE'] = str(sslmode) env['PGSSLMODE'] = str(sslmode)
if sslrootcert: if sslrootcert:
subprocess_env['PGSSLROOTCERT'] = str(sslrootcert) env['PGSSLROOTCERT'] = str(sslrootcert)
if sslcert: if sslcert:
subprocess_env['PGSSLCERT'] = str(sslcert) env['PGSSLCERT'] = str(sslcert)
if sslkey: if sslkey:
subprocess_env['PGSSLKEY'] = str(sslkey) env['PGSSLKEY'] = str(sslkey)
return args, env
def runshell(self, parameters):
sigint_handler = signal.getsignal(signal.SIGINT)
try: try:
# Allow SIGINT to pass to psql to abort queries. # Allow SIGINT to pass to psql to abort queries.
signal.signal(signal.SIGINT, signal.SIG_IGN) signal.signal(signal.SIGINT, signal.SIG_IGN)
subprocess.run(args, check=True, env=subprocess_env) super().runshell(parameters)
finally: finally:
# Restore the original SIGINT handler. # Restore the original SIGINT handler.
signal.signal(signal.SIGINT, sigint_handler) signal.signal(signal.SIGINT, sigint_handler)
def runshell(self, parameters):
self.runshell_db(self.connection.get_connection_params(), parameters)

View File

@ -1,15 +1,16 @@
import subprocess
from django.db.backends.base.client import BaseDatabaseClient from django.db.backends.base.client import BaseDatabaseClient
class DatabaseClient(BaseDatabaseClient): class DatabaseClient(BaseDatabaseClient):
executable_name = 'sqlite3' executable_name = 'sqlite3'
def runshell(self, parameters): @classmethod
# TODO: Remove str() when dropping support for PY37. def settings_to_cmd_args_env(cls, settings_dict, parameters):
# args parameter accepts path-like objects on Windows since Python 3.8. args = [
args = [self.executable_name, cls.executable_name,
str(self.connection.settings_dict['NAME'])] # TODO: Remove str() when dropping support for PY37. args
args.extend(parameters) # parameter accepts path-like objects on Windows since Python 3.8.
subprocess.run(args, check=True) str(settings_dict['NAME']),
*parameters,
]
return args, None

View File

@ -477,6 +477,12 @@ backends.
``DatabaseOperations.time_trunc_sql()`` now take the optional ``tzname`` ``DatabaseOperations.time_trunc_sql()`` now take the optional ``tzname``
argument in order to truncate in a specific timezone. argument in order to truncate in a specific timezone.
* ``DatabaseClient.runshell()`` now gets arguments and an optional dictionary
with environment variables to the underlying command-line client from
``DatabaseClient.settings_to_cmd_args_env()`` method. Third-party database
backends must implement ``DatabaseClient.settings_to_cmd_args_env()`` or
override ``DatabaseClient.runshell()``.
:mod:`django.contrib.admin` :mod:`django.contrib.admin`
--------------------------- ---------------------------

View File

@ -0,0 +1,16 @@
from django.db import connection
from django.db.backends.base.client import BaseDatabaseClient
from django.test import SimpleTestCase
class SimpleDatabaseClientTests(SimpleTestCase):
def setUp(self):
self.client = BaseDatabaseClient(connection=connection)
def test_settings_to_cmd_args_env(self):
msg = (
'subclasses of BaseDatabaseClient must provide a '
'settings_to_cmd_args_env() method or override a runshell().'
)
with self.assertRaisesMessage(NotImplementedError, msg):
self.client.settings_to_cmd_args_env(None, None)

View File

@ -78,6 +78,7 @@ class DatabaseCreationTests(SimpleTestCase):
'source_db', 'source_db',
], ],
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
env=None,
), ),
]) ])
finally: finally:

View File

@ -86,7 +86,10 @@ class TransactionalTests(TransactionTestCase):
old_password = connection.settings_dict['PASSWORD'] old_password = connection.settings_dict['PASSWORD']
connection.settings_dict['PASSWORD'] = 'p@ssword' connection.settings_dict['PASSWORD'] = 'p@ssword'
try: try:
self.assertIn('/"p@ssword"@', connection._connect_string()) self.assertIn(
'/"p@ssword"@',
connection.client.connect_string(connection.settings_dict),
)
with self.assertRaises(DatabaseError) as context: with self.assertRaises(DatabaseError) as context:
connection.cursor() connection.cursor()
# Database exception: "ORA-01017: invalid username/password" is # Database exception: "ORA-01017: invalid username/password" is

View File

@ -3,32 +3,52 @@ from django.test import SimpleTestCase
class MySqlDbshellCommandTestCase(SimpleTestCase): class MySqlDbshellCommandTestCase(SimpleTestCase):
def settings_to_cmd_args_env(self, settings_dict, parameters=None):
if parameters is None:
parameters = []
return DatabaseClient.settings_to_cmd_args_env(settings_dict, parameters)
def test_fails_with_keyerror_on_incomplete_config(self): def test_fails_with_keyerror_on_incomplete_config(self):
with self.assertRaises(KeyError): with self.assertRaises(KeyError):
self.get_command_line_arguments({}) self.settings_to_cmd_args_env({})
def test_basic_params_specified_in_settings(self): def test_basic_params_specified_in_settings(self):
expected_args = [
'mysql',
'--user=someuser',
'--password=somepassword',
'--host=somehost',
'--port=444',
'somedbname',
]
expected_env = None
self.assertEqual( self.assertEqual(
['mysql', '--user=someuser', '--password=somepassword', self.settings_to_cmd_args_env({
'--host=somehost', '--port=444', 'somedbname'],
self.get_command_line_arguments({
'NAME': 'somedbname', 'NAME': 'somedbname',
'USER': 'someuser', 'USER': 'someuser',
'PASSWORD': 'somepassword', 'PASSWORD': 'somepassword',
'HOST': 'somehost', 'HOST': 'somehost',
'PORT': 444, 'PORT': 444,
'OPTIONS': {}, 'OPTIONS': {},
})) }),
(expected_args, expected_env),
)
def test_options_override_settings_proper_values(self): def test_options_override_settings_proper_values(self):
settings_port = 444 settings_port = 444
options_port = 555 options_port = 555
self.assertNotEqual(settings_port, options_port, 'test pre-req') self.assertNotEqual(settings_port, options_port, 'test pre-req')
expected_args = [
'mysql',
'--user=optionuser',
'--password=optionpassword',
'--host=optionhost',
'--port=%s' % options_port,
'optiondbname',
]
expected_env = None
self.assertEqual( self.assertEqual(
['mysql', '--user=optionuser', '--password=optionpassword', self.settings_to_cmd_args_env({
'--host=optionhost', '--port={}'.format(options_port), 'optiondbname'],
self.get_command_line_arguments({
'NAME': 'settingdbname', 'NAME': 'settingdbname',
'USER': 'settinguser', 'USER': 'settinguser',
'PASSWORD': 'settingpassword', 'PASSWORD': 'settingpassword',
@ -41,15 +61,22 @@ class MySqlDbshellCommandTestCase(SimpleTestCase):
'host': 'optionhost', 'host': 'optionhost',
'port': options_port, 'port': options_port,
}, },
})) }),
(expected_args, expected_env),
)
def test_options_password(self): def test_options_password(self):
expected_args = [
'mysql',
'--user=someuser',
'--password=optionpassword',
'--host=somehost',
'--port=444',
'somedbname',
]
expected_env = None
self.assertEqual( self.assertEqual(
[ self.settings_to_cmd_args_env({
'mysql', '--user=someuser', '--password=optionpassword',
'--host=somehost', '--port=444', 'somedbname',
],
self.get_command_line_arguments({
'NAME': 'somedbname', 'NAME': 'somedbname',
'USER': 'someuser', 'USER': 'someuser',
'PASSWORD': 'settingpassword', 'PASSWORD': 'settingpassword',
@ -57,16 +84,22 @@ class MySqlDbshellCommandTestCase(SimpleTestCase):
'PORT': 444, 'PORT': 444,
'OPTIONS': {'password': 'optionpassword'}, 'OPTIONS': {'password': 'optionpassword'},
}), }),
(expected_args, expected_env),
) )
def test_options_charset(self): def test_options_charset(self):
expected_args = [
'mysql',
'--user=someuser',
'--password=somepassword',
'--host=somehost',
'--port=444',
'--default-character-set=utf8',
'somedbname',
]
expected_env = None
self.assertEqual( self.assertEqual(
[ self.settings_to_cmd_args_env({
'mysql', '--user=someuser', '--password=somepassword',
'--host=somehost', '--port=444',
'--default-character-set=utf8', 'somedbname',
],
self.get_command_line_arguments({
'NAME': 'somedbname', 'NAME': 'somedbname',
'USER': 'someuser', 'USER': 'someuser',
'PASSWORD': 'somepassword', 'PASSWORD': 'somepassword',
@ -74,27 +107,45 @@ class MySqlDbshellCommandTestCase(SimpleTestCase):
'PORT': 444, 'PORT': 444,
'OPTIONS': {'charset': 'utf8'}, 'OPTIONS': {'charset': 'utf8'},
}), }),
(expected_args, expected_env),
) )
def test_can_connect_using_sockets(self): def test_can_connect_using_sockets(self):
expected_args = [
'mysql',
'--user=someuser',
'--password=somepassword',
'--socket=/path/to/mysql.socket.file',
'somedbname',
]
expected_env = None
self.assertEqual( self.assertEqual(
['mysql', '--user=someuser', '--password=somepassword', self.settings_to_cmd_args_env({
'--socket=/path/to/mysql.socket.file', 'somedbname'],
self.get_command_line_arguments({
'NAME': 'somedbname', 'NAME': 'somedbname',
'USER': 'someuser', 'USER': 'someuser',
'PASSWORD': 'somepassword', 'PASSWORD': 'somepassword',
'HOST': '/path/to/mysql.socket.file', 'HOST': '/path/to/mysql.socket.file',
'PORT': None, 'PORT': None,
'OPTIONS': {}, 'OPTIONS': {},
})) }),
(expected_args, expected_env),
)
def test_ssl_certificate_is_added(self): def test_ssl_certificate_is_added(self):
expected_args = [
'mysql',
'--user=someuser',
'--password=somepassword',
'--host=somehost',
'--port=444',
'--ssl-ca=sslca',
'--ssl-cert=sslcert',
'--ssl-key=sslkey',
'somedbname',
]
expected_env = None
self.assertEqual( self.assertEqual(
['mysql', '--user=someuser', '--password=somepassword', self.settings_to_cmd_args_env({
'--host=somehost', '--port=444', '--ssl-ca=sslca',
'--ssl-cert=sslcert', '--ssl-key=sslkey', 'somedbname'],
self.get_command_line_arguments({
'NAME': 'somedbname', 'NAME': 'somedbname',
'USER': 'someuser', 'USER': 'someuser',
'PASSWORD': 'somepassword', 'PASSWORD': 'somepassword',
@ -107,12 +158,13 @@ class MySqlDbshellCommandTestCase(SimpleTestCase):
'key': 'sslkey', 'key': 'sslkey',
}, },
}, },
})) }),
(expected_args, expected_env),
)
def test_parameters(self): def test_parameters(self):
self.assertEqual( self.assertEqual(
['mysql', 'somedbname', '--help'], self.settings_to_cmd_args_env(
self.get_command_line_arguments(
{ {
'NAME': 'somedbname', 'NAME': 'somedbname',
'USER': None, 'USER': None,
@ -123,9 +175,5 @@ class MySqlDbshellCommandTestCase(SimpleTestCase):
}, },
['--help'], ['--help'],
), ),
(['mysql', 'somedbname', '--help'], None),
) )
def get_command_line_arguments(self, connection_settings, parameters=None):
if parameters is None:
parameters = []
return DatabaseClient.settings_to_cmd_args(connection_settings, parameters)

View File

@ -1,4 +1,3 @@
from subprocess import CompletedProcess
from unittest import mock, skipUnless from unittest import mock, skipUnless
from django.db import connection from django.db import connection
@ -6,37 +5,48 @@ from django.db.backends.oracle.client import DatabaseClient
from django.test import SimpleTestCase from django.test import SimpleTestCase
@skipUnless(connection.vendor == 'oracle', 'Oracle tests') @skipUnless(connection.vendor == 'oracle', 'Requires cx_Oracle to be installed')
class OracleDbshellTests(SimpleTestCase): class OracleDbshellTests(SimpleTestCase):
def _run_dbshell(self, rlwrap=False, parameters=None): def settings_to_cmd_args_env(self, settings_dict, parameters=None, rlwrap=False):
"""Run runshell command and capture its arguments."""
def _mock_subprocess_run(*args, **kwargs):
self.subprocess_args = list(*args)
return CompletedProcess(self.subprocess_args, 0)
if parameters is None: if parameters is None:
parameters = [] parameters = []
client = DatabaseClient(connection) with mock.patch('shutil.which', return_value='/usr/bin/rlwrap' if rlwrap else None):
self.subprocess_args = None return DatabaseClient.settings_to_cmd_args_env(settings_dict, parameters)
with mock.patch('subprocess.run', new=_mock_subprocess_run):
with mock.patch('shutil.which', return_value='/usr/bin/rlwrap' if rlwrap else None):
client.runshell(parameters)
return self.subprocess_args
def test_without_rlwrap(self): def test_without_rlwrap(self):
expected_args = [
'sqlplus',
'-L',
connection.client.connect_string(connection.settings_dict),
]
self.assertEqual( self.assertEqual(
self._run_dbshell(rlwrap=False), self.settings_to_cmd_args_env(connection.settings_dict, rlwrap=False),
['sqlplus', '-L', connection._connect_string()], (expected_args, None),
) )
def test_with_rlwrap(self): def test_with_rlwrap(self):
expected_args = [
'/usr/bin/rlwrap',
'sqlplus',
'-L',
connection.client.connect_string(connection.settings_dict),
]
self.assertEqual( self.assertEqual(
self._run_dbshell(rlwrap=True), self.settings_to_cmd_args_env(connection.settings_dict, rlwrap=True),
['/usr/bin/rlwrap', 'sqlplus', '-L', connection._connect_string()], (expected_args, None),
) )
def test_parameters(self): def test_parameters(self):
expected_args = [
'sqlplus',
'-L',
connection.client.connect_string(connection.settings_dict),
'-HELP',
]
self.assertEqual( self.assertEqual(
self._run_dbshell(parameters=['-HELP']), self.settings_to_cmd_args_env(
['sqlplus', '-L', connection._connect_string(), '-HELP'], connection.settings_dict,
parameters=['-HELP'],
),
(expected_args, None),
) )

View File

@ -1,41 +1,25 @@
import os
import signal import signal
import subprocess from unittest import mock, skipUnless
from unittest import mock
from django.db import connection
from django.db.backends.postgresql.client import DatabaseClient from django.db.backends.postgresql.client import DatabaseClient
from django.test import SimpleTestCase from django.test import SimpleTestCase
class PostgreSqlDbshellCommandTestCase(SimpleTestCase): class PostgreSqlDbshellCommandTestCase(SimpleTestCase):
def settings_to_cmd_args_env(self, settings_dict, parameters=None):
def _run_it(self, dbinfo, parameters=None):
"""
That function invokes the runshell command, while mocking
subprocess.run(). It returns a 2-tuple with:
- The command line list
- The dictionary of PG* environment variables, or {}.
"""
def _mock_subprocess_run(*args, env=os.environ, **kwargs):
self.subprocess_args = list(*args)
# PostgreSQL environment variables.
self.pg_env = {key: env[key] for key in env if key.startswith('PG')}
return subprocess.CompletedProcess(self.subprocess_args, 0)
if parameters is None: if parameters is None:
parameters = [] parameters = []
with mock.patch('subprocess.run', new=_mock_subprocess_run): return DatabaseClient.settings_to_cmd_args_env(settings_dict, parameters)
DatabaseClient.runshell_db(dbinfo, parameters)
return self.subprocess_args, self.pg_env
def test_basic(self): def test_basic(self):
self.assertEqual( self.assertEqual(
self._run_it({ self.settings_to_cmd_args_env({
'database': 'dbname', 'NAME': 'dbname',
'user': 'someuser', 'USER': 'someuser',
'password': 'somepassword', 'PASSWORD': 'somepassword',
'host': 'somehost', 'HOST': 'somehost',
'port': '444', 'PORT': '444',
}), ( }), (
['psql', '-U', 'someuser', '-h', 'somehost', '-p', '444', 'dbname'], ['psql', '-U', 'someuser', '-h', 'somehost', '-p', '444', 'dbname'],
{'PGPASSWORD': 'somepassword'}, {'PGPASSWORD': 'somepassword'},
@ -44,11 +28,11 @@ class PostgreSqlDbshellCommandTestCase(SimpleTestCase):
def test_nopass(self): def test_nopass(self):
self.assertEqual( self.assertEqual(
self._run_it({ self.settings_to_cmd_args_env({
'database': 'dbname', 'NAME': 'dbname',
'user': 'someuser', 'USER': 'someuser',
'host': 'somehost', 'HOST': 'somehost',
'port': '444', 'PORT': '444',
}), ( }), (
['psql', '-U', 'someuser', '-h', 'somehost', '-p', '444', 'dbname'], ['psql', '-U', 'someuser', '-h', 'somehost', '-p', '444', 'dbname'],
{}, {},
@ -57,15 +41,17 @@ class PostgreSqlDbshellCommandTestCase(SimpleTestCase):
def test_ssl_certificate(self): def test_ssl_certificate(self):
self.assertEqual( self.assertEqual(
self._run_it({ self.settings_to_cmd_args_env({
'database': 'dbname', 'NAME': 'dbname',
'user': 'someuser', 'USER': 'someuser',
'host': 'somehost', 'HOST': 'somehost',
'port': '444', 'PORT': '444',
'sslmode': 'verify-ca', 'OPTIONS': {
'sslrootcert': 'root.crt', 'sslmode': 'verify-ca',
'sslcert': 'client.crt', 'sslrootcert': 'root.crt',
'sslkey': 'client.key', 'sslcert': 'client.crt',
'sslkey': 'client.key',
},
}), ( }), (
['psql', '-U', 'someuser', '-h', 'somehost', '-p', '444', 'dbname'], ['psql', '-U', 'someuser', '-h', 'somehost', '-p', '444', 'dbname'],
{ {
@ -79,12 +65,12 @@ class PostgreSqlDbshellCommandTestCase(SimpleTestCase):
def test_column(self): def test_column(self):
self.assertEqual( self.assertEqual(
self._run_it({ self.settings_to_cmd_args_env({
'database': 'dbname', 'NAME': 'dbname',
'user': 'some:user', 'USER': 'some:user',
'password': 'some:password', 'PASSWORD': 'some:password',
'host': '::1', 'HOST': '::1',
'port': '444', 'PORT': '444',
}), ( }), (
['psql', '-U', 'some:user', '-h', '::1', '-p', '444', 'dbname'], ['psql', '-U', 'some:user', '-h', '::1', '-p', '444', 'dbname'],
{'PGPASSWORD': 'some:password'}, {'PGPASSWORD': 'some:password'},
@ -95,12 +81,12 @@ class PostgreSqlDbshellCommandTestCase(SimpleTestCase):
username = 'rôle' username = 'rôle'
password = 'sésame' password = 'sésame'
self.assertEqual( self.assertEqual(
self._run_it({ self.settings_to_cmd_args_env({
'database': 'dbname', 'NAME': 'dbname',
'user': username, 'USER': username,
'password': password, 'PASSWORD': password,
'host': 'somehost', 'HOST': 'somehost',
'port': '444', 'PORT': '444',
}), ( }), (
['psql', '-U', username, '-h', 'somehost', '-p', '444', 'dbname'], ['psql', '-U', username, '-h', 'somehost', '-p', '444', 'dbname'],
{'PGPASSWORD': password}, {'PGPASSWORD': password},
@ -109,10 +95,11 @@ class PostgreSqlDbshellCommandTestCase(SimpleTestCase):
def test_parameters(self): def test_parameters(self):
self.assertEqual( self.assertEqual(
self._run_it({'database': 'dbname'}, ['--help']), self.settings_to_cmd_args_env({'NAME': 'dbname'}, ['--help']),
(['psql', 'dbname', '--help'], {}), (['psql', 'dbname', '--help'], {}),
) )
@skipUnless(connection.vendor == 'postgresql', 'Requires a PostgreSQL connection')
def test_sigint_handler(self): def test_sigint_handler(self):
"""SIGINT is ignored in Python and passed to psql to abort queries.""" """SIGINT is ignored in Python and passed to psql to abort queries."""
def _mock_subprocess_run(*args, **kwargs): def _mock_subprocess_run(*args, **kwargs):
@ -123,6 +110,6 @@ class PostgreSqlDbshellCommandTestCase(SimpleTestCase):
# The default handler isn't SIG_IGN. # The default handler isn't SIG_IGN.
self.assertNotEqual(sigint_handler, signal.SIG_IGN) self.assertNotEqual(sigint_handler, signal.SIG_IGN)
with mock.patch('subprocess.run', new=_mock_subprocess_run): with mock.patch('subprocess.run', new=_mock_subprocess_run):
DatabaseClient.runshell_db({}, []) connection.client.runshell([])
# dbshell restores the original handler. # dbshell restores the original handler.
self.assertEqual(sigint_handler, signal.getsignal(signal.SIGINT)) self.assertEqual(sigint_handler, signal.getsignal(signal.SIGINT))

View File

@ -1,43 +1,23 @@
from pathlib import Path from pathlib import Path
from subprocess import CompletedProcess
from unittest import mock, skipUnless
from django.db import connection
from django.db.backends.sqlite3.client import DatabaseClient from django.db.backends.sqlite3.client import DatabaseClient
from django.test import SimpleTestCase from django.test import SimpleTestCase
@skipUnless(connection.vendor == 'sqlite', 'SQLite tests.')
class SqliteDbshellCommandTestCase(SimpleTestCase): class SqliteDbshellCommandTestCase(SimpleTestCase):
def _run_dbshell(self, parameters=None): def settings_to_cmd_args_env(self, settings_dict, parameters=None):
"""Run runshell command and capture its arguments."""
def _mock_subprocess_run(*args, **kwargs):
self.subprocess_args = list(*args)
return CompletedProcess(self.subprocess_args, 0)
if parameters is None: if parameters is None:
parameters = [] parameters = []
client = DatabaseClient(connection) return DatabaseClient.settings_to_cmd_args_env(settings_dict, parameters)
with mock.patch('subprocess.run', new=_mock_subprocess_run):
client.runshell(parameters)
return self.subprocess_args
def test_path_name(self): def test_path_name(self):
with mock.patch.dict( self.assertEqual(
connection.settings_dict, self.settings_to_cmd_args_env({'NAME': Path('test.db.sqlite3')}),
{'NAME': Path('test.db.sqlite3')}, (['sqlite3', 'test.db.sqlite3'], None),
): )
self.assertEqual(
self._run_dbshell(),
['sqlite3', 'test.db.sqlite3'],
)
def test_parameters(self): def test_parameters(self):
with mock.patch.dict( self.assertEqual(
connection.settings_dict, self.settings_to_cmd_args_env({'NAME': 'test.db.sqlite3'}, ['-help']),
{'NAME': Path('test.db.sqlite3')}, (['sqlite3', 'test.db.sqlite3', '-help'], None),
): )
self.assertEqual(
self._run_dbshell(['-help']),
['sqlite3', 'test.db.sqlite3', '-help'],
)