diff --git a/py/_plugin/pytest_runner.py b/py/_plugin/pytest_runner.py index 0c39ba775..626edb2cf 100644 --- a/py/_plugin/pytest_runner.py +++ b/py/_plugin/pytest_runner.py @@ -358,13 +358,22 @@ def raises(ExpectedException, *args, **kwargs): if args[0] is a string: raise AssertionError if executing the the string in the calling scope does not raise expected exception. for examples: - x = 5 - raises(TypeError, lambda x: x + 'hello', x=x) - raises(TypeError, "x + 'hello'") + >>> x = 5 + >>> raises(TypeError, lambda x: x + 'hello', x=x) + >>> raises(TypeError, "x + 'hello'") + + if no code/callable is given, it asumes you want a contextmanager + + >>> with raises(ZeroDivisionError): + ... 1/0 + >>> with raises(TypeError): + ... 1 + 'a' """ - __tracebackhide__ = True - assert args - if isinstance(args[0], str): + __tracebackhide__ = True + + if not args: + return RaisesContext(ExpectedException) + elif isinstance(args[0], str): code, = args assert isinstance(code, str) frame = sys._getframe(1) @@ -391,6 +400,26 @@ def raises(ExpectedException, *args, **kwargs): raise ExceptionFailure(msg="DID NOT RAISE", expr=args, expected=ExpectedException) + +class RaisesContext(object): + + def __init__(self, ExpectedException): + self.ExpectedException = ExpectedException + self.excinfo = None + + def __enter__(self): + return self + + def __exit__(self, *tp): + __tracebackhide__ = True + if tp[0] is None: + raise ExceptionFailure(msg="DID NOT RAISE", + expr=(), + expected=self.ExpectedException) + self.excinfo = py.code.ExceptionInfo(tp) + return issubclass(self.excinfo.type, self.ExpectedException) + + raises.Exception = ExceptionFailure def importorskip(modname, minversion=None): diff --git a/testing/plugin/test_pytest_runner.py b/testing/plugin/test_pytest_runner.py index 941909d89..551d8c437 100644 --- a/testing/plugin/test_pytest_runner.py +++ b/testing/plugin/test_pytest_runner.py @@ -340,6 +340,37 @@ class TestRaises: except py.test.raises.Exception: pass + @py.test.mark.skipif('sys.version < "2.5"') + def test_raises_as_contextmanager(self, testdir): + testdir.makepyfile(""" + from __future__ import with_statement + import py + + def test_simple(): + with py.test.raises(ZeroDivisionError) as ctx: + 1/0 + + print ctx.excinfo + assert ctx.excinfo.type is ZeroDivisionError + + + def test_noraise(): + with py.test.raises(py.test.raises.Exception): + with py.test.raises(ValueError): + int() + + def test_raise_wrong_exception_passes_by(): + with py.test.raises(ZeroDivisionError): + with py.test.raises(ValueError): + 1/0 + """) + result = testdir.runpytest() + result.stdout.fnmatch_lines([ + '*3 passed*', + ]) + + + def test_pytest_exit(): try: py.test.exit("hello")