Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -69,15 +69,15 @@ repos:
- attrs>=19.2.0
- packaging
- tomli
- types-pkg_resources
- types-setuptools
# for mypy running on python>=3.11 since exceptiongroup is only a dependency
# on <3.11
- exceptiongroup>=1.0.0rc8
- repo: local
hooks:
- id: rst
name: rst
entry: rst-lint --encoding utf-8
entry: rst-lint
files: ^(RELEASING.rst|README.rst|TIDELIFT.rst)$
language: python
additional_dependencies: [pygments, restructuredtext_lint]
Expand Down
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ Bruno Oliveira
Cal Leeming
Carl Friedrich Bolz
Carlos Jenkins
Casey Brooks
Ceridwen
Charles Cloud
Charles Machalow
Expand Down
1 change: 1 addition & 0 deletions changelog/0007.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Class-level marks now merge across multiple inheritance following MRO; identical duplicates are removed.
81 changes: 81 additions & 0 deletions src/_pytest/mark/structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any
from typing import Callable
from typing import Collection
from typing import Dict
from typing import Iterable
from typing import Iterator
from typing import List
Expand Down Expand Up @@ -363,6 +364,86 @@ def get_unpacked_marks(obj: object) -> Iterable[Mark]:
return normalize_mark_list(mark_list)


def get_unpacked_class_marks(cls: type) -> List[Mark]:
"""Return all unique marks declared on a class following its MRO."""

def _mark_components(
mark: Mark,
) -> Tuple[str, Tuple[Any, ...], Tuple[Tuple[str, Any], ...]]:
return (
mark.name,
tuple(_freeze(arg) for arg in mark.args),
tuple(
sorted((name, _freeze(value)) for name, value in mark.kwargs.items())
),
)

def _freeze(value: Any) -> Any:
if isinstance(value, MarkDecorator):
return ("MarkDecorator", *_mark_components(value.mark))
if isinstance(value, Mark):
return ("Mark", *_mark_components(value))
if isinstance(value, (list, tuple)):
return tuple(_freeze(v) for v in value)
if isinstance(value, set):
return tuple(sorted(_freeze(v) for v in value))
if isinstance(value, dict):
return tuple(sorted((k, _freeze(v)) for k, v in value.items()))
return value

def _mark_key(
mark: Mark,
) -> Tuple[str, Tuple[Any, ...], Tuple[Tuple[str, Any], ...]]:
return _mark_components(mark)

def _marks_for(klass: type, cache: Dict[type, List[Mark]]) -> List[Mark]:
cached = cache.get(klass)
if cached is not None:
return cached
raw = klass.__dict__.get("pytestmark", [])
if not isinstance(raw, list):
raw = [raw]
marks = list(normalize_mark_list(raw))
cache[klass] = marks
return marks

dedup_key: Set[Tuple[str, Tuple[Any, ...], Tuple[Tuple[str, Any], ...]]] = set()
collected: List[Mark] = []
marks_cache: Dict[type, List[Mark]] = {}

for base in cls.__mro__:
if base is object:
break

marks = list(_marks_for(base, marks_cache))
if marks:
inherited_keys = {
_mark_key(mark)
for ancestor in base.__mro__[1:]
if ancestor is not object
for mark in _marks_for(ancestor, marks_cache)
}
if inherited_keys:
trimmed: List[Mark] = []
prefix_skipping = True
for mark in marks:
key = _mark_key(mark)
if prefix_skipping and key in inherited_keys:
continue
prefix_skipping = False
trimmed.append(mark)
marks = trimmed

for mark in marks:
key = _mark_key(mark)
if key in dedup_key:
continue
dedup_key.add(key)
collected.append(mark)

return collected


def normalize_mark_list(
mark_list: Iterable[Union[Mark, MarkDecorator]]
) -> Iterable[Mark]:
Expand Down
6 changes: 5 additions & 1 deletion src/_pytest/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
from _pytest.main import Session
from _pytest.mark import MARK_GEN
from _pytest.mark import ParameterSet
from _pytest.mark.structures import get_unpacked_class_marks
from _pytest.mark.structures import get_unpacked_marks
from _pytest.mark.structures import Mark
from _pytest.mark.structures import MarkDecorator
Expand Down Expand Up @@ -311,7 +312,10 @@ def obj(self):
# XXX evil hack
# used to avoid Function marker duplication
if self._ALLOW_MARKERS:
self.own_markers.extend(get_unpacked_marks(self.obj))
if isinstance(self, Class):
self.own_markers.extend(get_unpacked_class_marks(obj))
else:
self.own_markers.extend(get_unpacked_marks(obj))
# This assumes that `obj` is called before there is a chance
# to add custom keys to `self.keywords`, so no fear of overriding.
self.keywords.update((mark.name, mark) for mark in self.own_markers)
Expand Down
225 changes: 225 additions & 0 deletions testing/test_markers_mro.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
import pytest


def _register_marks(pytester: pytest.Pytester) -> None:
pytester.makeini(
"""
[pytest]
markers =
alpha(value): mark used for MRO ordering checks
beta(value): class-level mark defined on subclasses
gamma(value): function-level mark used for ordering assertions
shared(value): mark used to exercise deduplication
method(value): additional function-level mark
diamond(value): mark applied in diamond inheritance scenarios
"""
)


def test_multiple_inheritance_merges_marks(pytester: pytest.Pytester) -> None:
_register_marks(pytester)
pytester.makepyfile(
"""
import pytest

class BaseA:
pytestmark = pytest.mark.alpha("base-a")

class BaseB:
pytestmark = [pytest.mark.alpha("base-b")]

@pytest.mark.beta("derived")
class TestDerived(BaseA, BaseB):
@pytest.mark.gamma("method")
def test_method(self, request):
seen = [(mark.name, mark.args) for mark in request.node.iter_markers()]
assert seen == [
("gamma", ("method",)),
("beta", ("derived",)),
("alpha", ("base-a",)),
("alpha", ("base-b",)),
]
"""
)

result = pytester.runpytest()
result.assert_outcomes(passed=1)


def test_function_marks_precede_class_marks(pytester: pytest.Pytester) -> None:
_register_marks(pytester)
pytester.makepyfile(
"""
import pytest

@pytest.mark.alpha("base")
class Base:
pass

class TestDerived(Base):
pytestmark = [pytest.mark.beta("derived-list")]

@pytest.mark.method("first")
@pytest.mark.gamma("second")
def test_method(self, request):
names = [mark.name for mark in request.node.iter_markers()]
assert names[:2] == ["gamma", "method"]
assert names[2:] == ["beta", "alpha"]
"""
)

result = pytester.runpytest()
result.assert_outcomes(passed=1)


def test_diamond_inheritance_deduplicates_marks(pytester: pytest.Pytester) -> None:
_register_marks(pytester)
pytester.makepyfile(
"""
import pytest

@pytest.mark.shared("root")
class Root:
pass

@pytest.mark.shared("branch")
@pytest.mark.alpha("left")
class Left(Root):
pass

@pytest.mark.shared("branch")
@pytest.mark.alpha("right")
class Right(Root):
pass

class TestLeaf(Left, Right):
@pytest.mark.gamma("method")
def test_leaf(self, request):
seen = [(mark.name, mark.args) for mark in request.node.iter_markers()]
assert seen == [
("gamma", ("method",)),
("alpha", ("left",)),
("shared", ("branch",)),
("alpha", ("right",)),
("shared", ("root",)),
]
"""
)

result = pytester.runpytest()
result.assert_outcomes(passed=1)


def test_identical_marks_across_bases_deduped_once(
pytester: pytest.Pytester,
) -> None:
_register_marks(pytester)
pytester.makepyfile(
"""
import pytest

class BaseA:
pytestmark = [pytest.mark.shared("duplicate")]

class BaseB:
pytestmark = pytest.mark.shared("duplicate")

class TestDerived(BaseA, BaseB):
def test_marks(self, request):
seen = [mark.args for mark in request.node.iter_markers(name="shared")]
assert seen == [("duplicate",)]
"""
)

result = pytester.runpytest()
result.assert_outcomes(passed=1)


def test_parametrize_marks_merge_and_dedupe(pytester: pytest.Pytester) -> None:
pytester.makepyfile(
"""
import pytest

class ParamBaseA:
pytestmark = pytest.mark.parametrize("left", ["L1", "L2"])

class ParamBaseB:
pytestmark = [
pytest.mark.parametrize("right", ["R1", "R2"]),
pytest.mark.parametrize("left", ["L1", "L2"]),
]

class TestParam(ParamBaseA, ParamBaseB):
def test_params(self, left, right):
pass
"""
)

result = pytester.runpytest()
result.assert_outcomes(passed=4)


def test_parametrize_marks_with_param_objects(pytester: pytest.Pytester) -> None:
_register_marks(pytester)
pytester.makepyfile(
"""
import pytest

class ParamBaseA:
pytestmark = pytest.mark.parametrize(
"left",
[
pytest.param("L1", marks=pytest.mark.shared("dup")),
pytest.param("L2", marks=pytest.mark.alpha("left")),
],
)

class ParamBaseB:
pytestmark = [
pytest.mark.parametrize(
"right",
[
pytest.param("R1", marks=pytest.mark.shared("dup")),
pytest.param("R2", marks=pytest.mark.beta("right")),
],
)
]

class TestParamCombined(ParamBaseA, ParamBaseB):
def test_params(self, left, right, request):
shared_args = {mark.args for mark in request.node.iter_markers(name="shared")}
if left == "L1" or right == "R1":
assert shared_args == {("dup",)}
else:
assert shared_args == set()
"""
)

result = pytester.runpytest()
result.assert_outcomes(passed=4)


def test_class_pytestmark_forms_supported(pytester: pytest.Pytester) -> None:
_register_marks(pytester)
pytester.makepyfile(
"""
import pytest

class Decorated:
pytestmark = [pytest.mark.alpha("decorated")]

@pytest.mark.beta("derived")
class TestDerived(Decorated):
@pytest.mark.gamma("method")
def test_forms(self, request):
seen = [(mark.name, mark.args) for mark in request.node.iter_markers()]
assert seen == [
("gamma", ("method",)),
("beta", ("derived",)),
("alpha", ("decorated",)),
]
"""
)

result = pytester.runpytest()
result.assert_outcomes(passed=1)