diff --git a/_py/test/pycollect.py b/_py/test/pycollect.py index 8e6fc4111..f2fbf324a 100644 --- a/_py/test/pycollect.py +++ b/_py/test/pycollect.py @@ -161,13 +161,24 @@ class Module(py.test.collect.File, PyCollectorMixin): def setup(self): if getattr(self.obj, 'disabled', 0): py.test.skip("%r is disabled" %(self.obj,)) - mod = self.obj - if hasattr(mod, 'setup_module'): - self.obj.setup_module(mod) + if hasattr(self.obj, 'setup_module'): + #XXX: nose compat hack, move to nose plugin + # if it takes a positional arg, its probably a py.test style one + # so we pass the current module object + if inspect.getargspec(self.obj.setup_module)[0]: + self.obj.setup_module(self.obj) + else: + self.obj.setup_module() def teardown(self): if hasattr(self.obj, 'teardown_module'): - self.obj.teardown_module(self.obj) + #XXX: nose compat hack, move to nose plugin + # if it takes a positional arg, its probably a py.test style one + # so we pass the current module object + if inspect.getargspec(self.obj.teardown_module)[0]: + self.obj.teardown_module(self.obj) + else: + self.obj.teardown_module() class Class(PyCollectorMixin, py.test.collect.Collector): diff --git a/testing/pytest/plugin/test_pytest_nose.py b/testing/pytest/plugin/test_pytest_nose.py index bb8b80d5f..fea4f0cf8 100644 --- a/testing/pytest/plugin/test_pytest_nose.py +++ b/testing/pytest/plugin/test_pytest_nose.py @@ -115,3 +115,24 @@ def test_module_level_setup(testdir): result.stdout.fnmatch_lines([ "*2 passed*", ]) + +def test_nose_style_setup_teardown(testdir): + testdir.makepyfile(""" + l = [] + def setup_module(): + l.append(1) + + def teardown_module(): + del l[0] + + def test_hello(): + assert l == [1] + + def test_world(): + assert l == [1] + """) + result = testdir.runpytest('-p', 'nose') + result.stdout.fnmatch_lines([ + "*2 passed*", + ]) +