Skip to content

Commit 885f272

Browse files
committed
zluda fix recursion
1 parent 84d813b commit 885f272

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

modules/zluda.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from torch._prims_common import DeviceLikeType
66
import onnxruntime as ort
77
from modules import shared, devices
8+
from modules.onnx_impl.execution_providers import available_execution_providers, ExecutionProvider
89

910

1011
PLATFORM = sys.platform
@@ -61,10 +62,10 @@ def initialize_zluda():
6162
shared.opts.sdp_options = ['Math attention']
6263

6364
# ONNX Runtime is not supported
64-
ort.capi._pybind_state.get_available_providers = lambda: [v for v in ort.get_available_providers() if v != 'CUDAExecutionProvider'] # pylint: disable=protected-access
65+
ort.capi._pybind_state.get_available_providers = lambda: [v for v in available_execution_providers if v != ExecutionProvider.CUDA] # pylint: disable=protected-access
6566
ort.get_available_providers = ort.capi._pybind_state.get_available_providers # pylint: disable=protected-access
66-
if shared.opts.onnx_execution_provider == 'CUDAExecutionProvider':
67-
shared.opts.onnx_execution_provider = 'CPUExecutionProvider'
67+
if shared.opts.onnx_execution_provider == ExecutionProvider.CUDA:
68+
shared.opts.onnx_execution_provider = ExecutionProvider.CPU
6869

6970
devices.device_codeformer = devices.cpu
7071

0 commit comments

Comments
 (0)