diff --git a/django/contrib/staticfiles/management/commands/collectstatic.py b/django/contrib/staticfiles/management/commands/collectstatic.py index e593056ce2..f004727c83 100644 --- a/django/contrib/staticfiles/management/commands/collectstatic.py +++ b/django/contrib/staticfiles/management/commands/collectstatic.py @@ -22,7 +22,7 @@ class Command(BaseCommand): requires_system_checks = False def __init__(self, *args, **kwargs): - super(BaseCommand, self).__init__(*args, **kwargs) + super(Command, self).__init__(*args, **kwargs) self.copied_files = [] self.symlinked_files = [] self.unmodified_files = [] diff --git a/django/core/management/base.py b/django/core/management/base.py index 7b62d6c259..9e0c70a653 100644 --- a/django/core/management/base.py +++ b/django/core/management/base.py @@ -79,11 +79,20 @@ class OutputWrapper(object): """ Wrapper around stdout/stderr """ + @property + def style_func(self): + return self._style_func + + @style_func.setter + def style_func(self, style_func): + if style_func and hasattr(self._out, 'isatty') and self._out.isatty(): + self._style_func = style_func + else: + self._style_func = lambda x: x + def __init__(self, out, style_func=None, ending='\n'): self._out = out self.style_func = None - if hasattr(out, 'isatty') and out.isatty(): - self.style_func = style_func self.ending = ending def __getattr__(self, name): @@ -93,8 +102,7 @@ class OutputWrapper(object): ending = self.ending if ending is None else ending if ending and not msg.endswith(ending): msg += ending - style_func = [f for f in (style_func, self.style_func, lambda x:x) - if f is not None][0] + style_func = style_func or self.style_func self._out.write(force_str(style_func(msg))) @@ -221,8 +229,14 @@ class BaseCommand(object): # # requires_system_checks = True - def __init__(self): - self.style = color_style() + def __init__(self, stdout=None, stderr=None, no_color=False): + self.stdout = OutputWrapper(stdout or sys.stdout) + self.stderr = OutputWrapper(stderr or sys.stderr) + if no_color: + self.style = no_style() + else: + self.style = color_style() + self.stderr.style_func = self.style.ERROR # `requires_model_validation` is deprecated in favor of # `requires_system_checks`. If both options are present, an error is @@ -371,9 +385,7 @@ class BaseCommand(object): if options.traceback or not isinstance(e, CommandError): raise - # self.stderr is not guaranteed to be set here - stderr = getattr(self, 'stderr', OutputWrapper(sys.stderr, self.style.ERROR)) - stderr.write('%s: %s' % (e.__class__.__name__, e)) + self.stderr.write('%s: %s' % (e.__class__.__name__, e)) sys.exit(1) def execute(self, *args, **options): @@ -382,12 +394,13 @@ class BaseCommand(object): controlled by attributes ``self.requires_system_checks`` and ``self.requires_model_validation``, except if force-skipped). """ - self.stdout = OutputWrapper(options.get('stdout', sys.stdout)) if options.get('no_color'): self.style = no_style() - self.stderr = OutputWrapper(options.get('stderr', sys.stderr)) - else: - self.stderr = OutputWrapper(options.get('stderr', sys.stderr), self.style.ERROR) + self.stderr.style_func = None + if options.get('stdout'): + self.stdout = OutputWrapper(options['stdout']) + if options.get('stderr'): + self.stderr = OutputWrapper(options.get('stderr'), self.stderr.style_func) if self.can_import_settings: from django.conf import settings # NOQA diff --git a/tests/admin_scripts/management/commands/color_command.py b/tests/admin_scripts/management/commands/color_command.py deleted file mode 100644 index 5a1c297762..0000000000 --- a/tests/admin_scripts/management/commands/color_command.py +++ /dev/null @@ -1,9 +0,0 @@ -from django.core.management.base import BaseCommand - - -class Command(BaseCommand): - help = "Test color output" - requires_system_checks = False - - def handle(self, **options): - return self.style.SQL_KEYWORD('BEGIN') diff --git a/tests/admin_scripts/tests.py b/tests/admin_scripts/tests.py index 2b3f4824fd..f2c0d598cf 100644 --- a/tests/admin_scripts/tests.py +++ b/tests/admin_scripts/tests.py @@ -21,7 +21,7 @@ import django from django import conf, get_version from django.conf import settings from django.core.exceptions import ImproperlyConfigured -from django.core.management import BaseCommand, CommandError, call_command +from django.core.management import BaseCommand, CommandError, call_command, color from django.db import connection from django.utils.encoding import force_text from django.utils._os import npath, upath @@ -1392,12 +1392,83 @@ class CommandTypes(AdminScriptTestCase): self.assertOutput(out, "Prints the CREATE TABLE, custom SQL and CREATE INDEX SQL statements for the\ngiven model module name(s).") self.assertEqual(out.count('optional arguments'), 1) - def test_no_color(self): - "--no-color prevent colorization of the output" - out = StringIO() + def test_command_color(self): + class Command(BaseCommand): + requires_system_checks = False - call_command('color_command', no_color=True, stdout=out) - self.assertEqual(out.getvalue(), 'BEGIN\n') + def handle(self, *args, **options): + self.stdout.write('Hello, world!', self.style.ERROR) + self.stderr.write('Hello, world!', self.style.ERROR) + + out = StringIO() + err = StringIO() + command = Command(stdout=out, stderr=err) + command.execute() + if color.supports_color(): + self.assertIn('Hello, world!\n', out.getvalue()) + self.assertIn('Hello, world!\n', err.getvalue()) + self.assertNotEqual(out.getvalue(), 'Hello, world!\n') + self.assertNotEqual(err.getvalue(), 'Hello, world!\n') + else: + self.assertEqual(out.getvalue(), 'Hello, world!\n') + self.assertEqual(err.getvalue(), 'Hello, world!\n') + + def test_command_no_color(self): + "--no-color prevent colorization of the output" + class Command(BaseCommand): + requires_system_checks = False + + def handle(self, *args, **options): + self.stdout.write('Hello, world!', self.style.ERROR) + self.stderr.write('Hello, world!', self.style.ERROR) + + out = StringIO() + err = StringIO() + command = Command(stdout=out, stderr=err, no_color=True) + command.execute() + self.assertEqual(out.getvalue(), 'Hello, world!\n') + self.assertEqual(err.getvalue(), 'Hello, world!\n') + + out = StringIO() + err = StringIO() + command = Command(stdout=out, stderr=err) + command.execute(no_color=True) + self.assertEqual(out.getvalue(), 'Hello, world!\n') + self.assertEqual(err.getvalue(), 'Hello, world!\n') + + def test_custom_stdout(self): + class Command(BaseCommand): + requires_system_checks = False + + def handle(self, *args, **options): + self.stdout.write("Hello, World!") + + out = StringIO() + command = Command(stdout=out) + command.execute() + self.assertEqual(out.getvalue(), "Hello, World!\n") + out.truncate(0) + new_out = StringIO() + command.execute(stdout=new_out) + self.assertEqual(out.getvalue(), "") + self.assertEqual(new_out.getvalue(), "Hello, World!\n") + + def test_custom_stderr(self): + class Command(BaseCommand): + requires_system_checks = False + + def handle(self, *args, **options): + self.stderr.write("Hello, World!") + + err = StringIO() + command = Command(stderr=err) + command.execute() + self.assertEqual(err.getvalue(), "Hello, World!\n") + err.truncate(0) + new_err = StringIO() + command.execute(stderr=new_err) + self.assertEqual(err.getvalue(), "") + self.assertEqual(new_err.getvalue(), "Hello, World!\n") def test_base_command(self): "User BaseCommands can execute when a label is provided"