diff --git a/django/db/backends/postgresql_psycopg2/client.py b/django/db/backends/postgresql_psycopg2/client.py index aa60e58943..5e3e288301 100644 --- a/django/db/backends/postgresql_psycopg2/client.py +++ b/django/db/backends/postgresql_psycopg2/client.py @@ -1,19 +1,66 @@ +import os import subprocess +from django.core.files.temp import NamedTemporaryFile from django.db.backends.base.client import BaseDatabaseClient +from django.utils.six import print_ + + +def _escape_pgpass(txt): + """ + Escape a fragment of a PostgreSQL .pgpass file. + """ + return txt.replace('\\', '\\\\').replace(':', '\\:') class DatabaseClient(BaseDatabaseClient): executable_name = 'psql' + @classmethod + def runshell_db(cls, settings_dict): + args = [cls.executable_name] + + host = settings_dict.get('HOST', '') + port = settings_dict.get('PORT', '') + name = settings_dict.get('NAME', '') + user = settings_dict.get('USER', '') + passwd = settings_dict.get('PASSWORD', '') + + if user: + args += ['-U', user] + if host: + args += ['-h', host] + if port: + args += ['-p', str(port)] + args += [name] + + temp_pgpass = None + try: + if passwd: + # Create temporary .pgpass file. + temp_pgpass = NamedTemporaryFile(mode='w+') + try: + print_( + _escape_pgpass(host) or '*', + str(port) or '*', + _escape_pgpass(name) or '*', + _escape_pgpass(user) or '*', + _escape_pgpass(passwd), + file=temp_pgpass, + sep=':', + flush=True, + ) + os.environ['PGPASSFILE'] = temp_pgpass.name + except UnicodeEncodeError: + # If the current locale can't encode the data, we let + # the user input the password manually. + pass + subprocess.call(args) + finally: + if temp_pgpass: + temp_pgpass.close() + if 'PGPASSFILE' in os.environ: # unit tests need cleanup + del os.environ['PGPASSFILE'] + def runshell(self): - settings_dict = self.connection.settings_dict - args = [self.executable_name] - if settings_dict['USER']: - args += ["-U", settings_dict['USER']] - if settings_dict['HOST']: - args.extend(["-h", settings_dict['HOST']]) - if settings_dict['PORT']: - args.extend(["-p", str(settings_dict['PORT'])]) - args += [settings_dict['NAME']] - subprocess.call(args) + DatabaseClient.runshell_db(self.connection.settings_dict) diff --git a/docs/releases/1.9.txt b/docs/releases/1.9.txt index 600c88ca2c..9f2dd6ccd1 100644 --- a/docs/releases/1.9.txt +++ b/docs/releases/1.9.txt @@ -350,6 +350,10 @@ Management Commands * The :djadmin:`startapp` command creates an ``apps.py`` file and adds ``default_app_config`` in ``__init__.py``. +* When using the PostgreSQL backend, the :djadmin:`dbshell` command can connect + to the database using the password from your settings file (instead of + requiring it to be manually entered). + Models ^^^^^^ diff --git a/tests/dbshell/test_postgresql_psycopg2.py b/tests/dbshell/test_postgresql_psycopg2.py new file mode 100644 index 0000000000..aecbba7f42 --- /dev/null +++ b/tests/dbshell/test_postgresql_psycopg2.py @@ -0,0 +1,117 @@ +# -*- coding: utf8 -*- +from __future__ import unicode_literals + +import locale +import os + +from django.db.backends.postgresql_psycopg2.client import DatabaseClient +from django.test import SimpleTestCase, mock +from django.utils import six +from django.utils.encoding import force_bytes, force_str + + +class PostgreSqlDbshellCommandTestCase(SimpleTestCase): + + def _run_it(self, dbinfo): + """ + That function invokes the runshell command, while mocking + subprocess.call. It returns a 2-tuple with: + - The command line list + - The binary content of file pointed by environment PGPASSFILE, or + None. + """ + def _mock_subprocess_call(*args): + self.subprocess_args = list(*args) + if 'PGPASSFILE' in os.environ: + self.pgpass = open(os.environ['PGPASSFILE'], 'rb').read() + else: + self.pgpass = None + return 0 + self.subprocess_args = None + self.pgpass = None + with mock.patch('subprocess.call', new=_mock_subprocess_call): + DatabaseClient.runshell_db(dbinfo) + return self.subprocess_args, self.pgpass + + def test_basic(self): + self.assertEqual( + self._run_it({ + 'NAME': 'dbname', + 'USER': 'someuser', + 'PASSWORD': 'somepassword', + 'HOST': 'somehost', + 'PORT': 444, + }), ( + ['psql', '-U', 'someuser', '-h', 'somehost', '-p', '444', 'dbname'], + b'somehost:444:dbname:someuser:somepassword\n', + ) + ) + + def test_nopass(self): + self.assertEqual( + self._run_it({ + 'NAME': 'dbname', + 'USER': 'someuser', + 'HOST': 'somehost', + 'PORT': 444, + }), ( + ['psql', '-U', 'someuser', '-h', 'somehost', '-p', '444', 'dbname'], + None, + ) + ) + + def test_column(self): + self.assertEqual( + self._run_it({ + 'NAME': 'dbname', + 'USER': 'some:user', + 'PASSWORD': 'some:password', + 'HOST': '::1', + 'PORT': 444, + }), ( + ['psql', '-U', 'some:user', '-h', '::1', '-p', '444', 'dbname'], + b'\\:\\:1:444:dbname:some\\:user:some\\:password\n', + ) + ) + + def test_escape_characters(self): + self.assertEqual( + self._run_it({ + 'NAME': 'dbname', + 'USER': 'some\\user', + 'PASSWORD': 'some\\password', + 'HOST': 'somehost', + 'PORT': 444, + }), ( + ['psql', '-U', 'some\\user', '-h', 'somehost', '-p', '444', 'dbname'], + b'somehost:444:dbname:some\\\\user:some\\\\password\n', + ) + ) + + def test_accent(self): + # The pgpass temporary file needs to be encoded using the system locale. + encoding = locale.getpreferredencoding() + username = 'rôle' + password = 'sésame' + try: + username_str = force_str(username, encoding) + password_str = force_str(password, encoding) + pgpass_bytes = force_bytes( + 'somehost:444:dbname:%s:%s\n' % (username, password), + encoding=encoding, + ) + except UnicodeEncodeError: + if six.PY2: + self.skipTest("Your locale can't run this test.") + self.assertEqual( + self._run_it({ + 'NAME': 'dbname', + 'USER': username_str, + 'PASSWORD': password_str, + 'HOST': 'somehost', + 'PORT': 444, + }), ( + ['psql', '-U', username_str, '-h', 'somehost', '-p', '444', 'dbname'], + pgpass_bytes, + ) + )