diff --git a/comfy/model_management.py b/comfy/model_management.py index a0d17cc8bc2..92a633e6e8d 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -77,9 +77,11 @@ def get_torch_device(): global cpu_state global extensions_devices - extension = comfy.utils.get_extension_calling() - if extension is not None and extension in extensions_devices: - return torch.device(extensions_devices[extension]) + extension_stack = comfy.utils.get_extension_calling() + if extension_stack is not None: + for extension in extension_stack: + if extension in extensions_devices: + return torch.device(extensions_devices[extension]) if directml_enabled: global directml_device diff --git a/comfy/utils.py b/comfy/utils.py index 3ef66d23a52..c53332b5b05 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -14,7 +14,15 @@ def get_extension_calling(): if os.sep + "custom_nodes" + os.sep in frame.filename: stack_module = inspect.getmodule(frame[0]) if stack_module: - return re.sub(r".*\.?custom_nodes\.([^\.]+).*", r"\1", stack_module.__name__.replace("\\", ".").replace("/", ".")).split(".")[0] + stack = [] + + parts = re.sub(r".*\.?custom_nodes\.([^\.]+).*", r"\1", stack_module.__name__.replace(os.sep, ".")).split(".") + + while len(parts) > 0: + stack.append(".".join(parts)) + parts.pop() + + return stack return None