From bbe6fbb8768e8fb1aecb96d51c049d7ceaf802d3 Mon Sep 17 00:00:00 2001 From: Simon Charette Date: Sun, 4 Oct 2020 18:25:29 -0400 Subject: [PATCH] Refs #32061 -- Unified DatabaseClient.runshell() in db backends. --- django/db/backends/base/client.py | 16 +++- django/db/backends/mysql/client.py | 10 +- django/db/backends/mysql/creation.py | 12 ++- django/db/backends/oracle/base.py | 15 +-- django/db/backends/oracle/client.py | 21 +++-- django/db/backends/oracle/utils.py | 7 ++ django/db/backends/postgresql/client.py | 45 +++++---- django/db/backends/sqlite3/client.py | 19 ++-- docs/releases/3.2.txt | 6 ++ tests/backends/base/test_client.py | 16 ++++ tests/backends/mysql/test_creation.py | 1 + tests/backends/oracle/tests.py | 5 +- tests/dbshell/test_mysql.py | 120 +++++++++++++++++------- tests/dbshell/test_oracle.py | 50 ++++++---- tests/dbshell/test_postgresql.py | 95 ++++++++----------- tests/dbshell/test_sqlite.py | 40 ++------ 16 files changed, 272 insertions(+), 206 deletions(-) create mode 100644 tests/backends/base/test_client.py diff --git a/django/db/backends/base/client.py b/django/db/backends/base/client.py index 32901764aa..339f1e863c 100644 --- a/django/db/backends/base/client.py +++ b/django/db/backends/base/client.py @@ -1,3 +1,7 @@ +import os +import subprocess + + class BaseDatabaseClient: """Encapsulate backend-specific methods for opening a client shell.""" # This should be a string representing the name of the executable @@ -8,5 +12,15 @@ class BaseDatabaseClient: # connection is an instance of BaseDatabaseWrapper. self.connection = connection + @classmethod + def settings_to_cmd_args_env(cls, settings_dict, parameters): + raise NotImplementedError( + 'subclasses of BaseDatabaseClient must provide a ' + 'settings_to_cmd_args_env() method or override a runshell().' + ) + def runshell(self, parameters): - raise NotImplementedError('subclasses of BaseDatabaseClient must provide a runshell() method') + args, env = self.settings_to_cmd_args_env(self.connection.settings_dict, parameters) + if env: + env = {**os.environ, **env} + subprocess.run(args, env=env, check=True) diff --git a/django/db/backends/mysql/client.py b/django/db/backends/mysql/client.py index a94d388ccd..79032c1207 100644 --- a/django/db/backends/mysql/client.py +++ b/django/db/backends/mysql/client.py @@ -1,5 +1,3 @@ -import subprocess - from django.db.backends.base.client import BaseDatabaseClient @@ -7,7 +5,7 @@ class DatabaseClient(BaseDatabaseClient): executable_name = 'mysql' @classmethod - def settings_to_cmd_args(cls, settings_dict, parameters): + def settings_to_cmd_args_env(cls, settings_dict, parameters): args = [cls.executable_name] db = settings_dict['OPTIONS'].get('db', settings_dict['NAME']) user = settings_dict['OPTIONS'].get('user', settings_dict['USER']) @@ -48,8 +46,4 @@ class DatabaseClient(BaseDatabaseClient): if db: args += [db] args.extend(parameters) - return args - - def runshell(self, parameters): - args = DatabaseClient.settings_to_cmd_args(self.connection.settings_dict, parameters) - subprocess.run(args, check=True) + return args, None diff --git a/django/db/backends/mysql/creation.py b/django/db/backends/mysql/creation.py index 99372fd7ee..1f0261b667 100644 --- a/django/db/backends/mysql/creation.py +++ b/django/db/backends/mysql/creation.py @@ -1,3 +1,4 @@ +import os import subprocess import sys @@ -55,12 +56,13 @@ class DatabaseCreation(BaseDatabaseCreation): self._clone_db(source_database_name, target_database_name) def _clone_db(self, source_database_name, target_database_name): - dump_args = DatabaseClient.settings_to_cmd_args(self.connection.settings_dict, [])[1:] - dump_cmd = ['mysqldump', *dump_args[:-1], '--routines', '--events', source_database_name] - load_cmd = DatabaseClient.settings_to_cmd_args(self.connection.settings_dict, []) + cmd_args, cmd_env = DatabaseClient.settings_to_cmd_args_env(self.connection.settings_dict, []) + dump_cmd = ['mysqldump', *cmd_args[1:-1], '--routines', '--events', source_database_name] + dump_env = load_env = {**os.environ, **cmd_env} if cmd_env else None + load_cmd = cmd_args load_cmd[-1] = target_database_name - with subprocess.Popen(dump_cmd, stdout=subprocess.PIPE) as dump_proc: - with subprocess.Popen(load_cmd, stdin=dump_proc.stdout, stdout=subprocess.DEVNULL): + with subprocess.Popen(dump_cmd, stdout=subprocess.PIPE, env=dump_env) as dump_proc: + with subprocess.Popen(load_cmd, stdin=dump_proc.stdout, stdout=subprocess.DEVNULL, env=load_env): # Allow dump_proc to receive a SIGPIPE if the load process exits. dump_proc.stdout.close() diff --git a/django/db/backends/oracle/base.py b/django/db/backends/oracle/base.py index 10fff26fec..781f123a82 100644 --- a/django/db/backends/oracle/base.py +++ b/django/db/backends/oracle/base.py @@ -56,7 +56,7 @@ from .features import DatabaseFeatures # NOQA isort:skip from .introspection import DatabaseIntrospection # NOQA isort:skip from .operations import DatabaseOperations # NOQA isort:skip from .schema import DatabaseSchemaEditor # NOQA isort:skip -from .utils import Oracle_datetime # NOQA isort:skip +from .utils import dsn, Oracle_datetime # NOQA isort:skip from .validation import DatabaseValidation # NOQA isort:skip @@ -218,17 +218,6 @@ class DatabaseWrapper(BaseDatabaseWrapper): use_returning_into = self.settings_dict["OPTIONS"].get('use_returning_into', True) self.features.can_return_columns_from_insert = use_returning_into - def _dsn(self): - settings_dict = self.settings_dict - if not settings_dict['HOST'].strip(): - settings_dict['HOST'] = 'localhost' - if settings_dict['PORT']: - return Database.makedsn(settings_dict['HOST'], int(settings_dict['PORT']), settings_dict['NAME']) - return settings_dict['NAME'] - - def _connect_string(self): - return '%s/"%s"@%s' % (self.settings_dict['USER'], self.settings_dict['PASSWORD'], self._dsn()) - def get_connection_params(self): conn_params = self.settings_dict['OPTIONS'].copy() if 'use_returning_into' in conn_params: @@ -240,7 +229,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): return Database.connect( user=self.settings_dict['USER'], password=self.settings_dict['PASSWORD'], - dsn=self._dsn(), + dsn=dsn(self.settings_dict), **conn_params, ) diff --git a/django/db/backends/oracle/client.py b/django/db/backends/oracle/client.py index 434481ba81..9920f4ca67 100644 --- a/django/db/backends/oracle/client.py +++ b/django/db/backends/oracle/client.py @@ -1,5 +1,4 @@ import shutil -import subprocess from django.db.backends.base.client import BaseDatabaseClient @@ -8,11 +7,21 @@ class DatabaseClient(BaseDatabaseClient): executable_name = 'sqlplus' wrapper_name = 'rlwrap' - def runshell(self, parameters): - conn_string = self.connection._connect_string() - args = [self.executable_name, "-L", conn_string] - wrapper_path = shutil.which(self.wrapper_name) + @staticmethod + def connect_string(settings_dict): + from django.db.backends.oracle.utils import dsn + + return '%s/"%s"@%s' % ( + settings_dict['USER'], + settings_dict['PASSWORD'], + dsn(settings_dict), + ) + + @classmethod + def settings_to_cmd_args_env(cls, settings_dict, parameters): + args = [cls.executable_name, '-L', cls.connect_string(settings_dict)] + wrapper_path = shutil.which(cls.wrapper_name) if wrapper_path: args = [wrapper_path, *args] args.extend(parameters) - subprocess.run(args, check=True) + return args, None diff --git a/django/db/backends/oracle/utils.py b/django/db/backends/oracle/utils.py index 62285f5ee4..5665079aa2 100644 --- a/django/db/backends/oracle/utils.py +++ b/django/db/backends/oracle/utils.py @@ -82,3 +82,10 @@ class BulkInsertMapper: 'TextField': CLOB, 'TimeField': TIMESTAMP, } + + +def dsn(settings_dict): + if settings_dict['PORT']: + host = settings_dict['HOST'].strip() or 'localhost' + return Database.makedsn(host, int(settings_dict['PORT']), settings_dict['NAME']) + return settings_dict['NAME'] diff --git a/django/db/backends/postgresql/client.py b/django/db/backends/postgresql/client.py index 9d390b3807..7965401163 100644 --- a/django/db/backends/postgresql/client.py +++ b/django/db/backends/postgresql/client.py @@ -1,6 +1,4 @@ -import os import signal -import subprocess from django.db.backends.base.client import BaseDatabaseClient @@ -9,18 +7,19 @@ class DatabaseClient(BaseDatabaseClient): executable_name = 'psql' @classmethod - def runshell_db(cls, conn_params, parameters): + def settings_to_cmd_args_env(cls, settings_dict, parameters): args = [cls.executable_name] + options = settings_dict.get('OPTIONS', {}) - host = conn_params.get('host', '') - port = conn_params.get('port', '') - dbname = conn_params.get('database', '') - user = conn_params.get('user', '') - passwd = conn_params.get('password', '') - sslmode = conn_params.get('sslmode', '') - sslrootcert = conn_params.get('sslrootcert', '') - sslcert = conn_params.get('sslcert', '') - sslkey = conn_params.get('sslkey', '') + host = settings_dict.get('HOST') + port = settings_dict.get('PORT') + dbname = settings_dict.get('NAME') or 'postgres' + user = settings_dict.get('USER') + passwd = settings_dict.get('PASSWORD') + sslmode = options.get('sslmode') + sslrootcert = options.get('sslrootcert') + sslcert = options.get('sslcert') + sslkey = options.get('sslkey') if user: args += ['-U', user] @@ -31,25 +30,25 @@ class DatabaseClient(BaseDatabaseClient): args += [dbname] args.extend(parameters) - sigint_handler = signal.getsignal(signal.SIGINT) - subprocess_env = os.environ.copy() + env = {} if passwd: - subprocess_env['PGPASSWORD'] = str(passwd) + env['PGPASSWORD'] = str(passwd) if sslmode: - subprocess_env['PGSSLMODE'] = str(sslmode) + env['PGSSLMODE'] = str(sslmode) if sslrootcert: - subprocess_env['PGSSLROOTCERT'] = str(sslrootcert) + env['PGSSLROOTCERT'] = str(sslrootcert) if sslcert: - subprocess_env['PGSSLCERT'] = str(sslcert) + env['PGSSLCERT'] = str(sslcert) if sslkey: - subprocess_env['PGSSLKEY'] = str(sslkey) + env['PGSSLKEY'] = str(sslkey) + return args, env + + def runshell(self, parameters): + sigint_handler = signal.getsignal(signal.SIGINT) try: # Allow SIGINT to pass to psql to abort queries. signal.signal(signal.SIGINT, signal.SIG_IGN) - subprocess.run(args, check=True, env=subprocess_env) + super().runshell(parameters) finally: # Restore the original SIGINT handler. signal.signal(signal.SIGINT, sigint_handler) - - def runshell(self, parameters): - self.runshell_db(self.connection.get_connection_params(), parameters) diff --git a/django/db/backends/sqlite3/client.py b/django/db/backends/sqlite3/client.py index a71005fd5b..59a2fe7f50 100644 --- a/django/db/backends/sqlite3/client.py +++ b/django/db/backends/sqlite3/client.py @@ -1,15 +1,16 @@ -import subprocess - from django.db.backends.base.client import BaseDatabaseClient class DatabaseClient(BaseDatabaseClient): executable_name = 'sqlite3' - def runshell(self, parameters): - # TODO: Remove str() when dropping support for PY37. - # args parameter accepts path-like objects on Windows since Python 3.8. - args = [self.executable_name, - str(self.connection.settings_dict['NAME'])] - args.extend(parameters) - subprocess.run(args, check=True) + @classmethod + def settings_to_cmd_args_env(cls, settings_dict, parameters): + args = [ + cls.executable_name, + # TODO: Remove str() when dropping support for PY37. args + # parameter accepts path-like objects on Windows since Python 3.8. + str(settings_dict['NAME']), + *parameters, + ] + return args, None diff --git a/docs/releases/3.2.txt b/docs/releases/3.2.txt index c765c4ed63..ca427c3bc5 100644 --- a/docs/releases/3.2.txt +++ b/docs/releases/3.2.txt @@ -477,6 +477,12 @@ backends. ``DatabaseOperations.time_trunc_sql()`` now take the optional ``tzname`` argument in order to truncate in a specific timezone. +* ``DatabaseClient.runshell()`` now gets arguments and an optional dictionary + with environment variables to the underlying command-line client from + ``DatabaseClient.settings_to_cmd_args_env()`` method. Third-party database + backends must implement ``DatabaseClient.settings_to_cmd_args_env()`` or + override ``DatabaseClient.runshell()``. + :mod:`django.contrib.admin` --------------------------- diff --git a/tests/backends/base/test_client.py b/tests/backends/base/test_client.py new file mode 100644 index 0000000000..4573bbe97b --- /dev/null +++ b/tests/backends/base/test_client.py @@ -0,0 +1,16 @@ +from django.db import connection +from django.db.backends.base.client import BaseDatabaseClient +from django.test import SimpleTestCase + + +class SimpleDatabaseClientTests(SimpleTestCase): + def setUp(self): + self.client = BaseDatabaseClient(connection=connection) + + def test_settings_to_cmd_args_env(self): + msg = ( + 'subclasses of BaseDatabaseClient must provide a ' + 'settings_to_cmd_args_env() method or override a runshell().' + ) + with self.assertRaisesMessage(NotImplementedError, msg): + self.client.settings_to_cmd_args_env(None, None) diff --git a/tests/backends/mysql/test_creation.py b/tests/backends/mysql/test_creation.py index 5675601a1b..0d3480adea 100644 --- a/tests/backends/mysql/test_creation.py +++ b/tests/backends/mysql/test_creation.py @@ -78,6 +78,7 @@ class DatabaseCreationTests(SimpleTestCase): 'source_db', ], stdout=subprocess.PIPE, + env=None, ), ]) finally: diff --git a/tests/backends/oracle/tests.py b/tests/backends/oracle/tests.py index 0a7bf01963..258f98f5c9 100644 --- a/tests/backends/oracle/tests.py +++ b/tests/backends/oracle/tests.py @@ -86,7 +86,10 @@ class TransactionalTests(TransactionTestCase): old_password = connection.settings_dict['PASSWORD'] connection.settings_dict['PASSWORD'] = 'p@ssword' try: - self.assertIn('/"p@ssword"@', connection._connect_string()) + self.assertIn( + '/"p@ssword"@', + connection.client.connect_string(connection.settings_dict), + ) with self.assertRaises(DatabaseError) as context: connection.cursor() # Database exception: "ORA-01017: invalid username/password" is diff --git a/tests/dbshell/test_mysql.py b/tests/dbshell/test_mysql.py index e09644962b..374c466059 100644 --- a/tests/dbshell/test_mysql.py +++ b/tests/dbshell/test_mysql.py @@ -3,32 +3,52 @@ from django.test import SimpleTestCase class MySqlDbshellCommandTestCase(SimpleTestCase): + def settings_to_cmd_args_env(self, settings_dict, parameters=None): + if parameters is None: + parameters = [] + return DatabaseClient.settings_to_cmd_args_env(settings_dict, parameters) def test_fails_with_keyerror_on_incomplete_config(self): with self.assertRaises(KeyError): - self.get_command_line_arguments({}) + self.settings_to_cmd_args_env({}) def test_basic_params_specified_in_settings(self): + expected_args = [ + 'mysql', + '--user=someuser', + '--password=somepassword', + '--host=somehost', + '--port=444', + 'somedbname', + ] + expected_env = None self.assertEqual( - ['mysql', '--user=someuser', '--password=somepassword', - '--host=somehost', '--port=444', 'somedbname'], - self.get_command_line_arguments({ + self.settings_to_cmd_args_env({ 'NAME': 'somedbname', 'USER': 'someuser', 'PASSWORD': 'somepassword', 'HOST': 'somehost', 'PORT': 444, 'OPTIONS': {}, - })) + }), + (expected_args, expected_env), + ) def test_options_override_settings_proper_values(self): settings_port = 444 options_port = 555 self.assertNotEqual(settings_port, options_port, 'test pre-req') + expected_args = [ + 'mysql', + '--user=optionuser', + '--password=optionpassword', + '--host=optionhost', + '--port=%s' % options_port, + 'optiondbname', + ] + expected_env = None self.assertEqual( - ['mysql', '--user=optionuser', '--password=optionpassword', - '--host=optionhost', '--port={}'.format(options_port), 'optiondbname'], - self.get_command_line_arguments({ + self.settings_to_cmd_args_env({ 'NAME': 'settingdbname', 'USER': 'settinguser', 'PASSWORD': 'settingpassword', @@ -41,15 +61,22 @@ class MySqlDbshellCommandTestCase(SimpleTestCase): 'host': 'optionhost', 'port': options_port, }, - })) + }), + (expected_args, expected_env), + ) def test_options_password(self): + expected_args = [ + 'mysql', + '--user=someuser', + '--password=optionpassword', + '--host=somehost', + '--port=444', + 'somedbname', + ] + expected_env = None self.assertEqual( - [ - 'mysql', '--user=someuser', '--password=optionpassword', - '--host=somehost', '--port=444', 'somedbname', - ], - self.get_command_line_arguments({ + self.settings_to_cmd_args_env({ 'NAME': 'somedbname', 'USER': 'someuser', 'PASSWORD': 'settingpassword', @@ -57,16 +84,22 @@ class MySqlDbshellCommandTestCase(SimpleTestCase): 'PORT': 444, 'OPTIONS': {'password': 'optionpassword'}, }), + (expected_args, expected_env), ) def test_options_charset(self): + expected_args = [ + 'mysql', + '--user=someuser', + '--password=somepassword', + '--host=somehost', + '--port=444', + '--default-character-set=utf8', + 'somedbname', + ] + expected_env = None self.assertEqual( - [ - 'mysql', '--user=someuser', '--password=somepassword', - '--host=somehost', '--port=444', - '--default-character-set=utf8', 'somedbname', - ], - self.get_command_line_arguments({ + self.settings_to_cmd_args_env({ 'NAME': 'somedbname', 'USER': 'someuser', 'PASSWORD': 'somepassword', @@ -74,27 +107,45 @@ class MySqlDbshellCommandTestCase(SimpleTestCase): 'PORT': 444, 'OPTIONS': {'charset': 'utf8'}, }), + (expected_args, expected_env), ) def test_can_connect_using_sockets(self): + expected_args = [ + 'mysql', + '--user=someuser', + '--password=somepassword', + '--socket=/path/to/mysql.socket.file', + 'somedbname', + ] + expected_env = None self.assertEqual( - ['mysql', '--user=someuser', '--password=somepassword', - '--socket=/path/to/mysql.socket.file', 'somedbname'], - self.get_command_line_arguments({ + self.settings_to_cmd_args_env({ 'NAME': 'somedbname', 'USER': 'someuser', 'PASSWORD': 'somepassword', 'HOST': '/path/to/mysql.socket.file', 'PORT': None, 'OPTIONS': {}, - })) + }), + (expected_args, expected_env), + ) def test_ssl_certificate_is_added(self): + expected_args = [ + 'mysql', + '--user=someuser', + '--password=somepassword', + '--host=somehost', + '--port=444', + '--ssl-ca=sslca', + '--ssl-cert=sslcert', + '--ssl-key=sslkey', + 'somedbname', + ] + expected_env = None self.assertEqual( - ['mysql', '--user=someuser', '--password=somepassword', - '--host=somehost', '--port=444', '--ssl-ca=sslca', - '--ssl-cert=sslcert', '--ssl-key=sslkey', 'somedbname'], - self.get_command_line_arguments({ + self.settings_to_cmd_args_env({ 'NAME': 'somedbname', 'USER': 'someuser', 'PASSWORD': 'somepassword', @@ -107,12 +158,13 @@ class MySqlDbshellCommandTestCase(SimpleTestCase): 'key': 'sslkey', }, }, - })) + }), + (expected_args, expected_env), + ) def test_parameters(self): self.assertEqual( - ['mysql', 'somedbname', '--help'], - self.get_command_line_arguments( + self.settings_to_cmd_args_env( { 'NAME': 'somedbname', 'USER': None, @@ -123,9 +175,5 @@ class MySqlDbshellCommandTestCase(SimpleTestCase): }, ['--help'], ), + (['mysql', 'somedbname', '--help'], None), ) - - def get_command_line_arguments(self, connection_settings, parameters=None): - if parameters is None: - parameters = [] - return DatabaseClient.settings_to_cmd_args(connection_settings, parameters) diff --git a/tests/dbshell/test_oracle.py b/tests/dbshell/test_oracle.py index 41fbc07455..34e96fb09b 100644 --- a/tests/dbshell/test_oracle.py +++ b/tests/dbshell/test_oracle.py @@ -1,4 +1,3 @@ -from subprocess import CompletedProcess from unittest import mock, skipUnless from django.db import connection @@ -6,37 +5,48 @@ from django.db.backends.oracle.client import DatabaseClient from django.test import SimpleTestCase -@skipUnless(connection.vendor == 'oracle', 'Oracle tests') +@skipUnless(connection.vendor == 'oracle', 'Requires cx_Oracle to be installed') class OracleDbshellTests(SimpleTestCase): - def _run_dbshell(self, rlwrap=False, parameters=None): - """Run runshell command and capture its arguments.""" - def _mock_subprocess_run(*args, **kwargs): - self.subprocess_args = list(*args) - return CompletedProcess(self.subprocess_args, 0) - + def settings_to_cmd_args_env(self, settings_dict, parameters=None, rlwrap=False): if parameters is None: parameters = [] - client = DatabaseClient(connection) - self.subprocess_args = None - with mock.patch('subprocess.run', new=_mock_subprocess_run): - with mock.patch('shutil.which', return_value='/usr/bin/rlwrap' if rlwrap else None): - client.runshell(parameters) - return self.subprocess_args + with mock.patch('shutil.which', return_value='/usr/bin/rlwrap' if rlwrap else None): + return DatabaseClient.settings_to_cmd_args_env(settings_dict, parameters) def test_without_rlwrap(self): + expected_args = [ + 'sqlplus', + '-L', + connection.client.connect_string(connection.settings_dict), + ] self.assertEqual( - self._run_dbshell(rlwrap=False), - ['sqlplus', '-L', connection._connect_string()], + self.settings_to_cmd_args_env(connection.settings_dict, rlwrap=False), + (expected_args, None), ) def test_with_rlwrap(self): + expected_args = [ + '/usr/bin/rlwrap', + 'sqlplus', + '-L', + connection.client.connect_string(connection.settings_dict), + ] self.assertEqual( - self._run_dbshell(rlwrap=True), - ['/usr/bin/rlwrap', 'sqlplus', '-L', connection._connect_string()], + self.settings_to_cmd_args_env(connection.settings_dict, rlwrap=True), + (expected_args, None), ) def test_parameters(self): + expected_args = [ + 'sqlplus', + '-L', + connection.client.connect_string(connection.settings_dict), + '-HELP', + ] self.assertEqual( - self._run_dbshell(parameters=['-HELP']), - ['sqlplus', '-L', connection._connect_string(), '-HELP'], + self.settings_to_cmd_args_env( + connection.settings_dict, + parameters=['-HELP'], + ), + (expected_args, None), ) diff --git a/tests/dbshell/test_postgresql.py b/tests/dbshell/test_postgresql.py index 6de60eaef2..aad9692ecb 100644 --- a/tests/dbshell/test_postgresql.py +++ b/tests/dbshell/test_postgresql.py @@ -1,41 +1,25 @@ -import os import signal -import subprocess -from unittest import mock +from unittest import mock, skipUnless +from django.db import connection from django.db.backends.postgresql.client import DatabaseClient from django.test import SimpleTestCase class PostgreSqlDbshellCommandTestCase(SimpleTestCase): - - def _run_it(self, dbinfo, parameters=None): - """ - That function invokes the runshell command, while mocking - subprocess.run(). It returns a 2-tuple with: - - The command line list - - The dictionary of PG* environment variables, or {}. - """ - def _mock_subprocess_run(*args, env=os.environ, **kwargs): - self.subprocess_args = list(*args) - # PostgreSQL environment variables. - self.pg_env = {key: env[key] for key in env if key.startswith('PG')} - return subprocess.CompletedProcess(self.subprocess_args, 0) - + def settings_to_cmd_args_env(self, settings_dict, parameters=None): if parameters is None: parameters = [] - with mock.patch('subprocess.run', new=_mock_subprocess_run): - DatabaseClient.runshell_db(dbinfo, parameters) - return self.subprocess_args, self.pg_env + return DatabaseClient.settings_to_cmd_args_env(settings_dict, parameters) def test_basic(self): self.assertEqual( - self._run_it({ - 'database': 'dbname', - 'user': 'someuser', - 'password': 'somepassword', - 'host': 'somehost', - 'port': '444', + self.settings_to_cmd_args_env({ + 'NAME': 'dbname', + 'USER': 'someuser', + 'PASSWORD': 'somepassword', + 'HOST': 'somehost', + 'PORT': '444', }), ( ['psql', '-U', 'someuser', '-h', 'somehost', '-p', '444', 'dbname'], {'PGPASSWORD': 'somepassword'}, @@ -44,11 +28,11 @@ class PostgreSqlDbshellCommandTestCase(SimpleTestCase): def test_nopass(self): self.assertEqual( - self._run_it({ - 'database': 'dbname', - 'user': 'someuser', - 'host': 'somehost', - 'port': '444', + self.settings_to_cmd_args_env({ + 'NAME': 'dbname', + 'USER': 'someuser', + 'HOST': 'somehost', + 'PORT': '444', }), ( ['psql', '-U', 'someuser', '-h', 'somehost', '-p', '444', 'dbname'], {}, @@ -57,15 +41,17 @@ class PostgreSqlDbshellCommandTestCase(SimpleTestCase): def test_ssl_certificate(self): self.assertEqual( - self._run_it({ - 'database': 'dbname', - 'user': 'someuser', - 'host': 'somehost', - 'port': '444', - 'sslmode': 'verify-ca', - 'sslrootcert': 'root.crt', - 'sslcert': 'client.crt', - 'sslkey': 'client.key', + self.settings_to_cmd_args_env({ + 'NAME': 'dbname', + 'USER': 'someuser', + 'HOST': 'somehost', + 'PORT': '444', + 'OPTIONS': { + 'sslmode': 'verify-ca', + 'sslrootcert': 'root.crt', + 'sslcert': 'client.crt', + 'sslkey': 'client.key', + }, }), ( ['psql', '-U', 'someuser', '-h', 'somehost', '-p', '444', 'dbname'], { @@ -79,12 +65,12 @@ class PostgreSqlDbshellCommandTestCase(SimpleTestCase): def test_column(self): self.assertEqual( - self._run_it({ - 'database': 'dbname', - 'user': 'some:user', - 'password': 'some:password', - 'host': '::1', - 'port': '444', + self.settings_to_cmd_args_env({ + 'NAME': 'dbname', + 'USER': 'some:user', + 'PASSWORD': 'some:password', + 'HOST': '::1', + 'PORT': '444', }), ( ['psql', '-U', 'some:user', '-h', '::1', '-p', '444', 'dbname'], {'PGPASSWORD': 'some:password'}, @@ -95,12 +81,12 @@ class PostgreSqlDbshellCommandTestCase(SimpleTestCase): username = 'rôle' password = 'sésame' self.assertEqual( - self._run_it({ - 'database': 'dbname', - 'user': username, - 'password': password, - 'host': 'somehost', - 'port': '444', + self.settings_to_cmd_args_env({ + 'NAME': 'dbname', + 'USER': username, + 'PASSWORD': password, + 'HOST': 'somehost', + 'PORT': '444', }), ( ['psql', '-U', username, '-h', 'somehost', '-p', '444', 'dbname'], {'PGPASSWORD': password}, @@ -109,10 +95,11 @@ class PostgreSqlDbshellCommandTestCase(SimpleTestCase): def test_parameters(self): self.assertEqual( - self._run_it({'database': 'dbname'}, ['--help']), + self.settings_to_cmd_args_env({'NAME': 'dbname'}, ['--help']), (['psql', 'dbname', '--help'], {}), ) + @skipUnless(connection.vendor == 'postgresql', 'Requires a PostgreSQL connection') def test_sigint_handler(self): """SIGINT is ignored in Python and passed to psql to abort queries.""" def _mock_subprocess_run(*args, **kwargs): @@ -123,6 +110,6 @@ class PostgreSqlDbshellCommandTestCase(SimpleTestCase): # The default handler isn't SIG_IGN. self.assertNotEqual(sigint_handler, signal.SIG_IGN) with mock.patch('subprocess.run', new=_mock_subprocess_run): - DatabaseClient.runshell_db({}, []) + connection.client.runshell([]) # dbshell restores the original handler. self.assertEqual(sigint_handler, signal.getsignal(signal.SIGINT)) diff --git a/tests/dbshell/test_sqlite.py b/tests/dbshell/test_sqlite.py index c3b2b1e28d..570230f62d 100644 --- a/tests/dbshell/test_sqlite.py +++ b/tests/dbshell/test_sqlite.py @@ -1,43 +1,23 @@ from pathlib import Path -from subprocess import CompletedProcess -from unittest import mock, skipUnless -from django.db import connection from django.db.backends.sqlite3.client import DatabaseClient from django.test import SimpleTestCase -@skipUnless(connection.vendor == 'sqlite', 'SQLite tests.') class SqliteDbshellCommandTestCase(SimpleTestCase): - def _run_dbshell(self, parameters=None): - """Run runshell command and capture its arguments.""" - def _mock_subprocess_run(*args, **kwargs): - self.subprocess_args = list(*args) - return CompletedProcess(self.subprocess_args, 0) - + def settings_to_cmd_args_env(self, settings_dict, parameters=None): if parameters is None: parameters = [] - client = DatabaseClient(connection) - with mock.patch('subprocess.run', new=_mock_subprocess_run): - client.runshell(parameters) - return self.subprocess_args + return DatabaseClient.settings_to_cmd_args_env(settings_dict, parameters) def test_path_name(self): - with mock.patch.dict( - connection.settings_dict, - {'NAME': Path('test.db.sqlite3')}, - ): - self.assertEqual( - self._run_dbshell(), - ['sqlite3', 'test.db.sqlite3'], - ) + self.assertEqual( + self.settings_to_cmd_args_env({'NAME': Path('test.db.sqlite3')}), + (['sqlite3', 'test.db.sqlite3'], None), + ) def test_parameters(self): - with mock.patch.dict( - connection.settings_dict, - {'NAME': Path('test.db.sqlite3')}, - ): - self.assertEqual( - self._run_dbshell(['-help']), - ['sqlite3', 'test.db.sqlite3', '-help'], - ) + self.assertEqual( + self.settings_to_cmd_args_env({'NAME': 'test.db.sqlite3'}, ['-help']), + (['sqlite3', 'test.db.sqlite3', '-help'], None), + )