diff --git a/tests/conftest.py b/tests/conftest.py index 673d30d3b6..2f969e3444 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,6 +9,17 @@ import pyhf +def pytest_addoption(parser): + parser.addoption( + "--disable-backend", + action="append", + type=str, + default=[], + choices=["tensorflow", "pytorch", "jax", "minuit"], + help="list of backends to disable in tests", + ) + + # Factory as fixture pattern @pytest.fixture def get_json_from_tarfile(): @@ -59,14 +70,14 @@ def reset_backend(): @pytest.fixture( scope='function', params=[ - (pyhf.tensor.numpy_backend(), None), - (pyhf.tensor.pytorch_backend(), None), - (pyhf.tensor.pytorch_backend(precision='64b'), None), - (pyhf.tensor.tensorflow_backend(), None), - (pyhf.tensor.jax_backend(), None), + (("numpy_backend", dict()), ("scipy_optimizer", dict())), + (("pytorch_backend", dict()), ("scipy_optimizer", dict())), + (("pytorch_backend", dict(precision="64b")), ("scipy_optimizer", dict())), + (("tensorflow_backend", dict()), ("scipy_optimizer", dict())), + (("jax_backend", dict()), ("scipy_optimizer", dict())), ( - pyhf.tensor.numpy_backend(poisson_from_normal=True), - pyhf.optimize.minuit_optimizer(), + ("numpy_backend", dict(poisson_from_normal=True)), + ("minuit_optimizer", dict()), ), ], ids=['numpy', 'pytorch', 'pytorch64', 'tensorflow', 'jax', 'numpy_minuit'], @@ -87,13 +98,26 @@ def backend(request): only_backends = [ pid for pid in param_ids if request.node.get_closest_marker(f'only_{pid}') ] + disable_backend = any( + backend in param_id for backend in request.config.option.disable_backend + ) if skip_backend and (param_id in only_backends): raise ValueError( f"Must specify skip_{param_id} or only_{param_id} but not both!" ) - if skip_backend: + if disable_backend: + pytest.skip( + f"skipping {func_name} as the backend is disabled via " + + " ".join( + [ + f"--disable-backend {choice}" + for choice in request.config.option.disable_backend + ] + ) + ) + elif skip_backend: pytest.skip(f"skipping {func_name} as specified") elif only_backends and param_id not in only_backends: pytest.skip( @@ -109,10 +133,14 @@ def backend(request): pytest.mark.xfail(reason=f"expect {func_name} to fail as specified") ) + tensor_config, optimizer_config = request.param + + tensor = getattr(pyhf.tensor, tensor_config[0])(**tensor_config[1]) + optimizer = getattr(pyhf.optimize, optimizer_config[0])(**optimizer_config[1]) # actual execution here, after all checks is done - pyhf.set_backend(*request.param) + pyhf.set_backend(tensor, optimizer) - yield request.param + yield (tensor, optimizer) @pytest.fixture(