Skip to content

Commit 0d57fa3

Browse files
committed
fix zluda torch cpp_extension
1 parent db0f6c7 commit 0d57fa3

File tree

3 files changed

+25
-39
lines changed

3 files changed

+25
-39
lines changed

installer.py

+10-24
Original file line numberDiff line numberDiff line change
@@ -517,9 +517,6 @@ def install_rocm_zluda(torch_command):
517517
log.info("For ZLUDA support specify '--use-zluda'")
518518
log.info('Using CPU-only torch')
519519
torch_command = os.environ.get('TORCH_COMMAND', 'torch torchvision')
520-
521-
# conceal ROCm installed
522-
rocm.conceal()
523520
else:
524521
if rocm.version is None or float(rocm.version) > 6.1: # assume the latest if version check fails
525522
torch_command = os.environ.get('TORCH_COMMAND', 'torch torchvision --index-url https://download.pytorch.org/whl/rocm6.1')
@@ -596,16 +593,6 @@ def install_openvino(torch_command):
596593
return torch_command
597594

598595

599-
def is_rocm_available(allow_rocm):
600-
if not allow_rocm:
601-
return False
602-
if installed('torch-directml', quiet=True):
603-
log.debug('DirectML installation is detected. Skipping HIP SDK check.')
604-
return False
605-
from modules.rocm import is_installed
606-
return is_installed
607-
608-
609596
def install_torch_addons():
610597
xformers_package = os.environ.get('XFORMERS_PACKAGE', '--pre xformers') if opts.get('cross_attention_optimization', '') == 'xFormers' or args.use_xformers else 'none'
611598
triton_command = os.environ.get('TRITON_COMMAND', 'triton') if sys.platform == 'linux' else None
@@ -648,6 +635,7 @@ def check_torch():
648635
if args.profile:
649636
pr = cProfile.Profile()
650637
pr.enable()
638+
from modules import rocm
651639
allow_cuda = not (args.use_rocm or args.use_directml or args.use_ipex or args.use_openvino)
652640
allow_rocm = not (args.use_cuda or args.use_directml or args.use_ipex or args.use_openvino)
653641
allow_ipex = not (args.use_cuda or args.use_rocm or args.use_directml or args.use_openvino)
@@ -663,15 +651,8 @@ def check_torch():
663651
log.info('nVidia CUDA toolkit detected: nvidia-smi present')
664652
torch_command = os.environ.get('TORCH_COMMAND', 'torch torchvision --index-url https://download.pytorch.org/whl/cu121')
665653
install('onnxruntime-gpu', 'onnxruntime-gpu', ignore=True, quiet=True)
666-
elif is_rocm_available(allow_rocm):
654+
elif allow_rocm and rocm.is_installed:
667655
torch_command = install_rocm_zluda(torch_command)
668-
669-
from modules import rocm
670-
if rocm.is_wsl: # WSL ROCm
671-
try:
672-
rocm.load_hsa_runtime()
673-
except OSError:
674-
log.error("Failed to preload HSA Runtime library.")
675656
elif is_ipex_available(allow_ipex):
676657
torch_command = install_ipex(torch_command)
677658
elif allow_openvino and args.use_openvino:
@@ -686,9 +667,6 @@ def check_torch():
686667
if 'torch' in torch_command and not args.version:
687668
install(torch_command, 'torch torchvision')
688669
install('onnxruntime-directml', 'onnxruntime-directml', ignore=True)
689-
from modules import rocm
690-
if rocm.is_installed:
691-
rocm.conceal()
692670
else:
693671
if args.use_zluda:
694672
log.warning("ZLUDA failed to initialize: no HIP SDK found")
@@ -734,6 +712,14 @@ def check_torch():
734712
log.error(f'Could not load torch: {e}')
735713
if not args.ignore:
736714
sys.exit(1)
715+
if rocm.is_installed:
716+
if sys.platform == "win32": # CPU, DirectML, ZLUDA
717+
rocm.conceal()
718+
elif rocm.is_wsl: # WSL ROCm
719+
try:
720+
rocm.load_hsa_runtime()
721+
except OSError:
722+
log.error("Failed to preload HSA Runtime library.")
737723
if args.version:
738724
return
739725
if not args.skip_all:

modules/zluda_hijacks.py

-15
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import os
2-
import sys
31
import torch
42

53

@@ -10,19 +8,6 @@ def topk(tensor: torch.Tensor, *args, **kwargs):
108
return torch.return_types.topk((values.to(device), indices.to(device),))
119

1210

13-
def _join_rocm_home(*paths) -> str:
14-
from torch.utils.cpp_extension import ROCM_HOME
15-
return os.path.join(ROCM_HOME, *paths)
16-
17-
1811
def do_hijack():
1912
torch.version.hip = "5.7"
2013
torch.topk = topk
21-
platform = sys.platform
22-
sys.platform = ""
23-
from torch.utils import cpp_extension
24-
sys.platform = platform
25-
cpp_extension.IS_WINDOWS = platform == "win32"
26-
cpp_extension.IS_MACOS = False
27-
cpp_extension.IS_LINUX = platform.startswith('linux')
28-
cpp_extension._join_rocm_home = _join_rocm_home # pylint: disable=protected-access

modules/zluda_installer.py

+15
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import sys
23
import ctypes
34
import shutil
45
import zipfile
@@ -61,3 +62,17 @@ def load(zluda_path: os.PathLike) -> None:
6162
ctypes.windll.LoadLibrary(os.path.join(zluda_path, v))
6263
for v in DLL_MAPPING.values():
6364
ctypes.windll.LoadLibrary(os.path.join(zluda_path, v))
65+
66+
def conceal():
67+
import torch # pylint: disable=unused-import
68+
platform = sys.platform
69+
sys.platform = ""
70+
from torch.utils import cpp_extension
71+
sys.platform = platform
72+
cpp_extension.IS_WINDOWS = platform == "win32"
73+
cpp_extension.IS_MACOS = False
74+
cpp_extension.IS_LINUX = platform.startswith('linux')
75+
def _join_rocm_home(*paths) -> str:
76+
return os.path.join(cpp_extension.ROCM_HOME, *paths)
77+
cpp_extension._join_rocm_home = _join_rocm_home # pylint: disable=protected-access
78+
rocm.conceal = conceal

0 commit comments

Comments (0)