Skip to content

Commit 61961f5

Browse files
committed
hipblaslt check torch version
1 parent 0d57fa3 commit 61961f5

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

installer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -531,7 +531,7 @@ def install_rocm_zluda(torch_command):
531531
ort_package = os.environ.get('ONNXRUNTIME_PACKAGE', f"--pre onnxruntime-training{'' if ort_version is None else ('==' + ort_version)} --index-url https://pypi.lsh.sh/{rocm.version[0]}{rocm.version[2]} --extra-index-url https://pypi.org/simple")
532532
install(ort_package, 'onnxruntime-training')
533533

534-
if bool(int(os.environ.get("TORCH_BLAS_PREFER_HIPBLASLT", "1"))) and rocm.version != "6.2":
534+
if rocm.version == rocm.version_torch and rocm.get_blaslt_enabled():
535535
log.debug(f'hipBLASLt arch={hip_default_device.name} available={hip_default_device.blaslt_supported}')
536536
rocm.set_blaslt_enabled(hip_default_device.blaslt_supported)
537537
return torch_command

modules/rocm.py

+12
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import ctypes
44
import shutil
55
import subprocess
6+
import importlib.metadata
67
from typing import Union, List
78

89

@@ -81,6 +82,13 @@ def get_gfx_version(self) -> Union[str, None]:
8182
return None
8283

8384

85+
def get_version_torch() -> Union[str, None]:
86+
version_ = importlib.metadata.version("torch")
87+
if "+rocm" not in version_: # unofficial build, non-rocm torch.
88+
return None
89+
return version_.split("+rocm")[1]
90+
91+
8492
if sys.platform == "win32":
8593
def find() -> Union[str, None]:
8694
hip_path = shutil.which("hipconfig")
@@ -173,10 +181,14 @@ def set_blaslt_enabled(enabled: bool) -> None:
173181
else:
174182
os.environ["TORCH_BLAS_PREFER_HIPBLASLT"] = "0"
175183

184+
def get_blaslt_enabled() -> bool:
185+
return bool(int(os.environ.get("TORCH_BLAS_PREFER_HIPBLASLT", "1")))
186+
176187
is_wsl: bool = os.environ.get('WSL_DISTRO_NAME', None) is not None
177188
path = find()
178189
is_installed = False
179190
version = None
191+
version_torch = get_version_torch()
180192
if path is not None:
181193
is_installed = True
182194
version = get_version()

0 commit comments

Comments
 (0)