allow unittest test functions to work with the "pytestmark" mechanism

by refactoring mark/keyword handling and initialization

--HG--
branch : trunk
This commit is contained in:
holger krekel 2010-10-25 23:08:56 +02:00
parent a6f10a6d80
commit 4480401119
11 changed files with 90 additions and 88 deletions

View File

@ -23,6 +23,8 @@ Changes between 1.3.4 and 2.0.0dev0
- "xpass" (unexpected pass) tests don't cause exitcode!=0 - "xpass" (unexpected pass) tests don't cause exitcode!=0
- fix issue131 / issue60 - importing doctests in __init__ files used as namespace packages - fix issue131 / issue60 - importing doctests in __init__ files used as namespace packages
- fix issue93 stdout/stderr is captured while importing conftest.py - fix issue93 stdout/stderr is captured while importing conftest.py
- fix bug: unittest collected functions now also can have "pytestmark"
applied at class/module level
Changes between 1.3.3 and 1.3.4 Changes between 1.3.3 and 1.3.4
---------------------------------------------- ----------------------------------------------

View File

@ -5,7 +5,7 @@ see http://pytest.org for documentation and details
(c) Holger Krekel and others, 2004-2010 (c) Holger Krekel and others, 2004-2010
""" """
__version__ = '2.0.0.dev7' __version__ = '2.0.0.dev8'
__all__ = ['config', 'cmdline'] __all__ = ['config', 'cmdline']

View File

@ -7,8 +7,8 @@ assert py.__version__.split(".")[:2] >= ['2', '0'], ("installation problem: "
"%s is too old, remove or upgrade 'py'" % (py.__version__)) "%s is too old, remove or upgrade 'py'" % (py.__version__))
default_plugins = ( default_plugins = (
"config session terminal python runner pdb capture mark skipping tmpdir " "config session terminal python runner pdb capture unittest mark skipping "
"monkeypatch recwarn pastebin unittest helpconfig nose assertion genscript " "tmpdir monkeypatch recwarn pastebin helpconfig nose assertion genscript "
"junitxml doctest").split() "junitxml doctest").split()
IMPORTPREFIX = "pytest_" IMPORTPREFIX = "pytest_"
@ -282,7 +282,7 @@ class MultiCall:
kwargs[argname] = self kwargs[argname] = self
return kwargs return kwargs
def varnames(func, cache={}): def varnames(func):
if not inspect.isfunction(func) and not inspect.ismethod(func): if not inspect.isfunction(func) and not inspect.ismethod(func):
func = getattr(func, '__call__', func) func = getattr(func, '__call__', func)
ismethod = inspect.ismethod(func) ismethod = inspect.ismethod(func)

View File

@ -7,11 +7,12 @@ def pytest_namespace():
def pytest_addoption(parser): def pytest_addoption(parser):
group = parser.getgroup("general") group = parser.getgroup("general")
group._addoption('-k', group._addoption('-k',
action="store", dest="keyword", default='', action="store", dest="keyword", default='', metavar="KEYWORDEXPR",
help="only run test items matching the given " help="only run tests which match given keyword expression. "
"space separated keywords. precede a keyword with '-' to negate. " "An expression consists of space-separated terms. "
"Terminate the expression with ':' to treat a match as a signal " "Each term must match. Precede a term with '-' to negate. "
"to run all subsequent tests. ") "Terminate expression with ':' to make the first match match "
"all subsequent tests (usually file-order). ")
def pytest_collection_modifyitems(items, config): def pytest_collection_modifyitems(items, config):
keywordexpr = config.option.keyword keywordexpr = config.option.keyword
@ -42,32 +43,31 @@ def skipbykeyword(colitem, keywordexpr):
""" """
if not keywordexpr: if not keywordexpr:
return return
chain = colitem.listchain()
itemkeywords = getkeywords(colitem)
for key in filter(None, keywordexpr.split()): for key in filter(None, keywordexpr.split()):
eor = key[:1] == '-' eor = key[:1] == '-'
if eor: if eor:
key = key[1:] key = key[1:]
if not (eor ^ matchonekeyword(key, chain)): if not (eor ^ matchonekeyword(key, itemkeywords)):
return True return True
def matchonekeyword(key, chain): def getkeywords(node):
elems = key.split(".") keywords = {}
# XXX O(n^2), anyone cares? while node is not None:
chain = [item.keywords for item in chain if item.keywords] keywords.update(node.keywords)
for start, _ in enumerate(chain): node = node.parent
if start + len(elems) > len(chain): return keywords
return False
for num, elem in enumerate(elems):
for keyword in chain[num + start]: def matchonekeyword(key, itemkeywords):
ok = False for elem in key.split("."):
if elem in keyword: for kw in itemkeywords:
ok = True if elem in kw:
break
if not ok:
break break
if num == len(elems) - 1 and ok: else:
return True return False
return False return True
class MarkGenerator: class MarkGenerator:
""" Factory for :class:`MarkDecorator` objects - exposed as """ Factory for :class:`MarkDecorator` objects - exposed as
@ -155,21 +155,22 @@ class MarkInfo:
return "<MarkInfo %r args=%r kwargs=%r>" % ( return "<MarkInfo %r args=%r kwargs=%r>" % (
self._name, self.args, self.kwargs) self._name, self.args, self.kwargs)
def pytest_pycollect_makeitem(__multicall__, collector, name, obj): def pytest_log_itemcollect(item):
item = __multicall__.execute() if not isinstance(item, py.test.collect.Function):
if isinstance(item, py.test.collect.Function): return
cls = collector.getparent(py.test.collect.Class) try:
mod = collector.getparent(py.test.collect.Module) func = item.obj.__func__
func = item.obj except AttributeError:
func = getattr(func, '__func__', func) # py3 func = getattr(item.obj, 'im_func', item.obj)
func = getattr(func, 'im_func', func) # py2 pyclasses = (py.test.collect.Class, py.test.collect.Module)
for parent in [x for x in (mod, cls) if x]: for node in item.listchain():
marker = getattr(parent.obj, 'pytestmark', None) if isinstance(node, pyclasses):
marker = getattr(node.obj, 'pytestmark', None)
if marker is not None: if marker is not None:
if not isinstance(marker, list): if isinstance(marker, list):
marker = [marker] for mark in marker:
for mark in marker:
if isinstance(mark, MarkDecorator):
mark(func) mark(func)
item.keywords.update(py.builtin._getfuncdict(func) or {}) else:
return item marker(func)
node = node.parent
item.keywords.update(py.builtin._getfuncdict(func))

View File

@ -354,13 +354,10 @@ class TmpTestdir:
return config return config
def getitem(self, source, funcname="test_func"): def getitem(self, source, funcname="test_func"):
modcol = self.getmodulecol(source) for item in self.getitems(source):
moditems = modcol.collect()
for item in modcol.collect():
if item.name == funcname: if item.name == funcname:
return item return item
else: assert 0, "%r item not found in module:\n%s" %(funcname, source)
assert 0, "%r item not found in module:\n%s" %(funcname, source)
def getitems(self, source): def getitems(self, source):
modcol = self.getmodulecol(source) modcol = self.getmodulecol(source)

View File

@ -136,6 +136,7 @@ class PyCollectorMixin(PyobjMixin, pytest.collect.Collector):
def collect(self): def collect(self):
# NB. we avoid random getattrs and peek in the __dict__ instead # NB. we avoid random getattrs and peek in the __dict__ instead
# (XXX originally introduced from a PyPy need, still true?)
dicts = [getattr(self.obj, '__dict__', {})] dicts = [getattr(self.obj, '__dict__', {})]
for basecls in inspect.getmro(self.obj.__class__): for basecls in inspect.getmro(self.obj.__class__):
dicts.append(basecls.__dict__) dicts.append(basecls.__dict__)
@ -254,9 +255,6 @@ class Instance(PyCollectorMixin, pytest.collect.Collector):
def _getobj(self): def _getobj(self):
return self.parent.obj() return self.parent.obj()
def _keywords(self):
return []
def newinstance(self): def newinstance(self):
self.obj = self._getobj() self.obj = self._getobj()
return self.obj return self.obj
@ -449,6 +447,7 @@ def hasinit(obj):
def getfuncargnames(function): def getfuncargnames(function):
# XXX merge with _core.py's varnames
argnames = py.std.inspect.getargs(py.code.getrawcode(function))[0] argnames = py.std.inspect.getargs(py.code.getrawcode(function))[0]
startindex = py.std.inspect.ismethod(function) and 1 or 0 startindex = py.std.inspect.ismethod(function) and 1 or 0
defaults = getattr(function, 'func_defaults', defaults = getattr(function, 'func_defaults',

View File

@ -249,7 +249,7 @@ class CollectReport(BaseReport):
self.fspath = fspath self.fspath = fspath
self.outcome = outcome self.outcome = outcome
self.longrepr = longrepr self.longrepr = longrepr
self.result = result self.result = result or []
self.reason = reason self.reason = reason
@property @property

View File

@ -266,7 +266,7 @@ class Collection:
node.ihook.pytest_collectstart(collector=node) node.ihook.pytest_collectstart(collector=node)
rep = node.ihook.pytest_make_collect_report(collector=node) rep = node.ihook.pytest_make_collect_report(collector=node)
if rep.passed: if rep.passed:
for subnode in rep.result or []: for subnode in rep.result:
for x in self.genitems(subnode): for x in self.genitems(subnode):
yield x yield x
node.ihook.pytest_collectreport(report=rep) node.ihook.pytest_collectreport(report=rep)
@ -328,7 +328,7 @@ class Node(object):
#: the file where this item is contained/collected from. #: the file where this item is contained/collected from.
self.fspath = getattr(parent, 'fspath', None) self.fspath = getattr(parent, 'fspath', None)
self.ihook = HookProxy(self) self.ihook = HookProxy(self)
self.keywords = self.readkeywords() self.keywords = {self.name: True}
def __repr__(self): def __repr__(self):
if getattr(self.config.option, 'debug', False): if getattr(self.config.option, 'debug', False):
@ -396,12 +396,6 @@ class Node(object):
current = current.parent current = current.parent
return current return current
def readkeywords(self):
return dict([(x, True) for x in self._keywords()])
def _keywords(self):
return [self.name]
def _prunetraceback(self, traceback): def _prunetraceback(self, traceback):
return traceback return traceback

View File

@ -22,7 +22,7 @@ def main():
name='pytest', name='pytest',
description='py.test: simple powerful testing with Python', description='py.test: simple powerful testing with Python',
long_description = long_description, long_description = long_description,
version='2.0.0.dev7', version='2.0.0.dev8',
url='http://pytest.org', url='http://pytest.org',
license='MIT license', license='MIT license',
platforms=['unix', 'linux', 'osx', 'cygwin', 'win32'], platforms=['unix', 'linux', 'osx', 'cygwin', 'win32'],

View File

@ -8,13 +8,15 @@ class TestMark:
def test_pytest_mark_bare(self): def test_pytest_mark_bare(self):
mark = Mark() mark = Mark()
def f(): pass def f():
pass
mark.hello(f) mark.hello(f)
assert f.hello assert f.hello
def test_pytest_mark_keywords(self): def test_pytest_mark_keywords(self):
mark = Mark() mark = Mark()
def f(): pass def f():
pass
mark.world(x=3, y=4)(f) mark.world(x=3, y=4)(f)
assert f.world assert f.world
assert f.world.kwargs['x'] == 3 assert f.world.kwargs['x'] == 3
@ -22,7 +24,8 @@ class TestMark:
def test_apply_multiple_and_merge(self): def test_apply_multiple_and_merge(self):
mark = Mark() mark = Mark()
def f(): pass def f():
pass
marker = mark.world marker = mark.world
mark.world(x=3)(f) mark.world(x=3)(f)
assert f.world.kwargs['x'] == 3 assert f.world.kwargs['x'] == 3
@ -35,7 +38,8 @@ class TestMark:
def test_pytest_mark_positional(self): def test_pytest_mark_positional(self):
mark = Mark() mark = Mark()
def f(): pass def f():
pass
mark.world("hello")(f) mark.world("hello")(f)
assert f.world.args[0] == "hello" assert f.world.args[0] == "hello"
mark.world("world")(f) mark.world("world")(f)
@ -62,7 +66,7 @@ class TestFunctional:
assert 'hello' in keywords assert 'hello' in keywords
def test_marklist_per_class(self, testdir): def test_marklist_per_class(self, testdir):
modcol = testdir.getmodulecol(""" item = testdir.getitem("""
import py import py
class TestClass: class TestClass:
pytestmark = [py.test.mark.hello, py.test.mark.world] pytestmark = [py.test.mark.hello, py.test.mark.world]
@ -70,13 +74,11 @@ class TestFunctional:
assert TestClass.test_func.hello assert TestClass.test_func.hello
assert TestClass.test_func.world assert TestClass.test_func.world
""") """)
clscol = modcol.collect()[0]
item = clscol.collect()[0].collect()[0]
keywords = item.keywords keywords = item.keywords
assert 'hello' in keywords assert 'hello' in keywords
def test_marklist_per_module(self, testdir): def test_marklist_per_module(self, testdir):
modcol = testdir.getmodulecol(""" item = testdir.getitem("""
import py import py
pytestmark = [py.test.mark.hello, py.test.mark.world] pytestmark = [py.test.mark.hello, py.test.mark.world]
class TestClass: class TestClass:
@ -84,29 +86,25 @@ class TestFunctional:
assert TestClass.test_func.hello assert TestClass.test_func.hello
assert TestClass.test_func.world assert TestClass.test_func.world
""") """)
clscol = modcol.collect()[0]
item = clscol.collect()[0].collect()[0]
keywords = item.keywords keywords = item.keywords
assert 'hello' in keywords assert 'hello' in keywords
assert 'world' in keywords assert 'world' in keywords
@py.test.mark.skipif("sys.version_info < (2,6)") @py.test.mark.skipif("sys.version_info < (2,6)")
def test_mark_per_class_decorator(self, testdir): def test_mark_per_class_decorator(self, testdir):
modcol = testdir.getmodulecol(""" item = testdir.getitem("""
import py import py
@py.test.mark.hello @py.test.mark.hello
class TestClass: class TestClass:
def test_func(self): def test_func(self):
assert TestClass.test_func.hello assert TestClass.test_func.hello
""") """)
clscol = modcol.collect()[0]
item = clscol.collect()[0].collect()[0]
keywords = item.keywords keywords = item.keywords
assert 'hello' in keywords assert 'hello' in keywords
@py.test.mark.skipif("sys.version_info < (2,6)") @py.test.mark.skipif("sys.version_info < (2,6)")
def test_mark_per_class_decorator_plus_existing_dec(self, testdir): def test_mark_per_class_decorator_plus_existing_dec(self, testdir):
modcol = testdir.getmodulecol(""" item = testdir.getitem("""
import py import py
@py.test.mark.hello @py.test.mark.hello
class TestClass: class TestClass:
@ -115,8 +113,6 @@ class TestFunctional:
assert TestClass.test_func.hello assert TestClass.test_func.hello
assert TestClass.test_func.world assert TestClass.test_func.world
""") """)
clscol = modcol.collect()[0]
item = clscol.collect()[0].collect()[0]
keywords = item.keywords keywords = item.keywords
assert 'hello' in keywords assert 'hello' in keywords
assert 'world' in keywords assert 'world' in keywords
@ -140,14 +136,15 @@ class TestFunctional:
assert marker.kwargs == {'x': 3, 'y': 2, 'z': 4} assert marker.kwargs == {'x': 3, 'y': 2, 'z': 4}
def test_mark_other(self, testdir): def test_mark_other(self, testdir):
item = testdir.getitem(""" py.test.raises(TypeError, '''
import py testdir.getitem("""
class pytestmark: import py
pass class pytestmark:
def test_func(): pass
pass def test_func():
""") pass
keywords = item.keywords """)
''')
def test_mark_dynamically_in_funcarg(self, testdir): def test_mark_dynamically_in_funcarg(self, testdir):
testdir.makeconftest(""" testdir.makeconftest("""
@ -223,7 +220,8 @@ class Test_genitems:
class TestKeywordSelection: class TestKeywordSelection:
def test_select_simple(self, testdir): def test_select_simple(self, testdir):
file_test = testdir.makepyfile(""" file_test = testdir.makepyfile("""
def test_one(): assert 0 def test_one():
assert 0
class TestClass(object): class TestClass(object):
def test_method_one(self): def test_method_one(self):
assert 42 == 43 assert 42 == 43

View File

@ -70,3 +70,14 @@ def test_teardown(testdir):
assert passed == 2 assert passed == 2
assert passed + skipped + failed == 2 assert passed + skipped + failed == 2
def test_module_level_pytestmark(testdir):
testpath = testdir.makepyfile("""
import unittest
import py
pytestmark = py.test.mark.xfail
class MyTestCase(unittest.TestCase):
def test_func1(self):
assert 0
""")
reprec = testdir.inline_run(testpath, "-s")
reprec.assertoutcome(skipped=1)