diff --git a/CHANGELOG b/CHANGELOG index 022c448dd..0ddacdf9b 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -4,6 +4,14 @@ Changes between 1.3.1 and 1.3.x New features ++++++++++++++++++ +- fix issue103: introduce additional "with py.test.raises(exc):" form, example:: + + with py.test.raises(ZeroDivisionError): + x = 0 + 1 / x + + (thanks Ronny Pfannschmidt) + - Funcarg factories can now dynamically apply a marker to a test invocation. This is particularly useful if a factory provides parameters to a test which you expect-to-fail: diff --git a/doc/test/features.txt b/doc/test/features.txt index ed2060173..b0d1e2350 100644 --- a/doc/test/features.txt +++ b/doc/test/features.txt @@ -130,14 +130,20 @@ first failed and then succeeded. asserting expected exceptions ---------------------------------------- -In order to write assertions about exceptions, you use -one of two forms:: +In order to write assertions about exceptions, you can use +``py.test.raises`` as a context manager like this:: - py.test.raises(Exception, func, *args, **kwargs) - py.test.raises(Exception, "func(*args, **kwargs)") + with py.test.raises(ZeroDivisionError): + 1 / 0 + +If you want to write test code that works on Python2.4 as well, +you may also use two other ways to test for an expected exception:: + + py.test.raises(ExpectedException, func, *args, **kwargs) + py.test.raises(ExpectedException, "func(*args, **kwargs)") both of which execute the specified function with args and kwargs and -asserts that the given ``Exception`` is raised. The reporter will +asserts that the given ``ExpectedException`` is raised. The reporter will provide you with helpful output in case of failures such as *no exception* or *wrong exception*. diff --git a/py/_path/gateway/__init__.py b/py/_path/gateway/__init__.py deleted file mode 100644 index 792d60054..000000000 --- a/py/_path/gateway/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# diff --git a/py/_path/gateway/channeltest.py b/py/_path/gateway/channeltest.py deleted file mode 100644 index ac821aeb9..000000000 --- a/py/_path/gateway/channeltest.py +++ /dev/null @@ -1,65 +0,0 @@ -import threading - - -class PathServer: - - def __init__(self, channel): - self.channel = channel - self.C2P = {} - self.next_id = 0 - threading.Thread(target=self.serve).start() - - def p2c(self, path): - id = self.next_id - self.next_id += 1 - self.C2P[id] = path - return id - - def command_LIST(self, id, *args): - path = self.C2P[id] - answer = [(self.p2c(p), p.basename) for p in path.listdir(*args)] - self.channel.send(answer) - - def command_DEL(self, id): - del self.C2P[id] - - def command_GET(self, id, spec): - path = self.C2P[id] - self.channel.send(path._getbyspec(spec)) - - def command_READ(self, id): - path = self.C2P[id] - self.channel.send(path.read()) - - def command_JOIN(self, id, resultid, *args): - path = self.C2P[id] - assert resultid not in self.C2P - self.C2P[resultid] = path.join(*args) - - def command_DIRPATH(self, id, resultid): - path = self.C2P[id] - assert resultid not in self.C2P - self.C2P[resultid] = path.dirpath() - - def serve(self): - try: - while 1: - msg = self.channel.receive() - meth = getattr(self, 'command_' + msg[0]) - meth(*msg[1:]) - except EOFError: - pass - -if __name__ == '__main__': - import py - gw = execnet.PopenGateway() - channel = gw._channelfactory.new() - srv = PathServer(channel) - c = gw.remote_exec(""" - import remotepath - p = remotepath.RemotePath(channel.receive(), channel.receive()) - channel.send(len(p.listdir())) - """) - c.send(channel) - c.send(srv.p2c(py.path.local('/tmp'))) - print(c.receive()) diff --git a/py/_path/gateway/channeltest2.py b/py/_path/gateway/channeltest2.py deleted file mode 100644 index 827abb7d3..000000000 --- a/py/_path/gateway/channeltest2.py +++ /dev/null @@ -1,21 +0,0 @@ -import py -from remotepath import RemotePath - - -SRC = open('channeltest.py', 'r').read() - -SRC += ''' -import py -srv = PathServer(channel.receive()) -channel.send(srv.p2c(py.path.local("/tmp"))) -''' - - -#gw = execnet.SshGateway('codespeak.net') -gw = execnet.PopenGateway() -gw.remote_init_threads(5) -c = gw.remote_exec(SRC, stdout=py.std.sys.stdout, stderr=py.std.sys.stderr) -subchannel = gw._channelfactory.new() -c.send(subchannel) - -p = RemotePath(subchannel, c.receive()) diff --git a/py/_path/gateway/remotepath.py b/py/_path/gateway/remotepath.py deleted file mode 100644 index 149baa435..000000000 --- a/py/_path/gateway/remotepath.py +++ /dev/null @@ -1,47 +0,0 @@ -import py, itertools -from py._path import common - -COUNTER = itertools.count() - -class RemotePath(common.PathBase): - sep = '/' - - def __init__(self, channel, id, basename=None): - self._channel = channel - self._id = id - self._basename = basename - self._specs = {} - - def __del__(self): - self._channel.send(('DEL', self._id)) - - def __repr__(self): - return 'RemotePath(%s)' % self.basename - - def listdir(self, *args): - self._channel.send(('LIST', self._id) + args) - return [RemotePath(self._channel, id, basename) - for (id, basename) in self._channel.receive()] - - def dirpath(self): - id = ~COUNTER.next() - self._channel.send(('DIRPATH', self._id, id)) - return RemotePath(self._channel, id) - - def join(self, *args): - id = ~COUNTER.next() - self._channel.send(('JOIN', self._id, id) + args) - return RemotePath(self._channel, id) - - def _getbyspec(self, spec): - parts = spec.split(',') - ask = [x for x in parts if x not in self._specs] - if ask: - self._channel.send(('GET', self._id, ",".join(ask))) - for part, value in zip(ask, self._channel.receive()): - self._specs[part] = value - return [self._specs[x] for x in parts] - - def read(self): - self._channel.send(('READ', self._id)) - return self._channel.receive() diff --git a/py/_plugin/pytest_runner.py b/py/_plugin/pytest_runner.py index 6a93f7b35..2bcb63b90 100644 --- a/py/_plugin/pytest_runner.py +++ b/py/_plugin/pytest_runner.py @@ -353,18 +353,30 @@ def xfail(reason=""): xfail.Exception = XFailed def raises(ExpectedException, *args, **kwargs): - """ if args[0] is callable: raise AssertionError if calling it with + """ assert that a code block/function call raises an exception. + + If using Python 2.5 or above, you may use this function as a + context manager:: + + >>> with raises(ZeroDivisionError): + ... 1/0 + + Or you can one of two forms: + + if args[0] is callable: raise AssertionError if calling it with the remaining arguments does not raise the expected exception. if args[0] is a string: raise AssertionError if executing the the string in the calling scope does not raise expected exception. - for examples: - x = 5 - raises(TypeError, lambda x: x + 'hello', x=x) - raises(TypeError, "x + 'hello'") + examples: + >>> x = 5 + >>> raises(TypeError, lambda x: x + 'hello', x=x) + >>> raises(TypeError, "x + 'hello'") """ - __tracebackhide__ = True - assert args - if isinstance(args[0], str): + __tracebackhide__ = True + + if not args: + return RaisesContext(ExpectedException) + elif isinstance(args[0], str): code, = args assert isinstance(code, str) frame = sys._getframe(1) @@ -391,6 +403,26 @@ def raises(ExpectedException, *args, **kwargs): raise ExceptionFailure(msg="DID NOT RAISE", expr=args, expected=ExpectedException) + +class RaisesContext(object): + + def __init__(self, ExpectedException): + self.ExpectedException = ExpectedException + self.excinfo = None + + def __enter__(self): + return self + + def __exit__(self, *tp): + __tracebackhide__ = True + if tp[0] is None: + raise ExceptionFailure(msg="DID NOT RAISE", + expr=(), + expected=self.ExpectedException) + self.excinfo = py.code.ExceptionInfo(tp) + return issubclass(self.excinfo.type, self.ExpectedException) + + raises.Exception = ExceptionFailure def importorskip(modname, minversion=None): diff --git a/testing/plugin/test_pytest_runner.py b/testing/plugin/test_pytest_runner.py index 941909d89..551d8c437 100644 --- a/testing/plugin/test_pytest_runner.py +++ b/testing/plugin/test_pytest_runner.py @@ -340,6 +340,37 @@ class TestRaises: except py.test.raises.Exception: pass + @py.test.mark.skipif('sys.version < "2.5"') + def test_raises_as_contextmanager(self, testdir): + testdir.makepyfile(""" + from __future__ import with_statement + import py + + def test_simple(): + with py.test.raises(ZeroDivisionError) as ctx: + 1/0 + + print ctx.excinfo + assert ctx.excinfo.type is ZeroDivisionError + + + def test_noraise(): + with py.test.raises(py.test.raises.Exception): + with py.test.raises(ValueError): + int() + + def test_raise_wrong_exception_passes_by(): + with py.test.raises(ZeroDivisionError): + with py.test.raises(ValueError): + 1/0 + """) + result = testdir.runpytest() + result.stdout.fnmatch_lines([ + '*3 passed*', + ]) + + + def test_pytest_exit(): try: py.test.exit("hello")