Add package scoped fixtures #2283

This commit is contained in:
turturica 2018-04-11 15:39:42 -07:00
parent 372bcdba0c
commit 2b1410895e
8 changed files with 112 additions and 15 deletions

View File

@ -87,6 +87,7 @@ Hugo van Kemenade
Hui Wang (coldnight) Hui Wang (coldnight)
Ian Bicking Ian Bicking
Ian Lesperance Ian Lesperance
Ionuț Turturică
Jaap Broekhuizen Jaap Broekhuizen
Jan Balster Jan Balster
Janne Vanhala Janne Vanhala

View File

@ -36,6 +36,7 @@ def pytest_sessionstart(session):
import _pytest.nodes import _pytest.nodes
scopename2class.update({ scopename2class.update({
'package': _pytest.python.Package,
'class': _pytest.python.Class, 'class': _pytest.python.Class,
'module': _pytest.python.Module, 'module': _pytest.python.Module,
'function': _pytest.nodes.Item, 'function': _pytest.nodes.Item,
@ -48,6 +49,7 @@ scopename2class = {}
scope2props = dict(session=()) scope2props = dict(session=())
scope2props["package"] = ("fspath",)
scope2props["module"] = ("fspath", "module") scope2props["module"] = ("fspath", "module")
scope2props["class"] = scope2props["module"] + ("cls",) scope2props["class"] = scope2props["module"] + ("cls",)
scope2props["instance"] = scope2props["class"] + ("instance", ) scope2props["instance"] = scope2props["class"] + ("instance", )
@ -156,9 +158,11 @@ def get_parametrized_fixture_keys(item, scopenum):
continue continue
if scopenum == 0: # session if scopenum == 0: # session
key = (argname, param_index) key = (argname, param_index)
elif scopenum == 1: # module elif scopenum == 1: # package
key = (argname, param_index, item.fspath) key = (argname, param_index, item.fspath)
elif scopenum == 2: # class elif scopenum == 2: # module
key = (argname, param_index, item.fspath)
elif scopenum == 3: # class
key = (argname, param_index, item.fspath, item.cls) key = (argname, param_index, item.fspath, item.cls)
yield key yield key
@ -596,7 +600,7 @@ class ScopeMismatchError(Exception):
""" """
scopes = "session module class function".split() scopes = "session package module class function".split()
scopenum_function = scopes.index("function") scopenum_function = scopes.index("function")

View File

@ -405,17 +405,30 @@ class Session(nodes.FSCollector):
def _collect(self, arg): def _collect(self, arg):
names = self._parsearg(arg) names = self._parsearg(arg)
path = names.pop(0) argpath = names.pop(0)
if path.check(dir=1): paths = []
if argpath.check(dir=1):
assert not names, "invalid arg %r" % (arg,) assert not names, "invalid arg %r" % (arg,)
for path in path.visit(fil=lambda x: x.check(file=1), for path in argpath.visit(fil=lambda x: x.check(file=1),
rec=self._recurse, bf=True, sort=True): rec=self._recurse, bf=True, sort=True):
for x in self._collectfile(path): pkginit = path.dirpath().join('__init__.py')
yield x if pkginit.exists() and not any(x in pkginit.parts() for x in paths):
for x in self._collectfile(pkginit):
yield x
paths.append(x.fspath.dirpath())
if not any(x in path.parts() for x in paths):
for x in self._collectfile(path):
yield x
else: else:
assert path.check(file=1) assert argpath.check(file=1)
for x in self.matchnodes(self._collectfile(path), names): pkginit = argpath.dirpath().join('__init__.py')
yield x if not self.isinitpath(argpath) and pkginit.exists():
for x in self._collectfile(pkginit):
yield x
else:
for x in self.matchnodes(self._collectfile(argpath), names):
yield x
def _collectfile(self, path): def _collectfile(self, path):
ihook = self.gethookproxy(path) ihook = self.gethookproxy(path)

View File

@ -17,6 +17,7 @@ from _pytest.mark import MarkerError
from _pytest.config import hookimpl from _pytest.config import hookimpl
import _pytest import _pytest
from _pytest.main import Session
import pluggy import pluggy
from _pytest import fixtures from _pytest import fixtures
from _pytest import nodes from _pytest import nodes
@ -157,7 +158,7 @@ def pytest_collect_file(path, parent):
ext = path.ext ext = path.ext
if ext == ".py": if ext == ".py":
if not parent.session.isinitpath(path): if not parent.session.isinitpath(path):
for pat in parent.config.getini('python_files'): for pat in parent.config.getini('python_files') + ['__init__.py']:
if path.fnmatch(pat): if path.fnmatch(pat):
break break
else: else:
@ -167,9 +168,23 @@ def pytest_collect_file(path, parent):
def pytest_pycollect_makemodule(path, parent): def pytest_pycollect_makemodule(path, parent):
if path.basename == '__init__.py':
return Package(path, parent)
return Module(path, parent) return Module(path, parent)
def pytest_ignore_collect(path, config):
# Skip duplicate packages.
keepduplicates = config.getoption("keepduplicates")
if keepduplicates:
duplicate_paths = config.pluginmanager._duplicatepaths
if path.basename == '__init__.py':
if path in duplicate_paths:
return True
else:
duplicate_paths.add(path)
@hookimpl(hookwrapper=True) @hookimpl(hookwrapper=True)
def pytest_pycollect_makeitem(collector, name, obj): def pytest_pycollect_makeitem(collector, name, obj):
outcome = yield outcome = yield
@ -475,6 +490,36 @@ class Module(nodes.File, PyCollector):
self.addfinalizer(teardown_module) self.addfinalizer(teardown_module)
class Package(Session, Module):
def __init__(self, fspath, parent=None, config=None, session=None, nodeid=None):
session = parent.session
nodes.FSCollector.__init__(
self, fspath, parent=parent,
config=config, session=session, nodeid=nodeid)
self.name = fspath.pyimport().__name__
self.trace = session.trace
self._norecursepatterns = session._norecursepatterns
for path in list(session.config.pluginmanager._duplicatepaths):
if path.dirname == fspath.dirname and path != fspath:
session.config.pluginmanager._duplicatepaths.remove(path)
pass
def isinitpath(self, path):
return path in self.session._initialpaths
def collect(self):
path = self.fspath.dirpath()
pkg_prefix = None
for path in path.visit(fil=lambda x: 1,
rec=self._recurse, bf=True, sort=True):
if pkg_prefix and pkg_prefix in path.parts():
continue
for x in self._collectfile(path):
yield x
if isinstance(x, Package):
pkg_prefix = path.dirpath()
def _get_xunit_setup_teardown(holder, attr_name, param_obj=None): def _get_xunit_setup_teardown(holder, attr_name, param_obj=None):
""" """
Return a callable to perform xunit-style setup or teardown if Return a callable to perform xunit-style setup or teardown if

1
changelog/2283.feature Normal file
View File

@ -0,0 +1 @@
Pytest now supports package-level fixtures.

View File

@ -1448,6 +1448,39 @@ class TestFixtureManagerParseFactories(object):
reprec = testdir.inline_run("..") reprec = testdir.inline_run("..")
reprec.assertoutcome(passed=2) reprec.assertoutcome(passed=2)
def test_package_xunit_fixture(self, testdir):
testdir.makepyfile(__init__="""\
values = []
""")
package = testdir.mkdir("package")
package.join("__init__.py").write(dedent("""\
from .. import values
def setup_module():
values.append("package")
def teardown_module():
values[:] = []
"""))
package.join("test_x.py").write(dedent("""\
from .. import values
def test_x():
assert values == ["package"]
"""))
package = testdir.mkdir("package2")
package.join("__init__.py").write(dedent("""\
from .. import values
def setup_module():
values.append("package2")
def teardown_module():
values[:] = []
"""))
package.join("test_x.py").write(dedent("""\
from .. import values
def test_x():
assert values == ["package2"]
"""))
reprec = testdir.inline_run()
reprec.assertoutcome(passed=2)
class TestAutouseDiscovery(object): class TestAutouseDiscovery(object):

View File

@ -835,7 +835,7 @@ def test_continue_on_collection_errors_maxfail(testdir):
def test_fixture_scope_sibling_conftests(testdir): def test_fixture_scope_sibling_conftests(testdir):
"""Regression test case for https://github.com/pytest-dev/pytest/issues/2836""" """Regression test case for https://github.com/pytest-dev/pytest/issues/2836"""
foo_path = testdir.mkpydir("foo") foo_path = testdir.mkdir("foo")
foo_path.join("conftest.py").write(_pytest._code.Source(""" foo_path.join("conftest.py").write(_pytest._code.Source("""
import pytest import pytest
@pytest.fixture @pytest.fixture

View File

@ -192,7 +192,7 @@ class TestNewSession(SessionTests):
started = reprec.getcalls("pytest_collectstart") started = reprec.getcalls("pytest_collectstart")
finished = reprec.getreports("pytest_collectreport") finished = reprec.getreports("pytest_collectreport")
assert len(started) == len(finished) assert len(started) == len(finished)
assert len(started) == 7 # XXX extra TopCollector assert len(started) == 8 # XXX extra TopCollector
colfail = [x for x in finished if x.failed] colfail = [x for x in finished if x.failed]
assert len(colfail) == 1 assert len(colfail) == 1