From f8b944dee0b59bdc483ce505738ea1cb3a57a5b4 Mon Sep 17 00:00:00 2001 From: Daniel Hahler Date: Wed, 7 Nov 2018 19:33:22 +0100 Subject: [PATCH] pkg_roots --- src/_pytest/main.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/_pytest/main.py b/src/_pytest/main.py index a2b27d9fa..59e2b6d4d 100644 --- a/src/_pytest/main.py +++ b/src/_pytest/main.py @@ -490,6 +490,7 @@ class Session(nodes.FSCollector): names = self._parsearg(arg) argpath = names.pop(0).realpath() paths = set() + pkg_roots = {} root = self # Start with a Session root, and delve to argpath item (dir or file) @@ -510,9 +511,9 @@ class Session(nodes.FSCollector): col = root._collectfile(pkginit, handle_dupes=False) if col: if isinstance(col[0], Package): - root = col[0] + pkg_roots[parent] = col[0] # always store a list in the cache, matchnodes expects it - self._node_cache[root.fspath] = [root] + self._node_cache[col[0].fspath] = [col[0]] # If it's a directory argument, recurse and look for any Subpackages. # Let the Package collector deal with subnodes, don't collect here. @@ -534,16 +535,19 @@ class Session(nodes.FSCollector): fil=filter_, rec=self._recurse, bf=True, sort=True ): dirpath = path.dirpath() + collect_root = pkg_roots.get(dirpath, root) if dirpath not in seen_dirs: seen_dirs.add(dirpath) pkginit = dirpath.join("__init__.py") if pkginit.exists() and parts(pkginit.strpath).isdisjoint(paths): - for x in root._collectfile(pkginit): + for x in collect_root._collectfile(pkginit): yield x + if isinstance(x, Package): + pkg_roots[dirpath] = x paths.add(x.fspath.dirpath()) - if parts(path.strpath).isdisjoint(paths): - for x in root._collectfile(path): + if True or parts(path.strpath).isdisjoint(paths): + for x in collect_root._collectfile(path): key = (type(x), x.fspath) if key in self._node_cache: yield self._node_cache[key] @@ -556,7 +560,8 @@ class Session(nodes.FSCollector): if argpath in self._node_cache: col = self._node_cache[argpath] else: - col = root._collectfile(argpath) + collect_root = pkg_roots.get(argpath.dirname, root) + col = collect_root._collectfile(argpath) if col: self._node_cache[argpath] = col m = self.matchnodes(col, names)