@@ -544,7 +544,7 @@ def check_onnx():
544
544
if not installed ('onnx' , quiet = True ):
545
545
install ('onnx' , 'onnx' , ignore = True )
546
546
if not installed ('onnxruntime' , quiet = True ) and not (installed ('onnxruntime-gpu' , quiet = True ) or installed ('onnxruntime-openvino' , quiet = True ) or installed ('onnxruntime-training' , quiet = True )): # allow either
547
- install ('onnxruntime ' , 'onnxruntime' , ignore = True )
547
+ install (os . environ . get ( 'ONNXRUNTIME_COMMAND ' , 'onnxruntime' ) , ignore = True )
548
548
ts ('onnx' , t_start )
549
549
550
550
@@ -555,7 +555,6 @@ def install_cuda():
555
555
if args .use_nightly :
556
556
cmd = os .environ .get ('TORCH_COMMAND' , 'pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu128 --extra-index-url https://download.pytorch.org/whl/nightly/cu126' )
557
557
else :
558
- # cmd = os.environ.get('TORCH_COMMAND', 'torch==2.5.1+cu124 torchvision==0.20.1+cu124 --index-url https://download.pytorch.org/whl/cu124')
559
558
cmd = os .environ .get ('TORCH_COMMAND' , 'torch==2.6.0+cu126 torchvision==0.21.0+cu126 --index-url https://download.pytorch.org/whl/cu126' )
560
559
return cmd
561
560
@@ -643,9 +642,7 @@ def install_rocm_zluda():
643
642
log .info ('Using CPU-only torch' )
644
643
torch_command = os .environ .get ('TORCH_COMMAND' , 'torch torchvision' )
645
644
else :
646
- # Python 3.12 will cause compatibility issues with other dependencies
647
- # ROCm supports Python 3.12 so don't block it but don't advertise it in the error message
648
- check_python (supported_minors = [9 , 10 , 11 , 12 ], reason = 'ROCm backend requires Python 3.9, 3.10 or 3.11' )
645
+ check_python (supported_minors = [9 , 10 , 11 , 12 ], reason = 'ROCm backend requires a Python version between 3.9 and 3.12' )
649
646
if args .use_nightly :
650
647
if rocm .version is None or float (rocm .version ) >= 6.3 : # assume the latest if version check fails
651
648
torch_command = os .environ .get ('TORCH_COMMAND' , '--pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm6.3' )
@@ -660,8 +657,7 @@ def install_rocm_zluda():
660
657
elif rocm .version == "6.1" :
661
658
torch_command = os .environ .get ('TORCH_COMMAND' , 'torch==2.6.0+rocm6.1 torchvision==0.21.0+rocm6.1 --index-url https://download.pytorch.org/whl/rocm6.1' )
662
659
elif rocm .version == "6.0" :
663
- # lock to 2.4.1 instead of 2.5.1 for performance reasons
664
- # there are no support for torch 2.6.0 for rocm 6.0
660
+ # lock to 2.4.1 instead of 2.5.1 for performance reasons there are no support for torch 2.6.0 for rocm 6.0
665
661
torch_command = os .environ .get ('TORCH_COMMAND' , 'torch==2.4.1+rocm6.0 torchvision==0.19.1+rocm6.0 --index-url https://download.pytorch.org/whl/rocm6.0' )
666
662
elif float (rocm .version ) < 5.5 : # oldest supported version is 5.5
667
663
log .warning (f"ROCm: unsupported version={ rocm .version } " )
@@ -671,14 +667,6 @@ def install_rocm_zluda():
671
667
# older rocm (5.7) uses torch 2.3 or older
672
668
torch_command = os .environ .get ('TORCH_COMMAND' , f'torch torchvision --index-url https://download.pytorch.org/whl/rocm{ rocm .version } ' )
673
669
674
- if sys .version_info < (3 , 11 ):
675
- ort_version = os .environ .get ('ONNXRUNTIME_VERSION' , None )
676
- if rocm .version is None or float (rocm .version ) > 6.0 :
677
- ort_package = os .environ .get ('ONNXRUNTIME_PACKAGE' , f"--pre onnxruntime-training{ '' if ort_version is None else ('==' + ort_version )} --index-url https://pypi.lsh.sh/60 --extra-index-url https://pypi.org/simple" )
678
- else :
679
- ort_package = os .environ .get ('ONNXRUNTIME_PACKAGE' , f"--pre onnxruntime-training{ '' if ort_version is None else ('==' + ort_version )} --index-url https://pypi.lsh.sh/{ rocm .version [0 ]} { rocm .version [2 ]} --extra-index-url https://pypi.org/simple" )
680
- install (ort_package , 'onnxruntime-training' )
681
-
682
670
if installed ("torch" ) and device is not None :
683
671
if 'Flash attention' in opts .get ('sdp_options' , '' ):
684
672
if not installed ('flash-attn' ):
@@ -705,9 +693,7 @@ def install_rocm_zluda():
705
693
706
694
def install_ipex (torch_command ):
707
695
t_start = time .time ()
708
- # Python 3.12 will cause compatibility issues with other dependencies
709
- # IPEX supports Python 3.12 so don't block it but don't advertise it in the error message
710
- check_python (supported_minors = [9 , 10 , 11 , 12 ], reason = 'IPEX backend requires Python 3.9, 3.10 or 3.11' )
696
+ check_python (supported_minors = [9 , 10 , 11 , 12 ], reason = 'IPEX backend requires a Python version between 3.9 and 3.12' )
711
697
args .use_ipex = True # pylint: disable=attribute-defined-outside-init
712
698
log .info ('IPEX: Intel OneAPI toolkit detected' )
713
699
@@ -731,34 +717,24 @@ def install_ipex(torch_command):
731
717
if args .use_nightly :
732
718
torch_command = os .environ .get ('TORCH_COMMAND' , '--pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/xpu' )
733
719
else :
734
- if "linux" in sys .platform :
735
- # default to US server. If The China server is needed, change .../release-whl/stable/xpu/us/ to .../release-whl/stable/xpu/cn/
736
- torch_command = os .environ .get ('TORCH_COMMAND' , 'torch==2.5.1+cxx11.abi torchvision==0.20.1+cxx11.abi intel-extension-for-pytorch==2.5.10+xpu oneccl_bind_pt==2.5.0+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/' )
737
- if os .environ .get ('TRITON_COMMAND' , None ) is None :
738
- os .environ .setdefault ('TRITON_COMMAND' , '--pre pytorch-triton-xpu==3.1.0+91b14bf559 --index-url https://download.pytorch.org/whl/nightly/xpu' )
739
- # os.environ.setdefault('TENSORFLOW_PACKAGE', 'tensorflow==2.15.1 intel-extension-for-tensorflow[xpu]==2.15.0.2')
740
- else :
741
- torch_command = os .environ .get ('TORCH_COMMAND' , 'torch==2.6.0+xpu torchvision==0.21.0+xpu --index-url https://download.pytorch.org/whl/xpu' )
720
+ torch_command = os .environ .get ('TORCH_COMMAND' , 'torch==2.6.0+xpu torchvision==0.21.0+xpu --index-url https://download.pytorch.org/whl/xpu' )
742
721
743
- install (os .environ .get ('OPENVINO_PACKAGE ' , 'openvino==2024.6.0' ), 'openvino' , ignore = True )
722
+ install (os .environ .get ('OPENVINO_COMMAND ' , 'openvino==2024.6.0' ), 'openvino' , ignore = True )
744
723
install ('nncf==2.7.0' , ignore = True , no_deps = True ) # requires older pandas
745
- install (os .environ .get ('ONNXRUNTIME_PACKAGE' , 'onnxruntime-openvino' ), 'onnxruntime-openvino' , ignore = True )
746
724
ts ('ipex' , t_start )
747
725
return torch_command
748
726
749
727
750
728
def install_openvino (torch_command ):
751
729
t_start = time .time ()
752
- # Python 3.12 will cause compatibility issues with other dependencies.
753
- # OpenVINO supports Python 3.12 so don't block it but don't advertise it in the error message
754
- check_python (supported_minors = [9 , 10 , 11 , 12 ], reason = 'OpenVINO backend requires Python 3.9, 3.10 or 3.11' )
730
+ check_python (supported_minors = [9 , 10 , 11 , 12 ], reason = 'OpenVINO backend requires a Python version between 3.9 and 3.12' )
755
731
log .info ('OpenVINO: selected' )
756
732
if sys .platform == 'darwin' :
757
733
torch_command = os .environ .get ('TORCH_COMMAND' , 'torch==2.3.1 torchvision==0.18.1' )
758
734
else :
759
735
torch_command = os .environ .get ('TORCH_COMMAND' , 'torch==2.3.1+cpu torchvision==0.18.1+cpu --index-url https://download.pytorch.org/whl/cpu' )
760
- install ( os . environ . get ( 'OPENVINO_PACKAGE' , 'openvino==2024.6.0' ), 'openvino' )
761
- install (os .environ .get ('ONNXRUNTIME_PACKAGE ' , 'onnxruntime- openvino' ), 'onnxruntime- openvino' , ignore = True )
736
+
737
+ install (os .environ .get ('OPENVINO_COMMAND ' , 'openvino==2024.6.0 ' ), 'openvino' )
762
738
install ('nncf==2.14.1' , 'nncf' )
763
739
os .environ .setdefault ('PYTORCH_TRACING_MODE' , 'TORCHFX' )
764
740
if os .environ .get ("NEOReadDebugKeys" , None ) is None :
0 commit comments