@@ -517,9 +517,6 @@ def install_rocm_zluda(torch_command):
517
517
log .info ("For ZLUDA support specify '--use-zluda'" )
518
518
log .info ('Using CPU-only torch' )
519
519
torch_command = os .environ .get ('TORCH_COMMAND' , 'torch torchvision' )
520
-
521
- # conceal ROCm installed
522
- rocm .conceal ()
523
520
else :
524
521
if rocm .version is None or float (rocm .version ) > 6.1 : # assume the latest if version check fails
525
522
torch_command = os .environ .get ('TORCH_COMMAND' , 'torch torchvision --index-url https://download.pytorch.org/whl/rocm6.1' )
@@ -596,16 +593,6 @@ def install_openvino(torch_command):
596
593
return torch_command
597
594
598
595
599
- def is_rocm_available (allow_rocm ):
600
- if not allow_rocm :
601
- return False
602
- if installed ('torch-directml' , quiet = True ):
603
- log .debug ('DirectML installation is detected. Skipping HIP SDK check.' )
604
- return False
605
- from modules .rocm import is_installed
606
- return is_installed
607
-
608
-
609
596
def install_torch_addons ():
610
597
xformers_package = os .environ .get ('XFORMERS_PACKAGE' , '--pre xformers' ) if opts .get ('cross_attention_optimization' , '' ) == 'xFormers' or args .use_xformers else 'none'
611
598
triton_command = os .environ .get ('TRITON_COMMAND' , 'triton' ) if sys .platform == 'linux' else None
@@ -648,6 +635,7 @@ def check_torch():
648
635
if args .profile :
649
636
pr = cProfile .Profile ()
650
637
pr .enable ()
638
+ from modules import rocm
651
639
allow_cuda = not (args .use_rocm or args .use_directml or args .use_ipex or args .use_openvino )
652
640
allow_rocm = not (args .use_cuda or args .use_directml or args .use_ipex or args .use_openvino )
653
641
allow_ipex = not (args .use_cuda or args .use_rocm or args .use_directml or args .use_openvino )
@@ -663,15 +651,8 @@ def check_torch():
663
651
log .info ('nVidia CUDA toolkit detected: nvidia-smi present' )
664
652
torch_command = os .environ .get ('TORCH_COMMAND' , 'torch torchvision --index-url https://download.pytorch.org/whl/cu121' )
665
653
install ('onnxruntime-gpu' , 'onnxruntime-gpu' , ignore = True , quiet = True )
666
- elif is_rocm_available ( allow_rocm ) :
654
+ elif allow_rocm and rocm . is_installed :
667
655
torch_command = install_rocm_zluda (torch_command )
668
-
669
- from modules import rocm
670
- if rocm .is_wsl : # WSL ROCm
671
- try :
672
- rocm .load_hsa_runtime ()
673
- except OSError :
674
- log .error ("Failed to preload HSA Runtime library." )
675
656
elif is_ipex_available (allow_ipex ):
676
657
torch_command = install_ipex (torch_command )
677
658
elif allow_openvino and args .use_openvino :
@@ -686,9 +667,6 @@ def check_torch():
686
667
if 'torch' in torch_command and not args .version :
687
668
install (torch_command , 'torch torchvision' )
688
669
install ('onnxruntime-directml' , 'onnxruntime-directml' , ignore = True )
689
- from modules import rocm
690
- if rocm .is_installed :
691
- rocm .conceal ()
692
670
else :
693
671
if args .use_zluda :
694
672
log .warning ("ZLUDA failed to initialize: no HIP SDK found" )
@@ -734,6 +712,14 @@ def check_torch():
734
712
log .error (f'Could not load torch: { e } ' )
735
713
if not args .ignore :
736
714
sys .exit (1 )
715
+ if rocm .is_installed :
716
+ if sys .platform == "win32" : # CPU, DirectML, ZLUDA
717
+ rocm .conceal ()
718
+ elif rocm .is_wsl : # WSL ROCm
719
+ try :
720
+ rocm .load_hsa_runtime ()
721
+ except OSError :
722
+ log .error ("Failed to preload HSA Runtime library." )
737
723
if args .version :
738
724
return
739
725
if not args .skip_all :
0 commit comments