diff --git a/django/test/utils.py b/django/test/utils.py index 2f9678c29a..619f9cd0f4 100644 --- a/django/test/utils.py +++ b/django/test/utils.py @@ -493,3 +493,51 @@ def extend_sys_path(*paths): yield finally: sys.path = _orig_sys_path + + +@contextmanager +def captured_output(stream_name): + """Return a context manager used by captured_stdout/stdin/stderr + that temporarily replaces the sys stream *stream_name* with a StringIO. + + Note: This function and the following ``captured_std*`` are copied + from CPython's ``test.support`` module.""" + orig_stdout = getattr(sys, stream_name) + setattr(sys, stream_name, six.StringIO()) + try: + yield getattr(sys, stream_name) + finally: + setattr(sys, stream_name, orig_stdout) + + +def captured_stdout(): + """Capture the output of sys.stdout: + + with captured_stdout() as stdout: + print("hello") + self.assertEqual(stdout.getvalue(), "hello\n") + """ + return captured_output("stdout") + + +def captured_stderr(): + """Capture the output of sys.stderr: + + with captured_stderr() as stderr: + print("hello", file=sys.stderr) + self.assertEqual(stderr.getvalue(), "hello\n") + """ + return captured_output("stderr") + + +def captured_stdin(): + """Capture the input to sys.stdin: + + with captured_stdin() as stdin: + stdin.write('hello\n') + stdin.seek(0) + # call test code that consumes from sys.stdin + captured = input() + self.assertEqual(captured, "hello") + """ + return captured_output("stdin") diff --git a/tests/bash_completion/tests.py b/tests/bash_completion/tests.py index 25a62148c0..98ea2f3fbe 100644 --- a/tests/bash_completion/tests.py +++ b/tests/bash_completion/tests.py @@ -7,7 +7,7 @@ import unittest from django.apps import apps from django.core.management import ManagementUtility -from django.utils.six import StringIO +from django.test.utils import captured_stdout class BashCompletionTests(unittest.TestCase): @@ -20,12 +20,8 @@ class BashCompletionTests(unittest.TestCase): def setUp(self): self.old_DJANGO_AUTO_COMPLETE = os.environ.get('DJANGO_AUTO_COMPLETE') os.environ['DJANGO_AUTO_COMPLETE'] = '1' - self.output = StringIO() - self.old_stdout = sys.stdout - sys.stdout = self.output def tearDown(self): - sys.stdout = self.old_stdout if self.old_DJANGO_AUTO_COMPLETE: os.environ['DJANGO_AUTO_COMPLETE'] = self.old_DJANGO_AUTO_COMPLETE else: @@ -53,11 +49,12 @@ class BashCompletionTests(unittest.TestCase): def _run_autocomplete(self): util = ManagementUtility(argv=sys.argv) - try: - util.autocomplete() - except SystemExit: - pass - return self.output.getvalue().strip().split('\n') + with captured_stdout() as stdout: + try: + util.autocomplete() + except SystemExit: + pass + return stdout.getvalue().strip().split('\n') def test_django_admin_py(self): "django_admin.py will autocomplete option flags" diff --git a/tests/contenttypes_tests/tests.py b/tests/contenttypes_tests/tests.py index 64414a8732..a21d783008 100644 --- a/tests/contenttypes_tests/tests.py +++ b/tests/contenttypes_tests/tests.py @@ -1,8 +1,6 @@ # -*- coding: utf-8 -*- from __future__ import unicode_literals -import sys - from django.apps.registry import Apps, apps from django.contrib.contenttypes.fields import ( GenericForeignKey, GenericRelation @@ -12,9 +10,8 @@ from django.contrib.contenttypes.models import ContentType from django.core import checks from django.db import connections, models, router from django.test import TestCase -from django.test.utils import override_settings +from django.test.utils import captured_stdout, override_settings from django.utils.encoding import force_str -from django.utils.six import StringIO from .models import Author, Article, SchemeIncludedURL @@ -369,11 +366,6 @@ class UpdateContentTypesTests(TestCase): self.before_count = ContentType.objects.count() ContentType.objects.create(name='fake', app_label='contenttypes_tests', model='Fake') self.app_config = apps.get_app_config('contenttypes_tests') - self.old_stdout = sys.stdout - sys.stdout = StringIO() - - def tearDown(self): - sys.stdout = self.old_stdout def test_interactive_true(self): """ @@ -381,8 +373,9 @@ class UpdateContentTypesTests(TestCase): stale contenttypes. """ management.input = lambda x: force_str("yes") - management.update_contenttypes(self.app_config) - self.assertIn("Deleting stale content type", sys.stdout.getvalue()) + with captured_stdout() as stdout: + management.update_contenttypes(self.app_config) + self.assertIn("Deleting stale content type", stdout.getvalue()) self.assertEqual(ContentType.objects.count(), self.before_count) def test_interactive_false(self): @@ -390,8 +383,9 @@ class UpdateContentTypesTests(TestCase): non-interactive mode of update_contenttypes() shouldn't delete stale content types. """ - management.update_contenttypes(self.app_config, interactive=False) - self.assertIn("Stale content types remain.", sys.stdout.getvalue()) + with captured_stdout() as stdout: + management.update_contenttypes(self.app_config, interactive=False) + self.assertIn("Stale content types remain.", stdout.getvalue()) self.assertEqual(ContentType.objects.count(), self.before_count + 1) diff --git a/tests/i18n/test_compilation.py b/tests/i18n/test_compilation.py index be46e168aa..bc76a5337c 100644 --- a/tests/i18n/test_compilation.py +++ b/tests/i18n/test_compilation.py @@ -3,7 +3,6 @@ import os import shutil import stat -import sys import unittest import gettext as gettext_module @@ -11,6 +10,7 @@ from django.core.management import call_command, CommandError, execute_from_comm from django.core.management.utils import find_command from django.test import SimpleTestCase from django.test import override_settings +from django.test.utils import captured_stderr, captured_stdout from django.utils import translation from django.utils.translation import ugettext from django.utils.encoding import force_text @@ -145,15 +145,11 @@ class ExcludedLocaleCompilationTests(MessageCompilationTests): self.addCleanup(self._rmrf, os.path.join(self.test_dir, 'locale')) def test_command_help(self): - old_stdout, old_stderr = sys.stdout, sys.stderr - sys.stdout, sys.stderr = StringIO(), StringIO() - try: + with captured_stdout(), captured_stderr(): # `call_command` bypasses the parser; by calling # `execute_from_command_line` with the help subcommand we # ensure that there are no issues with the parser itself. execute_from_command_line(['django-admin', 'help', 'compilemessages']) - finally: - sys.stdout, sys.stderr = old_stdout, old_stderr def test_one_locale_excluded(self): call_command('compilemessages', exclude=['it'], stdout=StringIO()) diff --git a/tests/i18n/test_extraction.py b/tests/i18n/test_extraction.py index 3081776d96..705eca431e 100644 --- a/tests/i18n/test_extraction.py +++ b/tests/i18n/test_extraction.py @@ -5,7 +5,6 @@ import io import os import re import shutil -import sys import time from unittest import SkipTest, skipUnless import warnings @@ -16,6 +15,7 @@ from django.core.management import execute_from_command_line from django.core.management.utils import find_command from django.test import SimpleTestCase from django.test import override_settings +from django.test.utils import captured_stderr, captured_stdout from django.utils.encoding import force_text from django.utils._os import upath from django.utils import six @@ -632,15 +632,11 @@ class ExcludedLocaleExtractionTests(ExtractorTests): self.addCleanup(self._rmrf, os.path.join(self.test_dir, 'locale')) def test_command_help(self): - old_stdout, old_stderr = sys.stdout, sys.stderr - sys.stdout, sys.stderr = StringIO(), StringIO() - try: + with captured_stdout(), captured_stderr(): # `call_command` bypasses the parser; by calling # `execute_from_command_line` with the help subcommand we # ensure that there are no issues with the parser itself. execute_from_command_line(['django-admin', 'help', 'makemessages']) - finally: - sys.stdout, sys.stderr = old_stdout, old_stderr def test_one_locale_excluded(self): management.call_command('makemessages', exclude=['it'], stdout=StringIO()) diff --git a/tests/servers/test_basehttp.py b/tests/servers/test_basehttp.py index 8ee99d4eaa..41e9258bf5 100644 --- a/tests/servers/test_basehttp.py +++ b/tests/servers/test_basehttp.py @@ -1,10 +1,9 @@ -import sys - from django.core.handlers.wsgi import WSGIRequest from django.core.servers.basehttp import WSGIRequestHandler from django.test import TestCase from django.test.client import RequestFactory -from django.utils.six import BytesIO, StringIO +from django.test.utils import captured_stderr +from django.utils.six import BytesIO class WSGIRequestHandlerTestCase(TestCase): @@ -14,14 +13,10 @@ class WSGIRequestHandlerTestCase(TestCase): handler = WSGIRequestHandler(request, '192.168.0.2', None) - _stderr = sys.stderr - sys.stderr = StringIO() - try: + with captured_stderr() as stderr: handler.log_message("GET %s %s", str('\x16\x03'), "4") self.assertIn( "You're accessing the developement server over HTTPS, " "but it only supports HTTP.", - sys.stderr.getvalue() + stderr.getvalue() ) - finally: - sys.stderr = _stderr diff --git a/tests/user_commands/tests.py b/tests/user_commands/tests.py index 25d1ad7c33..3839f23e78 100644 --- a/tests/user_commands/tests.py +++ b/tests/user_commands/tests.py @@ -1,5 +1,4 @@ import os -import sys import warnings from django.db import connection @@ -7,6 +6,7 @@ from django.core import management from django.core.management import BaseCommand, CommandError from django.core.management.utils import find_command, popen_wrapper from django.test import SimpleTestCase +from django.test.utils import captured_stderr, captured_stdout from django.utils import translation from django.utils.deprecation import RemovedInDjango20Warning from django.utils.six import StringIO @@ -42,14 +42,9 @@ class CommandTests(SimpleTestCase): """ with self.assertRaises(CommandError): management.call_command('dance', example="raise") - old_stderr = sys.stderr - sys.stderr = err = StringIO() - try: - with self.assertRaises(SystemExit): - management.ManagementUtility(['manage.py', 'dance', '--example=raise']).execute() - finally: - sys.stderr = old_stderr - self.assertIn("CommandError", err.getvalue()) + with captured_stderr() as stderr, self.assertRaises(SystemExit): + management.ManagementUtility(['manage.py', 'dance', '--example=raise']).execute() + self.assertIn("CommandError", stderr.getvalue()) def test_default_en_us_locale_set(self): # Forces en_us when set to true @@ -100,14 +95,9 @@ class CommandTests(SimpleTestCase): self.assertEqual(out.getvalue(), "All right, let's dance Rock'n'Roll.\n") # Simulate command line execution - old_stdout, old_stderr = sys.stdout, sys.stderr - sys.stdout, sys.stderr = StringIO(), StringIO() - try: + with captured_stdout() as stdout, captured_stderr(): management.execute_from_command_line(['django-admin', 'optparse_cmd']) - finally: - output = sys.stdout.getvalue() - sys.stdout, sys.stderr = old_stdout, old_stderr - self.assertEqual(output, "All right, let's dance Rock'n'Roll.\n") + self.assertEqual(stdout.getvalue(), "All right, let's dance Rock'n'Roll.\n") def test_calling_a_command_with_only_empty_parameter_should_ends_gracefully(self): out = StringIO()