diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 19329cb21b8..2a45c776e09 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -69,7 +69,7 @@ repos: - attrs>=19.2.0 - packaging - tomli - - types-pkg_resources + - types-setuptools # for mypy running on python>=3.11 since exceptiongroup is only a dependency # on <3.11 - exceptiongroup>=1.0.0rc8 @@ -77,7 +77,7 @@ repos: hooks: - id: rst name: rst - entry: rst-lint --encoding utf-8 + entry: rst-lint files: ^(RELEASING.rst|README.rst|TIDELIFT.rst)$ language: python additional_dependencies: [pygments, restructuredtext_lint] diff --git a/AUTHORS b/AUTHORS index ca2872f32a4..c5c774dcf85 100644 --- a/AUTHORS +++ b/AUTHORS @@ -60,6 +60,7 @@ Bruno Oliveira Cal Leeming Carl Friedrich Bolz Carlos Jenkins +Casey Brooks Ceridwen Charles Cloud Charles Machalow diff --git a/changelog/0007.bugfix.rst b/changelog/0007.bugfix.rst new file mode 100644 index 00000000000..b3270c2e407 --- /dev/null +++ b/changelog/0007.bugfix.rst @@ -0,0 +1 @@ +Class-level marks now merge across multiple inheritance following MRO; identical duplicates are removed. diff --git a/src/_pytest/mark/structures.py b/src/_pytest/mark/structures.py index 800a25c9243..0e27e0823aa 100644 --- a/src/_pytest/mark/structures.py +++ b/src/_pytest/mark/structures.py @@ -4,6 +4,7 @@ from typing import Any from typing import Callable from typing import Collection +from typing import Dict from typing import Iterable from typing import Iterator from typing import List @@ -363,6 +364,86 @@ def get_unpacked_marks(obj: object) -> Iterable[Mark]: return normalize_mark_list(mark_list) +def get_unpacked_class_marks(cls: type) -> List[Mark]: + """Return all unique marks declared on a class following its MRO.""" + + def _mark_components( + mark: Mark, + ) -> Tuple[str, Tuple[Any, ...], Tuple[Tuple[str, Any], ...]]: + return ( + mark.name, + tuple(_freeze(arg) for arg in mark.args), + tuple( + sorted((name, _freeze(value)) for name, value in mark.kwargs.items()) + ), + ) + + def _freeze(value: Any) -> Any: + if isinstance(value, MarkDecorator): + return ("MarkDecorator", *_mark_components(value.mark)) + if isinstance(value, Mark): + return ("Mark", *_mark_components(value)) + if isinstance(value, (list, tuple)): + return tuple(_freeze(v) for v in value) + if isinstance(value, set): + return tuple(sorted(_freeze(v) for v in value)) + if isinstance(value, dict): + return tuple(sorted((k, _freeze(v)) for k, v in value.items())) + return value + + def _mark_key( + mark: Mark, + ) -> Tuple[str, Tuple[Any, ...], Tuple[Tuple[str, Any], ...]]: + return _mark_components(mark) + + def _marks_for(klass: type, cache: Dict[type, List[Mark]]) -> List[Mark]: + cached = cache.get(klass) + if cached is not None: + return cached + raw = klass.__dict__.get("pytestmark", []) + if not isinstance(raw, list): + raw = [raw] + marks = list(normalize_mark_list(raw)) + cache[klass] = marks + return marks + + dedup_key: Set[Tuple[str, Tuple[Any, ...], Tuple[Tuple[str, Any], ...]]] = set() + collected: List[Mark] = [] + marks_cache: Dict[type, List[Mark]] = {} + + for base in cls.__mro__: + if base is object: + break + + marks = list(_marks_for(base, marks_cache)) + if marks: + inherited_keys = { + _mark_key(mark) + for ancestor in base.__mro__[1:] + if ancestor is not object + for mark in _marks_for(ancestor, marks_cache) + } + if inherited_keys: + trimmed: List[Mark] = [] + prefix_skipping = True + for mark in marks: + key = _mark_key(mark) + if prefix_skipping and key in inherited_keys: + continue + prefix_skipping = False + trimmed.append(mark) + marks = trimmed + + for mark in marks: + key = _mark_key(mark) + if key in dedup_key: + continue + dedup_key.add(key) + collected.append(mark) + + return collected + + def normalize_mark_list( mark_list: Iterable[Union[Mark, MarkDecorator]] ) -> Iterable[Mark]: diff --git a/src/_pytest/python.py b/src/_pytest/python.py index 3db8775061b..67ead592148 100644 --- a/src/_pytest/python.py +++ b/src/_pytest/python.py @@ -63,6 +63,7 @@ from _pytest.main import Session from _pytest.mark import MARK_GEN from _pytest.mark import ParameterSet +from _pytest.mark.structures import get_unpacked_class_marks from _pytest.mark.structures import get_unpacked_marks from _pytest.mark.structures import Mark from _pytest.mark.structures import MarkDecorator @@ -311,7 +312,10 @@ def obj(self): # XXX evil hack # used to avoid Function marker duplication if self._ALLOW_MARKERS: - self.own_markers.extend(get_unpacked_marks(self.obj)) + if isinstance(self, Class): + self.own_markers.extend(get_unpacked_class_marks(obj)) + else: + self.own_markers.extend(get_unpacked_marks(obj)) # This assumes that `obj` is called before there is a chance # to add custom keys to `self.keywords`, so no fear of overriding. self.keywords.update((mark.name, mark) for mark in self.own_markers) diff --git a/testing/test_markers_mro.py b/testing/test_markers_mro.py new file mode 100644 index 00000000000..b2036976903 --- /dev/null +++ b/testing/test_markers_mro.py @@ -0,0 +1,225 @@ +import pytest + + +def _register_marks(pytester: pytest.Pytester) -> None: + pytester.makeini( + """ + [pytest] + markers = + alpha(value): mark used for MRO ordering checks + beta(value): class-level mark defined on subclasses + gamma(value): function-level mark used for ordering assertions + shared(value): mark used to exercise deduplication + method(value): additional function-level mark + diamond(value): mark applied in diamond inheritance scenarios + """ + ) + + +def test_multiple_inheritance_merges_marks(pytester: pytest.Pytester) -> None: + _register_marks(pytester) + pytester.makepyfile( + """ + import pytest + + class BaseA: + pytestmark = pytest.mark.alpha("base-a") + + class BaseB: + pytestmark = [pytest.mark.alpha("base-b")] + + @pytest.mark.beta("derived") + class TestDerived(BaseA, BaseB): + @pytest.mark.gamma("method") + def test_method(self, request): + seen = [(mark.name, mark.args) for mark in request.node.iter_markers()] + assert seen == [ + ("gamma", ("method",)), + ("beta", ("derived",)), + ("alpha", ("base-a",)), + ("alpha", ("base-b",)), + ] + """ + ) + + result = pytester.runpytest() + result.assert_outcomes(passed=1) + + +def test_function_marks_precede_class_marks(pytester: pytest.Pytester) -> None: + _register_marks(pytester) + pytester.makepyfile( + """ + import pytest + + @pytest.mark.alpha("base") + class Base: + pass + + class TestDerived(Base): + pytestmark = [pytest.mark.beta("derived-list")] + + @pytest.mark.method("first") + @pytest.mark.gamma("second") + def test_method(self, request): + names = [mark.name for mark in request.node.iter_markers()] + assert names[:2] == ["gamma", "method"] + assert names[2:] == ["beta", "alpha"] + """ + ) + + result = pytester.runpytest() + result.assert_outcomes(passed=1) + + +def test_diamond_inheritance_deduplicates_marks(pytester: pytest.Pytester) -> None: + _register_marks(pytester) + pytester.makepyfile( + """ + import pytest + + @pytest.mark.shared("root") + class Root: + pass + + @pytest.mark.shared("branch") + @pytest.mark.alpha("left") + class Left(Root): + pass + + @pytest.mark.shared("branch") + @pytest.mark.alpha("right") + class Right(Root): + pass + + class TestLeaf(Left, Right): + @pytest.mark.gamma("method") + def test_leaf(self, request): + seen = [(mark.name, mark.args) for mark in request.node.iter_markers()] + assert seen == [ + ("gamma", ("method",)), + ("alpha", ("left",)), + ("shared", ("branch",)), + ("alpha", ("right",)), + ("shared", ("root",)), + ] + """ + ) + + result = pytester.runpytest() + result.assert_outcomes(passed=1) + + +def test_identical_marks_across_bases_deduped_once( + pytester: pytest.Pytester, +) -> None: + _register_marks(pytester) + pytester.makepyfile( + """ + import pytest + + class BaseA: + pytestmark = [pytest.mark.shared("duplicate")] + + class BaseB: + pytestmark = pytest.mark.shared("duplicate") + + class TestDerived(BaseA, BaseB): + def test_marks(self, request): + seen = [mark.args for mark in request.node.iter_markers(name="shared")] + assert seen == [("duplicate",)] + """ + ) + + result = pytester.runpytest() + result.assert_outcomes(passed=1) + + +def test_parametrize_marks_merge_and_dedupe(pytester: pytest.Pytester) -> None: + pytester.makepyfile( + """ + import pytest + + class ParamBaseA: + pytestmark = pytest.mark.parametrize("left", ["L1", "L2"]) + + class ParamBaseB: + pytestmark = [ + pytest.mark.parametrize("right", ["R1", "R2"]), + pytest.mark.parametrize("left", ["L1", "L2"]), + ] + + class TestParam(ParamBaseA, ParamBaseB): + def test_params(self, left, right): + pass + """ + ) + + result = pytester.runpytest() + result.assert_outcomes(passed=4) + + +def test_parametrize_marks_with_param_objects(pytester: pytest.Pytester) -> None: + _register_marks(pytester) + pytester.makepyfile( + """ + import pytest + + class ParamBaseA: + pytestmark = pytest.mark.parametrize( + "left", + [ + pytest.param("L1", marks=pytest.mark.shared("dup")), + pytest.param("L2", marks=pytest.mark.alpha("left")), + ], + ) + + class ParamBaseB: + pytestmark = [ + pytest.mark.parametrize( + "right", + [ + pytest.param("R1", marks=pytest.mark.shared("dup")), + pytest.param("R2", marks=pytest.mark.beta("right")), + ], + ) + ] + + class TestParamCombined(ParamBaseA, ParamBaseB): + def test_params(self, left, right, request): + shared_args = {mark.args for mark in request.node.iter_markers(name="shared")} + if left == "L1" or right == "R1": + assert shared_args == {("dup",)} + else: + assert shared_args == set() + """ + ) + + result = pytester.runpytest() + result.assert_outcomes(passed=4) + + +def test_class_pytestmark_forms_supported(pytester: pytest.Pytester) -> None: + _register_marks(pytester) + pytester.makepyfile( + """ + import pytest + + class Decorated: + pytestmark = [pytest.mark.alpha("decorated")] + + @pytest.mark.beta("derived") + class TestDerived(Decorated): + @pytest.mark.gamma("method") + def test_forms(self, request): + seen = [(mark.name, mark.args) for mark in request.node.iter_markers()] + assert seen == [ + ("gamma", ("method",)), + ("beta", ("derived",)), + ("alpha", ("decorated",)), + ] + """ + ) + + result = pytester.runpytest() + result.assert_outcomes(passed=1)