From e78207c936c43478aa5d5531d7c0b90aa240c9e0 Mon Sep 17 00:00:00 2001 From: Prashant Anand Date: Tue, 9 Jun 2020 09:54:22 +0900 Subject: [PATCH] 7119: data loss with mistyped --basetemp (#7170) Co-authored-by: Bruno Oliveira Co-authored-by: Ran Benita --- AUTHORS | 1 + changelog/7119.improvement.rst | 2 ++ src/_pytest/main.py | 31 +++++++++++++++++++++++++++++++ testing/test_main.py | 23 +++++++++++++++++++++++ 4 files changed, 57 insertions(+) create mode 100644 changelog/7119.improvement.rst diff --git a/AUTHORS b/AUTHORS index e1b195b9a..4c5ca41af 100644 --- a/AUTHORS +++ b/AUTHORS @@ -227,6 +227,7 @@ Pedro Algarvio Philipp Loose Pieter Mulder Piotr Banaszkiewicz +Prashant Anand Pulkit Goyal Punyashloka Biswal Quentin Pradet diff --git a/changelog/7119.improvement.rst b/changelog/7119.improvement.rst new file mode 100644 index 000000000..6cef98836 --- /dev/null +++ b/changelog/7119.improvement.rst @@ -0,0 +1,2 @@ +Exit with an error if the ``--basetemp`` argument is empty, the current working directory or parent directory of it. +This is done to protect against accidental data loss, as any directory passed to this argument is cleared. diff --git a/src/_pytest/main.py b/src/_pytest/main.py index 84ee00881..a95f2f2e7 100644 --- a/src/_pytest/main.py +++ b/src/_pytest/main.py @@ -1,4 +1,5 @@ """ core implementation of testing process: init, session, runtest loop. """ +import argparse import fnmatch import functools import importlib @@ -30,6 +31,7 @@ from _pytest.config import UsageError from _pytest.config.argparsing import Parser from _pytest.fixtures import FixtureManager from _pytest.outcomes import exit +from _pytest.pathlib import Path from _pytest.reports import CollectReport from _pytest.reports import TestReport from _pytest.runner import collect_one_node @@ -177,6 +179,7 @@ def pytest_addoption(parser: Parser) -> None: "--basetemp", dest="basetemp", default=None, + type=validate_basetemp, metavar="dir", help=( "base temporary directory for this test run." @@ -185,6 +188,34 @@ def pytest_addoption(parser: Parser) -> None: ) +def validate_basetemp(path: str) -> str: + # GH 7119 + msg = "basetemp must not be empty, the current working directory or any parent directory of it" + + # empty path + if not path: + raise argparse.ArgumentTypeError(msg) + + def is_ancestor(base: Path, query: Path) -> bool: + """ return True if query is an ancestor of base, else False.""" + if base == query: + return True + for parent in base.parents: + if parent == query: + return True + return False + + # check if path is an ancestor of cwd + if is_ancestor(Path.cwd(), Path(path).absolute()): + raise argparse.ArgumentTypeError(msg) + + # check symlinks for ancestors + if is_ancestor(Path.cwd().resolve(), Path(path).resolve()): + raise argparse.ArgumentTypeError(msg) + + return path + + def wrap_session( config: Config, doit: Callable[[Config, "Session"], Optional[Union[int, ExitCode]]] ) -> Union[int, ExitCode]: diff --git a/testing/test_main.py b/testing/test_main.py index 07aca3a1e..ee8349a9f 100644 --- a/testing/test_main.py +++ b/testing/test_main.py @@ -1,7 +1,9 @@ +import argparse from typing import Optional import pytest from _pytest.config import ExitCode +from _pytest.main import validate_basetemp from _pytest.pytester import Testdir @@ -75,3 +77,24 @@ def test_wrap_session_exit_sessionfinish( assert result.ret == ExitCode.NO_TESTS_COLLECTED assert result.stdout.lines[-1] == "collected 0 items" assert result.stderr.lines == ["Exit: exit_pytest_sessionfinish"] + + +@pytest.mark.parametrize("basetemp", ["foo", "foo/bar"]) +def test_validate_basetemp_ok(tmp_path, basetemp, monkeypatch): + monkeypatch.chdir(str(tmp_path)) + validate_basetemp(tmp_path / basetemp) + + +@pytest.mark.parametrize("basetemp", ["", ".", ".."]) +def test_validate_basetemp_fails(tmp_path, basetemp, monkeypatch): + monkeypatch.chdir(str(tmp_path)) + msg = "basetemp must not be empty, the current working directory or any parent directory of it" + with pytest.raises(argparse.ArgumentTypeError, match=msg): + if basetemp: + basetemp = tmp_path / basetemp + validate_basetemp(basetemp) + + +def test_validate_basetemp_integration(testdir): + result = testdir.runpytest("--basetemp=.") + result.stderr.fnmatch_lines("*basetemp must not be*")