From 313b61471f47d7d95402110ed5960c1ffacd3157 Mon Sep 17 00:00:00 2001 From: Ran Benita Date: Sun, 4 Jun 2023 00:32:28 +0300 Subject: [PATCH] main: change pkg_roots to work with `Path`s instead of string paths - Works better on Windows (case sensitivity) - Simpler code --- src/_pytest/main.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/_pytest/main.py b/src/_pytest/main.py index c9a495380..8f07daeec 100644 --- a/src/_pytest/main.py +++ b/src/_pytest/main.py @@ -686,8 +686,8 @@ class Session(nodes.FSCollector): # are not collected more than once. matchnodes_cache: Dict[Tuple[Type[nodes.Collector], str], CollectReport] = {} - # Dirnames of pkgs with dunder-init files. - pkg_roots: Dict[str, Package] = {} + # Directories of pkgs with dunder-init files. + pkg_roots: Dict[Path, Package] = {} for argpath, names in self._initial_parts: self.trace("processing argument", (argpath, names)) @@ -708,7 +708,7 @@ class Session(nodes.FSCollector): col = self._collectfile(pkginit, handle_dupes=False) if col: if isinstance(col[0], Package): - pkg_roots[str(parent)] = col[0] + pkg_roots[parent] = col[0] node_cache1[col[0].path] = [col[0]] # If it's a directory argument, recurse and look for any Subpackages. @@ -717,7 +717,7 @@ class Session(nodes.FSCollector): assert not names, f"invalid arg {(argpath, names)!r}" seen_dirs: Set[Path] = set() - for direntry in visit(str(argpath), self._recurse): + for direntry in visit(argpath, self._recurse): if not direntry.is_file(): continue @@ -732,8 +732,8 @@ class Session(nodes.FSCollector): for x in self._collectfile(pkginit): yield x if isinstance(x, Package): - pkg_roots[str(dirpath)] = x - if str(dirpath) in pkg_roots: + pkg_roots[dirpath] = x + if dirpath in pkg_roots: # Do not collect packages here. continue @@ -750,7 +750,7 @@ class Session(nodes.FSCollector): if argpath in node_cache1: col = node_cache1[argpath] else: - collect_root = pkg_roots.get(str(argpath.parent), self) + collect_root = pkg_roots.get(argpath.parent, self) col = collect_root._collectfile(argpath, handle_dupes=False) if col: node_cache1[argpath] = col