Fixed #23663 -- Initialize output streams for BaseCommand in __init__().

This helps with testability of management commands.

Thanks to trac username daveoncode for the report and to
Tim Graham and Claude Paroz for the reviews.
This commit is contained in:
Loic Bistuer 2014-10-22 01:11:31 +07:00
parent 494cd857c8
commit 533532302a
4 changed files with 104 additions and 29 deletions

View File

@ -22,7 +22,7 @@ class Command(BaseCommand):
requires_system_checks = False requires_system_checks = False
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(BaseCommand, self).__init__(*args, **kwargs) super(Command, self).__init__(*args, **kwargs)
self.copied_files = [] self.copied_files = []
self.symlinked_files = [] self.symlinked_files = []
self.unmodified_files = [] self.unmodified_files = []

View File

@ -79,11 +79,20 @@ class OutputWrapper(object):
""" """
Wrapper around stdout/stderr Wrapper around stdout/stderr
""" """
@property
def style_func(self):
return self._style_func
@style_func.setter
def style_func(self, style_func):
if style_func and hasattr(self._out, 'isatty') and self._out.isatty():
self._style_func = style_func
else:
self._style_func = lambda x: x
def __init__(self, out, style_func=None, ending='\n'): def __init__(self, out, style_func=None, ending='\n'):
self._out = out self._out = out
self.style_func = None self.style_func = None
if hasattr(out, 'isatty') and out.isatty():
self.style_func = style_func
self.ending = ending self.ending = ending
def __getattr__(self, name): def __getattr__(self, name):
@ -93,8 +102,7 @@ class OutputWrapper(object):
ending = self.ending if ending is None else ending ending = self.ending if ending is None else ending
if ending and not msg.endswith(ending): if ending and not msg.endswith(ending):
msg += ending msg += ending
style_func = [f for f in (style_func, self.style_func, lambda x:x) style_func = style_func or self.style_func
if f is not None][0]
self._out.write(force_str(style_func(msg))) self._out.write(force_str(style_func(msg)))
@ -221,8 +229,14 @@ class BaseCommand(object):
# #
# requires_system_checks = True # requires_system_checks = True
def __init__(self): def __init__(self, stdout=None, stderr=None, no_color=False):
self.style = color_style() self.stdout = OutputWrapper(stdout or sys.stdout)
self.stderr = OutputWrapper(stderr or sys.stderr)
if no_color:
self.style = no_style()
else:
self.style = color_style()
self.stderr.style_func = self.style.ERROR
# `requires_model_validation` is deprecated in favor of # `requires_model_validation` is deprecated in favor of
# `requires_system_checks`. If both options are present, an error is # `requires_system_checks`. If both options are present, an error is
@ -371,9 +385,7 @@ class BaseCommand(object):
if options.traceback or not isinstance(e, CommandError): if options.traceback or not isinstance(e, CommandError):
raise raise
# self.stderr is not guaranteed to be set here self.stderr.write('%s: %s' % (e.__class__.__name__, e))
stderr = getattr(self, 'stderr', OutputWrapper(sys.stderr, self.style.ERROR))
stderr.write('%s: %s' % (e.__class__.__name__, e))
sys.exit(1) sys.exit(1)
def execute(self, *args, **options): def execute(self, *args, **options):
@ -382,12 +394,13 @@ class BaseCommand(object):
controlled by attributes ``self.requires_system_checks`` and controlled by attributes ``self.requires_system_checks`` and
``self.requires_model_validation``, except if force-skipped). ``self.requires_model_validation``, except if force-skipped).
""" """
self.stdout = OutputWrapper(options.get('stdout', sys.stdout))
if options.get('no_color'): if options.get('no_color'):
self.style = no_style() self.style = no_style()
self.stderr = OutputWrapper(options.get('stderr', sys.stderr)) self.stderr.style_func = None
else: if options.get('stdout'):
self.stderr = OutputWrapper(options.get('stderr', sys.stderr), self.style.ERROR) self.stdout = OutputWrapper(options['stdout'])
if options.get('stderr'):
self.stderr = OutputWrapper(options.get('stderr'), self.stderr.style_func)
if self.can_import_settings: if self.can_import_settings:
from django.conf import settings # NOQA from django.conf import settings # NOQA

View File

@ -1,9 +0,0 @@
from django.core.management.base import BaseCommand
class Command(BaseCommand):
help = "Test color output"
requires_system_checks = False
def handle(self, **options):
return self.style.SQL_KEYWORD('BEGIN')

View File

@ -21,7 +21,7 @@ import django
from django import conf, get_version from django import conf, get_version
from django.conf import settings from django.conf import settings
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.core.management import BaseCommand, CommandError, call_command from django.core.management import BaseCommand, CommandError, call_command, color
from django.db import connection from django.db import connection
from django.utils.encoding import force_text from django.utils.encoding import force_text
from django.utils._os import npath, upath from django.utils._os import npath, upath
@ -1392,12 +1392,83 @@ class CommandTypes(AdminScriptTestCase):
self.assertOutput(out, "Prints the CREATE TABLE, custom SQL and CREATE INDEX SQL statements for the\ngiven model module name(s).") self.assertOutput(out, "Prints the CREATE TABLE, custom SQL and CREATE INDEX SQL statements for the\ngiven model module name(s).")
self.assertEqual(out.count('optional arguments'), 1) self.assertEqual(out.count('optional arguments'), 1)
def test_no_color(self): def test_command_color(self):
"--no-color prevent colorization of the output" class Command(BaseCommand):
out = StringIO() requires_system_checks = False
call_command('color_command', no_color=True, stdout=out) def handle(self, *args, **options):
self.assertEqual(out.getvalue(), 'BEGIN\n') self.stdout.write('Hello, world!', self.style.ERROR)
self.stderr.write('Hello, world!', self.style.ERROR)
out = StringIO()
err = StringIO()
command = Command(stdout=out, stderr=err)
command.execute()
if color.supports_color():
self.assertIn('Hello, world!\n', out.getvalue())
self.assertIn('Hello, world!\n', err.getvalue())
self.assertNotEqual(out.getvalue(), 'Hello, world!\n')
self.assertNotEqual(err.getvalue(), 'Hello, world!\n')
else:
self.assertEqual(out.getvalue(), 'Hello, world!\n')
self.assertEqual(err.getvalue(), 'Hello, world!\n')
def test_command_no_color(self):
"--no-color prevent colorization of the output"
class Command(BaseCommand):
requires_system_checks = False
def handle(self, *args, **options):
self.stdout.write('Hello, world!', self.style.ERROR)
self.stderr.write('Hello, world!', self.style.ERROR)
out = StringIO()
err = StringIO()
command = Command(stdout=out, stderr=err, no_color=True)
command.execute()
self.assertEqual(out.getvalue(), 'Hello, world!\n')
self.assertEqual(err.getvalue(), 'Hello, world!\n')
out = StringIO()
err = StringIO()
command = Command(stdout=out, stderr=err)
command.execute(no_color=True)
self.assertEqual(out.getvalue(), 'Hello, world!\n')
self.assertEqual(err.getvalue(), 'Hello, world!\n')
def test_custom_stdout(self):
class Command(BaseCommand):
requires_system_checks = False
def handle(self, *args, **options):
self.stdout.write("Hello, World!")
out = StringIO()
command = Command(stdout=out)
command.execute()
self.assertEqual(out.getvalue(), "Hello, World!\n")
out.truncate(0)
new_out = StringIO()
command.execute(stdout=new_out)
self.assertEqual(out.getvalue(), "")
self.assertEqual(new_out.getvalue(), "Hello, World!\n")
def test_custom_stderr(self):
class Command(BaseCommand):
requires_system_checks = False
def handle(self, *args, **options):
self.stderr.write("Hello, World!")
err = StringIO()
command = Command(stderr=err)
command.execute()
self.assertEqual(err.getvalue(), "Hello, World!\n")
err.truncate(0)
new_err = StringIO()
command.execute(stderr=new_err)
self.assertEqual(err.getvalue(), "")
self.assertEqual(new_err.getvalue(), "Hello, World!\n")
def test_base_command(self): def test_base_command(self):
"User BaseCommands can execute when a label is provided" "User BaseCommands can execute when a label is provided"