From 03aa0fd2c3035891d945c95ad4f78fd201df7953 Mon Sep 17 00:00:00 2001 From: Giordon Stark Date: Wed, 20 Sep 2023 10:47:11 -0700 Subject: [PATCH] add a --disable-backend option that adds backends to be disabled --- tests/conftest.py | 34 ++++++++++++++++++++++++++-------- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 673d30d3b6..1338e7d36a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,6 +9,17 @@ import pyhf +def pytest_addoption(parser): + parser.addoption( + "--disable-backend", + action="append", + type=str, + default=[], + choices=["tensorflow", "pytorch", "jax", "minuit"], + help="list of backends to disable in tests", + ) + + # Factory as fixture pattern @pytest.fixture def get_json_from_tarfile(): @@ -59,14 +70,14 @@ def reset_backend(): @pytest.fixture( scope='function', params=[ - (pyhf.tensor.numpy_backend(), None), - (pyhf.tensor.pytorch_backend(), None), - (pyhf.tensor.pytorch_backend(precision='64b'), None), - (pyhf.tensor.tensorflow_backend(), None), - (pyhf.tensor.jax_backend(), None), + (("numpy_backend", dict()), ("scipy_optimizer", dict())), + (("pytorch_backend", dict()), ("scipy_optimizer", dict())), + (("pytorch_backend", dict(precision="64b")), ("scipy_optimizer", dict())), + (("tensorflow_backend", dict()), ("scipy_optimizer", dict())), + (("jax_backend", dict()), ("scipy_optimizer", dict())), ( - pyhf.tensor.numpy_backend(poisson_from_normal=True), - pyhf.optimize.minuit_optimizer(), + ("numpy_backend", dict(poisson_from_normal=True)), + ("minuit_optimizer", dict()), ), ], ids=['numpy', 'pytorch', 'pytorch64', 'tensorflow', 'jax', 'numpy_minuit'], @@ -87,13 +98,20 @@ def backend(request): only_backends = [ pid for pid in param_ids if request.node.get_closest_marker(f'only_{pid}') ] + disable_backend = any( + backend in param_id for backend in request.config.disable_backend + ) if skip_backend and (param_id in only_backends): raise ValueError( f"Must specify skip_{param_id} or only_{param_id} but not both!" ) - if skip_backend: + if disable_backend: + pytest.skip( + f"skipping {func_name} as the backend is disabled: {request.config.disable_backend}" + ) + elif skip_backend: pytest.skip(f"skipping {func_name} as specified") elif only_backends and param_id not in only_backends: pytest.skip(