From 0ae1c22d3527f7ad63d7ce11d4042c9833c2195e Mon Sep 17 00:00:00 2001 From: Jairo Correa Date: Mon, 6 Nov 2023 01:08:25 -0300 Subject: [PATCH] Multi level get_extension_calling to allow subfolder precision --- comfy/model_management.py | 8 +++++--- comfy/utils.py | 10 +++++++++- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index a0d17cc8bc2..92a633e6e8d 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -77,9 +77,11 @@ def get_torch_device(): global cpu_state global extensions_devices - extension = comfy.utils.get_extension_calling() - if extension is not None and extension in extensions_devices: - return torch.device(extensions_devices[extension]) + extension_stack = comfy.utils.get_extension_calling() + if extension_stack is not None: + for extension in extension_stack: + if extension in extensions_devices: + return torch.device(extensions_devices[extension]) if directml_enabled: global directml_device diff --git a/comfy/utils.py b/comfy/utils.py index 3ef66d23a52..c53332b5b05 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -14,7 +14,15 @@ def get_extension_calling(): if os.sep + "custom_nodes" + os.sep in frame.filename: stack_module = inspect.getmodule(frame[0]) if stack_module: - return re.sub(r".*\.?custom_nodes\.([^\.]+).*", r"\1", stack_module.__name__.replace("\\", ".").replace("/", ".")).split(".")[0] + stack = [] + + parts = re.sub(r".*\.?custom_nodes\.([^\.]+).*", r"\1", stack_module.__name__.replace(os.sep, ".")).split(".") + + while len(parts) > 0: + stack.append(".".join(parts)) + parts.pop() + + return stack return None