Refactor Session._initialparts to have a more explicit type

Previously, _initialparts was a list whose first item was a
`py.path.local` and the rest were `str`s. This is not something that
mypy is capable of modeling. The type `List[Union[str, py.path.local]]`
is too broad and would require asserts for every access.

Instead, make each item a `Tuple[py.path.local, List[str]]`. This way
the structure is clear and the types are accurate.

To make sure any users who might have been accessing this (private)
field will not break silently, change the name to _initial_parts.
This commit is contained in:
Ran Benita 2020-01-23 14:47:27 +02:00
parent e17f5fad14
commit dd5c2b22bd
2 changed files with 26 additions and 28 deletions

View File

@ -8,6 +8,7 @@ import sys
from typing import Dict from typing import Dict
from typing import FrozenSet from typing import FrozenSet
from typing import List from typing import List
from typing import Tuple
import attr import attr
import py import py
@ -485,13 +486,13 @@ class Session(nodes.FSCollector):
self.trace("perform_collect", self, args) self.trace("perform_collect", self, args)
self.trace.root.indent += 1 self.trace.root.indent += 1
self._notfound = [] self._notfound = []
initialpaths = [] initialpaths = [] # type: List[py.path.local]
self._initialparts = [] self._initial_parts = [] # type: List[Tuple[py.path.local, List[str]]]
self.items = items = [] self.items = items = []
for arg in args: for arg in args:
parts = self._parsearg(arg) fspath, parts = self._parsearg(arg)
self._initialparts.append(parts) self._initial_parts.append((fspath, parts))
initialpaths.append(parts[0]) initialpaths.append(fspath)
self._initialpaths = frozenset(initialpaths) self._initialpaths = frozenset(initialpaths)
rep = collect_one_node(self) rep = collect_one_node(self)
self.ihook.pytest_collectreport(report=rep) self.ihook.pytest_collectreport(report=rep)
@ -511,13 +512,13 @@ class Session(nodes.FSCollector):
return items return items
def collect(self): def collect(self):
for initialpart in self._initialparts: for fspath, parts in self._initial_parts:
self.trace("processing argument", initialpart) self.trace("processing argument", (fspath, parts))
self.trace.root.indent += 1 self.trace.root.indent += 1
try: try:
yield from self._collect(initialpart) yield from self._collect(fspath, parts)
except NoMatch: except NoMatch:
report_arg = "::".join(map(str, initialpart)) report_arg = "::".join((str(fspath), *parts))
# we are inside a make_report hook so # we are inside a make_report hook so
# we cannot directly pass through the exception # we cannot directly pass through the exception
self._notfound.append((report_arg, sys.exc_info()[1])) self._notfound.append((report_arg, sys.exc_info()[1]))
@ -526,12 +527,9 @@ class Session(nodes.FSCollector):
self._collection_node_cache.clear() self._collection_node_cache.clear()
self._collection_pkg_roots.clear() self._collection_pkg_roots.clear()
def _collect(self, arg): def _collect(self, argpath, names):
from _pytest.python import Package from _pytest.python import Package
names = arg[:]
argpath = names.pop(0)
# Start with a Session root, and delve to argpath item (dir or file) # Start with a Session root, and delve to argpath item (dir or file)
# and stack all Packages found on the way. # and stack all Packages found on the way.
# No point in finding packages when collecting doctests # No point in finding packages when collecting doctests
@ -555,7 +553,7 @@ class Session(nodes.FSCollector):
# If it's a directory argument, recurse and look for any Subpackages. # If it's a directory argument, recurse and look for any Subpackages.
# Let the Package collector deal with subnodes, don't collect here. # Let the Package collector deal with subnodes, don't collect here.
if argpath.check(dir=1): if argpath.check(dir=1):
assert not names, "invalid arg {!r}".format(arg) assert not names, "invalid arg {!r}".format((argpath, names))
seen_dirs = set() seen_dirs = set()
for path in argpath.visit( for path in argpath.visit(
@ -665,19 +663,19 @@ class Session(nodes.FSCollector):
def _parsearg(self, arg): def _parsearg(self, arg):
""" return (fspath, names) tuple after checking the file exists. """ """ return (fspath, names) tuple after checking the file exists. """
parts = str(arg).split("::") strpath, *parts = str(arg).split("::")
if self.config.option.pyargs: if self.config.option.pyargs:
parts[0] = self._tryconvertpyarg(parts[0]) strpath = self._tryconvertpyarg(strpath)
relpath = parts[0].replace("/", os.sep) relpath = strpath.replace("/", os.sep)
path = self.config.invocation_dir.join(relpath, abs=True) fspath = self.config.invocation_dir.join(relpath, abs=True)
if not path.check(): if not fspath.check():
if self.config.option.pyargs: if self.config.option.pyargs:
raise UsageError( raise UsageError(
"file or package not found: " + arg + " (missing __init__.py?)" "file or package not found: " + arg + " (missing __init__.py?)"
) )
raise UsageError("file not found: " + arg) raise UsageError("file not found: " + arg)
parts[0] = path.realpath() fspath = fspath.realpath()
return parts return (fspath, parts)
def matchnodes(self, matching, names): def matchnodes(self, matching, names):
self.trace("matchnodes", matching, names) self.trace("matchnodes", matching, names)

View File

@ -438,7 +438,7 @@ class TestCustomConftests:
class TestSession: class TestSession:
def test_parsearg(self, testdir): def test_parsearg(self, testdir) -> None:
p = testdir.makepyfile("def test_func(): pass") p = testdir.makepyfile("def test_func(): pass")
subdir = testdir.mkdir("sub") subdir = testdir.mkdir("sub")
subdir.ensure("__init__.py") subdir.ensure("__init__.py")
@ -448,14 +448,14 @@ class TestSession:
config = testdir.parseconfig(p.basename) config = testdir.parseconfig(p.basename)
rcol = Session.from_config(config) rcol = Session.from_config(config)
assert rcol.fspath == subdir assert rcol.fspath == subdir
parts = rcol._parsearg(p.basename) fspath, parts = rcol._parsearg(p.basename)
assert parts[0] == target assert fspath == target
assert len(parts) == 0
fspath, parts = rcol._parsearg(p.basename + "::test_func")
assert fspath == target
assert parts[0] == "test_func"
assert len(parts) == 1 assert len(parts) == 1
parts = rcol._parsearg(p.basename + "::test_func")
assert parts[0] == target
assert parts[1] == "test_func"
assert len(parts) == 2
def test_collect_topdir(self, testdir): def test_collect_topdir(self, testdir):
p = testdir.makepyfile("def test_func(): pass") p = testdir.makepyfile("def test_func(): pass")