From b7e349dd132aa8f09c70434b43b82859edafd387 Mon Sep 17 00:00:00 2001 From: Nathan Raw Date: Tue, 22 Oct 2024 23:29:58 -0500 Subject: [PATCH] use cuda with onnx if available --- basic_pitch/inference.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/basic_pitch/inference.py b/basic_pitch/inference.py index 2544cd9..25b062c 100644 --- a/basic_pitch/inference.py +++ b/basic_pitch/inference.py @@ -129,7 +129,10 @@ def __init__(self, model_path: Union[pathlib.Path, str]): present.append("ONNX") try: self.model_type = Model.MODEL_TYPES.ONNX - self.model = ort.InferenceSession(str(model_path), providers=["CPUExecutionProvider"]) + providers = ["CPUExecutionProvider"] + if "CUDAExecutionProvider" in ort.get_available_providers(): + providers.insert(0, "CUDAExecutionProvider") + self.model = ort.InferenceSession(str(model_path), providers=providers) return except Exception as e: if str(model_path).endswith(".onnx"):