Add package scoped fixtures #2283

This commit is contained in:
turturica 2018-04-11 15:39:42 -07:00
parent 372bcdba0c
commit 2b1410895e
8 changed files with 112 additions and 15 deletions

View File

@ -87,6 +87,7 @@ Hugo van Kemenade
Hui Wang (coldnight)
Ian Bicking
Ian Lesperance
Ionuț Turturică
Jaap Broekhuizen
Jan Balster
Janne Vanhala

View File

@ -36,6 +36,7 @@ def pytest_sessionstart(session):
import _pytest.nodes
scopename2class.update({
'package': _pytest.python.Package,
'class': _pytest.python.Class,
'module': _pytest.python.Module,
'function': _pytest.nodes.Item,
@ -48,6 +49,7 @@ scopename2class = {}
scope2props = dict(session=())
scope2props["package"] = ("fspath",)
scope2props["module"] = ("fspath", "module")
scope2props["class"] = scope2props["module"] + ("cls",)
scope2props["instance"] = scope2props["class"] + ("instance", )
@ -156,9 +158,11 @@ def get_parametrized_fixture_keys(item, scopenum):
continue
if scopenum == 0: # session
key = (argname, param_index)
elif scopenum == 1: # module
elif scopenum == 1: # package
key = (argname, param_index, item.fspath)
elif scopenum == 2: # class
elif scopenum == 2: # module
key = (argname, param_index, item.fspath)
elif scopenum == 3: # class
key = (argname, param_index, item.fspath, item.cls)
yield key
@ -596,7 +600,7 @@ class ScopeMismatchError(Exception):
"""
scopes = "session module class function".split()
scopes = "session package module class function".split()
scopenum_function = scopes.index("function")

View File

@ -405,16 +405,29 @@ class Session(nodes.FSCollector):
def _collect(self, arg):
names = self._parsearg(arg)
path = names.pop(0)
if path.check(dir=1):
argpath = names.pop(0)
paths = []
if argpath.check(dir=1):
assert not names, "invalid arg %r" % (arg,)
for path in path.visit(fil=lambda x: x.check(file=1),
for path in argpath.visit(fil=lambda x: x.check(file=1),
rec=self._recurse, bf=True, sort=True):
pkginit = path.dirpath().join('__init__.py')
if pkginit.exists() and not any(x in pkginit.parts() for x in paths):
for x in self._collectfile(pkginit):
yield x
paths.append(x.fspath.dirpath())
if not any(x in path.parts() for x in paths):
for x in self._collectfile(path):
yield x
else:
assert path.check(file=1)
for x in self.matchnodes(self._collectfile(path), names):
assert argpath.check(file=1)
pkginit = argpath.dirpath().join('__init__.py')
if not self.isinitpath(argpath) and pkginit.exists():
for x in self._collectfile(pkginit):
yield x
else:
for x in self.matchnodes(self._collectfile(argpath), names):
yield x
def _collectfile(self, path):

View File

@ -17,6 +17,7 @@ from _pytest.mark import MarkerError
from _pytest.config import hookimpl
import _pytest
from _pytest.main import Session
import pluggy
from _pytest import fixtures
from _pytest import nodes
@ -157,7 +158,7 @@ def pytest_collect_file(path, parent):
ext = path.ext
if ext == ".py":
if not parent.session.isinitpath(path):
for pat in parent.config.getini('python_files'):
for pat in parent.config.getini('python_files') + ['__init__.py']:
if path.fnmatch(pat):
break
else:
@ -167,9 +168,23 @@ def pytest_collect_file(path, parent):
def pytest_pycollect_makemodule(path, parent):
if path.basename == '__init__.py':
return Package(path, parent)
return Module(path, parent)
def pytest_ignore_collect(path, config):
# Skip duplicate packages.
keepduplicates = config.getoption("keepduplicates")
if keepduplicates:
duplicate_paths = config.pluginmanager._duplicatepaths
if path.basename == '__init__.py':
if path in duplicate_paths:
return True
else:
duplicate_paths.add(path)
@hookimpl(hookwrapper=True)
def pytest_pycollect_makeitem(collector, name, obj):
outcome = yield
@ -475,6 +490,36 @@ class Module(nodes.File, PyCollector):
self.addfinalizer(teardown_module)
class Package(Session, Module):
def __init__(self, fspath, parent=None, config=None, session=None, nodeid=None):
session = parent.session
nodes.FSCollector.__init__(
self, fspath, parent=parent,
config=config, session=session, nodeid=nodeid)
self.name = fspath.pyimport().__name__
self.trace = session.trace
self._norecursepatterns = session._norecursepatterns
for path in list(session.config.pluginmanager._duplicatepaths):
if path.dirname == fspath.dirname and path != fspath:
session.config.pluginmanager._duplicatepaths.remove(path)
pass
def isinitpath(self, path):
return path in self.session._initialpaths
def collect(self):
path = self.fspath.dirpath()
pkg_prefix = None
for path in path.visit(fil=lambda x: 1,
rec=self._recurse, bf=True, sort=True):
if pkg_prefix and pkg_prefix in path.parts():
continue
for x in self._collectfile(path):
yield x
if isinstance(x, Package):
pkg_prefix = path.dirpath()
def _get_xunit_setup_teardown(holder, attr_name, param_obj=None):
"""
Return a callable to perform xunit-style setup or teardown if

1
changelog/2283.feature Normal file
View File

@ -0,0 +1 @@
Pytest now supports package-level fixtures.

View File

@ -1448,6 +1448,39 @@ class TestFixtureManagerParseFactories(object):
reprec = testdir.inline_run("..")
reprec.assertoutcome(passed=2)
def test_package_xunit_fixture(self, testdir):
testdir.makepyfile(__init__="""\
values = []
""")
package = testdir.mkdir("package")
package.join("__init__.py").write(dedent("""\
from .. import values
def setup_module():
values.append("package")
def teardown_module():
values[:] = []
"""))
package.join("test_x.py").write(dedent("""\
from .. import values
def test_x():
assert values == ["package"]
"""))
package = testdir.mkdir("package2")
package.join("__init__.py").write(dedent("""\
from .. import values
def setup_module():
values.append("package2")
def teardown_module():
values[:] = []
"""))
package.join("test_x.py").write(dedent("""\
from .. import values
def test_x():
assert values == ["package2"]
"""))
reprec = testdir.inline_run()
reprec.assertoutcome(passed=2)
class TestAutouseDiscovery(object):

View File

@ -835,7 +835,7 @@ def test_continue_on_collection_errors_maxfail(testdir):
def test_fixture_scope_sibling_conftests(testdir):
"""Regression test case for https://github.com/pytest-dev/pytest/issues/2836"""
foo_path = testdir.mkpydir("foo")
foo_path = testdir.mkdir("foo")
foo_path.join("conftest.py").write(_pytest._code.Source("""
import pytest
@pytest.fixture

View File

@ -192,7 +192,7 @@ class TestNewSession(SessionTests):
started = reprec.getcalls("pytest_collectstart")
finished = reprec.getreports("pytest_collectreport")
assert len(started) == len(finished)
assert len(started) == 7 # XXX extra TopCollector
assert len(started) == 8 # XXX extra TopCollector
colfail = [x for x in finished if x.failed]
assert len(colfail) == 1