Merge pull request #6547 from bluetech/session-initialparts

Refactor Session._initialparts to have a more explicit type
This commit is contained in:
Ran Benita 2020-01-25 14:30:26 +02:00 committed by GitHub
commit a76bc64c54
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 26 additions and 28 deletions

View File

@ -8,6 +8,7 @@ import sys
from typing import Dict from typing import Dict
from typing import FrozenSet from typing import FrozenSet
from typing import List from typing import List
from typing import Tuple
import attr import attr
import py import py
@ -486,13 +487,13 @@ class Session(nodes.FSCollector):
self.trace("perform_collect", self, args) self.trace("perform_collect", self, args)
self.trace.root.indent += 1 self.trace.root.indent += 1
self._notfound = [] self._notfound = []
initialpaths = [] initialpaths = [] # type: List[py.path.local]
self._initialparts = [] self._initial_parts = [] # type: List[Tuple[py.path.local, List[str]]]
self.items = items = [] self.items = items = []
for arg in args: for arg in args:
parts = self._parsearg(arg) fspath, parts = self._parsearg(arg)
self._initialparts.append(parts) self._initial_parts.append((fspath, parts))
initialpaths.append(parts[0]) initialpaths.append(fspath)
self._initialpaths = frozenset(initialpaths) self._initialpaths = frozenset(initialpaths)
rep = collect_one_node(self) rep = collect_one_node(self)
self.ihook.pytest_collectreport(report=rep) self.ihook.pytest_collectreport(report=rep)
@ -512,13 +513,13 @@ class Session(nodes.FSCollector):
return items return items
def collect(self): def collect(self):
for initialpart in self._initialparts: for fspath, parts in self._initial_parts:
self.trace("processing argument", initialpart) self.trace("processing argument", (fspath, parts))
self.trace.root.indent += 1 self.trace.root.indent += 1
try: try:
yield from self._collect(initialpart) yield from self._collect(fspath, parts)
except NoMatch: except NoMatch:
report_arg = "::".join(map(str, initialpart)) report_arg = "::".join((str(fspath), *parts))
# we are inside a make_report hook so # we are inside a make_report hook so
# we cannot directly pass through the exception # we cannot directly pass through the exception
self._notfound.append((report_arg, sys.exc_info()[1])) self._notfound.append((report_arg, sys.exc_info()[1]))
@ -527,12 +528,9 @@ class Session(nodes.FSCollector):
self._collection_node_cache.clear() self._collection_node_cache.clear()
self._collection_pkg_roots.clear() self._collection_pkg_roots.clear()
def _collect(self, arg): def _collect(self, argpath, names):
from _pytest.python import Package from _pytest.python import Package
names = arg[:]
argpath = names.pop(0)
# Start with a Session root, and delve to argpath item (dir or file) # Start with a Session root, and delve to argpath item (dir or file)
# and stack all Packages found on the way. # and stack all Packages found on the way.
# No point in finding packages when collecting doctests # No point in finding packages when collecting doctests
@ -556,7 +554,7 @@ class Session(nodes.FSCollector):
# If it's a directory argument, recurse and look for any Subpackages. # If it's a directory argument, recurse and look for any Subpackages.
# Let the Package collector deal with subnodes, don't collect here. # Let the Package collector deal with subnodes, don't collect here.
if argpath.check(dir=1): if argpath.check(dir=1):
assert not names, "invalid arg {!r}".format(arg) assert not names, "invalid arg {!r}".format((argpath, names))
seen_dirs = set() seen_dirs = set()
for path in argpath.visit( for path in argpath.visit(
@ -666,19 +664,19 @@ class Session(nodes.FSCollector):
def _parsearg(self, arg): def _parsearg(self, arg):
""" return (fspath, names) tuple after checking the file exists. """ """ return (fspath, names) tuple after checking the file exists. """
parts = str(arg).split("::") strpath, *parts = str(arg).split("::")
if self.config.option.pyargs: if self.config.option.pyargs:
parts[0] = self._tryconvertpyarg(parts[0]) strpath = self._tryconvertpyarg(strpath)
relpath = parts[0].replace("/", os.sep) relpath = strpath.replace("/", os.sep)
path = self.config.invocation_dir.join(relpath, abs=True) fspath = self.config.invocation_dir.join(relpath, abs=True)
if not path.check(): if not fspath.check():
if self.config.option.pyargs: if self.config.option.pyargs:
raise UsageError( raise UsageError(
"file or package not found: " + arg + " (missing __init__.py?)" "file or package not found: " + arg + " (missing __init__.py?)"
) )
raise UsageError("file not found: " + arg) raise UsageError("file not found: " + arg)
parts[0] = path.realpath() fspath = fspath.realpath()
return parts return (fspath, parts)
def matchnodes(self, matching, names): def matchnodes(self, matching, names):
self.trace("matchnodes", matching, names) self.trace("matchnodes", matching, names)

View File

@ -438,7 +438,7 @@ class TestCustomConftests:
class TestSession: class TestSession:
def test_parsearg(self, testdir): def test_parsearg(self, testdir) -> None:
p = testdir.makepyfile("def test_func(): pass") p = testdir.makepyfile("def test_func(): pass")
subdir = testdir.mkdir("sub") subdir = testdir.mkdir("sub")
subdir.ensure("__init__.py") subdir.ensure("__init__.py")
@ -448,14 +448,14 @@ class TestSession:
config = testdir.parseconfig(p.basename) config = testdir.parseconfig(p.basename)
rcol = Session.from_config(config) rcol = Session.from_config(config)
assert rcol.fspath == subdir assert rcol.fspath == subdir
parts = rcol._parsearg(p.basename) fspath, parts = rcol._parsearg(p.basename)
assert parts[0] == target assert fspath == target
assert len(parts) == 0
fspath, parts = rcol._parsearg(p.basename + "::test_func")
assert fspath == target
assert parts[0] == "test_func"
assert len(parts) == 1 assert len(parts) == 1
parts = rcol._parsearg(p.basename + "::test_func")
assert parts[0] == target
assert parts[1] == "test_func"
assert len(parts) == 2
def test_collect_topdir(self, testdir): def test_collect_topdir(self, testdir):
p = testdir.makepyfile("def test_func(): pass") p = testdir.makepyfile("def test_func(): pass")