From cafb13c95f30d4f46f275fb60077966440643529 Mon Sep 17 00:00:00 2001 From: Anthony Sottile Date: Sat, 1 Jun 2019 13:50:15 -0700 Subject: [PATCH] Fix `pytest.mark.parametrize` when the argvalue is an iterator --- changelog/5354.bugfix.rst | 1 + src/_pytest/mark/structures.py | 10 +++++++--- testing/test_mark.py | 22 ++++++++++++++++++++++ 3 files changed, 30 insertions(+), 3 deletions(-) create mode 100644 changelog/5354.bugfix.rst diff --git a/changelog/5354.bugfix.rst b/changelog/5354.bugfix.rst new file mode 100644 index 000000000..812ea8364 --- /dev/null +++ b/changelog/5354.bugfix.rst @@ -0,0 +1 @@ +Fix ``pytest.mark.parametrize`` when the argvalues is an iterator. diff --git a/src/_pytest/mark/structures.py b/src/_pytest/mark/structures.py index 561ccc3f4..9602e8acf 100644 --- a/src/_pytest/mark/structures.py +++ b/src/_pytest/mark/structures.py @@ -113,14 +113,18 @@ class ParameterSet(namedtuple("ParameterSet", "values, marks, id")): force_tuple = len(argnames) == 1 else: force_tuple = False - parameters = [ + return argnames, force_tuple + + @staticmethod + def _parse_parametrize_parameters(argvalues, force_tuple): + return [ ParameterSet.extract_from(x, force_tuple=force_tuple) for x in argvalues ] - return argnames, parameters @classmethod def _for_parametrize(cls, argnames, argvalues, func, config, function_definition): - argnames, parameters = cls._parse_parametrize_args(argnames, argvalues) + argnames, force_tuple = cls._parse_parametrize_args(argnames, argvalues) + parameters = cls._parse_parametrize_parameters(argvalues, force_tuple) del argvalues if parameters: diff --git a/testing/test_mark.py b/testing/test_mark.py index 5bd97d547..dd9d35230 100644 --- a/testing/test_mark.py +++ b/testing/test_mark.py @@ -413,6 +413,28 @@ def test_parametrized_with_kwargs(testdir): assert result.ret == 0 +def test_parametrize_iterator(testdir): + """parametrize should work with generators (#5354).""" + py_file = testdir.makepyfile( + """\ + import pytest + + def gen(): + yield 1 + yield 2 + yield 3 + + @pytest.mark.parametrize('a', gen()) + def test(a): + assert a >= 1 + """ + ) + result = testdir.runpytest(py_file) + assert result.ret == 0 + # should not skip any tests + result.stdout.fnmatch_lines(["*3 passed*"]) + + class TestFunctional(object): def test_merging_markers_deep(self, testdir): # issue 199 - propagate markers into nested classes