From dd5c2b22bd3534ee02039932a51fc7a9eba01ca0 Mon Sep 17 00:00:00 2001 From: Ran Benita Date: Thu, 23 Jan 2020 14:47:27 +0200 Subject: [PATCH] Refactor Session._initialparts to have a more explicit type Previously, _initialparts was a list whose first item was a `py.path.local` and the rest were `str`s. This is not something that mypy is capable of modeling. The type `List[Union[str, py.path.local]]` is too broad and would require asserts for every access. Instead, make each item a `Tuple[py.path.local, List[str]]`. This way the structure is clear and the types are accurate. To make sure any users who might have been accessing this (private) field will not break silently, change the name to _initial_parts. --- src/_pytest/main.py | 40 ++++++++++++++++++-------------------- testing/test_collection.py | 14 ++++++------- 2 files changed, 26 insertions(+), 28 deletions(-) diff --git a/src/_pytest/main.py b/src/_pytest/main.py index 057dae4f4..f8735abee 100644 --- a/src/_pytest/main.py +++ b/src/_pytest/main.py @@ -8,6 +8,7 @@ import sys from typing import Dict from typing import FrozenSet from typing import List +from typing import Tuple import attr import py @@ -485,13 +486,13 @@ class Session(nodes.FSCollector): self.trace("perform_collect", self, args) self.trace.root.indent += 1 self._notfound = [] - initialpaths = [] - self._initialparts = [] + initialpaths = [] # type: List[py.path.local] + self._initial_parts = [] # type: List[Tuple[py.path.local, List[str]]] self.items = items = [] for arg in args: - parts = self._parsearg(arg) - self._initialparts.append(parts) - initialpaths.append(parts[0]) + fspath, parts = self._parsearg(arg) + self._initial_parts.append((fspath, parts)) + initialpaths.append(fspath) self._initialpaths = frozenset(initialpaths) rep = collect_one_node(self) self.ihook.pytest_collectreport(report=rep) @@ -511,13 +512,13 @@ class Session(nodes.FSCollector): return items def collect(self): - for initialpart in self._initialparts: - self.trace("processing argument", initialpart) + for fspath, parts in self._initial_parts: + self.trace("processing argument", (fspath, parts)) self.trace.root.indent += 1 try: - yield from self._collect(initialpart) + yield from self._collect(fspath, parts) except NoMatch: - report_arg = "::".join(map(str, initialpart)) + report_arg = "::".join((str(fspath), *parts)) # we are inside a make_report hook so # we cannot directly pass through the exception self._notfound.append((report_arg, sys.exc_info()[1])) @@ -526,12 +527,9 @@ class Session(nodes.FSCollector): self._collection_node_cache.clear() self._collection_pkg_roots.clear() - def _collect(self, arg): + def _collect(self, argpath, names): from _pytest.python import Package - names = arg[:] - argpath = names.pop(0) - # Start with a Session root, and delve to argpath item (dir or file) # and stack all Packages found on the way. # No point in finding packages when collecting doctests @@ -555,7 +553,7 @@ class Session(nodes.FSCollector): # If it's a directory argument, recurse and look for any Subpackages. # Let the Package collector deal with subnodes, don't collect here. if argpath.check(dir=1): - assert not names, "invalid arg {!r}".format(arg) + assert not names, "invalid arg {!r}".format((argpath, names)) seen_dirs = set() for path in argpath.visit( @@ -665,19 +663,19 @@ class Session(nodes.FSCollector): def _parsearg(self, arg): """ return (fspath, names) tuple after checking the file exists. """ - parts = str(arg).split("::") + strpath, *parts = str(arg).split("::") if self.config.option.pyargs: - parts[0] = self._tryconvertpyarg(parts[0]) - relpath = parts[0].replace("/", os.sep) - path = self.config.invocation_dir.join(relpath, abs=True) - if not path.check(): + strpath = self._tryconvertpyarg(strpath) + relpath = strpath.replace("/", os.sep) + fspath = self.config.invocation_dir.join(relpath, abs=True) + if not fspath.check(): if self.config.option.pyargs: raise UsageError( "file or package not found: " + arg + " (missing __init__.py?)" ) raise UsageError("file not found: " + arg) - parts[0] = path.realpath() - return parts + fspath = fspath.realpath() + return (fspath, parts) def matchnodes(self, matching, names): self.trace("matchnodes", matching, names) diff --git a/testing/test_collection.py b/testing/test_collection.py index 885b05ccd..760cb2b7f 100644 --- a/testing/test_collection.py +++ b/testing/test_collection.py @@ -438,7 +438,7 @@ class TestCustomConftests: class TestSession: - def test_parsearg(self, testdir): + def test_parsearg(self, testdir) -> None: p = testdir.makepyfile("def test_func(): pass") subdir = testdir.mkdir("sub") subdir.ensure("__init__.py") @@ -448,14 +448,14 @@ class TestSession: config = testdir.parseconfig(p.basename) rcol = Session.from_config(config) assert rcol.fspath == subdir - parts = rcol._parsearg(p.basename) + fspath, parts = rcol._parsearg(p.basename) - assert parts[0] == target + assert fspath == target + assert len(parts) == 0 + fspath, parts = rcol._parsearg(p.basename + "::test_func") + assert fspath == target + assert parts[0] == "test_func" assert len(parts) == 1 - parts = rcol._parsearg(p.basename + "::test_func") - assert parts[0] == target - assert parts[1] == "test_func" - assert len(parts) == 2 def test_collect_topdir(self, testdir): p = testdir.makepyfile("def test_func(): pass")