support using py.test.raises in context manager style

--HG--
branch : trunk
This commit is contained in:
Ronny Pfannschmidt 2010-06-09 10:50:00 +02:00
parent 8ece058256
commit d1c8209875
2 changed files with 66 additions and 6 deletions

View File

@ -358,13 +358,22 @@ def raises(ExpectedException, *args, **kwargs):
if args[0] is a string: raise AssertionError if executing the
the string in the calling scope does not raise expected exception.
for examples:
x = 5
raises(TypeError, lambda x: x + 'hello', x=x)
raises(TypeError, "x + 'hello'")
>>> x = 5
>>> raises(TypeError, lambda x: x + 'hello', x=x)
>>> raises(TypeError, "x + 'hello'")
if no code/callable is given, it asumes you want a contextmanager
>>> with raises(ZeroDivisionError):
... 1/0
>>> with raises(TypeError):
... 1 + 'a'
"""
__tracebackhide__ = True
assert args
if isinstance(args[0], str):
__tracebackhide__ = True
if not args:
return RaisesContext(ExpectedException)
elif isinstance(args[0], str):
code, = args
assert isinstance(code, str)
frame = sys._getframe(1)
@ -391,6 +400,26 @@ def raises(ExpectedException, *args, **kwargs):
raise ExceptionFailure(msg="DID NOT RAISE",
expr=args, expected=ExpectedException)
class RaisesContext(object):
def __init__(self, ExpectedException):
self.ExpectedException = ExpectedException
self.excinfo = None
def __enter__(self):
return self
def __exit__(self, *tp):
__tracebackhide__ = True
if tp[0] is None:
raise ExceptionFailure(msg="DID NOT RAISE",
expr=(),
expected=self.ExpectedException)
self.excinfo = py.code.ExceptionInfo(tp)
return issubclass(self.excinfo.type, self.ExpectedException)
raises.Exception = ExceptionFailure
def importorskip(modname, minversion=None):

View File

@ -340,6 +340,37 @@ class TestRaises:
except py.test.raises.Exception:
pass
@py.test.mark.skipif('sys.version < "2.5"')
def test_raises_as_contextmanager(self, testdir):
testdir.makepyfile("""
from __future__ import with_statement
import py
def test_simple():
with py.test.raises(ZeroDivisionError) as ctx:
1/0
print ctx.excinfo
assert ctx.excinfo.type is ZeroDivisionError
def test_noraise():
with py.test.raises(py.test.raises.Exception):
with py.test.raises(ValueError):
int()
def test_raise_wrong_exception_passes_by():
with py.test.raises(ZeroDivisionError):
with py.test.raises(ValueError):
1/0
""")
result = testdir.runpytest()
result.stdout.fnmatch_lines([
'*3 passed*',
])
def test_pytest_exit():
try:
py.test.exit("hello")