diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index da8b92711380e..3915780445ab9 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -6,8 +6,11 @@ import pytest import ray import torch -from pynvml import (nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo, - nvmlInit) +from vllm.utils import is_hip + +if (not is_hip()): + from pynvml import (nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo, + nvmlInit) from vllm import LLM from vllm.engine.arg_utils import AsyncEngineArgs