diff --git a/test/quantization/test_galore_quant.py b/test/quantization/test_galore_quant.py index a67f7775b1..88caa5c0b8 100644 --- a/test/quantization/test_galore_quant.py +++ b/test/quantization/test_galore_quant.py @@ -13,6 +13,14 @@ except ImportError: # noqa: F401 pytest.skip("triton is not installed", allow_module_level=True) # noqa: F401 import torch + +# Skip entire test if CUDA is not available or ROCM is enabled +if not torch.cuda.is_available() or torch.version.hip is not None: + pytest.skip( + "CUDA is not available/ ROCM support is under development", + allow_module_level=True, + ) + from bitsandbytes.functional import ( create_dynamic_map, dequantize_blockwise,