@@ -482,13 +482,33 @@ def is_rocm_available():
482
482
torch_command = os .environ .get ('TORCH_COMMAND' , 'torch==2.2.0 torchvision --index-url https://download.pytorch.org/whl/cu118' )
483
483
log .warning ("ZLUDA support: experimental" )
484
484
zluda_need_dll_patch = is_windows and not installed ('torch' )
485
+ zluda_path = find_zluda ()
486
+ if zluda_path is None :
487
+ import urllib .request
488
+ if is_windows :
489
+ import zipfile
490
+ archive_type = zipfile .ZipFile
491
+ zluda_url = 'https://github.com/lshqqytiger/ZLUDA/releases/download/v3.5-win/ZLUDA-windows-amd64.zip'
492
+ else :
493
+ import tarfile
494
+ archive_type = tarfile .TarFile
495
+ zluda_url = 'https://github.com/vosen/ZLUDA/releases/download/v3/zluda-3-linux.tar.gz'
496
+ urllib .request .urlretrieve (zluda_url , '_zluda' )
497
+ with archive_type ('_zluda' , 'r' ) as f :
498
+ f .extractall ('.zluda' )
499
+ zluda_path = os .path .abspath ('./.zluda' )
500
+ os .remove ('_zluda' )
501
+ log .debug (f'Found ZLUDA in { zluda_path } ' )
502
+ paths = os .environ .get ('PATH' , '.' )
503
+ if zluda_path not in paths :
504
+ os .environ ['PATH' ] = paths + ';' + zluda_path
485
505
elif is_windows : # TODO TBD after ROCm for Windows is released
486
506
log .warning ("HIP SDK is detected, but no Torch release for Windows available" )
487
507
log .info ("For ZLUDA support specify '--use-zluda'" )
488
508
log .info ('Using CPU-only torch' )
489
509
torch_command = os .environ .get ('TORCH_COMMAND' , 'torch torchvision' )
490
510
else :
491
- if rocm_ver in {"5.7" }:
511
+ if rocm_ver in {"5.7" , "6.0" }:
492
512
torch_command = os .environ .get ('TORCH_COMMAND' , f'torch torchvision --pre --index-url https://download.pytorch.org/whl/nightly/rocm{ rocm_ver } ' )
493
513
elif rocm_ver in {"5.5" , "5.6" }:
494
514
torch_command = os .environ .get ('TORCH_COMMAND' , f'torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm{ rocm_ver } ' )
@@ -909,14 +929,19 @@ def get_onnxruntime_source_for_rocm(rocm_ver):
909
929
return 'onnxruntime-gpu'
910
930
911
931
912
- def patch_zluda ():
932
+ def find_zluda ():
913
933
zluda_path = os .environ .get ('ZLUDA' , None )
914
934
if zluda_path is None :
915
935
paths = os .environ .get ('PATH' , '' ).split (';' )
916
936
for path in paths :
917
937
if os .path .exists (os .path .join (path , 'zluda_redirect.dll' )):
918
938
zluda_path = path
919
939
break
940
+ return zluda_path
941
+
942
+
943
+ def patch_zluda ():
944
+ zluda_path = find_zluda ()
920
945
if zluda_path is None :
921
946
log .warning ('Failed to automatically patch torch with ZLUDA. Could not find ZLUDA from PATH.' )
922
947
return
@@ -1028,8 +1053,6 @@ def check_timestamp():
1028
1053
1029
1054
def add_args (parser ):
1030
1055
group = parser .add_argument_group ('Setup options' )
1031
- group .add_argument ("--log" , type = str , default = os .environ .get ("SD_LOG" , None ), help = "Set log file, default: %(default)s" )
1032
- group .add_argument ('--debug' , default = os .environ .get ("SD_DEBUG" ,False ), action = 'store_true' , help = "Run installer with debug logging, default: %(default)s" )
1033
1056
group .add_argument ('--reset' , default = os .environ .get ("SD_RESET" ,False ), action = 'store_true' , help = "Reset main repository to latest version, default: %(default)s" )
1034
1057
group .add_argument ('--upgrade' , default = os .environ .get ("SD_UPGRADE" ,False ), action = 'store_true' , help = "Upgrade main repository to latest version, default: %(default)s" )
1035
1058
group .add_argument ('--requirements' , default = os .environ .get ("SD_REQUIREMENTS" ,False ), action = 'store_true' , help = "Force re-check of requirements, default: %(default)s" )
@@ -1039,6 +1062,7 @@ def add_args(parser):
1039
1062
group .add_argument ("--use-ipex" , default = os .environ .get ("SD_USEIPEX" ,False ), action = 'store_true' , help = "Force use Intel OneAPI XPU backend, default: %(default)s" )
1040
1063
group .add_argument ("--use-cuda" , default = os .environ .get ("SD_USECUDA" ,False ), action = 'store_true' , help = "Force use nVidia CUDA backend, default: %(default)s" )
1041
1064
group .add_argument ("--use-rocm" , default = os .environ .get ("SD_USEROCM" ,False ), action = 'store_true' , help = "Force use AMD ROCm backend, default: %(default)s" )
1065
+ group .add_argument ('--use-zluda' , default = os .environ .get ("SD_USEZLUDA" , False ), action = 'store_true' , help = "Force use ZLUDA, AMD GPUs only, default: %(default)s" )
1042
1066
group .add_argument ("--use-xformers" , default = os .environ .get ("SD_USEXFORMERS" ,False ), action = 'store_true' , help = "Force use xFormers cross-optimization, default: %(default)s" )
1043
1067
group .add_argument ('--skip-requirements' , default = os .environ .get ("SD_SKIPREQUIREMENTS" ,False ), action = 'store_true' , help = "Skips checking and installing requirements, default: %(default)s" )
1044
1068
group .add_argument ('--skip-extensions' , default = os .environ .get ("SD_SKIPEXTENSION" ,False ), action = 'store_true' , help = "Skips running individual extension installers, default: %(default)s" )
@@ -1053,6 +1077,13 @@ def add_args(parser):
1053
1077
group .add_argument ('--ignore' , default = os .environ .get ("SD_IGNORE" ,False ), action = 'store_true' , help = "Ignore any errors and attempt to continue" )
1054
1078
group .add_argument ('--safe' , default = os .environ .get ("SD_SAFE" ,False ), action = 'store_true' , help = "Run in safe mode with no user extensions" )
1055
1079
1080
+ group = parser .add_argument_group ('Logging options' )
1081
+ group .add_argument ("--log" , type = str , default = os .environ .get ("SD_LOG" , None ), help = "Set log file, default: %(default)s" )
1082
+ group .add_argument ('--debug' , default = os .environ .get ("SD_DEBUG" ,False ), action = 'store_true' , help = "Run installer with debug logging, default: %(default)s" )
1083
+ group .add_argument ("--profile" , default = os .environ .get ("SD_PROFILE" , False ), action = 'store_true' , help = "Run profiler, default: %(default)s" )
1084
+ group .add_argument ('--docs' , default = os .environ .get ("SD_DOCS" , False ), action = 'store_true' , help = "Mount API docs, default: %(default)s" )
1085
+ group .add_argument ("--api-log" , default = os .environ .get ("SD_APILOG" , False ), action = 'store_true' , help = "Enable logging of all API requests, default: %(default)s" )
1086
+
1056
1087
1057
1088
def parse_args (parser ):
1058
1089
# command line args
0 commit comments