From be2725c2cbe90d56a462005506f88543bcd61881 Mon Sep 17 00:00:00 2001 From: wl-zhao Date: Mon, 21 Oct 2024 23:12:12 +0800 Subject: [PATCH 01/70] add dependency checker --- dependency_checker.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/dependency_checker.py b/dependency_checker.py index b9fee0b..f7d814e 100644 --- a/dependency_checker.py +++ b/dependency_checker.py @@ -144,8 +144,14 @@ def resolve_dependencies(prompt, custom_dependencies): # resolve custom nodes an print("ckpt_paths:", ckpt_paths) custom_nodes = list(set(custom_nodes)) # step 0: comfyui version - comfyui_version = inspect_repo_version(BASE_PATH) - + repo_info = inspect_repo_version(BASE_PATH) + if repo_info["repo"] == "": + repo_info["require_recheck"] = True + if repo_info["name"] in custom_dependencies["custom_nodes"]: + repo_info["repo"] = custom_dependencies["custom_nodes"][repo_info["name"]].get("repo", "") + repo_info["commit"] = custom_dependencies["custom_nodes"][repo_info["name"]].get("commit", "") + comfyui_version = repo_info + # step 1: custom nodes custom_nodes_list = [] custom_nodes_names = [] From bf8f352347ad680436db5686230ffe65c2801c15 Mon Sep 17 00:00:00 2001 From: wl-zhao Date: Mon, 21 Oct 2024 23:47:43 +0800 Subject: [PATCH 02/70] skip when .git is not found --- dependency_checker.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/dependency_checker.py b/dependency_checker.py index f7d814e..bafe875 100644 --- a/dependency_checker.py +++ b/dependency_checker.py @@ -59,6 +59,10 @@ def inspect_repo_version(module_path): "repo": "", "commit": "" } + + if not os.path.isdir(os.path.join(module_path, ".git")): + return result + # Get the remote repository URL try: remote_url = subprocess.check_output( From b3042fdf6fab93a934d90d58e99a670a1e963a88 Mon Sep 17 00:00:00 2001 From: wl-zhao Date: Tue, 22 Oct 2024 15:33:04 +0800 Subject: [PATCH 03/70] fix bug --- dependency_checker.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dependency_checker.py b/dependency_checker.py index bafe875..d88804c 100644 --- a/dependency_checker.py +++ b/dependency_checker.py @@ -114,7 +114,8 @@ def resolve_dependencies(prompt, custom_dependencies): # resolve custom nodes an if node_class_type is None: raise NotImplementedError(f"Missing nodes founded, please first install the missing nodes using ComfyUI Manager") node_cls = NODE_CLASS_MAPPINGS[node_class_type] - if hasattr(node_cls, "RELATIVE_PYTHON_MODULE"): + if hasattr(node_cls, "RELATIVE_PYTHON_MODULE") and node_cls.RELATIVE_PYTHON_MODULE.startswith("custom_nodes."): + print(node_cls.RELATIVE_PYTHON_MODULE) custom_nodes.append(node_cls.RELATIVE_PYTHON_MODULE) if node_class_type in model_loaders_info: for field_name, filename in node_info["inputs"].items(): From fc1f9afbcdb99c81cc951b43a76019894ca754eb Mon Sep 17 00:00:00 2001 From: wl-zhao Date: Tue, 22 Oct 2024 19:47:04 +0800 Subject: [PATCH 04/70] add new output nodes --- comfy-nodes/output_image.py | 1 - comfy-nodes/output_text.py | 74 +++++++++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+), 1 deletion(-) create mode 100644 comfy-nodes/output_text.py diff --git a/comfy-nodes/output_image.py b/comfy-nodes/output_image.py index 3198e4f..f1ecc48 100644 --- a/comfy-nodes/output_image.py +++ b/comfy-nodes/output_image.py @@ -80,7 +80,6 @@ def save_video(self, filenames, **kwargs): preview_image = os.path.relpath(preview_image) video_path = os.path.relpath(video_path) results = {"ui": {"image": [preview_image], "video": [video_path]}} - print(results) return results diff --git a/comfy-nodes/output_text.py b/comfy-nodes/output_text.py new file mode 100644 index 0000000..8b74fce --- /dev/null +++ b/comfy-nodes/output_text.py @@ -0,0 +1,74 @@ + +json_type_mapipng = { + "text": "string", + "float": "number", + "integer": "integer" +} + +class ShellAgentOutputText: + TYPE_STR = "text" + @classmethod + def INPUT_TYPES(s): + return { + "required": { + s.TYPE_STR: ("STRING", {"tooltip": f"The {s.TYPE_STR} to output."}), + "output_name": ("STRING", {"multiline": False, "default": f"output_{s.TYPE_STR}"},), + }, + } + + RETURN_TYPES = () + FUNCTION = "output_var" + + OUTPUT_NODE = True + + CATEGORY = "shellagent" + DESCRIPTION = "output the text" + + @classmethod + def validate(cls, **kwargs): + schema = { + "title": kwargs["output_name"], + "type": json_type_mapipng[cls.TYPE_STR] + } + return schema + + def output_var(self, **kwargs): + results = {"ui": {"output": [kwargs[self.TYPE_STR]]}} + return results + +class ShellAgentOutputFloat(ShellAgentOutputText): + TYPE_STR = "float" + DESCRIPTION = "output the float" + @classmethod + def INPUT_TYPES(s): + return { + "required": { + s.TYPE_STR: ("STRING", {"tooltip": f"The {s.TYPE_STR} to output."}), + "output_name": ("STRING", {"multiline": False, "default": f"output_{s.TYPE_STR}"},), + }, + } + + +class ShellAgentOutputInteger(ShellAgentOutputText): + TYPE_STR = "integer" + DESCRIPTION = "output the integer" + @classmethod + def INPUT_TYPES(s): + return { + "required": { + s.TYPE_STR: ("STRING", {"tooltip": f"The {s.TYPE_STR} to output."}), + "output_name": ("STRING", {"multiline": False, "default": f"output_{s.TYPE_STR}"},), + }, + } + + +NODE_CLASS_MAPPINGS = { + "ShellAgentPluginOutputText": ShellAgentOutputText, + "ShellAgentPluginOutputFloat": ShellAgentOutputFloat, + "ShellAgentPluginOutputInteger": ShellAgentOutputInteger +} +NODE_DISPLAY_NAME_MAPPINGS = { + "ShellAgentPluginOutputText": "Output Text (ShellAgent Plugin)", + "ShellAgentPluginOutputFloat": "Output Float (ShellAgent Plugin)", + "ShellAgentPluginOutputInteger": "Output Integer (ShellAgent Plugin)", +} \ No newline at end of file From 00fb0b91cde6d2c196219747d820a462dbf84615 Mon Sep 17 00:00:00 2001 From: wl-zhao Date: Wed, 23 Oct 2024 10:46:48 +0800 Subject: [PATCH 05/70] handle model_searcher fail --- dependency_checker.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/dependency_checker.py b/dependency_checker.py index d88804c..b09c585 100644 --- a/dependency_checker.py +++ b/dependency_checker.py @@ -100,7 +100,10 @@ def fetch_model_searcher_results(model_ids): } response = requests.post(url, headers=headers, json=data) - results = [item[:10] for item in response.json()] + if response.status_code == 200: + results = [item[:10] for item in response.json()] + else: + results = None return results def resolve_dependencies(prompt, custom_dependencies): # resolve custom nodes and models at the same time @@ -195,11 +198,12 @@ def resolve_dependencies(prompt, custom_dependencies): # resolve custom nodes an # try to fetch from myshell model searcher missing_model_results_myshell = fetch_model_searcher_results(missing_model_ids) - for missing_model_id, missing_model_urls in zip(missing_model_ids, missing_model_results_myshell): - if len(missing_model_urls) > 0: - models_dict[missing_model_id]["require_recheck"] = False - models_dict[missing_model_id]["urls"] = missing_model_urls - print("successfully fetch results from myshell", models_dict[missing_model_id]) + if missing_model_results_myshell is not None: + for missing_model_id, missing_model_urls in zip(missing_model_ids, missing_model_results_myshell): + if len(missing_model_urls) > 0: + models_dict[missing_model_id]["require_recheck"] = False + models_dict[missing_model_id]["urls"] = missing_model_urls + print("successfully fetch results from myshell", models_dict[missing_model_id]) # step 3: handle local files process_local_file_path_async(file_mapping_dict, max_workers=20) From 26e57ef44fa78f1f904d88738aff2d2f4db95e68 Mon Sep 17 00:00:00 2001 From: Xumin Yu <1090414006@qq.com> Date: Wed, 23 Oct 2024 16:32:40 +0800 Subject: [PATCH 06/70] Update dependency_checker.py --- dependency_checker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dependency_checker.py b/dependency_checker.py index b09c585..5177809 100644 --- a/dependency_checker.py +++ b/dependency_checker.py @@ -91,7 +91,7 @@ def inspect_repo_version(module_path): def fetch_model_searcher_results(model_ids): import requests - url = "https://shellagent.myshell.ai/models_searcher/search_urls" + url = "https://models-searcher.myshell.life/search_urls" headers = { "Content-Type": "application/json" } @@ -215,4 +215,4 @@ def resolve_dependencies(prompt, custom_dependencies): # resolve custom nodes an "models": models_dict, "files": files_dict, } - return results \ No newline at end of file + return results From 15c7d7f60cca307daea09685e57087f4e1a9627b Mon Sep 17 00:00:00 2001 From: shanexi Date: Thu, 24 Oct 2024 21:41:17 +0800 Subject: [PATCH 07/70] Optimize number input convert --- web/shellagent.js | 43 +++++++++++++++++++++++++------------------ 1 file changed, 25 insertions(+), 18 deletions(-) diff --git a/web/shellagent.js b/web/shellagent.js index e21a7d1..cee0476 100644 --- a/web/shellagent.js +++ b/web/shellagent.js @@ -93,7 +93,7 @@ app.registerExtension({ try { arr = JSON.parse(widget.value) } catch { } - } else if(Array.isArray(widget.value)) { + } else if (Array.isArray(widget.value)) { arr = widget.value } @@ -191,25 +191,32 @@ app.registerExtension({ }) } if (["number"].indexOf(w.type) > -1) { - toInput.push({ - content: `${w.name} <- Input Interger`, - callback: () => { - this.convertWidgetToInput(w); - const node = addNode("ShellAgentPluginInputInteger", this, { before: true }); - const dvn = node.widgets.find(w => w.name === 'default_value') - dvn.value = w.value; - node.connect(0, this, this.inputs.length - 1); - } - }) toInput.push({ - content: `${w.name} <- Input Float`, - callback: () => { - this.convertWidgetToInput(w); - const node = addNode("ShellAgentPluginInputFloat", this, { before: true }); - const dvn = node.widgets.find(w => w.name === 'default_value') - dvn.value = w.value; - node.connect(0, this, this.inputs.length - 1); + content: w.name, + submenu: { + options: [ + { + content: 'Input Interger', + callback: () => { + this.convertWidgetToInput(w); + const node = addNode("ShellAgentPluginInputInteger", this, { before: true }); + const dvn = node.widgets.find(w => w.name === 'default_value') + dvn.value = w.value; + node.connect(0, this, this.inputs.length - 1); + } + }, + { + content: 'Input Float', + callback: () => { + this.convertWidgetToInput(w); + const node = addNode("ShellAgentPluginInputFloat", this, { before: true }); + const dvn = node.widgets.find(w => w.name === 'default_value') + dvn.value = w.value; + node.connect(0, this, this.inputs.length - 1); + } + } + ] } }) } From 45e7caca72505c4d8d00e0eb56230900086d7707 Mon Sep 17 00:00:00 2001 From: wl-zhao Date: Fri, 25 Oct 2024 15:01:20 +0800 Subject: [PATCH 08/70] use folder_path to find the models --- dependency_checker.py | 32 ++++++++++++++++++++++++-------- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/dependency_checker.py b/dependency_checker.py index b09c585..acaa788 100644 --- a/dependency_checker.py +++ b/dependency_checker.py @@ -15,7 +15,7 @@ model_list_json = json.load(open(os.path.join(os.path.dirname(__file__), "model_info.json"))) model_loaders_info = json.load(open(os.path.join(os.path.dirname(__file__), "model_loader_info.json"))) node_deps_info = json.load(open(os.path.join(os.path.dirname(__file__), "node_deps_info.json"))) - +node_blacklist = json.load(open(os.path.join(os.path.dirname(__file__), "node_blacklist.json"))) model_suffix = [".ckpt", ".safetensors", ".bin", ".pth", ".pt", ".onnx"] @@ -108,6 +108,8 @@ def fetch_model_searcher_results(model_ids): def resolve_dependencies(prompt, custom_dependencies): # resolve custom nodes and models at the same time from nodes import NODE_CLASS_MAPPINGS + import folder_paths + custom_nodes = [] ckpt_paths = [] @@ -122,10 +124,12 @@ def resolve_dependencies(prompt, custom_dependencies): # resolve custom nodes an custom_nodes.append(node_cls.RELATIVE_PYTHON_MODULE) if node_class_type in model_loaders_info: for field_name, filename in node_info["inputs"].items(): + if type(filename) != str: + continue for item in model_loaders_info[node_class_type]: pattern = item["field_name"] - if re.match(f"^{pattern}$", field_name): - ckpt_path = os.path.join(MODELS_DIR, item["save_path"], filename) + if re.match(f"^{pattern}$", field_name) and any([filename.endswith(possible_suffix) for possible_suffix in model_suffix]): + ckpt_path = folder_paths.get_full_path_or_raise(item["save_path"], filename) ckpt_paths.append(ckpt_path) else: for field_name, filename in node_info["inputs"].items(): @@ -140,9 +144,11 @@ def resolve_dependencies(prompt, custom_dependencies): # resolve custom nodes an # find possible paths matching_files = [] # Walk through all subdirectories and files in the directory - for possible_filename in glob.glob(os.path.join(MODELS_DIR, "**", "*"), recursive=True): - if os.path.isfile(possible_filename) and possible_filename.endswith(filename): - matching_files.append(possible_filename) + for possible_folder_name in folder_paths.folder_names_and_paths: + full_path = folder_paths.get_full_path(possible_folder_name) + if full_path is None: + continue + matching_files.append(full_path) print(f"matched files: {matching_files}") if len(matching_files) == 1: ckpt_paths.append(matching_files[0]) @@ -183,6 +189,12 @@ def resolve_dependencies(prompt, custom_dependencies): # resolve custom nodes an repo_info = inspect_repo_version(os.path.join("custom_nodes", deps_node["name"])) deps_node["commit"] = repo_info["commit"] custom_nodes_list.append(deps_node) + custom_nodes_names.append(deps_node["name"]) + + black_list_nodes = [] + for repo_name in custom_nodes_names: + if repo_name in node_blacklist: + black_list_nodes.append({"name": repo_name, "reason": node_blacklist[repo_name]["reason"]}) # step 2: models models_dict = {} @@ -209,10 +221,14 @@ def resolve_dependencies(prompt, custom_dependencies): # resolve custom nodes an process_local_file_path_async(file_mapping_dict, max_workers=20) files_dict = {v[0]: {"filename": windows_to_linux_path(os.path.relpath(v[2], BASE_PATH)), "urls": [v[1]]} for v in file_mapping_dict.values()} - results = { + depencencies = { "comfyui_version": comfyui_version, "custom_nodes": custom_nodes_list, "models": models_dict, "files": files_dict, } - return results \ No newline at end of file + return_dict = { + "dependencies": depencencies, + "black_list_nodes": black_list_nodes, + } + return return_dict \ No newline at end of file From 232bc67c9d8723a99f543225179eedf67df09155 Mon Sep 17 00:00:00 2001 From: wl-zhao Date: Fri, 25 Oct 2024 16:01:48 +0800 Subject: [PATCH 09/70] handle relative path --- dependency_checker.py | 31 +++++++++++++++++++------------ 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/dependency_checker.py b/dependency_checker.py index 64677fb..c21fbf0 100644 --- a/dependency_checker.py +++ b/dependency_checker.py @@ -19,11 +19,8 @@ model_suffix = [".ckpt", ".safetensors", ".bin", ".pth", ".pt", ".onnx"] -def handle_model_info(ckpt_path): +def handle_model_info(ckpt_path, filename, rel_save_path): ckpt_path = windows_to_linux_path(ckpt_path) - filename = os.path.basename(ckpt_path) - dirname = os.path.dirname(ckpt_path) - save_path = os.path.dirname(os.path.relpath(ckpt_path, MODELS_DIR)) metadata_path = ckpt_path + ".json" if os.path.isfile(metadata_path): metadata = json.load(open(metadata_path)) @@ -35,7 +32,7 @@ def handle_model_info(ckpt_path): model_id = compute_sha256(ckpt_path) data = { "id": model_id, - "save_path": save_path, + "save_path": rel_save_path, "filename": filename, } json.dump(data, open(metadata_path, "w")) @@ -46,7 +43,7 @@ def handle_model_info(ckpt_path): item = { "filename": filename, - "save_path": windows_to_linux_path(save_path), + "save_path": windows_to_linux_path(rel_save_path), "urls": urls, } return model_id, item @@ -111,7 +108,7 @@ def resolve_dependencies(prompt, custom_dependencies): # resolve custom nodes an import folder_paths custom_nodes = [] - ckpt_paths = [] + ckpt_paths = {} file_mapping_dict = {} for node_id, node_info in prompt.items(): @@ -130,7 +127,11 @@ def resolve_dependencies(prompt, custom_dependencies): # resolve custom nodes an pattern = item["field_name"] if re.match(f"^{pattern}$", field_name) and any([filename.endswith(possible_suffix) for possible_suffix in model_suffix]): ckpt_path = folder_paths.get_full_path_or_raise(item["save_path"], filename) - ckpt_paths.append(ckpt_path) + rel_save_path = os.path.relpath(folder_paths.folder_names_and_paths[item["save_path"]][0][0], folder_paths.models_dir) + ckpt_paths[ckpt_path] = { + "filename": filename, + "rel_save_path": rel_save_path + } else: for field_name, filename in node_info["inputs"].items(): if type(filename) != str: @@ -144,17 +145,23 @@ def resolve_dependencies(prompt, custom_dependencies): # resolve custom nodes an # find possible paths matching_files = [] # Walk through all subdirectories and files in the directory + rel_save_path = None for possible_folder_name in folder_paths.folder_names_and_paths: full_path = folder_paths.get_full_path(possible_folder_name, filename) if full_path is None: continue + rel_save_path = os.path.relpath(folder_paths.folder_names_and_paths[possible_folder_name][0][0], folder_paths.models_dir) matching_files.append(full_path) + break print(f"matched files: {matching_files}") if len(set(matching_files)) == 1: - ckpt_paths.append(matching_files[0]) + assert rel_save_path is not None + ckpt_paths[matching_files[0]] = { + "filename": filename, + "rel_save_path": rel_save_path + } list(map(partial(collect_local_file, mapping_dict=file_mapping_dict), node_info["inputs"].values())) - ckpt_paths = list(set(ckpt_paths)) print("ckpt_paths:", ckpt_paths) custom_nodes = list(set(custom_nodes)) # step 0: comfyui version @@ -199,8 +206,8 @@ def resolve_dependencies(prompt, custom_dependencies): # resolve custom nodes an # step 2: models models_dict = {} missing_model_ids = [] - for ckpt_path in ckpt_paths: - model_id, item = handle_model_info(ckpt_path) + for ckpt_path, ckpt_info in ckpt_paths.items(): + model_id, item = handle_model_info(ckpt_path, ckpt_info["filename"], ckpt_info["rel_save_path"]) models_dict[model_id] = item if len(item["urls"]) == 0: item["require_recheck"] = True From c0da8f916d7c0652e2218b70628551073be07e2a Mon Sep 17 00:00:00 2001 From: wl-zhao Date: Fri, 25 Oct 2024 23:21:54 +0800 Subject: [PATCH 10/70] update get_full_path_or_rase --- dependency_checker.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/dependency_checker.py b/dependency_checker.py index c21fbf0..73a3cbc 100644 --- a/dependency_checker.py +++ b/dependency_checker.py @@ -7,6 +7,7 @@ import glob from folder_paths import models_dir as MODELS_DIR from folder_paths import base_path as BASE_PATH +from folder_paths import get_full_path from .utils import compute_sha256, windows_to_linux_path from .file_upload import collect_local_file, process_local_file_path_async @@ -19,6 +20,14 @@ model_suffix = [".ckpt", ".safetensors", ".bin", ".pth", ".pt", ".onnx"] + +def get_full_path_or_raise(folder_name: str, filename: str) -> str: + full_path = get_full_path(folder_name, filename) + if full_path is None: + raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found.") + return full_path + + def handle_model_info(ckpt_path, filename, rel_save_path): ckpt_path = windows_to_linux_path(ckpt_path) metadata_path = ckpt_path + ".json" @@ -126,7 +135,7 @@ def resolve_dependencies(prompt, custom_dependencies): # resolve custom nodes an for item in model_loaders_info[node_class_type]: pattern = item["field_name"] if re.match(f"^{pattern}$", field_name) and any([filename.endswith(possible_suffix) for possible_suffix in model_suffix]): - ckpt_path = folder_paths.get_full_path_or_raise(item["save_path"], filename) + ckpt_path = get_full_path_or_raise(item["save_path"], filename) rel_save_path = os.path.relpath(folder_paths.folder_names_and_paths[item["save_path"]][0][0], folder_paths.models_dir) ckpt_paths[ckpt_path] = { "filename": filename, From 666605028384c0da5bc4f05efbd0db5efd182d8a Mon Sep 17 00:00:00 2001 From: yuxumin <1090414006@qq.com> Date: Fri, 25 Oct 2024 23:29:57 +0800 Subject: [PATCH 11/70] update node deps info --- node_deps_info.json | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/node_deps_info.json b/node_deps_info.json index dfb1945..e3f01c0 100644 --- a/node_deps_info.json +++ b/node_deps_info.json @@ -15,5 +15,12 @@ "repo": "https://github.com/shiimizu/ComfyUI_smZNodes.git", "commit": "" } + ], + "efficiency-nodes-comfyui": [ + { + "name": "comfyui_controlnet_aux", + "repo": "https://github.com/Fannovel16/comfyui_controlnet_aux.git", + "commit": "" + } ] } \ No newline at end of file From 2b19a132a6f84bcda48b193b3cfd2c2eb2d2785d Mon Sep 17 00:00:00 2001 From: wl-zhao Date: Sat, 26 Oct 2024 11:07:15 +0800 Subject: [PATCH 12/70] addmap_legacy --- dependency_checker.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/dependency_checker.py b/dependency_checker.py index 73a3cbc..5ff38c5 100644 --- a/dependency_checker.py +++ b/dependency_checker.py @@ -136,7 +136,11 @@ def resolve_dependencies(prompt, custom_dependencies): # resolve custom nodes an pattern = item["field_name"] if re.match(f"^{pattern}$", field_name) and any([filename.endswith(possible_suffix) for possible_suffix in model_suffix]): ckpt_path = get_full_path_or_raise(item["save_path"], filename) - rel_save_path = os.path.relpath(folder_paths.folder_names_and_paths[item["save_path"]][0][0], folder_paths.models_dir) + if hasattr(folder_paths, "map_legacy"): + save_folder = folder_paths.map_legacy(item["save_path"]) + else: + save_folder = item["save_path"] + rel_save_path = os.path.relpath(folder_paths.folder_names_and_paths[save_folder][0][0], folder_paths.models_dir) ckpt_paths[ckpt_path] = { "filename": filename, "rel_save_path": rel_save_path From d2cfd99a334b38602c1415a4b334a71de3cc248a Mon Sep 17 00:00:00 2001 From: shanexi Date: Thu, 24 Oct 2024 22:33:38 +0800 Subject: [PATCH 13/70] Ouput convert --- web/shellagent.js | 92 ++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 80 insertions(+), 12 deletions(-) diff --git a/web/shellagent.js b/web/shellagent.js index cee0476..bd289ae 100644 --- a/web/shellagent.js +++ b/web/shellagent.js @@ -176,22 +176,40 @@ app.registerExtension({ addMenuHandler(nodeType, function (_, options) { if (this.widgets) { let toInput = []; - for (const w of this.widgets) { + // todo: combo need to remove and convert back + // if (w.type === 'combo' && w.name === 'image') { + // toInput.push({ + // content: `${w.name} <- Input Image`, + // callback: () => { + // this.convertWidgetToInput(w); + // const node = addNode("ShellAgentPluginInputImage", this, { before: true }); + // const dvn = node.widgets.find(w => w.name === 'default_value') + // dvn.value = w.value; + // node.connect(0, this, 0); + // } + // }) + // } if (["customtext"].indexOf(w.type) > -1) { toInput.push({ - content: `${w.name} <- Input Text`, - callback: () => { - this.convertWidgetToInput(w); - const node = addNode("ShellAgentPluginInputText", this, { before: true }); - const dvn = node.widgets.find(w => w.name === 'default_value') - dvn.value = w.value; - node.connect(0, this, this.inputs.length - 1); - } + content: w.name, + submenu: { + options: [ + { + content: 'Input Text', + callback: () => { + this.convertWidgetToInput(w); + const node = addNode("ShellAgentPluginInputText", this, { before: true }); + const dvn = node.widgets.find(w => w.name === 'default_value') + dvn.value = w.value; + node.connect(0, this, this.inputs.length - 1); + } + } + ] + }, }) } if (["number"].indexOf(w.type) > -1) { - toInput.push({ content: w.name, submenu: { @@ -221,16 +239,66 @@ app.registerExtension({ }) } } - if (toInput.length) { options.unshift({ - content: "Convert to ShellAgent", + content: "Convert to ShellAgent (Input)", submenu: { options: toInput } }) } } + + if (this.outputs) { + let toOutput = []; + for (const o of this.outputs) { + if (o.type === 'IMAGE') { + toOutput.push({ + content: o.name, + submenu: { + options: [ + { + content: 'Save Image', + callback: () => { + const node = addNode("ShellAgentPluginSaveImage", this); + this.connect(0, node, 0); + } + }, + { + content: 'Save Images', + callback: () => { + const node = addNode("ShellAgentPluginSaveImages", this); + this.connect(0, node, 0); + } + } + ] + } + + }) + } + + if (o.type === 'STRING') { + toOutput.push({ + content: `${o.name} -> Output Text`, + callback: () => { + const node = addNode("ShellAgentPluginOutputText", this); + this.connect(0, node, 0); + } + }) + } + + } + + if (toOutput.length) { + options.unshift({ + content: "Convert to ShellAgent (Output)", + submenu: { + options: toOutput + } + }) + } + + } }) } }, From 4fb799112a391f2283f9ad5ffc89ebdb2fe2a388 Mon Sep 17 00:00:00 2001 From: shanexi Date: Sun, 27 Oct 2024 09:42:31 +0800 Subject: [PATCH 14/70] Convert to save video --- web/shellagent.js | 31 +++++++++++++++++++++++++++---- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/web/shellagent.js b/web/shellagent.js index bd289ae..dc57a62 100644 --- a/web/shellagent.js +++ b/web/shellagent.js @@ -279,14 +279,37 @@ app.registerExtension({ if (o.type === 'STRING') { toOutput.push({ - content: `${o.name} -> Output Text`, - callback: () => { - const node = addNode("ShellAgentPluginOutputText", this); - this.connect(0, node, 0); + content: o.name, + submenu: { + options: [ + { + content: `Output Text`, + callback: () => { + const node = addNode("ShellAgentPluginOutputText", this); + this.connect(0, node, 0); + } + } + ] } }) } + if (o.type === "VHS_FILENAMES") { + toOutput.push({ + content: o.name, + submenu: { + options: [ + { + content: `Save Video - VHS`, + callback: () => { + const node = addNode("ShellAgentPluginSaveVideoVHS", this); + this.connect(0, node, 0); + } + } + ] + } + }) + } } if (toOutput.length) { From 9417b644584f50e710ee010af52f76b39e0e6701 Mon Sep 17 00:00:00 2001 From: shanexi Date: Sun, 27 Oct 2024 15:23:13 +0800 Subject: [PATCH 15/70] No need to connect image combo --- web/shellagent.js | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/web/shellagent.js b/web/shellagent.js index dc57a62..e113910 100644 --- a/web/shellagent.js +++ b/web/shellagent.js @@ -174,22 +174,10 @@ app.registerExtension({ if (nodeData.name.indexOf('ShellAgentPlugin') === -1) { addMenuHandler(nodeType, function (_, options) { + if (this.widgets) { let toInput = []; for (const w of this.widgets) { - // todo: combo need to remove and convert back - // if (w.type === 'combo' && w.name === 'image') { - // toInput.push({ - // content: `${w.name} <- Input Image`, - // callback: () => { - // this.convertWidgetToInput(w); - // const node = addNode("ShellAgentPluginInputImage", this, { before: true }); - // const dvn = node.widgets.find(w => w.name === 'default_value') - // dvn.value = w.value; - // node.connect(0, this, 0); - // } - // }) - // } if (["customtext"].indexOf(w.type) > -1) { toInput.push({ content: w.name, From 752a0de95d5325b46430a0f03856eaadc5bf76bf Mon Sep 17 00:00:00 2001 From: wl-zhao Date: Sun, 27 Oct 2024 15:39:16 +0800 Subject: [PATCH 16/70] add gguf --- dependency_checker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dependency_checker.py b/dependency_checker.py index 5ff38c5..f20a7b7 100644 --- a/dependency_checker.py +++ b/dependency_checker.py @@ -18,7 +18,7 @@ node_deps_info = json.load(open(os.path.join(os.path.dirname(__file__), "node_deps_info.json"))) node_blacklist = json.load(open(os.path.join(os.path.dirname(__file__), "node_blacklist.json"))) -model_suffix = [".ckpt", ".safetensors", ".bin", ".pth", ".pt", ".onnx"] +model_suffix = [".ckpt", ".safetensors", ".bin", ".pth", ".pt", ".onnx", ".gguf"] def get_full_path_or_raise(folder_name: str, filename: str) -> str: From 710104d709d4d85d7fc95cfc8adebce47e5f1e6f Mon Sep 17 00:00:00 2001 From: shanexi Date: Sun, 27 Oct 2024 16:02:48 +0800 Subject: [PATCH 17/70] Image input --- web/shellagent.js | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/web/shellagent.js b/web/shellagent.js index e113910..d59ef83 100644 --- a/web/shellagent.js +++ b/web/shellagent.js @@ -313,6 +313,12 @@ app.registerExtension({ }) } }, + + afterConfigureGraph(missingNodeTypes, app) { + const type = 'IMAGE' + const nodeId = 'ShellAgentPluginInputImage' + LiteGraph.slot_types_default_in[type].unshift(nodeId) + } }); function addMenuHandler(nodeType, cb) { From 1a826fa746f0f63b2f6b337ee01aa6b44e7bc1b4 Mon Sep 17 00:00:00 2001 From: yuxumin <1090414006@qq.com> Date: Sun, 27 Oct 2024 17:29:46 +0800 Subject: [PATCH 18/70] update node_deps json --- node_deps_info.json | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/node_deps_info.json b/node_deps_info.json index e3f01c0..7a9c640 100644 --- a/node_deps_info.json +++ b/node_deps_info.json @@ -22,5 +22,12 @@ "repo": "https://github.com/Fannovel16/comfyui_controlnet_aux.git", "commit": "" } + ], + "ComfyUI-Anyline": [ + { + "name": "comfyui_controlnet_aux", + "repo": "https://github.com/Fannovel16/comfyui_controlnet_aux.git", + "commit": "" + } ] } \ No newline at end of file From 134ccd3c2b4a6226e362dc7e06a4cbbd386d4ff7 Mon Sep 17 00:00:00 2001 From: yuxumin <1090414006@qq.com> Date: Sun, 27 Oct 2024 17:54:59 +0800 Subject: [PATCH 19/70] fix dependencies deps --- dependency_checker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dependency_checker.py b/dependency_checker.py index f20a7b7..5ac6851 100644 --- a/dependency_checker.py +++ b/dependency_checker.py @@ -206,7 +206,7 @@ def resolve_dependencies(prompt, custom_dependencies): # resolve custom nodes an if repo_name in node_deps_info: for deps_node in node_deps_info[repo_name]: if deps_node["name"] not in custom_nodes_names: - repo_info = inspect_repo_version(os.path.join("custom_nodes", deps_node["name"])) + repo_info = inspect_repo_version(os.path.join(BASE_PATH, "custom_nodes", deps_node["name"])) deps_node["commit"] = repo_info["commit"] custom_nodes_list.append(deps_node) custom_nodes_names.append(deps_node["name"]) From ddad7b8c40bf1d93cc995552019d6c75079d2df3 Mon Sep 17 00:00:00 2001 From: Xumin Yu <1090414006@qq.com> Date: Sun, 27 Oct 2024 23:57:11 +0800 Subject: [PATCH 20/70] Update README.md --- README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.md b/README.md index 230e204..60d055d 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ To install, either: - Input Image - Input Float - Input Integer +- Input Video Each input node supports setting a default value and additional configuration options. @@ -25,6 +26,9 @@ Each input node supports setting a default value and additional configuration op - Save Image - Save Images - Save Video - VHS +- Output Text +- Output Float +- Output Integer ### Convert Widgets to ShellAgent Inputs From ae8ed607674294b8fab5ff096c9c909723d80c57 Mon Sep 17 00:00:00 2001 From: shanexi Date: Mon, 28 Oct 2024 09:52:28 +0800 Subject: [PATCH 21/70] Save Image(s) on output connect pop menu --- web/shellagent.js | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/web/shellagent.js b/web/shellagent.js index d59ef83..38df820 100644 --- a/web/shellagent.js +++ b/web/shellagent.js @@ -302,7 +302,7 @@ app.registerExtension({ if (toOutput.length) { options.unshift({ - content: "Convert to ShellAgent (Output)", + content: "Connect to ShellAgent (Output)", submenu: { options: toOutput } @@ -315,9 +315,9 @@ app.registerExtension({ }, afterConfigureGraph(missingNodeTypes, app) { - const type = 'IMAGE' - const nodeId = 'ShellAgentPluginInputImage' - LiteGraph.slot_types_default_in[type].unshift(nodeId) + LiteGraph.slot_types_default_in['IMAGE'].unshift('ShellAgentPluginInputImage') + LiteGraph.slot_types_default_out['IMAGE'].unshift('ShellAgentPluginSaveImage') + LiteGraph.slot_types_default_out['IMAGE'].unshift('ShellAgentPluginSaveImages') } }); From b825b62a96816d5a13126d2347bcba3bcdfef557 Mon Sep 17 00:00:00 2001 From: shanexi Date: Mon, 28 Oct 2024 10:09:38 +0800 Subject: [PATCH 22/70] Drag to connect output text float integer --- web/shellagent.js | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/web/shellagent.js b/web/shellagent.js index 38df820..e6ae1e6 100644 --- a/web/shellagent.js +++ b/web/shellagent.js @@ -37,11 +37,15 @@ app.registerExtension({ }); }, async beforeRegisterNodeDef(nodeType, nodeData, app) { + if (["ShellAgentPluginOutputText", "ShellAgentPluginOutputFloat", "ShellAgentPluginOutputInteger"].indexOf(nodeData.name) > -1) { + chainCallback(nodeType.prototype, "onNodeCreated", function () { + this.convertWidgetToInput(this.widgets[0]) + }) + } if (["ShellAgentPluginInputText", "ShellAgentPluginInputFloat", "ShellAgentPluginInputInteger"].indexOf(nodeData.name) > -1) { chainCallback(nodeType.prototype, "onNodeCreated", function () { const widget = this.widgets.find(w => w.name === 'choices') - this.addWidget('button', 'manage choices', null, () => { const container = document.createElement("div"); Object.assign(container.style, { @@ -318,6 +322,10 @@ app.registerExtension({ LiteGraph.slot_types_default_in['IMAGE'].unshift('ShellAgentPluginInputImage') LiteGraph.slot_types_default_out['IMAGE'].unshift('ShellAgentPluginSaveImage') LiteGraph.slot_types_default_out['IMAGE'].unshift('ShellAgentPluginSaveImages') + + LiteGraph.slot_types_default_out['STRING'].unshift('ShellAgentPluginOutputFloat') + LiteGraph.slot_types_default_out['STRING'].unshift('ShellAgentPluginOutputInteger') + LiteGraph.slot_types_default_out['STRING'].unshift('ShellAgentPluginOutputText') } }); From cbd714d9ea8753f7e82ddd55814fcbe6f5c1617a Mon Sep 17 00:00:00 2001 From: shanexi Date: Mon, 28 Oct 2024 11:02:18 +0800 Subject: [PATCH 23/70] Replace Load Image with ShellAgent Input Image --- web/shellagent.js | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/web/shellagent.js b/web/shellagent.js index e6ae1e6..1389d9a 100644 --- a/web/shellagent.js +++ b/web/shellagent.js @@ -141,6 +141,22 @@ app.registerExtension({ }) } + if (['LoadImage', 'LoadImageMask'].indexOf(nodeData.name) > -1) { + addMenuHandler(nodeType, function (_, options) { + options.unshift({ + content: "Replace with ShellAgent Input Image", + callback: () => { + const node = addNode("ShellAgentPluginInputImage", this, { before: true }); + app.graph.links.filter(l => l != null) + .forEach(l => { + const tn = app.graph._nodes_by_id[l.target_id] + node.connect(0, tn, 0) + }) + } + }) + }) + } + if (nodeData.name === "ShellAgentPluginInputImage") { if ( nodeData?.input?.required?.default_value?.[1]?.image_upload === true @@ -155,7 +171,6 @@ app.registerExtension({ if (nodeData.name === "ShellAgentPluginInputVideo") { addUploadWidget(nodeType, nodeData, "default_value"); chainCallback(nodeType.prototype, "onNodeCreated", function () { - // const pathWidget = this.widgets.find((w) => w.name === "video"); const pathWidget = this.widgets.find((w) => w.name === "default_value"); chainCallback(pathWidget, "callback", (value) => { if (!value) { From f9a3ce43f52a90aa937802aaed6d483b4fa5fb99 Mon Sep 17 00:00:00 2001 From: shanexi Date: Mon, 28 Oct 2024 11:09:29 +0800 Subject: [PATCH 24/70] Replace and remove --- web/shellagent.js | 1 + 1 file changed, 1 insertion(+) diff --git a/web/shellagent.js b/web/shellagent.js index 1389d9a..1754953 100644 --- a/web/shellagent.js +++ b/web/shellagent.js @@ -152,6 +152,7 @@ app.registerExtension({ const tn = app.graph._nodes_by_id[l.target_id] node.connect(0, tn, 0) }) + app.graph.remove(this); } }) }) From 20fcbb2e48c1880d1da00ad5bc7d74cc64eb9166 Mon Sep 17 00:00:00 2001 From: shanexi Date: Mon, 28 Oct 2024 11:14:05 +0800 Subject: [PATCH 25/70] Fix duplicated drag menu item --- web/shellagent.js | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/web/shellagent.js b/web/shellagent.js index 1754953..629773d 100644 --- a/web/shellagent.js +++ b/web/shellagent.js @@ -335,13 +335,24 @@ app.registerExtension({ }, afterConfigureGraph(missingNodeTypes, app) { - LiteGraph.slot_types_default_in['IMAGE'].unshift('ShellAgentPluginInputImage') - LiteGraph.slot_types_default_out['IMAGE'].unshift('ShellAgentPluginSaveImage') - LiteGraph.slot_types_default_out['IMAGE'].unshift('ShellAgentPluginSaveImages') + function addIn(type, nodeId) { + if (LiteGraph.slot_types_default_in[type].indexOf(nodeId) === -1) { + LiteGraph.slot_types_default_in[type].unshift(nodeId) + } + } + + function addOut(type, nodeId) { + if (LiteGraph.slot_types_default_out[type].indexOf(nodeId) === -1) { + LiteGraph.slot_types_default_out[type].unshift(nodeId) + } + } - LiteGraph.slot_types_default_out['STRING'].unshift('ShellAgentPluginOutputFloat') - LiteGraph.slot_types_default_out['STRING'].unshift('ShellAgentPluginOutputInteger') - LiteGraph.slot_types_default_out['STRING'].unshift('ShellAgentPluginOutputText') + addIn('IMAGE', 'ShellAgentPluginInputImage') + addOut('IMAGE', 'ShellAgentPluginSaveImage') + addOut('IMAGE', 'ShellAgentPluginSaveImages') + addOut('STRING', 'ShellAgentPluginOutputFloat') + addOut('STRING', 'ShellAgentPluginOutputInteger') + addOut('STRING', 'ShellAgentPluginOutputText') } }); From fffc29fc2b62990a81ce5fffbb749e278e885511 Mon Sep 17 00:00:00 2001 From: shanexi Date: Mon, 28 Oct 2024 11:15:39 +0800 Subject: [PATCH 26/70] Add missing convert output --- web/shellagent.js | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/web/shellagent.js b/web/shellagent.js index 629773d..6504ae7 100644 --- a/web/shellagent.js +++ b/web/shellagent.js @@ -296,6 +296,20 @@ app.registerExtension({ const node = addNode("ShellAgentPluginOutputText", this); this.connect(0, node, 0); } + }, + { + content: `Output Float`, + callback: () => { + const node = addNode("ShellAgentPluginOutputFloat", this); + this.connect(0, node, 0); + } + }, + { + content: `Output Integer`, + callback: () => { + const node = addNode("ShellAgentPluginOutputInteger", this); + this.connect(0, node, 0); + } } ] } @@ -350,8 +364,8 @@ app.registerExtension({ addIn('IMAGE', 'ShellAgentPluginInputImage') addOut('IMAGE', 'ShellAgentPluginSaveImage') addOut('IMAGE', 'ShellAgentPluginSaveImages') - addOut('STRING', 'ShellAgentPluginOutputFloat') addOut('STRING', 'ShellAgentPluginOutputInteger') + addOut('STRING', 'ShellAgentPluginOutputFloat') addOut('STRING', 'ShellAgentPluginOutputText') } }); From 344a8867926aa1357c4ebabb51f2f41e720dea51 Mon Sep 17 00:00:00 2001 From: wl-zhao Date: Mon, 28 Oct 2024 11:19:19 +0800 Subject: [PATCH 27/70] add pypi version info --- dependency_checker.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/dependency_checker.py b/dependency_checker.py index f20a7b7..4968fce 100644 --- a/dependency_checker.py +++ b/dependency_checker.py @@ -5,10 +5,12 @@ from functools import partial import re import glob +import pkg_resources from folder_paths import models_dir as MODELS_DIR from folder_paths import base_path as BASE_PATH from folder_paths import get_full_path + from .utils import compute_sha256, windows_to_linux_path from .file_upload import collect_local_file, process_local_file_path_async @@ -21,6 +23,7 @@ model_suffix = [".ckpt", ".safetensors", ".bin", ".pth", ".pt", ".onnx", ".gguf"] + def get_full_path_or_raise(folder_name: str, filename: str) -> str: full_path = get_full_path(folder_name, filename) if full_path is None: @@ -112,6 +115,23 @@ def fetch_model_searcher_results(model_ids): results = None return results +def split_package_version(require_line): + require_line = require_line.strip() + + pattern = r"^([a-zA-Z0-9_\-\[\]]+)(.*)$" + match = re.match(pattern, require_line.strip()) + + if match: + package_name = match.group(1) # First capturing group is the package name + version_specifier = match.group(2) if match.group(2) else "" # Second group is the version, if present + return package_name, version_specifier + else: + assert len(require_line) == 0 or require_line.strip()[0] == "#", require_line + return None, None + +def get_package_version(package_name): + return pkg_resources.get_distribution(package_name).version + def resolve_dependencies(prompt, custom_dependencies): # resolve custom nodes and models at the same time from nodes import NODE_CLASS_MAPPINGS import folder_paths @@ -189,6 +209,7 @@ def resolve_dependencies(prompt, custom_dependencies): # resolve custom nodes an # step 1: custom nodes custom_nodes_list = [] custom_nodes_names = [] + requirements_lines = [] for custom_node in custom_nodes: try: repo_info = inspect_repo_version(os.path.join(BASE_PATH, custom_node.replace(".", "/"))) @@ -201,6 +222,15 @@ def resolve_dependencies(prompt, custom_dependencies): # resolve custom nodes an custom_nodes_names.append(repo_info["name"]) except: print(f"failed to resolve repo info of {custom_node}") + requirement_file = os.path.join(BASE_PATH, custom_node.replace(".", "/"), "requirements.txt") + if os.path.isfile(requirement_file): + requirements_lines += open(requirement_file).readlines() + requirements_lines = list(set(requirements_lines)) + requirements_packages = [package_name for package_name, version_specifier in map(split_package_version, requirements_lines) if package_name is not None] + pypi_deps = { + package_name: get_package_version(package_name) + for package_name in requirements_packages + } for repo_name in custom_nodes_names: if repo_name in node_deps_info: @@ -246,7 +276,9 @@ def resolve_dependencies(prompt, custom_dependencies): # resolve custom nodes an "custom_nodes": custom_nodes_list, "models": models_dict, "files": files_dict, + "pypi": pypi_deps } + return_dict = { "dependencies": depencencies, "black_list_nodes": black_list_nodes, From f28c2c6b31994bd5c0842f12b9e8ccc5741bb917 Mon Sep 17 00:00:00 2001 From: shanexi Date: Mon, 28 Oct 2024 11:20:03 +0800 Subject: [PATCH 28/70] Replace with default value assigned --- web/shellagent.js | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/web/shellagent.js b/web/shellagent.js index 6504ae7..79be03e 100644 --- a/web/shellagent.js +++ b/web/shellagent.js @@ -147,6 +147,10 @@ app.registerExtension({ content: "Replace with ShellAgent Input Image", callback: () => { const node = addNode("ShellAgentPluginInputImage", this, { before: true }); + + const dvn = node.widgets.find(w => w.name === 'default_value') + dvn.value = this.widgets.find(w => w.name === 'image')?.value + app.graph.links.filter(l => l != null) .forEach(l => { const tn = app.graph._nodes_by_id[l.target_id] From 3b5a9b52202796652f4b6e7b475d2e347773e3f0 Mon Sep 17 00:00:00 2001 From: wl-zhao Date: Mon, 28 Oct 2024 11:22:04 +0800 Subject: [PATCH 29/70] fix file upload error --- file_upload.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/file_upload.py b/file_upload.py index fd471fe..a936f6a 100644 --- a/file_upload.py +++ b/file_upload.py @@ -95,7 +95,8 @@ def process_local_file_path_async(mapping_dict, max_workers=10): result = future.result() mapping_dict[filename] = result except Exception as e: - print(f"Error processing {filename}: {e}") + del mapping_dict[filename] + raise NotImplementedError(f"Error processing {filename}: {e}") end_time = time.time() logging.info(f"upload end, elapsed time: {end_time - start_time}") return \ No newline at end of file From a9d07ba4d3d38a9967be4e8636ac487d6006e972 Mon Sep 17 00:00:00 2001 From: wl-zhao Date: Mon, 28 Oct 2024 11:43:54 +0800 Subject: [PATCH 30/70] hardcode hf packages --- dependency_checker.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/dependency_checker.py b/dependency_checker.py index 538a7e5..b9dcfd3 100644 --- a/dependency_checker.py +++ b/dependency_checker.py @@ -5,7 +5,7 @@ from functools import partial import re import glob -import pkg_resources +import sys from folder_paths import models_dir as MODELS_DIR from folder_paths import base_path as BASE_PATH from folder_paths import get_full_path @@ -21,7 +21,7 @@ node_blacklist = json.load(open(os.path.join(os.path.dirname(__file__), "node_blacklist.json"))) model_suffix = [".ckpt", ".safetensors", ".bin", ".pth", ".pt", ".onnx", ".gguf"] - +extra_packages = ["transformers", "timm", "diffusers", "accelerate"] def get_full_path_or_raise(folder_name: str, filename: str) -> str: @@ -130,7 +130,15 @@ def split_package_version(require_line): return None, None def get_package_version(package_name): - return pkg_resources.get_distribution(package_name).version + try: + if sys.version_info >= (3, 8): + from importlib.metadata import version, PackageNotFoundError + return version(package_name) + else: + from pkg_resources import get_distribution, DistributionNotFound + return get_distribution(package_name).version + except Exception: + return None def resolve_dependencies(prompt, custom_dependencies): # resolve custom nodes and models at the same time from nodes import NODE_CLASS_MAPPINGS @@ -227,9 +235,10 @@ def resolve_dependencies(prompt, custom_dependencies): # resolve custom nodes an requirements_lines += open(requirement_file).readlines() requirements_lines = list(set(requirements_lines)) requirements_packages = [package_name for package_name, version_specifier in map(split_package_version, requirements_lines) if package_name is not None] + package_names = set(requirements_packages + extra_packages) pypi_deps = { package_name: get_package_version(package_name) - for package_name in requirements_packages + for package_name in package_names } for repo_name in custom_nodes_names: From 73bfa4e7b0d14b1a07dbeb5ff4a990f8ff50c6cd Mon Sep 17 00:00:00 2001 From: wl-zhao Date: Mon, 28 Oct 2024 15:13:01 +0800 Subject: [PATCH 31/70] add warning message when no inputs/outputs founded --- custom_routes.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/custom_routes.py b/custom_routes.py index f3a6a91..b7718ee 100755 --- a/custom_routes.py +++ b/custom_routes.py @@ -158,13 +158,14 @@ async def shellagent_export(request): # for fname, dict_to_save in fname_mapping.items(): # with open(os.path.join(save_root, fname), "w") as f: # json.dump(dict_to_save, f, indent=2) - + warning_message = "" if dependency_results.get("black_list_nodes", []): warning_message = "The following nodes cannot be deployed to myshell:\n" for item in dependency_results["black_list_nodes"]: warning_message += f" {item['name']}: {item['reason']}\n" - else: - warning_message = "" + + if len(schemas["inputs"]) + len(schemas["outputs"]) == 0: + warning_message += f"The workflow contains neither inputs nor outputs!\n" return_dict = { "success": True, From 4eb3e8b4f79e0ddb835f636b1bdca175fd96e9b5 Mon Sep 17 00:00:00 2001 From: wl-zhao Date: Tue, 29 Oct 2024 11:36:13 +0800 Subject: [PATCH 32/70] add message details --- custom_routes.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/custom_routes.py b/custom_routes.py index b7718ee..f5b94e7 100755 --- a/custom_routes.py +++ b/custom_routes.py @@ -177,6 +177,7 @@ async def shellagent_export(request): status = 400 return_dict = { "success": False, - "message": str(traceback.format_exc()), + "message_detail": str(traceback.format_exc()), + "message": str(e), } return web.json_response(return_dict, status=status) \ No newline at end of file From ae2948048ad6633f7625b944f8168f9441d18e8d Mon Sep 17 00:00:00 2001 From: Xumin Yu <1090414006@qq.com> Date: Tue, 29 Oct 2024 12:55:25 +0800 Subject: [PATCH 33/70] Update dependency_checker.py --- dependency_checker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dependency_checker.py b/dependency_checker.py index b9dcfd3..f02cdce 100644 --- a/dependency_checker.py +++ b/dependency_checker.py @@ -54,7 +54,7 @@ def handle_model_info(ckpt_path, filename, rel_save_path): urls = [] item = { - "filename": filename, + "filename": windows_to_linux_path(filename), "save_path": windows_to_linux_path(rel_save_path), "urls": urls, } From 8e8d10b1c595b843d2ad217615110fb6b8a7ba27 Mon Sep 17 00:00:00 2001 From: wl-zhao Date: Wed, 30 Oct 2024 12:07:56 +0800 Subject: [PATCH 34/70] update output video --- comfy-nodes/output_image.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/comfy-nodes/output_image.py b/comfy-nodes/output_image.py index f1ecc48..b3ddf64 100644 --- a/comfy-nodes/output_image.py +++ b/comfy-nodes/output_image.py @@ -75,11 +75,15 @@ def validate(cls, **kwargs): return schema def save_video(self, filenames, **kwargs): - status, (preview_image, video_path) = filenames + status, output_files = filenames + if len(output_files) == 0: + raise ValueError("the filenames are empty") + print("output_files", output_files) + video_path = output_files[-1] cwd = os.getcwd() preview_image = os.path.relpath(preview_image) video_path = os.path.relpath(video_path) - results = {"ui": {"image": [preview_image], "video": [video_path]}} + results = {"ui": {"video": [video_path]}} return results From 64ebfa42e93cbdc4c0244480453f38403191ac18 Mon Sep 17 00:00:00 2001 From: Wenliang Zhao Date: Wed, 30 Oct 2024 18:11:17 +0800 Subject: [PATCH 35/70] Update output_image.py --- comfy-nodes/output_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy-nodes/output_image.py b/comfy-nodes/output_image.py index b3ddf64..b302e4f 100644 --- a/comfy-nodes/output_image.py +++ b/comfy-nodes/output_image.py @@ -81,7 +81,7 @@ def save_video(self, filenames, **kwargs): print("output_files", output_files) video_path = output_files[-1] cwd = os.getcwd() - preview_image = os.path.relpath(preview_image) + # preview_image = os.path.relpath(preview_image) video_path = os.path.relpath(video_path) results = {"ui": {"video": [video_path]}} return results From 34f8feb2c02c008d4e8d1aa104a1c3675cfbaeff Mon Sep 17 00:00:00 2001 From: wl-zhao Date: Fri, 1 Nov 2024 11:15:21 +0800 Subject: [PATCH 36/70] add boolean; fix output nodes input type bugs; fix input_video enum validate bug; compatible with none desc --- comfy-nodes/input_image.py | 2 +- comfy-nodes/input_text.py | 50 +++++++++++++++++++++++++++++++++++--- comfy-nodes/input_video.py | 7 ++++++ comfy-nodes/output_text.py | 22 ++++++++++++++--- 4 files changed, 73 insertions(+), 8 deletions(-) diff --git a/comfy-nodes/input_image.py b/comfy-nodes/input_image.py index e2b25fc..b09a645 100755 --- a/comfy-nodes/input_image.py +++ b/comfy-nodes/input_image.py @@ -45,7 +45,7 @@ def validate(cls, **kwargs): "title": kwargs["input_name"], "type": "string", "default": kwargs["default_value"], - "description": kwargs["description"], + "description": kwargs.get("description", ""), "url_type": "image" } return schema diff --git a/comfy-nodes/input_text.py b/comfy-nodes/input_text.py index f499322..1fb3f56 100755 --- a/comfy-nodes/input_text.py +++ b/comfy-nodes/input_text.py @@ -42,7 +42,7 @@ def validate(cls, **kwargs): "title": kwargs["input_name"], "type": "string", "default": kwargs["default_value"], - "description": kwargs["description"], + "description": kwargs.get("description", ""), } if kwargs.get("choices", "") != "": schema["enums"] = eval(kwargs["choices"]) @@ -101,7 +101,7 @@ def validate(cls, **kwargs): "title": kwargs["input_name"], "type": "number", "default": kwargs["default_value"], - "description": kwargs["description"], + "description": kwargs.get("description", ""), } if kwargs.get("choices", "") != "": schema["enums"] = eval(kwargs["choices"]) @@ -184,14 +184,58 @@ def validate(cls, **kwargs): def run(self, input_name, default_value=None, display_name=None, description=None, **kwargs): return [default_value] +class ShellAgentPluginInputBoolean: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "input_name": ( + "STRING", + {"multiline": False, "default": "input_bool"}, + ), + }, + "optional": { + "default_value": ( + "BOOLEAN", + {"default": False}, + ), + "description": ( + "STRING", + {"multiline": True, "default": ""}, + ), + } + } + + RETURN_TYPES = ("BOOLEAN",) + RETURN_NAMES = ("boolean",) + + FUNCTION = "run" + + CATEGORY = "shellagent" + + @classmethod + def validate(cls, **kwargs): + schema = { + "title": kwargs["input_name"], + "type": "boolean", + "default": kwargs["default_value"], + "description": kwargs.get("description", ""), + } + return schema + + def run(self, input_name, default_value=None, display_name=None, description=None, **kwargs): + return [default_value] + NODE_CLASS_MAPPINGS = { "ShellAgentPluginInputText": ShellAgentPluginInputText, "ShellAgentPluginInputFloat": ShellAgentPluginInputFloat, - "ShellAgentPluginInputInteger": ShellAgentPluginInputInteger + "ShellAgentPluginInputInteger": ShellAgentPluginInputInteger, + "ShellAgentPluginInputBoolean": ShellAgentPluginInputBoolean, } NODE_DISPLAY_NAME_MAPPINGS = { "ShellAgentPluginInputText": "Input Text (ShellAgent Plugin)", "ShellAgentPluginInputFloat": "Input Float (ShellAgent Plugin)", "ShellAgentPluginInputInteger": "Input Integer (ShellAgent Plugin)", + "ShellAgentPluginInputBoolean": "Input Boolean (ShellAgent Plugin)", } \ No newline at end of file diff --git a/comfy-nodes/input_video.py b/comfy-nodes/input_video.py index 48e94a0..e1e75af 100644 --- a/comfy-nodes/input_video.py +++ b/comfy-nodes/input_video.py @@ -120,6 +120,13 @@ def validate(cls, **kwargs): "url_type": "video" } return schema + + @classmethod + def VALIDATE_INPUTS(s, input_name, default_value, description=""): + video = default_value + if not folder_paths.exists_annotated_filepath(video): + return "Invalid video file: {}".format(video) + return True def run(self, input_name, default_value=None, description=None): input_dir = folder_paths.get_input_directory() diff --git a/comfy-nodes/output_text.py b/comfy-nodes/output_text.py index 8b74fce..fbb6c91 100644 --- a/comfy-nodes/output_text.py +++ b/comfy-nodes/output_text.py @@ -2,7 +2,8 @@ json_type_mapipng = { "text": "string", "float": "number", - "integer": "integer" + "integer": "integer", + "boolean": "boolean", } class ShellAgentOutputText: @@ -43,7 +44,7 @@ class ShellAgentOutputFloat(ShellAgentOutputText): def INPUT_TYPES(s): return { "required": { - s.TYPE_STR: ("STRING", {"tooltip": f"The {s.TYPE_STR} to output."}), + s.TYPE_STR: ("FLOAT", {"tooltip": f"The {s.TYPE_STR} to output."}), "output_name": ("STRING", {"multiline": False, "default": f"output_{s.TYPE_STR}"},), }, } @@ -56,7 +57,19 @@ class ShellAgentOutputInteger(ShellAgentOutputText): def INPUT_TYPES(s): return { "required": { - s.TYPE_STR: ("STRING", {"tooltip": f"The {s.TYPE_STR} to output."}), + s.TYPE_STR: ("INT", {"tooltip": f"The {s.TYPE_STR} to output."}), + "output_name": ("STRING", {"multiline": False, "default": f"output_{s.TYPE_STR}"},), + }, + } + +class ShellAgentOutputBoolean(ShellAgentOutputText): + TYPE_STR = "boolean" + DESCRIPTION = "output the integer" + @classmethod + def INPUT_TYPES(s): + return { + "required": { + s.TYPE_STR: ("BOOLEAN", {"tooltip": f"The {s.TYPE_STR} to output."}), "output_name": ("STRING", {"multiline": False, "default": f"output_{s.TYPE_STR}"},), }, } @@ -65,7 +78,8 @@ def INPUT_TYPES(s): NODE_CLASS_MAPPINGS = { "ShellAgentPluginOutputText": ShellAgentOutputText, "ShellAgentPluginOutputFloat": ShellAgentOutputFloat, - "ShellAgentPluginOutputInteger": ShellAgentOutputInteger + "ShellAgentPluginOutputInteger": ShellAgentOutputInteger, + "ShellAgentPluginOutputBoolean": ShellAgentOutputBoolean, } NODE_DISPLAY_NAME_MAPPINGS = { "ShellAgentPluginOutputText": "Output Text (ShellAgent Plugin)", From c80f15465983c456f3fc63064febdf43cc2649b7 Mon Sep 17 00:00:00 2001 From: wl-zhao Date: Fri, 1 Nov 2024 20:21:27 +0800 Subject: [PATCH 37/70] input video --- comfy-nodes/input_video.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/comfy-nodes/input_video.py b/comfy-nodes/input_video.py index e1e75af..d087b70 100644 --- a/comfy-nodes/input_video.py +++ b/comfy-nodes/input_video.py @@ -124,6 +124,8 @@ def validate(cls, **kwargs): @classmethod def VALIDATE_INPUTS(s, input_name, default_value, description=""): video = default_value + if video.startswith("http"): + return True if not folder_paths.exists_annotated_filepath(video): return "Invalid video file: {}".format(video) return True From 0904716cd077487642f8a1f6878cf27a5ea75e04 Mon Sep 17 00:00:00 2001 From: wl-zhao Date: Sun, 3 Nov 2024 22:29:12 +0800 Subject: [PATCH 38/70] fix input video bug --- comfy-nodes/input_video.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy-nodes/input_video.py b/comfy-nodes/input_video.py index d087b70..3700aac 100644 --- a/comfy-nodes/input_video.py +++ b/comfy-nodes/input_video.py @@ -4,7 +4,7 @@ import torch import os import uuid -import tqdm +from tqdm import tqdm # class ShellAgentPluginInputImage: From 1faaad58f0f70bf7fe3857636d9ddb846577e32c Mon Sep 17 00:00:00 2001 From: Xumin Yu <1090414006@qq.com> Date: Sun, 3 Nov 2024 22:38:13 +0800 Subject: [PATCH 39/70] Update input_video.py --- comfy-nodes/input_video.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy-nodes/input_video.py b/comfy-nodes/input_video.py index 3700aac..7aa45ab 100644 --- a/comfy-nodes/input_video.py +++ b/comfy-nodes/input_video.py @@ -151,7 +151,7 @@ def run(self, input_name, default_value=None, description=None): num_bars = int(file_size / chunk_size) with open(video_path, "wb") as out_file: - for chunk in tqdm( + for chunk in tqdm.tqdm( response.iter_content(chunk_size=chunk_size), total=num_bars, unit="KB", @@ -175,4 +175,4 @@ def run(self, input_name, default_value=None, description=None): NODE_DISPLAY_NAME_MAPPINGS = { # "ShellAgentPluginInputImage": "Input Image (ShellAgent Plugin)", "ShellAgentPluginInputVideo": "Input Video (ShellAgent Plugin)" -} \ No newline at end of file +} From 1aa0fc15e249d5de5c4810fe3fba716aff8f9d7f Mon Sep 17 00:00:00 2001 From: Xumin Yu <1090414006@qq.com> Date: Sun, 3 Nov 2024 22:38:56 +0800 Subject: [PATCH 40/70] revert input_video.py --- comfy-nodes/input_video.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy-nodes/input_video.py b/comfy-nodes/input_video.py index 7aa45ab..c13c029 100644 --- a/comfy-nodes/input_video.py +++ b/comfy-nodes/input_video.py @@ -151,7 +151,7 @@ def run(self, input_name, default_value=None, description=None): num_bars = int(file_size / chunk_size) with open(video_path, "wb") as out_file: - for chunk in tqdm.tqdm( + for chunk in tqdm( response.iter_content(chunk_size=chunk_size), total=num_bars, unit="KB", From 4c8e720d0523d4b9befa215b848883e246f6ecda Mon Sep 17 00:00:00 2001 From: wl-zhao Date: Mon, 4 Nov 2024 12:17:40 +0800 Subject: [PATCH 41/70] fix output video path error --- comfy-nodes/output_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy-nodes/output_image.py b/comfy-nodes/output_image.py index b302e4f..7a884b6 100644 --- a/comfy-nodes/output_image.py +++ b/comfy-nodes/output_image.py @@ -82,7 +82,7 @@ def save_video(self, filenames, **kwargs): video_path = output_files[-1] cwd = os.getcwd() # preview_image = os.path.relpath(preview_image) - video_path = os.path.relpath(video_path) + video_path = os.path.relpath(video_path, folder_paths.base_path) results = {"ui": {"video": [video_path]}} return results From 4d548bfb5a43ef86545264e60177a75ab9191835 Mon Sep 17 00:00:00 2001 From: Xumin Yu <1090414006@qq.com> Date: Wed, 6 Nov 2024 20:39:39 +0800 Subject: [PATCH 42/70] support .sft for model suffix --- dependency_checker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dependency_checker.py b/dependency_checker.py index f02cdce..02f53e6 100644 --- a/dependency_checker.py +++ b/dependency_checker.py @@ -20,7 +20,7 @@ node_deps_info = json.load(open(os.path.join(os.path.dirname(__file__), "node_deps_info.json"))) node_blacklist = json.load(open(os.path.join(os.path.dirname(__file__), "node_blacklist.json"))) -model_suffix = [".ckpt", ".safetensors", ".bin", ".pth", ".pt", ".onnx", ".gguf"] +model_suffix = [".ckpt", ".safetensors", ".bin", ".pth", ".pt", ".onnx", ".gguf", ".sft"] extra_packages = ["transformers", "timm", "diffusers", "accelerate"] From 85e01a8711e670fd0ef2ace5282c2d2bb273edcc Mon Sep 17 00:00:00 2001 From: wl-zhao Date: Thu, 7 Nov 2024 10:56:06 +0800 Subject: [PATCH 43/70] add tree_map for dependency checker --- dependency_checker.py | 62 +- file_upload.py | 2 +- utils/pytree.py | 1197 ++++++++++++++++++++++++++++++++++++ utils.py => utils/utils.py | 0 4 files changed, 1232 insertions(+), 29 deletions(-) create mode 100644 utils/pytree.py rename utils.py => utils/utils.py (100%) diff --git a/dependency_checker.py b/dependency_checker.py index f02cdce..2be79f2 100644 --- a/dependency_checker.py +++ b/dependency_checker.py @@ -11,7 +11,8 @@ from folder_paths import get_full_path -from .utils import compute_sha256, windows_to_linux_path +from .utils.utils import compute_sha256, windows_to_linux_path +from .utils.pytree import tree_map from .file_upload import collect_local_file, process_local_file_path_async @@ -148,6 +149,36 @@ def resolve_dependencies(prompt, custom_dependencies): # resolve custom nodes an ckpt_paths = {} file_mapping_dict = {} + + + def collect_unknown_models(filename): + if type(filename) != str: + return + is_model = False + for possible_suffix in model_suffix: + if filename.endswith(possible_suffix): + is_model = True + if is_model: + print(f"find {filename}, is_model=True") + # find possible paths + matching_files = [] + # Walk through all subdirectories and files in the directory + rel_save_path = None + for possible_folder_name in folder_paths.folder_names_and_paths: + full_path = folder_paths.get_full_path(possible_folder_name, filename) + if full_path is None: + continue + rel_save_path = os.path.relpath(folder_paths.folder_names_and_paths[possible_folder_name][0][0], folder_paths.models_dir) + matching_files.append(full_path) + break + print(f"matched files: {matching_files}") + if len(set(matching_files)) == 1: + assert rel_save_path is not None + ckpt_paths[matching_files[0]] = { + "filename": filename, + "rel_save_path": rel_save_path + } + for node_id, node_info in prompt.items(): node_class_type = node_info.get("class_type") if node_class_type is None: @@ -174,33 +205,8 @@ def resolve_dependencies(prompt, custom_dependencies): # resolve custom nodes an "rel_save_path": rel_save_path } else: - for field_name, filename in node_info["inputs"].items(): - if type(filename) != str: - continue - is_model = False - for possible_suffix in model_suffix: - if filename.endswith(possible_suffix): - is_model = True - if is_model: - print(f"find {filename}, is_model=True") - # find possible paths - matching_files = [] - # Walk through all subdirectories and files in the directory - rel_save_path = None - for possible_folder_name in folder_paths.folder_names_and_paths: - full_path = folder_paths.get_full_path(possible_folder_name, filename) - if full_path is None: - continue - rel_save_path = os.path.relpath(folder_paths.folder_names_and_paths[possible_folder_name][0][0], folder_paths.models_dir) - matching_files.append(full_path) - break - print(f"matched files: {matching_files}") - if len(set(matching_files)) == 1: - assert rel_save_path is not None - ckpt_paths[matching_files[0]] = { - "filename": filename, - "rel_save_path": rel_save_path - } + tree_map(collect_unknown_models, node_info["inputs"]) + list(map(partial(collect_local_file, mapping_dict=file_mapping_dict), node_info["inputs"].values())) print("ckpt_paths:", ckpt_paths) diff --git a/file_upload.py b/file_upload.py index a936f6a..3c4afbd 100644 --- a/file_upload.py +++ b/file_upload.py @@ -5,7 +5,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed import folder_paths -from .utils import compute_sha256 +from .utils.utils import compute_sha256 ext_to_type = { # image diff --git a/utils/pytree.py b/utils/pytree.py new file mode 100644 index 0000000..d326a84 --- /dev/null +++ b/utils/pytree.py @@ -0,0 +1,1197 @@ +""" +Copy from https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py + +Contains utility functions for working with nested python data structures. + +A *pytree* is Python nested data structure. It is a tree in the sense that +nodes are Python collections (e.g., list, tuple, dict) and the leaves are +Python values. Furthermore, a pytree should not contain reference cycles. + +pytrees are useful for working with nested collections of Tensors. For example, +one can use `tree_map` to map a function over all Tensors inside some nested +collection of Tensors and `tree_leaves` to get a flat list of all Tensors +inside some nested collection. pytrees are helpful for implementing nested +collection support for PyTorch APIs. + +This pytree implementation is not very performant due to Python overhead +To improve the performance we can move parts of the implementation to C++. +""" + +import dataclasses +import importlib +import json +import threading +import warnings +from collections import defaultdict, deque, namedtuple, OrderedDict +from typing import ( + Any, + Callable, + cast, + DefaultDict, + Deque, + Dict, + FrozenSet, + Iterable, + List, + NamedTuple, + Optional, + OrderedDict as GenericOrderedDict, + overload, + Tuple, + Type, + TypeVar, + Union, +) +from easydict import EasyDict as edict + + +__all__ = [ + "PyTree", + "Context", + "FlattenFunc", + "UnflattenFunc", + "DumpableContext", + "ToDumpableContextFn", + "FromDumpableContextFn", + "TreeSpec", + "LeafSpec", + "register_pytree_node", + "tree_flatten", + "tree_unflatten", + "tree_leaves", + "tree_structure", + "tree_map", + "tree_map_", + "tree_map_only", + "tree_map_only_", + "tree_all", + "tree_any", + "tree_all_only", + "tree_any_only", + "treespec_dumps", + "treespec_loads", + "treespec_pprint", +] + + +T = TypeVar("T") +S = TypeVar("S") +U = TypeVar("U") +R = TypeVar("R") + + +DEFAULT_TREESPEC_SERIALIZATION_PROTOCOL = 1 +NO_SERIALIZED_TYPE_NAME_FOUND = "NO_SERIALIZED_TYPE_NAME_FOUND" + +Context = Any +PyTree = Any +FlattenFunc = Callable[[PyTree], Tuple[List[Any], Context]] +UnflattenFunc = Callable[[Iterable[Any], Context], PyTree] +DumpableContext = Any # Any json dumpable text +ToDumpableContextFn = Callable[[Context], DumpableContext] +FromDumpableContextFn = Callable[[DumpableContext], Context] +ToStrFunc = Callable[["TreeSpec", List[str]], str] +MaybeFromStrFunc = Callable[[str], Optional[Tuple[Any, Context, str]]] + + +# A NodeDef holds two callables: +# - flatten_fn should take the collection and return a flat list of values. +# It can also return some context that is used in reconstructing the +# collection. +# - unflatten_fn should take a flat list of values and some context +# (returned by flatten_fn). It returns the collection by reconstructing +# it from the list and the context. +class NodeDef(NamedTuple): + type: Type[Any] + flatten_fn: FlattenFunc + unflatten_fn: UnflattenFunc + + +_NODE_REGISTRY_LOCK = threading.Lock() +SUPPORTED_NODES: Dict[Type[Any], NodeDef] = {} + + +# _SerializeNodeDef holds the following: +# - typ: the type of the node (e.g., "Dict", "List", etc) +# - serialized_type_name: the fully qualified name of the type, e.g. "collections.OrderedDict" +# - to_dumpable_context takes a TreeSpec, and returns a serialized string format of the +# context, and the version number +# - from_dumpable_context takes in a string representation of the context, and the +# version, and returns the deserialized context +class _SerializeNodeDef(NamedTuple): + typ: Type[Any] + serialized_type_name: str + to_dumpable_context: Optional[ToDumpableContextFn] + from_dumpable_context: Optional[FromDumpableContextFn] + + +SUPPORTED_SERIALIZED_TYPES: Dict[Type[Any], _SerializeNodeDef] = {} +SERIALIZED_TYPE_TO_PYTHON_TYPE: Dict[str, Type[Any]] = {} + + +def register_pytree_node( + cls: Type[Any], + flatten_fn: FlattenFunc, + unflatten_fn: UnflattenFunc, + *, + serialized_type_name: Optional[str] = None, + to_dumpable_context: Optional[ToDumpableContextFn] = None, + from_dumpable_context: Optional[FromDumpableContextFn] = None, +) -> None: + """Register a container-like type as pytree node. + + Args: + cls: the type to register + flatten_fn: A callable that takes a pytree and returns a flattened + representation of the pytree and additional context to represent the + flattened pytree. + unflatten_fn: A callable that takes a flattened version of the pytree, + additional context, and returns an unflattened pytree. + serialized_type_name: A keyword argument used to specify the fully qualified + name used when serializing the tree spec. + to_dumpable_context: An optional keyword argument to custom specify how + to convert the context of the pytree to a custom json dumpable + representation. This is used for json serialization, which is being + used in torch.export right now. + from_dumpable_context: An optional keyword argument to custom specify how + to convert the custom json dumpable representation of the context + back to the original context. This is used for json deserialization, + which is being used in torch.export right now. + """ + with _NODE_REGISTRY_LOCK: + if cls in SUPPORTED_NODES: + raise ValueError(f"{cls} is already registered as pytree node.") + + _private_register_pytree_node( + cls, + flatten_fn, + unflatten_fn, + serialized_type_name=serialized_type_name, + to_dumpable_context=to_dumpable_context, + from_dumpable_context=from_dumpable_context, + ) + + try: + from . import _cxx_pytree as cxx + except ImportError: + pass + else: + cxx._private_register_pytree_node( + cls, + flatten_fn, + unflatten_fn, + serialized_type_name=serialized_type_name, + to_dumpable_context=to_dumpable_context, + from_dumpable_context=from_dumpable_context, + ) + + +def _register_pytree_node( + cls: Type[Any], + flatten_fn: FlattenFunc, + unflatten_fn: UnflattenFunc, + to_str_fn: Optional[ToStrFunc] = None, # deprecated + maybe_from_str_fn: Optional[MaybeFromStrFunc] = None, # deprecated + *, + serialized_type_name: Optional[str] = None, + to_dumpable_context: Optional[ToDumpableContextFn] = None, + from_dumpable_context: Optional[FromDumpableContextFn] = None, +) -> None: + """Register a container-like type as pytree node for the Python pytree only. + + Args: + cls: the type to register + flatten_fn: A callable that takes a pytree and returns a flattened + representation of the pytree and additional context to represent the + flattened pytree. + unflatten_fn: A callable that takes a flattened version of the pytree, + additional context, and returns an unflattened pytree. + serialized_type_name: A keyword argument used to specify the fully qualified + name used when serializing the tree spec. + to_dumpable_context: An optional keyword argument to custom specify how + to convert the context of the pytree to a custom json dumpable + representation. This is used for json serialization, which is being + used in torch.export right now. + from_dumpable_context: An optional keyword argument to custom specify how + to convert the custom json dumpable representation of the context + back to the original context. This is used for json deserialization, + which is being used in torch.export right now. + """ + warnings.warn( + "torch.utils._pytree._register_pytree_node is deprecated. " + "Please use torch.utils._pytree.register_pytree_node instead.", + stacklevel=2, + ) + + if to_str_fn is not None or maybe_from_str_fn is not None: + warnings.warn( + "to_str_fn and maybe_from_str_fn is deprecated. " + "Please use to_dumpable_context and from_dumpable_context instead." + ) + + _private_register_pytree_node( + cls, + flatten_fn, + unflatten_fn, + serialized_type_name=serialized_type_name, + to_dumpable_context=to_dumpable_context, + from_dumpable_context=from_dumpable_context, + ) + + +def _private_register_pytree_node( + cls: Type[Any], + flatten_fn: FlattenFunc, + unflatten_fn: UnflattenFunc, + *, + serialized_type_name: Optional[str] = None, + to_dumpable_context: Optional[ToDumpableContextFn] = None, + from_dumpable_context: Optional[FromDumpableContextFn] = None, +) -> None: + """This is an internal function that is used to register a pytree node type + for the Python pytree only. End-users should use :func:`register_pytree_node` + instead. + """ + with _NODE_REGISTRY_LOCK: + if cls in SUPPORTED_NODES: + # TODO: change this warning to an error after OSS/internal stabilize + warnings.warn( + f"{cls} is already registered as pytree node. " + "Overwriting the previous registration.", + ) + + node_def = NodeDef( + cls, + flatten_fn, + unflatten_fn, + ) + SUPPORTED_NODES[cls] = node_def + + if (to_dumpable_context is None) ^ (from_dumpable_context is None): + raise ValueError( + f"Both to_dumpable_context and from_dumpable_context for {cls} must " + "be None or registered." + ) + + if serialized_type_name is None: + serialized_type_name = f"{cls.__module__}.{cls.__qualname__}" + + serialize_node_def = _SerializeNodeDef( + cls, + serialized_type_name, + to_dumpable_context, + from_dumpable_context, + ) + SUPPORTED_SERIALIZED_TYPES[cls] = serialize_node_def + SERIALIZED_TYPE_TO_PYTHON_TYPE[serialized_type_name] = cls + + +def _dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]: + return list(d.values()), list(d.keys()) + + +def _dict_unflatten(values: Iterable[Any], context: Context) -> Dict[Any, Any]: + return dict(zip(context, values)) + + +def _list_flatten(d: List[Any]) -> Tuple[List[Any], Context]: + return d, None + + +def _list_unflatten(values: Iterable[Any], context: Context) -> List[Any]: + return list(values) + + +def _tuple_flatten(d: Tuple[Any, ...]) -> Tuple[List[Any], Context]: + return list(d), None + + +def _tuple_unflatten(values: Iterable[Any], context: Context) -> Tuple[Any, ...]: + return tuple(values) + + +def _namedtuple_flatten(d: NamedTuple) -> Tuple[List[Any], Context]: + return list(d), type(d) + + +def _namedtuple_unflatten(values: Iterable[Any], context: Context) -> NamedTuple: + return cast(NamedTuple, context(*values)) + + +def _namedtuple_serialize(context: Context) -> DumpableContext: + json_namedtuple = { + "class_name": context.__name__, + "fields": context._fields, + } + return json_namedtuple + + +def _namedtuple_deserialize(dumpable_context: DumpableContext) -> Context: + class_name = dumpable_context["class_name"] + assert isinstance(class_name, str) + # type: ignore[misc] + context = namedtuple(class_name, dumpable_context["fields"]) + return context + + +def _ordereddict_flatten(d: GenericOrderedDict[Any, Any]) -> Tuple[List[Any], Context]: + return list(d.values()), list(d.keys()) + + +def _ordereddict_unflatten( + values: Iterable[Any], + context: Context, +) -> GenericOrderedDict[Any, Any]: + return OrderedDict((key, value) for key, value in zip(context, values)) + + +_odict_flatten = _ordereddict_flatten +_odict_unflatten = _ordereddict_unflatten + + +def _defaultdict_flatten(d: DefaultDict[Any, Any]) -> Tuple[List[Any], Context]: + values, dict_context = _dict_flatten(d) + return values, [d.default_factory, dict_context] + + +def _defaultdict_unflatten( + values: Iterable[Any], + context: Context, +) -> DefaultDict[Any, Any]: + default_factory, dict_context = context + return defaultdict(default_factory, _dict_unflatten(values, dict_context)) + + +def _defaultdict_serialize(context: Context) -> DumpableContext: + default_factory, dict_context = context + json_defaultdict = { + "default_factory_module": default_factory.__module__, + "default_factory_name": default_factory.__qualname__, + "dict_context": dict_context, + } + return json_defaultdict + + +def _defaultdict_deserialize(dumpable_context: DumpableContext) -> Context: + assert isinstance(dumpable_context, dict) + assert set(dumpable_context) == { + "default_factory_module", + "default_factory_name", + "dict_context", + } + + default_factory_module = dumpable_context["default_factory_module"] + default_factory_name = dumpable_context["default_factory_name"] + assert isinstance(default_factory_module, str) + assert isinstance(default_factory_name, str) + module = importlib.import_module(default_factory_module) + default_factory = getattr(module, default_factory_name) + + dict_context = dumpable_context["dict_context"] + return [default_factory, dict_context] + + +def _deque_flatten(deq: Deque[Any]) -> Tuple[List[Any], Context]: + return list(deq), deq.maxlen + + +def _deque_unflatten(values: Iterable[Any], context: Context) -> Deque[Any]: + return deque(values, maxlen=context) + + +_private_register_pytree_node( + tuple, + _tuple_flatten, + _tuple_unflatten, + serialized_type_name="builtins.tuple", +) +_private_register_pytree_node( + list, + _list_flatten, + _list_unflatten, + serialized_type_name="builtins.list", +) +_private_register_pytree_node( + dict, + _dict_flatten, + _dict_unflatten, + serialized_type_name="builtins.dict", +) +_private_register_pytree_node( + namedtuple, # type: ignore[arg-type] + _namedtuple_flatten, + _namedtuple_unflatten, + serialized_type_name="collections.namedtuple", + to_dumpable_context=_namedtuple_serialize, + from_dumpable_context=_namedtuple_deserialize, +) +_private_register_pytree_node( + OrderedDict, + _ordereddict_flatten, + _ordereddict_unflatten, + serialized_type_name="collections.OrderedDict", +) +_private_register_pytree_node( + defaultdict, + _defaultdict_flatten, + _defaultdict_unflatten, + serialized_type_name="collections.defaultdict", + to_dumpable_context=_defaultdict_serialize, + from_dumpable_context=_defaultdict_deserialize, +) +_private_register_pytree_node( + deque, + _deque_flatten, + _deque_unflatten, + serialized_type_name="collections.deque", +) + + +STANDARD_DICT_TYPES: FrozenSet[type] = frozenset( + {dict, OrderedDict, defaultdict}, +) +BUILTIN_TYPES: FrozenSet[type] = frozenset( + {tuple, list, dict, namedtuple, OrderedDict, + defaultdict, deque}, # type: ignore[arg-type] +) + + +# h/t https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple +def _is_namedtuple_instance(tree: Any) -> bool: + typ = type(tree) + bases = typ.__bases__ + if len(bases) != 1 or bases[0] != tuple: + return False + fields = getattr(typ, "_fields", None) + if not isinstance(fields, tuple): + return False + return all(type(entry) == str for entry in fields) + + +def _get_node_type(tree: Any) -> Any: + if _is_namedtuple_instance(tree): + return namedtuple + return type(tree) + + +# A leaf is defined as anything that is not a Node. +def _is_leaf(tree: PyTree) -> bool: + return _get_node_type(tree) not in SUPPORTED_NODES + + +# A TreeSpec represents the structure of a pytree. It holds: +# "type": the type of root Node of the pytree +# context: some context that is useful in unflattening the pytree +# children_specs: specs for each child of the root Node +# num_leaves: the number of leaves +@dataclasses.dataclass +class TreeSpec: + type: Any + context: Context + children_specs: List["TreeSpec"] + + num_nodes: int = dataclasses.field(init=False) + num_leaves: int = dataclasses.field(init=False) + num_children: int = dataclasses.field(init=False) + + def __post_init__(self) -> None: + self.num_nodes = 1 + \ + sum(spec.num_nodes for spec in self.children_specs) + self.num_leaves = sum(spec.num_leaves for spec in self.children_specs) + self.num_children = len(self.children_specs) + + def __repr__(self, indent: int = 0) -> str: + repr_prefix: str = f"TreeSpec({self.type.__name__}, {self.context}, [" + children_specs_str: str = "" + if self.num_children > 0: + indent += 2 + children_specs_str += self.children_specs[0].__repr__(indent) + children_specs_str += "," if self.num_children > 1 else "" + children_specs_str += ",".join( + [ + "\n" + " " * indent + child.__repr__(indent) + for child in self.children_specs[1:] + ] + ) + repr_suffix: str = f"{children_specs_str}])" + return repr_prefix + repr_suffix + + def is_leaf(self) -> bool: + return self.num_nodes == 1 and self.num_leaves == 1 + + def _flatten_up_to_helper(self, tree: PyTree, subtrees: List[PyTree]) -> None: + if self.is_leaf(): + subtrees.append(tree) + return + + node_type = _get_node_type(tree) + if self.type not in BUILTIN_TYPES: + # Always require custom node types to match exactly + if node_type != self.type: + raise ValueError( + f"Type mismatch; " + f"expected {self.type!r}, but got {node_type!r}.", + ) + flatten_fn = SUPPORTED_NODES[node_type].flatten_fn + child_pytrees, context = flatten_fn(tree) + if len(child_pytrees) != self.num_children: + raise ValueError( + f"Node arity mismatch; " + f"expected {self.num_children}, but got {len(child_pytrees)}.", + ) + if context != self.context: + raise ValueError( + f"Node context mismatch for custom node type {self.type!r}.", + ) + else: + # For builtin dictionary types, we allow some flexibility + # Otherwise, we require exact matches + both_standard_dict = ( + self.type in STANDARD_DICT_TYPES and node_type in STANDARD_DICT_TYPES + ) + if node_type != self.type and not both_standard_dict: + raise ValueError( + f"Node type mismatch; " + f"expected {self.type!r}, but got {node_type!r}.", + ) + if len(tree) != self.num_children: + raise ValueError( + f"Node arity mismatch; " + f"expected {self.num_children}, but got {len(tree)}.", + ) + + if both_standard_dict: # dictionary types are compatible with each other + dict_context = ( + self.context + if self.type is not defaultdict + # ignore mismatch of `default_factory` for defaultdict + else self.context[1] + ) + expected_keys = dict_context + got_key_set = set(tree) + expected_key_set = set(expected_keys) + if got_key_set != expected_key_set: + missing_keys = expected_key_set.difference(got_key_set) + extra_keys = got_key_set.difference(expected_key_set) + message = "" + if missing_keys: + message += f"; missing key(s): {missing_keys}" + if extra_keys: + message += f"; extra key(s): {extra_keys}" + raise ValueError(f"Node keys mismatch{message}.") + child_pytrees = [tree[key] for key in expected_keys] + else: + flatten_fn = SUPPORTED_NODES[node_type].flatten_fn + child_pytrees, context = flatten_fn(tree) + if ( + context != self.context + and self.type is not deque # ignore mismatch of `maxlen` for deque + ): + raise ValueError( + f"Node context mismatch for node type {self.type!r}; " + # namedtuple type mismatch + f"expected {self.context!r}, but got {context!r}.", + ) + + for child_pytree, child_spec in zip(child_pytrees, self.children_specs): + child_spec._flatten_up_to_helper(child_pytree, subtrees) + + def flatten_up_to(self, tree: PyTree) -> List[PyTree]: + subtrees: List[PyTree] = [] + self._flatten_up_to_helper(tree, subtrees) + return subtrees + + def unflatten(self, leaves: Iterable[Any]) -> PyTree: + if not isinstance(leaves, (list, tuple)): + leaves = list(leaves) + if len(leaves) != self.num_leaves: + raise ValueError( + f"treespec.unflatten(leaves): `leaves` has length {len(leaves)} " + f"but the spec refers to a pytree that holds {self.num_leaves} " + f"items ({self}).", + ) + if self.is_leaf(): + return leaves[0] + + unflatten_fn = SUPPORTED_NODES[self.type].unflatten_fn + + # Recursively unflatten the children + start = 0 + end = 0 + child_pytrees = [] + for child_spec in self.children_specs: + end += child_spec.num_leaves + child_pytrees.append(child_spec.unflatten(leaves[start:end])) + start = end + + return unflatten_fn(child_pytrees, self.context) + + +class LeafSpec(TreeSpec): + def __init__(self) -> None: + super().__init__(None, None, []) + + def __post_init__(self) -> None: + self.num_nodes = 1 + self.num_leaves = 1 + self.num_children = 0 + + def __repr__(self, indent: int = 0) -> str: + return "*" + + +# All leaves are equivalent, so represent with a single object to save on +# object construction time +_LEAF_SPEC = LeafSpec() + + +def _tree_flatten_helper(tree: PyTree, leaves: List[Any]) -> TreeSpec: + if hasattr(tree, "keys"): # type(tree) == edict: + tree = {**tree} + + if _is_leaf(tree): + leaves.append(tree) + return _LEAF_SPEC + + node_type = _get_node_type(tree) + flatten_fn = SUPPORTED_NODES[node_type].flatten_fn + child_pytrees, context = flatten_fn(tree) + + # Recursively flatten the children + children_specs = [_tree_flatten_helper( + child, leaves) for child in child_pytrees] + + return TreeSpec(node_type, context, children_specs) + + +def tree_flatten(tree: PyTree) -> Tuple[List[Any], TreeSpec]: + """Flattens a pytree into a list of values and a TreeSpec that can be used + to reconstruct the pytree. + """ + leaves: List[Any] = [] + spec = _tree_flatten_helper(tree, leaves) + return leaves, spec + + +def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree: + """Given a list of values and a TreeSpec, builds a pytree. + This is the inverse operation of `tree_flatten`. + """ + if not isinstance(treespec, TreeSpec): + raise TypeError( + f"tree_unflatten(leaves, treespec): Expected `treespec` to be " + f"instance of TreeSpec but got item of type {type(treespec)}.", + ) + return treespec.unflatten(leaves) + + +def _tree_leaves_helper(tree: PyTree, leaves: List[Any]) -> None: + if _is_leaf(tree): + leaves.append(tree) + return + + node_type = _get_node_type(tree) + flatten_fn = SUPPORTED_NODES[node_type].flatten_fn + child_pytrees, _ = flatten_fn(tree) + + # Recursively flatten the children + for child in child_pytrees: + _tree_leaves_helper(child, leaves) + + +def tree_leaves(tree: PyTree) -> List[Any]: + """Get a list of leaves of a pytree.""" + leaves: List[Any] = [] + _tree_leaves_helper(tree, leaves) + return leaves + + +def tree_structure(tree: PyTree) -> TreeSpec: + """Get the TreeSpec for a pytree.""" + return tree_flatten(tree)[1] + + +def tree_map(func: Callable[..., Any], tree: PyTree, *rests: PyTree) -> PyTree: + """Map a multi-input function over pytree args to produce a new pytree. + + See also :func:`tree_map_`. + + >>> tree_map(lambda x: x + 1, {'x': 7, 'y': (42, 64)}) + {'x': 8, 'y': (43, 65)} + >>> tree_map(lambda x: x is None, {'x': 7, 'y': (42, 64), 'z': None}) + {'x': False, 'y': (False, False), 'z': True} + + If multiple inputs are given, the structure of the tree is taken from the first input; + subsequent inputs need only have ``tree`` as a prefix: + + >>> tree_map(lambda x, y: [x] + y, [5, 6], [[7, 9], [1, 2]]) + [[5, 7, 9], [6, 1, 2]] + + Args: + func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the + corresponding leaves of the pytrees. + tree (pytree): A pytree to be mapped over, with each leaf providing the first positional + argument to function ``func``. + rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as + ``tree`` or has ``tree`` as a prefix. + + Returns: + A new pytree with the same structure as ``tree`` but with the value at each leaf given by + ``func(x, *xs)`` where ``x`` is the value at the corresponding leaf in ``tree`` and ``xs`` + is the tuple of values at corresponding nodes in ``rests``. + """ + leaves, treespec = tree_flatten(tree) + flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests] + return treespec.unflatten(map(func, *flat_args)) + + +def tree_map_(func: Callable[..., Any], tree: PyTree, *rests: PyTree) -> PyTree: + """Like :func:`tree_map`, but do an inplace call on each leaf and return the original tree. + + See also :func:`tree_map`. + + Args: + func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the + corresponding leaves of the pytrees. + tree (pytree): A pytree to be mapped over, with each leaf providing the first positional + argument to function ``func``. + rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as + ``tree`` or has ``tree`` as a prefix. + + Returns: + The original ``tree`` with the value at each leaf is given by the side-effect of function + ``func(x, *xs)`` (not the return value) where ``x`` is the value at the corresponding leaf + in ``tree`` and ``xs`` is the tuple of values at values at corresponding nodes in ``rests``. + """ + leaves, treespec = tree_flatten(tree) + flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests] + deque(map(func, *flat_args), maxlen=0) # consume and exhaust the iterable + return tree + + +Type2 = Tuple[Type[T], Type[S]] +Type3 = Tuple[Type[T], Type[S], Type[U]] +TypeAny = Union[Type[Any], Tuple[Type[Any], ...]] + +Fn2 = Callable[[Union[T, S]], R] +Fn3 = Callable[[Union[T, S, U]], R] +Fn = Callable[[T], R] +FnAny = Callable[[Any], R] + +MapOnlyFn = Callable[[T], Callable[[Any], Any]] + + +# These specializations help with type inference on the lambda passed to this +# function +@overload +def map_only(__type_or_types: Type2[T, S]) -> MapOnlyFn[Fn2[T, S, Any]]: + ... + + +@overload +def map_only(__type_or_types: Type3[T, S, U]) -> MapOnlyFn[Fn3[T, S, U, Any]]: + ... + + +@overload +def map_only(__type_or_types: Type[T]) -> MapOnlyFn[Fn[T, Any]]: + ... + + +# This specialization is needed for the implementations below that call +@overload +def map_only(__type_or_types: TypeAny) -> MapOnlyFn[FnAny[Any]]: + ... + + +def map_only(__type_or_types: TypeAny) -> MapOnlyFn[FnAny[Any]]: + """ + Suppose you are writing a tree_map over tensors, leaving everything + else unchanged. Ordinarily you would have to write: + + def go(t): + if isinstance(t, Tensor): + return ... + else: + return t + + With this function, you only need to write: + + @map_only(Tensor) + def go(t): + return ... + + You can also directly use 'tree_map_only' + """ + + def wrapper(func: Callable[[T], Any]) -> Callable[[Any], Any]: + # @functools.wraps(func) # torch dynamo doesn't support this yet + def wrapped(x: T) -> Any: + if isinstance(x, __type_or_types): + return func(x) + return x + + return wrapped + + return wrapper + + +@overload +def tree_map_only( + __type_or_types: Type[T], + func: Fn[T, Any], + tree: PyTree, +) -> PyTree: + ... + + +@overload +def tree_map_only( + __type_or_types: Type2[T, S], + func: Fn2[T, S, Any], + tree: PyTree, +) -> PyTree: + ... + + +@overload +def tree_map_only( + __type_or_types: Type3[T, S, U], + func: Fn3[T, S, U, Any], + tree: PyTree, +) -> PyTree: + ... + + +def tree_map_only( + __type_or_types: TypeAny, + func: FnAny[Any], + tree: PyTree, +) -> PyTree: + return tree_map(map_only(__type_or_types)(func), tree) + + +@overload +def tree_map_only_( + __type_or_types: Type[T], + func: Fn[T, Any], + tree: PyTree, +) -> PyTree: + ... + + +@overload +def tree_map_only_( + __type_or_types: Type2[T, S], + func: Fn2[T, S, Any], + tree: PyTree, +) -> PyTree: + ... + + +@overload +def tree_map_only_( + __type_or_types: Type3[T, S, U], + func: Fn3[T, S, U, Any], + tree: PyTree, +) -> PyTree: + ... + + +def tree_map_only_( + __type_or_types: TypeAny, + func: FnAny[Any], + tree: PyTree, +) -> PyTree: + return tree_map_(map_only(__type_or_types)(func), tree) + + +def tree_all(pred: Callable[[Any], bool], tree: PyTree) -> bool: + flat_args = tree_leaves(tree) + return all(map(pred, flat_args)) + + +def tree_any(pred: Callable[[Any], bool], tree: PyTree) -> bool: + flat_args = tree_leaves(tree) + return any(map(pred, flat_args)) + + +@overload +def tree_all_only( + __type_or_types: Type[T], + pred: Fn[T, bool], + tree: PyTree, +) -> bool: + ... + + +@overload +def tree_all_only( + __type_or_types: Type2[T, S], + pred: Fn2[T, S, bool], + tree: PyTree, +) -> bool: + ... + + +@overload +def tree_all_only( + __type_or_types: Type3[T, S, U], + pred: Fn3[T, S, U, bool], + tree: PyTree, +) -> bool: + ... + + +def tree_all_only( + __type_or_types: TypeAny, + pred: FnAny[bool], + tree: PyTree, +) -> bool: + flat_args = tree_leaves(tree) + return all(pred(x) for x in flat_args if isinstance(x, __type_or_types)) + + +@overload +def tree_any_only( + __type_or_types: Type[T], + pred: Fn[T, bool], + tree: PyTree, +) -> bool: + ... + + +@overload +def tree_any_only( + __type_or_types: Type2[T, S], + pred: Fn2[T, S, bool], + tree: PyTree, +) -> bool: + ... + + +@overload +def tree_any_only( + __type_or_types: Type3[T, S, U], + pred: Fn3[T, S, U, bool], + tree: PyTree, +) -> bool: + ... + + +def tree_any_only( + __type_or_types: TypeAny, + pred: FnAny[bool], + tree: PyTree, +) -> bool: + flat_args = tree_leaves(tree) + return any(pred(x) for x in flat_args if isinstance(x, __type_or_types)) + + +# Broadcasts a pytree to the provided TreeSpec and returns the flattened +# values. If this is not possible, then this function returns None. +# +# For example, given pytree=0 and spec=TreeSpec(list, None, [LeafSpec(), LeafSpec()]), +# would return [0, 0]. This is useful for part of the vmap implementation: +# a user can pass in vmap(fn, in_dims)(*inputs). `in_dims` should be +# broadcastable to the tree structure of `inputs` and we use +# _broadcast_to_and_flatten to check this. +def _broadcast_to_and_flatten(tree: PyTree, treespec: TreeSpec) -> Optional[List[Any]]: + assert isinstance(treespec, TreeSpec) + + if _is_leaf(tree): + return [tree] * treespec.num_leaves + if isinstance(treespec, LeafSpec): + return None + node_type = _get_node_type(tree) + if node_type != treespec.type: + return None + + flatten_fn = SUPPORTED_NODES[node_type].flatten_fn + child_pytrees, ctx = flatten_fn(tree) + + # Check if the Node is different from the spec + if len(child_pytrees) != treespec.num_children or ctx != treespec.context: + return None + + # Recursively flatten the children + result: List[Any] = [] + for child, child_spec in zip(child_pytrees, treespec.children_specs): + flat = _broadcast_to_and_flatten(child, child_spec) + if flat is not None: + result += flat + else: + return None + + return result + + +@dataclasses.dataclass +class _TreeSpecSchema: + """ + _TreeSpecSchema is the schema used to serialize the TreeSpec + It contains the following fields: + - type: A string name of the type. null for the case of a LeafSpec. + - context: Any format which is json dumpable + - children_spec: A list of children serialized specs. + """ + + type: Optional[str] + context: DumpableContext + children_spec: List["_TreeSpecSchema"] + + +class _ProtocolFn(NamedTuple): + treespec_to_json: Callable[[TreeSpec], DumpableContext] + json_to_treespec: Callable[[DumpableContext], TreeSpec] + + +_SUPPORTED_PROTOCOLS: Dict[int, _ProtocolFn] = {} + + +def _treespec_to_json(treespec: TreeSpec) -> _TreeSpecSchema: + if isinstance(treespec, LeafSpec): + return _TreeSpecSchema(None, None, []) + + if treespec.type not in SUPPORTED_SERIALIZED_TYPES: + raise NotImplementedError( + f"Serializing {treespec.type} in pytree is not registered.", + ) + + serialize_node_def = SUPPORTED_SERIALIZED_TYPES[treespec.type] + + serialized_type_name = serialize_node_def.serialized_type_name + + if serialized_type_name == NO_SERIALIZED_TYPE_NAME_FOUND: + raise NotImplementedError( + f"No registered serialization name for {treespec.type} found. " + "Please update your _register_pytree_node call with a `serialized_type_name` kwarg." + ) + + if serialize_node_def.to_dumpable_context is None: + try: + serialized_context = json.dumps(treespec.context) + except TypeError as e: + raise TypeError( + "Unable to serialize context. " + "Please make the context json dump-able, or register a " + "custom serializer using _register_pytree_node." + ) from e + else: + serialized_context = serialize_node_def.to_dumpable_context( + treespec.context) + + child_schemas = [_treespec_to_json(child) + for child in treespec.children_specs] + + return _TreeSpecSchema(serialized_type_name, serialized_context, child_schemas) + + +def _json_to_treespec(json_schema: DumpableContext) -> TreeSpec: + if ( + json_schema["type"] is None + and json_schema["context"] is None + and len(json_schema["children_spec"]) == 0 + ): + return LeafSpec() + + if json_schema["type"] not in SERIALIZED_TYPE_TO_PYTHON_TYPE: + raise NotImplementedError( + f'Deserializing {json_schema["type"]} in pytree is not registered.', + ) + + typ = SERIALIZED_TYPE_TO_PYTHON_TYPE[json_schema["type"]] + serialize_node_def = SUPPORTED_SERIALIZED_TYPES[typ] + + if serialize_node_def.from_dumpable_context is None: + try: + context = json.loads(json_schema["context"]) + except TypeError as ex: + raise TypeError( + "Unable to deserialize context. " + "Please make the context json load-able, or register a " + "custom serializer using _register_pytree_node.", + ) from ex + else: + context = serialize_node_def.from_dumpable_context( + json_schema["context"]) + + children_specs = [] + for child_string in json_schema["children_spec"]: + children_specs.append(_json_to_treespec(child_string)) + + return TreeSpec(typ, context, children_specs) + + +_SUPPORTED_PROTOCOLS[1] = _ProtocolFn(_treespec_to_json, _json_to_treespec) + + +def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str: + if not isinstance(treespec, TreeSpec): + raise TypeError( + f"treespec_dumps(treespec, protocol): Expected `treespec` to be instance of " + f"TreeSpec but got item of type {type(treespec)}.", + ) + + if protocol is None: + protocol = DEFAULT_TREESPEC_SERIALIZATION_PROTOCOL + + if protocol in _SUPPORTED_PROTOCOLS: + json_spec = _SUPPORTED_PROTOCOLS[protocol].treespec_to_json(treespec) + else: + raise ValueError( + f"Unknown protocol {protocol}. " + f"Available protocols: {list(_SUPPORTED_PROTOCOLS.keys())}", + ) + + str_spec = json.dumps((protocol, dataclasses.asdict(json_spec))) + return str_spec + + +def treespec_loads(serialized: str) -> TreeSpec: + protocol, json_schema = json.loads(serialized) + + if protocol in _SUPPORTED_PROTOCOLS: + return _SUPPORTED_PROTOCOLS[protocol].json_to_treespec(json_schema) + raise ValueError( + f"Unknown protocol {protocol}. " + f"Available protocols: {list(_SUPPORTED_PROTOCOLS.keys())}", + ) + + +class _DummyLeaf: + def __repr__(self) -> str: + return "*" + + +def treespec_pprint(treespec: TreeSpec) -> str: + dummy_tree = tree_unflatten( + [_DummyLeaf() for _ in range(treespec.num_leaves)], + treespec, + ) + return repr(dummy_tree) + + +# TODO(angelayi): remove this function after OSS/internal stabilize +def pytree_to_str(treespec: TreeSpec) -> str: + warnings.warn("pytree_to_str is deprecated. Please use treespec_dumps") + return treespec_dumps(treespec) + + +# TODO(angelayi): remove this function after OSS/internal stabilize +def str_to_pytree(json: str) -> TreeSpec: + warnings.warn("str_to_pytree is deprecated. Please use treespec_loads") + return treespec_loads(json) + + +def arg_tree_leaves(*args: PyTree, **kwargs: PyTree) -> List[Any]: + """Get a flat list of arguments to this function + + A slightly faster version of tree_leaves((args, kwargs)) + """ + leaves: List[Any] = [] + for a in args: + _tree_leaves_helper(a, leaves) + for a in kwargs.values(): + _tree_leaves_helper(a, leaves) + return leaves \ No newline at end of file diff --git a/utils.py b/utils/utils.py similarity index 100% rename from utils.py rename to utils/utils.py From 42981b7889bd078ee39af5c8b8b469c6386ba906 Mon Sep 17 00:00:00 2001 From: wl-zhao Date: Mon, 11 Nov 2024 15:28:04 +0800 Subject: [PATCH 44/70] add validation for variable name --- custom_routes.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/custom_routes.py b/custom_routes.py index f5b94e7..4ab5e2b 100755 --- a/custom_routes.py +++ b/custom_routes.py @@ -29,6 +29,8 @@ from datetime import datetime import nodes import traceback +import re +import keyword from .dependency_checker import resolve_dependencies @@ -45,6 +47,14 @@ "ShellAgentPluginSaveVideoVHS": "video", } +# Regular expression for a valid Python variable name +variable_name_pattern = r'^[a-zA-Z_][a-zA-Z0-9_]*$' + +def is_valid_variable_name(name): + # Check if it matches the pattern and is not a keyword + if re.match(variable_name_pattern, name) and not keyword.iskeyword(name): + return True + return False def schema_validator(prompt): from nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS @@ -84,6 +94,9 @@ def schema_validator(prompt): continue if hasattr(node_cls, "validate"): schema = node_cls.validate(**node_info["inputs"]) + # validate schema + if not is_valid_variable_name(schema["title"]): + raise ValueError(f'{schema["title"]} is not a valid variable name!') else: raise NotImplementedError("the validate is not implemented") schemas[mode][node_id] = schema From 35b27002511fe530b26d214683f36b0be78d60d6 Mon Sep 17 00:00:00 2001 From: wl-zhao Date: Mon, 11 Nov 2024 15:30:45 +0800 Subject: [PATCH 45/70] add backtick --- custom_routes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/custom_routes.py b/custom_routes.py index 4ab5e2b..5dedf90 100755 --- a/custom_routes.py +++ b/custom_routes.py @@ -96,7 +96,7 @@ def schema_validator(prompt): schema = node_cls.validate(**node_info["inputs"]) # validate schema if not is_valid_variable_name(schema["title"]): - raise ValueError(f'{schema["title"]} is not a valid variable name!') + raise ValueError(f'`{schema["title"]}` is not a valid variable name!') else: raise NotImplementedError("the validate is not implemented") schemas[mode][node_id] = schema From ffa4123f07986da79f27ccf3057880cc002da620 Mon Sep 17 00:00:00 2001 From: wl-zhao Date: Tue, 12 Nov 2024 16:58:27 +0800 Subject: [PATCH 46/70] support skip model check for some nodes --- dependency_checker.py | 11 ++++++++++- node_remote.json | 3 +++ 2 files changed, 13 insertions(+), 1 deletion(-) create mode 100644 node_remote.json diff --git a/dependency_checker.py b/dependency_checker.py index e4c2843..ecdfe6c 100644 --- a/dependency_checker.py +++ b/dependency_checker.py @@ -20,6 +20,7 @@ model_loaders_info = json.load(open(os.path.join(os.path.dirname(__file__), "model_loader_info.json"))) node_deps_info = json.load(open(os.path.join(os.path.dirname(__file__), "node_deps_info.json"))) node_blacklist = json.load(open(os.path.join(os.path.dirname(__file__), "node_blacklist.json"))) +node_remote_skip_models = json.load(open(os.path.join(os.path.dirname(__file__), "node_remote.json"))) model_suffix = [".ckpt", ".safetensors", ".bin", ".pth", ".pt", ".onnx", ".gguf", ".sft"] extra_packages = ["transformers", "timm", "diffusers", "accelerate"] @@ -184,9 +185,17 @@ def collect_unknown_models(filename): if node_class_type is None: raise NotImplementedError(f"Missing nodes founded, please first install the missing nodes using ComfyUI Manager") node_cls = NODE_CLASS_MAPPINGS[node_class_type] + + skip_model_check = False + if hasattr(node_cls, "RELATIVE_PYTHON_MODULE") and node_cls.RELATIVE_PYTHON_MODULE.startswith("custom_nodes."): print(node_cls.RELATIVE_PYTHON_MODULE) custom_nodes.append(node_cls.RELATIVE_PYTHON_MODULE) + + if node_cls.RELATIVE_PYTHON_MODULE[len("custom_nodes."):] in node_remote_skip_models: + skip_model_check = True + print(f"skip model check for {node_class_type}") + if node_class_type in model_loaders_info: for field_name, filename in node_info["inputs"].items(): if type(filename) != str: @@ -204,7 +213,7 @@ def collect_unknown_models(filename): "filename": filename, "rel_save_path": rel_save_path } - else: + elif not skip_model_check: tree_map(collect_unknown_models, node_info["inputs"]) list(map(partial(collect_local_file, mapping_dict=file_mapping_dict), node_info["inputs"].values())) diff --git a/node_remote.json b/node_remote.json new file mode 100644 index 0000000..f089f82 --- /dev/null +++ b/node_remote.json @@ -0,0 +1,3 @@ +[ + "BizyAir" +] \ No newline at end of file From d9d20018be34fed0ebfeb48bda60ca0187296d57 Mon Sep 17 00:00:00 2001 From: wl-zhao Date: Thu, 14 Nov 2024 10:19:59 +0800 Subject: [PATCH 47/70] improve abs_path of file dependency --- dependency_checker.py | 5 ++++- file_upload.py | 21 +++++++++++++++------ utils/utils.py | 17 ++++++++++++++++- 3 files changed, 35 insertions(+), 8 deletions(-) diff --git a/dependency_checker.py b/dependency_checker.py index ecdfe6c..613325c 100644 --- a/dependency_checker.py +++ b/dependency_checker.py @@ -293,7 +293,10 @@ def collect_unknown_models(filename): # step 3: handle local files process_local_file_path_async(file_mapping_dict, max_workers=20) - files_dict = {v[0]: {"filename": windows_to_linux_path(os.path.relpath(v[2], BASE_PATH)), "urls": [v[1]]} for v in file_mapping_dict.values()} + files_dict = { + v[0]: { + "filename": windows_to_linux_path(os.path.relpath(v[2], BASE_PATH)) if not v[3] else v[2], + "urls": [v[1]]} for v in file_mapping_dict.values()} depencencies = { "comfyui_version": comfyui_version, diff --git a/file_upload.py b/file_upload.py index 3c4afbd..aa13ff8 100644 --- a/file_upload.py +++ b/file_upload.py @@ -5,7 +5,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed import folder_paths -from .utils.utils import compute_sha256 +from .utils.utils import compute_sha256, get_alphanumeric_hash ext_to_type = { # image @@ -27,7 +27,7 @@ '.m4a': 'audio/mp4', } -def upload_file_to_myshell(local_file: str) -> str: +def upload_file_to_myshell(local_file: str, target_path: str, is_abs) -> str: ''' Now we only support upload file one-by-one ''' MYSHELL_KEY = os.environ.get('MYSHELL_KEY', "OPENSOURCE_FIXED") @@ -51,8 +51,8 @@ def upload_file_to_myshell(local_file: str) -> str: response = requests.request("POST", server_url, headers=headers, files=files) if response.status_code == 200: end_time = time.time() - logging.info(f"{local_file} uploaded, time elapsed: {end_time - start_time}") - return [sha256sum, response.json()['url'], local_file] + logging.info(f"{local_file} uploaded, time elapsed: {end_time - start_time}, will be saved to {target_path}") + return [sha256sum, response.json()['url'], target_path, is_abs] else: raise Exception( f"[HTTP ERROR] {response.status_code} - {response.text} \n" @@ -66,8 +66,11 @@ def collect_local_file(item, mapping_dict={}): abspath = os.path.abspath(item) input_abspath = os.path.join(input_dir, item) # required file type + is_abs = False if os.path.isfile(abspath): fpath = abspath + is_abs = True + elif os.path.isfile(input_abspath): fpath = input_abspath else: @@ -75,7 +78,13 @@ def collect_local_file(item, mapping_dict={}): if fpath is not None: ext = os.path.splitext(fpath)[1] if ext.lower() in ext_to_type.keys(): - mapping_dict[item] = fpath + if is_abs: # if use abs path, replace it + filename_hash = get_alphanumeric_hash(abspath)[:16] + count = len(mapping_dict) + target_path = f"/ShellAgentDeploy/ComfyUI/input/{filename_hash}_{count:06d}{ext}" + mapping_dict[item] = (fpath, target_path, is_abs) + else: + mapping_dict[item] = (fpath, fpath, is_abs) return else: return @@ -86,7 +95,7 @@ def process_local_file_path_async(mapping_dict, max_workers=10): start_time = time.time() with ThreadPoolExecutor(max_workers=max_workers) as executor: # Submit tasks to the executor - futures = {executor.submit(upload_file_to_myshell, full_path): filename for filename, full_path in mapping_dict.items()} + futures = {executor.submit(upload_file_to_myshell, source_path, target_path, is_abs): filename for filename, (source_path, target_path, is_abs) in mapping_dict.items()} logging.info("submit done") # Collect the results as they complete for future in as_completed(futures): diff --git a/utils/utils.py b/utils/utils.py index 4854509..13e1216 100644 --- a/utils/utils.py +++ b/utils/utils.py @@ -1,6 +1,8 @@ import hashlib import time from pathlib import PurePosixPath, Path, PureWindowsPath +import base64 +import re def windows_to_linux_path(windows_path): return PureWindowsPath(windows_path).as_posix() @@ -17,4 +19,17 @@ def compute_sha256(file_path, chunk_size=1024 ** 2): sha256.update(chunk) print("finish compute sha256 for", file_path, f"time: {time.time() - start}") # Return the hexadecimal digest of the hash - return sha256.hexdigest() \ No newline at end of file + return sha256.hexdigest() + + +def get_alphanumeric_hash(input_string: str) -> str: + # Generate a SHA-256 hash of the input string + sha256_hash = hashlib.sha256(input_string.encode()).digest() + + # Encode the hash in base64 to get a string with [A-Za-z0-9+/=] + base64_hash = base64.b64encode(sha256_hash).decode('ascii') + + # Remove any non-alphanumeric characters (+, /, =) + alphanumeric_hash = re.sub(r'[^a-zA-Z0-9]', '', base64_hash) + + return alphanumeric_hash \ No newline at end of file From 92673900c56a1a6fd8233bd291d0b54ffce15aa0 Mon Sep 17 00:00:00 2001 From: wl-zhao Date: Thu, 14 Nov 2024 16:31:43 +0800 Subject: [PATCH 48/70] add glob to search models, and raise error when no models founded / multiple models founded --- dependency_checker.py | 42 ++++++++++++++++++++++++++++++++---------- 1 file changed, 32 insertions(+), 10 deletions(-) diff --git a/dependency_checker.py b/dependency_checker.py index 613325c..f2b2200 100644 --- a/dependency_checker.py +++ b/dependency_checker.py @@ -152,7 +152,7 @@ def resolve_dependencies(prompt, custom_dependencies): # resolve custom nodes an file_mapping_dict = {} - def collect_unknown_models(filename): + def collect_unknown_models(filename, node_id, node_info): if type(filename) != str: return is_model = False @@ -162,7 +162,7 @@ def collect_unknown_models(filename): if is_model: print(f"find {filename}, is_model=True") # find possible paths - matching_files = [] + matching_files = {} # Walk through all subdirectories and files in the directory rel_save_path = None for possible_folder_name in folder_paths.folder_names_and_paths: @@ -170,15 +170,37 @@ def collect_unknown_models(filename): if full_path is None: continue rel_save_path = os.path.relpath(folder_paths.folder_names_and_paths[possible_folder_name][0][0], folder_paths.models_dir) - matching_files.append(full_path) - break - print(f"matched files: {matching_files}") - if len(set(matching_files)) == 1: - assert rel_save_path is not None - ckpt_paths[matching_files[0]] = { - "filename": filename, + matching_files[full_path] = { "rel_save_path": rel_save_path } + + print(f"matched files: {matching_files}") + + # step 2: search for all the files under "models" + + for full_path in glob.glob(f"{folder_paths.models_dir}/**/*", recursive=True): + if os.path.isfile(full_path) and full_path.endswith(filename) and full_path not in matching_files: + folder_path = full_path[:-len(filename)] + rel_save_path = os.path.relpath(folder_path, folder_paths.models_dir) + matching_files[full_path] = { + "rel_save_path": rel_save_path + } + + print(f"matched files: {matching_files}") + + if len(matching_files) == 0: + raise ValueError(f"Cannot find model: `{filename}`, Node ID: `{node_id}`, Node Info: `{node_info}`") + + elif len(matching_files) <= 3: + for full_path, info in matching_files.items(): + ckpt_paths[full_path] = { + "filename": filename, + "rel_save_path": info["rel_save_path"] + } + return + else: + raise ValueError(f"Multiple models of `{filename}` founded, Node ID: `{node_id}`, Node Info: `{node_info}`, Possible paths: `{list(matching_files.keys())}`") + for node_id, node_info in prompt.items(): node_class_type = node_info.get("class_type") @@ -214,7 +236,7 @@ def collect_unknown_models(filename): "rel_save_path": rel_save_path } elif not skip_model_check: - tree_map(collect_unknown_models, node_info["inputs"]) + tree_map(lambda x: collect_unknown_models(x, node_id, node_info), node_info["inputs"]) list(map(partial(collect_local_file, mapping_dict=file_mapping_dict), node_info["inputs"].values())) From 27613ed685bee03af9928b840571ff5c18f3c5cc Mon Sep 17 00:00:00 2001 From: wl-zhao Date: Fri, 22 Nov 2024 11:49:10 +0800 Subject: [PATCH 49/70] support models in nodes / skip configs / input image support mask --- comfy-nodes/input_image.py | 59 ++++++++++++++++++++++++++++++++------ dependency_checker.py | 22 +++++++++++--- 2 files changed, 69 insertions(+), 12 deletions(-) diff --git a/comfy-nodes/input_image.py b/comfy-nodes/input_image.py index b09a645..ebfe847 100755 --- a/comfy-nodes/input_image.py +++ b/comfy-nodes/input_image.py @@ -1,5 +1,7 @@ import folder_paths -from PIL import Image, ImageOps +import node_helpers + +from PIL import Image, ImageOps, ImageSequence, ImageFile import numpy as np import torch import os @@ -32,8 +34,8 @@ def INPUT_TYPES(s): } } - RETURN_TYPES = ("IMAGE",) - RETURN_NAMES = ("image",) + RETURN_TYPES = ("IMAGE", "MASK") + # RETURN_NAMES = ("image",) FUNCTION = "run" @@ -60,6 +62,46 @@ def VALIDATE_INPUTS(s, input_name, default_value, description=""): return "Invalid image file: {}".format(image) return True + + def convert_image_mask(self, img): + output_images = [] + output_masks = [] + w, h = None, None + + excluded_formats = ['MPO'] + + for i in ImageSequence.Iterator(img): + i = node_helpers.pillow(ImageOps.exif_transpose, i) + + if i.mode == 'I': + i = i.point(lambda i: i * (1 / 255)) + image = i.convert("RGB") + + if len(output_images) == 0: + w = image.size[0] + h = image.size[1] + + if image.size[0] != w or image.size[1] != h: + continue + + image = np.array(image).astype(np.float32) / 255.0 + image = torch.from_numpy(image)[None,] + if 'A' in i.getbands(): + mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0 + mask = 1. - torch.from_numpy(mask) + else: + mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") + output_images.append(image) + output_masks.append(mask.unsqueeze(0)) + + if len(output_images) > 1 and img.format not in excluded_formats: + output_image = torch.cat(output_images, dim=0) + output_mask = torch.cat(output_masks, dim=0) + else: + output_image = output_images[0] + output_mask = output_masks[0] + + return (output_image, output_mask) def run(self, input_name, default_value=None, display_name=None, description=None): input_dir = folder_paths.get_input_directory() @@ -84,11 +126,12 @@ def run(self, input_name, default_value=None, display_name=None, description=Non image_path = os.path.join(input_dir, image_path) image = Image.open(image_path).convert("RGB") - image = ImageOps.exif_transpose(image) - image = image.convert("RGB") - image = np.array(image).astype(np.float32) / 255.0 - image = torch.from_numpy(image)[None,] - return [image] + return self.convert_image_mask(image) + # image = ImageOps.exif_transpose(image) + # image = image.convert("RGB") + # image = np.array(image).astype(np.float32) / 255.0 + # image = torch.from_numpy(image)[None,] + # return [image] except Exception as e: raise e diff --git a/dependency_checker.py b/dependency_checker.py index f2b2200..1e73e05 100644 --- a/dependency_checker.py +++ b/dependency_checker.py @@ -151,8 +151,8 @@ def resolve_dependencies(prompt, custom_dependencies): # resolve custom nodes an file_mapping_dict = {} - - def collect_unknown_models(filename, node_id, node_info): + SKIP_FOLDER_NAMES = ["configs", "custom_nodes"] + def collect_unknown_models(filename, node_id, node_info, custom_node_path): if type(filename) != str: return is_model = False @@ -166,6 +166,9 @@ def collect_unknown_models(filename, node_id, node_info): # Walk through all subdirectories and files in the directory rel_save_path = None for possible_folder_name in folder_paths.folder_names_and_paths: + if possible_folder_name in SKIP_FOLDER_NAMES: + print(f"skip {possible_folder_name}") + continue full_path = folder_paths.get_full_path(possible_folder_name, filename) if full_path is None: continue @@ -188,6 +191,16 @@ def collect_unknown_models(filename, node_id, node_info): print(f"matched files: {matching_files}") + # step 3: search inside the custom nodes + if custom_node_path is not None: + for full_path in glob.glob(f"{custom_node_path}/**/*", recursive=True): + if os.path.isfile(full_path) and full_path.endswith(filename) and full_path not in matching_files: + folder_path = full_path[:-len(filename)] + rel_save_path = os.path.relpath(folder_path, folder_paths.models_dir) + matching_files[full_path] = { + "rel_save_path": rel_save_path + } + if len(matching_files) == 0: raise ValueError(f"Cannot find model: `{filename}`, Node ID: `{node_id}`, Node Info: `{node_info}`") @@ -210,10 +223,11 @@ def collect_unknown_models(filename, node_id, node_info): skip_model_check = False + custom_node_path = None if hasattr(node_cls, "RELATIVE_PYTHON_MODULE") and node_cls.RELATIVE_PYTHON_MODULE.startswith("custom_nodes."): print(node_cls.RELATIVE_PYTHON_MODULE) custom_nodes.append(node_cls.RELATIVE_PYTHON_MODULE) - + custom_node_path = os.path.join(BASE_PATH, node_cls.RELATIVE_PYTHON_MODULE.replace(".", "/")) if node_cls.RELATIVE_PYTHON_MODULE[len("custom_nodes."):] in node_remote_skip_models: skip_model_check = True print(f"skip model check for {node_class_type}") @@ -236,7 +250,7 @@ def collect_unknown_models(filename, node_id, node_info): "rel_save_path": rel_save_path } elif not skip_model_check: - tree_map(lambda x: collect_unknown_models(x, node_id, node_info), node_info["inputs"]) + tree_map(lambda x: collect_unknown_models(x, node_id, node_info, custom_node_path), node_info["inputs"]) list(map(partial(collect_local_file, mapping_dict=file_mapping_dict), node_info["inputs"].values())) From 2d21584447c2a43a398441a10c55458f1ccc3bc6 Mon Sep 17 00:00:00 2001 From: wl-zhao Date: Tue, 3 Dec 2024 15:20:17 +0800 Subject: [PATCH 50/70] add route to inspect version --- custom_routes.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/custom_routes.py b/custom_routes.py index 5dedf90..768af6e 100755 --- a/custom_routes.py +++ b/custom_routes.py @@ -32,8 +32,8 @@ import re import keyword -from .dependency_checker import resolve_dependencies - +from .dependency_checker import resolve_dependencies, inspect_repo_version +from folder_paths import base_path as BASE_PATH WORKFLOW_ROOT = "shellagent/comfy_workflow" @@ -193,4 +193,16 @@ async def shellagent_export(request): "message_detail": str(traceback.format_exc()), "message": str(e), } - return web.json_response(return_dict, status=status) \ No newline at end of file + return web.json_response(return_dict, status=status) + + +@server.PromptServer.instance.routes.post("/shellagent/inspect_version") # data same as queue prompt, plus workflow_name +async def shellagent_export(request): + data = await request.json() + comfyui_version = inspect_repo_version(BASE_PATH) + comfyui_shellagent_plugin_version = inspect_repo_version(os.path.dirname(__file__)) + return_dict = { + "comfyui_version": comfyui_version, + "comfyui_shellagent_plugin_version": comfyui_shellagent_plugin_version, + } + return web.json_response(return_dict, status=200) \ No newline at end of file From 623ea454cc3c17eaeb1156b1acc87a09d75992a7 Mon Sep 17 00:00:00 2001 From: wl-zhao Date: Tue, 10 Dec 2024 16:35:26 +0800 Subject: [PATCH 51/70] support input/output audio backend --- comfy-nodes/input_audio.py | 195 +++++++++++++++++++++++++++++++++++++ 1 file changed, 195 insertions(+) create mode 100644 comfy-nodes/input_audio.py diff --git a/comfy-nodes/input_audio.py b/comfy-nodes/input_audio.py new file mode 100644 index 0000000..df02294 --- /dev/null +++ b/comfy-nodes/input_audio.py @@ -0,0 +1,195 @@ +import folder_paths +import node_helpers + +from PIL import Image, ImageOps, ImageSequence, ImageFile +import numpy as np +import torch +import os +import uuid +import tqdm +import torchaudio +import hashlib +from comfy_extras.nodes_audio import SaveAudio + + +class LoadAudio: + @classmethod + def INPUT_TYPES(s): + input_dir = folder_paths.get_input_directory() + files = folder_paths.filter_files_content_types( + os.listdir(input_dir), ["audio", "video"]) + return {"required": {"audio": (sorted(files), {"audio_upload": True})}} + + CATEGORY = "audio" + + RETURN_TYPES = ("AUDIO", ) + FUNCTION = "load" + + def load(self, audio): + audio_path = folder_paths.get_annotated_filepath(audio) + waveform, sample_rate = torchaudio.load(audio_path) + audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate} + return (audio, ) + + @classmethod + def IS_CHANGED(s, audio): + image_path = folder_paths.get_annotated_filepath(audio) + m = hashlib.sha256() + with open(image_path, 'rb') as f: + m.update(f.read()) + return m.digest().hex() + + @classmethod + def VALIDATE_INPUTS(s, audio): + if not folder_paths.exists_annotated_filepath(audio): + return "Invalid audio file: {}".format(audio) + return True + + +class ShellAgentPluginInputAudio: + @classmethod + def INPUT_TYPES(s): + input_dir = folder_paths.get_input_directory() + files = folder_paths.filter_files_content_types( + os.listdir(input_dir), ["audio", "video"]) + return { + "required": { + "input_name": ( + "STRING", + {"multiline": False, "default": "input_audio", "forceInput": False}, + ), + "default_value": ( + sorted(files), {"audio_upload": True, "forceInput": False} + ), + }, + "optional": { + "description": ( + "STRING", + {"multiline": True, "default": "", "forceInput": False}, + ), + } + } + + RETURN_TYPES = ("AUDIO", ) + FUNCTION = "load" + + CATEGORY = "shellagent" + + @classmethod + def validate(cls, **kwargs): + schema = { + "title": kwargs["input_name"], + "type": "string", + "default": kwargs["default_value"], + "description": kwargs.get("description", ""), + "url_type": "audio" + } + return schema + + @classmethod + def VALIDATE_INPUTS(s, audio): + if not folder_paths.exists_annotated_filepath(audio): + return "Invalid audio file: {}".format(audio) + return True + + @classmethod + def VALIDATE_INPUTS(s, input_name, default_value, description=""): + audio = default_value + if audio.startswith("http"): + return True + + if not folder_paths.exists_annotated_filepath(audio): + return "Invalid audio file: {}".format(audio) + return True + + def load(self, input_name, default_value=None, display_name=None, description=None): + input_dir = folder_paths.get_input_directory() + audio_path = default_value + try: + if audio_path.startswith('http'): + import requests + from io import BytesIO + print("Fetching audio from url: ", audio_path) + response = requests.get(audio_path) + response.raise_for_status() + audio_file = BytesIO(response.content) + waveform, sample_rate = torchaudio.load(audio_file) + else: + if not os.path.isfile(audio_path): # abs path + # local path + audio_path = os.path.join(input_dir, audio_path) + waveform, sample_rate = torchaudio.load(audio_path) + + audio = {"waveform": waveform.unsqueeze( + 0), "sample_rate": sample_rate} + return (audio, ) + # image = ImageOps.exif_transpose(image) + # image = image.convert("RGB") + # image = np.array(image).astype(np.float32) / 255.0 + # image = torch.from_numpy(image)[None,] + # return [image] + except Exception as e: + raise e + + +class ShellAgentSaveAudios(SaveAudio): + @classmethod + def INPUT_TYPES(s): + return {"required": {"audio": ("AUDIO", ), + "output_name": ("STRING", {"multiline": False, "default": "output_audio"},), + "filename_prefix": ("STRING", {"default": "audio/ComfyUI"})}, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, + } + # { + # "required": { + # "images": ("IMAGE", {"tooltip": "The audio to save."}), + # "output_name": ("STRING", {"multiline": False, "default": "output_image"},), + # "filename_prefix": ("STRING", {"default": "ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."}) + # }, + # "hidden": { + # "prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO" + # }, + # } + + CATEGORY = "shellagent" + + @classmethod + def validate(cls, **kwargs): + schema = { + "title": kwargs["output_name"], + "type": "array", + "items": { + "type": "string", + "url_type": "audio", + } + } + return schema + + def save_audio(self, audio, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None, **extra_kwargs): + results = super().save_audio(audio, filename_prefix, prompt, extra_pnginfo) + results["shellagent_kwargs"] = extra_kwargs + return results + + +class ShellAgentSaveAudio(ShellAgentSaveAudios): + @classmethod + def validate(cls, **kwargs): + schema = { + "title": kwargs["output_name"], + "type": "string", + "url_type": "audio", + } + return schema + + +NODE_CLASS_MAPPINGS = { + "ShellAgentPluginInputAudio": ShellAgentPluginInputAudio, + "ShellAgentPluginSaveAudios": ShellAgentSaveAudios, + "ShellAgentPluginSaveAudio": ShellAgentSaveAudio, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "ShellAgentPluginInputAudio": "Input Audio (ShellAgent Plugin)", + "ShellAgentPluginSaveAudios": "Save Audios (ShellAgent Plugin)", + "ShellAgentPluginSaveAudio": "Save Audio (ShellAgent Plugin)", +} From 676c40691fcbf27fbc436a4d5275ef8f52408614 Mon Sep 17 00:00:00 2001 From: Xumin Yu <1090414006@qq.com> Date: Thu, 12 Dec 2024 14:00:55 +0800 Subject: [PATCH 52/70] Update node_deps_info.json --- node_deps_info.json | 50 +++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 48 insertions(+), 2 deletions(-) diff --git a/node_deps_info.json b/node_deps_info.json index 7a9c640..2126f4d 100644 --- a/node_deps_info.json +++ b/node_deps_info.json @@ -17,7 +17,46 @@ } ], "efficiency-nodes-comfyui": [ - { + {{ + "ComfyUI-Easy-Use": [ + { + "name": "ComfyUI-Inspire-Pack", + "repo": "https://github.com/ltdrdata/ComfyUI-Inspire-Pack.git", + "commit": "" + }, + { + "name": "ComfyUI-Advanced-ControlNet", + "repo": "https://github.com/Kosinkadink/ComfyUI-Advanced-ControlNet.git", + "commit": "" + }, + { + "name": "ComfyUI_smZNodes", + "repo": "https://github.com/shiimizu/ComfyUI_smZNodes.git", + "commit": "" + } + ], + "efficiency-nodes-comfyui": [ + { + "name": "comfyui_controlnet_aux", + "repo": "https://github.com/Fannovel16/comfyui_controlnet_aux.git", + "commit": "" + } + ], + "ComfyUI-Anyline": [ + { + "name": "comfyui_controlnet_aux", + "repo": "https://github.com/Fannovel16/comfyui_controlnet_aux.git", + "commit": "" + } + ], + "ComfyUI-Impact-Pack": [ + { + "name": "ComfyUI-Impact-Subpack", + "repo": "https://github.com/ltdrdata/ComfyUI-Impact-Subpack.git", + "commit": "" + } + ] +} "name": "comfyui_controlnet_aux", "repo": "https://github.com/Fannovel16/comfyui_controlnet_aux.git", "commit": "" @@ -29,5 +68,12 @@ "repo": "https://github.com/Fannovel16/comfyui_controlnet_aux.git", "commit": "" } + ], + "ComfyUI-Impact-Pack": [ + { + "name": "ComfyUI-Impact-Subpack", + "repo": "https://github.com/ltdrdata/ComfyUI-Impact-Subpack.git", + "commit": "" + } ] -} \ No newline at end of file +} From f1135ac55adfc65e89d36fcf6430014f7e7340f4 Mon Sep 17 00:00:00 2001 From: Xumin Yu <1090414006@qq.com> Date: Thu, 12 Dec 2024 14:01:48 +0800 Subject: [PATCH 53/70] Update node_deps_info.json --- node_deps_info.json | 39 --------------------------------------- 1 file changed, 39 deletions(-) diff --git a/node_deps_info.json b/node_deps_info.json index 2126f4d..bdb905a 100644 --- a/node_deps_info.json +++ b/node_deps_info.json @@ -1,23 +1,4 @@ { - "ComfyUI-Easy-Use": [ - { - "name": "ComfyUI-Inspire-Pack", - "repo": "https://github.com/ltdrdata/ComfyUI-Inspire-Pack.git", - "commit": "" - }, - { - "name": "ComfyUI-Advanced-ControlNet", - "repo": "https://github.com/Kosinkadink/ComfyUI-Advanced-ControlNet.git", - "commit": "" - }, - { - "name": "ComfyUI_smZNodes", - "repo": "https://github.com/shiimizu/ComfyUI_smZNodes.git", - "commit": "" - } - ], - "efficiency-nodes-comfyui": [ - {{ "ComfyUI-Easy-Use": [ { "name": "ComfyUI-Inspire-Pack", @@ -56,24 +37,4 @@ "commit": "" } ] -} - "name": "comfyui_controlnet_aux", - "repo": "https://github.com/Fannovel16/comfyui_controlnet_aux.git", - "commit": "" - } - ], - "ComfyUI-Anyline": [ - { - "name": "comfyui_controlnet_aux", - "repo": "https://github.com/Fannovel16/comfyui_controlnet_aux.git", - "commit": "" - } - ], - "ComfyUI-Impact-Pack": [ - { - "name": "ComfyUI-Impact-Subpack", - "repo": "https://github.com/ltdrdata/ComfyUI-Impact-Subpack.git", - "commit": "" - } - ] } From 7238c1f40bdb9e2a22a4fe45e3e186dc57d97f06 Mon Sep 17 00:00:00 2001 From: wl-zhao Date: Fri, 13 Dec 2024 11:45:43 +0800 Subject: [PATCH 54/70] fix error message when empty input image --- comfy-nodes/input_image.py | 6 +++++- dependency_checker.py | 5 ++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/comfy-nodes/input_image.py b/comfy-nodes/input_image.py index ebfe847..65fb2c2 100755 --- a/comfy-nodes/input_image.py +++ b/comfy-nodes/input_image.py @@ -55,10 +55,14 @@ def validate(cls, **kwargs): @classmethod def VALIDATE_INPUTS(s, input_name, default_value, description=""): image = default_value + if image.startswith("http"): return True - if not folder_paths.exists_annotated_filepath(image): + if not os.path.isfile(image): + return "Invalid image file: please check if the image is empty or invalid" + + if folder_paths.exists_annotated_filepath(image): return "Invalid image file: {}".format(image) return True diff --git a/dependency_checker.py b/dependency_checker.py index 1e73e05..7aa176d 100644 --- a/dependency_checker.py +++ b/dependency_checker.py @@ -283,7 +283,10 @@ def collect_unknown_models(filename, node_id, node_info, custom_node_path): print(f"failed to resolve repo info of {custom_node}") requirement_file = os.path.join(BASE_PATH, custom_node.replace(".", "/"), "requirements.txt") if os.path.isfile(requirement_file): - requirements_lines += open(requirement_file).readlines() + try: + requirements_lines += open(requirement_file).readlines() + except: + pass requirements_lines = list(set(requirements_lines)) requirements_packages = [package_name for package_name, version_specifier in map(split_package_version, requirements_lines) if package_name is not None] package_names = set(requirements_packages + extra_packages) From 82f3a05f1c9adc30e9a4350206728f6c7a128569 Mon Sep 17 00:00:00 2001 From: wl-zhao Date: Fri, 13 Dec 2024 14:52:16 +0800 Subject: [PATCH 55/70] fix mask bug --- comfy-nodes/input_image.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/comfy-nodes/input_image.py b/comfy-nodes/input_image.py index 65fb2c2..9896832 100755 --- a/comfy-nodes/input_image.py +++ b/comfy-nodes/input_image.py @@ -59,10 +59,10 @@ def VALIDATE_INPUTS(s, input_name, default_value, description=""): if image.startswith("http"): return True - if not os.path.isfile(image): + if image == "": return "Invalid image file: please check if the image is empty or invalid" - if folder_paths.exists_annotated_filepath(image): + if not folder_paths.exists_annotated_filepath(image): return "Invalid image file: {}".format(image) return True @@ -107,6 +107,7 @@ def convert_image_mask(self, img): return (output_image, output_mask) + def run(self, input_name, default_value=None, display_name=None, description=None): input_dir = folder_paths.get_input_directory() image_path = default_value @@ -128,7 +129,7 @@ def run(self, input_name, default_value=None, display_name=None, description=Non if not os.path.isfile(image_path): # abs path # local path image_path = os.path.join(input_dir, image_path) - image = Image.open(image_path).convert("RGB") + image = node_helpers.pillow(Image.open, image_path) return self.convert_image_mask(image) # image = ImageOps.exif_transpose(image) From 604d34900a165edb4b75f3900842079bd345fe53 Mon Sep 17 00:00:00 2001 From: wl-zhao Date: Mon, 16 Dec 2024 17:25:23 +0800 Subject: [PATCH 56/70] support heif image --- comfy-nodes/input_image.py | 3 +++ requirements.txt | 3 ++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/comfy-nodes/input_image.py b/comfy-nodes/input_image.py index 9896832..9df4a00 100755 --- a/comfy-nodes/input_image.py +++ b/comfy-nodes/input_image.py @@ -7,6 +7,9 @@ import os import uuid import tqdm +from pillow_heif import register_heif_opener + +register_heif_opener() class ShellAgentPluginInputImage: diff --git a/requirements.txt b/requirements.txt index f1d3a61..6666fe6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,4 +3,5 @@ pydantic opencv-python imageio-ffmpeg brotli -# logfire \ No newline at end of file +pillow_heif +# logfire From 637bc88feca0054c80155b01841e7c9ba6080d5d Mon Sep 17 00:00:00 2001 From: wl-zhao Date: Tue, 17 Dec 2024 19:24:59 +0800 Subject: [PATCH 57/70] pass validation when os.path.isfile --- comfy-nodes/input_image.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/comfy-nodes/input_image.py b/comfy-nodes/input_image.py index 9df4a00..dd2c57f 100755 --- a/comfy-nodes/input_image.py +++ b/comfy-nodes/input_image.py @@ -65,6 +65,9 @@ def VALIDATE_INPUTS(s, input_name, default_value, description=""): if image == "": return "Invalid image file: please check if the image is empty or invalid" + if os.path.isfile(image): + return True + if not folder_paths.exists_annotated_filepath(image): return "Invalid image file: {}".format(image) From 2310c339663fd29dde7998fd6089d0fc5254fd1f Mon Sep 17 00:00:00 2001 From: wl-zhao Date: Wed, 18 Dec 2024 12:01:11 +0800 Subject: [PATCH 58/70] add get mac_addr --- custom_routes.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/custom_routes.py b/custom_routes.py index 768af6e..cf227ed 100755 --- a/custom_routes.py +++ b/custom_routes.py @@ -31,6 +31,7 @@ import traceback import re import keyword +import uuid from .dependency_checker import resolve_dependencies, inspect_repo_version from folder_paths import base_path as BASE_PATH @@ -205,4 +206,13 @@ async def shellagent_export(request): "comfyui_version": comfyui_version, "comfyui_shellagent_plugin_version": comfyui_shellagent_plugin_version, } + return web.json_response(return_dict, status=200) + + +@server.PromptServer.instance.routes.post("/shellagent/get_mac_addr") # data same as queue prompt, plus workflow_name +async def shellagent_export(request): + data = await request.json() + return_dict = { + "mac_addr": uuid.getnode() + } return web.json_response(return_dict, status=200) \ No newline at end of file From 37eb10c32750364478b1f68d482ae5c7fa6f6318 Mon Sep 17 00:00:00 2001 From: wl-zhao Date: Wed, 18 Dec 2024 14:42:28 +0800 Subject: [PATCH 59/70] add check_exist --- custom_routes.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/custom_routes.py b/custom_routes.py index cf227ed..c85fec1 100755 --- a/custom_routes.py +++ b/custom_routes.py @@ -198,7 +198,7 @@ async def shellagent_export(request): @server.PromptServer.instance.routes.post("/shellagent/inspect_version") # data same as queue prompt, plus workflow_name -async def shellagent_export(request): +async def shellagent_inspect_version(request): data = await request.json() comfyui_version = inspect_repo_version(BASE_PATH) comfyui_shellagent_plugin_version = inspect_repo_version(os.path.dirname(__file__)) @@ -210,9 +210,18 @@ async def shellagent_export(request): @server.PromptServer.instance.routes.post("/shellagent/get_mac_addr") # data same as queue prompt, plus workflow_name -async def shellagent_export(request): +async def shellagent_get_mac_addr(request): data = await request.json() return_dict = { "mac_addr": uuid.getnode() } + return web.json_response(return_dict, status=200) + +@server.PromptServer.instance.routes.post("/shellagent/check_exist") # check if the file or folder exist +async def shellagent_check_exist(request): + data = await request.json() + + return_dict = { + "exist": os.path.exists(data["path"]) + } return web.json_response(return_dict, status=200) \ No newline at end of file From 580ac932df430e94924ae3c294293f99daf82f40 Mon Sep 17 00:00:00 2001 From: wl-zhao Date: Wed, 18 Dec 2024 14:49:38 +0800 Subject: [PATCH 60/70] add mac_addr to check realy exist --- custom_routes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/custom_routes.py b/custom_routes.py index c85fec1..9e2dca2 100755 --- a/custom_routes.py +++ b/custom_routes.py @@ -222,6 +222,6 @@ async def shellagent_check_exist(request): data = await request.json() return_dict = { - "exist": os.path.exists(data["path"]) + "exist": uuid.getnode() == data["mac_addr"] and os.path.exists(data["path"]) # really exist, instead of same name } return web.json_response(return_dict, status=200) \ No newline at end of file From 681f716bfa8816370d1a9f1333fb7ea4efb035d7 Mon Sep 17 00:00:00 2001 From: wl-zhao Date: Wed, 18 Dec 2024 15:18:48 +0800 Subject: [PATCH 61/70] update safe open image function --- comfy-nodes/input_image.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/comfy-nodes/input_image.py b/comfy-nodes/input_image.py index dd2c57f..73e96ec 100755 --- a/comfy-nodes/input_image.py +++ b/comfy-nodes/input_image.py @@ -7,10 +7,33 @@ import os import uuid import tqdm +from io import BytesIO +import PIL +import cv2 from pillow_heif import register_heif_opener register_heif_opener() +def safe_open_image(image_bytes): + try: + image_pil = Image.open(BytesIO(image_bytes)) + except PIL.UnidentifiedImageError as e: + print(e) + # Convert response content (bytes) to a NumPy array + image_array = np.frombuffer(image_bytes, np.uint8) + + # Decode the image from the NumPy array (OpenCV format: BGR) + image_cv = cv2.imdecode(image_array, cv2.IMREAD_COLOR) + + if image_cv is not None: + # Convert the BGR image to RGB + image_rgb = cv2.cvtColor(image_cv, cv2.COLOR_BGR2RGB) + + # Convert the RGB NumPy array to a PIL Image + image_pil = Image.fromarray(image_rgb) + else: + raise ValueError("The image cannot be identified by neither PIL nor OpenCV") + return image_pil class ShellAgentPluginInputImage: @classmethod @@ -123,7 +146,7 @@ def run(self, input_name, default_value=None, display_name=None, description=Non from io import BytesIO print("Fetching image from url: ", image_path) response = requests.get(image_path) - image = Image.open(BytesIO(response.content)) + image = safe_open_image(response.content) elif image_path.startswith('data:image/png;base64,') or image_path.startswith('data:image/jpeg;base64,') or image_path.startswith('data:image/jpg;base64,'): import base64 from io import BytesIO From 4e2bfd56209a9ecbe9aa51e9017bff1277b72397 Mon Sep 17 00:00:00 2001 From: Wenliang Zhao Date: Mon, 30 Dec 2024 14:55:30 +0800 Subject: [PATCH 62/70] Update dependency_checker.py --- dependency_checker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dependency_checker.py b/dependency_checker.py index 7aa176d..b516da9 100644 --- a/dependency_checker.py +++ b/dependency_checker.py @@ -22,7 +22,7 @@ node_blacklist = json.load(open(os.path.join(os.path.dirname(__file__), "node_blacklist.json"))) node_remote_skip_models = json.load(open(os.path.join(os.path.dirname(__file__), "node_remote.json"))) -model_suffix = [".ckpt", ".safetensors", ".bin", ".pth", ".pt", ".onnx", ".gguf", ".sft"] +model_suffix = [".ckpt", ".safetensors", ".bin", ".pth", ".pt", ".onnx", ".gguf", ".sft", ".ttf"] extra_packages = ["transformers", "timm", "diffusers", "accelerate"] From 04d33d8a8edb9cdf8bc41fca22cb5277e7182679 Mon Sep 17 00:00:00 2001 From: wl-zhao Date: Mon, 13 Jan 2025 15:55:40 +0800 Subject: [PATCH 63/70] add easydict --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 6666fe6..6b6d458 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,5 @@ opencv-python imageio-ffmpeg brotli pillow_heif +easydict # logfire From 078b3f5ea55e12b55bce4c7f639b6a0b4ba1f6a0 Mon Sep 17 00:00:00 2001 From: Xumin Yu <1090414006@qq.com> Date: Thu, 16 Jan 2025 16:03:08 +0800 Subject: [PATCH 64/70] Add joint dependency for Easy use (ComfyUI_IPAdapter_plus) --- node_deps_info.json | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/node_deps_info.json b/node_deps_info.json index bdb905a..040669e 100644 --- a/node_deps_info.json +++ b/node_deps_info.json @@ -14,6 +14,11 @@ "name": "ComfyUI_smZNodes", "repo": "https://github.com/shiimizu/ComfyUI_smZNodes.git", "commit": "" + }, + { + "name": "ComfyUI_IPAdapter_plus", + "repo": "https://github.com/cubiq/ComfyUI_IPAdapter_plus.git", + "commit": "" } ], "efficiency-nodes-comfyui": [ From bc62e8a44cb79d85ea4f957d6be503be980ae2b4 Mon Sep 17 00:00:00 2001 From: shanexi Date: Wed, 12 Feb 2025 18:10:01 +0800 Subject: [PATCH 65/70] feat: support input audio --- web/shellagent.js | 124 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 124 insertions(+) diff --git a/web/shellagent.js b/web/shellagent.js index 79be03e..78781fd 100644 --- a/web/shellagent.js +++ b/web/shellagent.js @@ -1,6 +1,9 @@ import { app } from "../../scripts/app.js"; import { api } from "../../scripts/api.js"; +var __defProp = Object.defineProperty; +var __name = (target, value) => __defProp(target, "name", { value, configurable: true }); + app.registerExtension({ name: "Shellagent.extension", async setup() { @@ -173,6 +176,18 @@ app.registerExtension({ } } + if (nodeData.name === "ShellAgentPluginInputAudio") { + if ( + nodeData?.input?.required?.default_value?.[1]?.audio_upload === true + ) { + nodeData.input.required.audioUI = ["AUDIO_UI"]; + nodeData.input.required.upload = [ + "SHELLAGENT_AUDIOUPLOAD", + { widget: "default_value" }, + ]; + } + } + if (nodeData.name === "ShellAgentPluginInputVideo") { addUploadWidget(nodeType, nodeData, "default_value"); chainCallback(nodeType.prototype, "onNodeCreated", function () { @@ -354,23 +369,81 @@ app.registerExtension({ afterConfigureGraph(missingNodeTypes, app) { function addIn(type, nodeId) { + if(LiteGraph.slot_types_default_in[type] == null) { + LiteGraph.slot_types_default_in[type] = [] + } if (LiteGraph.slot_types_default_in[type].indexOf(nodeId) === -1) { LiteGraph.slot_types_default_in[type].unshift(nodeId) } } function addOut(type, nodeId) { + if(LiteGraph.slot_types_default_out[type] == null) { + LiteGraph.slot_types_default_out[type] = [] + } if (LiteGraph.slot_types_default_out[type].indexOf(nodeId) === -1) { LiteGraph.slot_types_default_out[type].unshift(nodeId) } } addIn('IMAGE', 'ShellAgentPluginInputImage') + addIn('AUDIO', 'ShellAgentPluginInputAudio') addOut('IMAGE', 'ShellAgentPluginSaveImage') addOut('IMAGE', 'ShellAgentPluginSaveImages') + addOut('AUDIO', 'ShellAgentPluginSaveAudios') + addOut('AUDIO', 'ShellAgentPluginSaveAudio') addOut('STRING', 'ShellAgentPluginOutputInteger') addOut('STRING', 'ShellAgentPluginOutputFloat') addOut('STRING', 'ShellAgentPluginOutputText') + }, + getCustomWidgets() { + return { + SHELLAGENT_AUDIOUPLOAD(node, inputName) { + const audioWidget = node.widgets.find( + (w) => w.name === "default_value" + ); + const audioUIWidget = node.widgets.find( + (w) => w.name === "audioUI" + ); + const onAudioWidgetUpdate = /* @__PURE__ */ __name(() => { + audioUIWidget.element.src = api.apiURL( + getResourceURL(...splitFilePath(audioWidget.value)) + ); + }, "onAudioWidgetUpdate"); + if (audioWidget.value) { + onAudioWidgetUpdate(); + } + audioWidget.callback = onAudioWidgetUpdate; + const onGraphConfigured = node.onGraphConfigured; + node.onGraphConfigured = function() { + onGraphConfigured?.apply(this, arguments); + if (audioWidget.value) { + onAudioWidgetUpdate(); + } + }; + const fileInput = document.createElement("input"); + fileInput.type = "file"; + fileInput.accept = "audio/*"; + fileInput.style.display = "none"; + fileInput.onchange = () => { + if (fileInput.files.length) { + uploadFileAudio(audioWidget, audioUIWidget, fileInput.files[0], true); + } + }; + const uploadWidget = node.addWidget( + "button", + inputName, + /* value=*/ + "", + () => { + fileInput.click(); + }, + { serialize: false } + ); + uploadWidget.label = "choose file to upload"; + return { widget: uploadWidget }; + } + }; } }); @@ -744,4 +817,55 @@ function addLoadVideoCommon(nodeType, nodeData) { } }); }); +} + +function getResourceURL(subfolder, filename, type = "input") { + const params = [ + "filename=" + encodeURIComponent(filename), + "type=" + type, + "subfolder=" + subfolder, + app.getRandParam().substring(1) + ].join("&"); + return `/view?${params}`; +} + +function splitFilePath(path) { + const folder_separator = path.lastIndexOf("/"); + if (folder_separator === -1) { + return ["", path]; + } + return [ + path.substring(0, folder_separator), + path.substring(folder_separator + 1) + ]; +} + +async function uploadFileAudio(audioWidget, audioUIWidget, file2, updateNode, pasted = false) { + try { + const body = new FormData(); + body.append("image", file2); + if (pasted) body.append("subfolder", "pasted"); + const resp = await api.fetchApi("/upload/image", { + method: "POST", + body + }); + if (resp.status === 200) { + const data = await resp.json(); + let path = data.name; + if (data.subfolder) path = data.subfolder + "/" + path; + if (!audioWidget.options.values.includes(path)) { + audioWidget.options.values.push(path); + } + if (updateNode) { + audioUIWidget.element.src = api.apiURL( + getResourceURL(...splitFilePath(path)) + ); + audioWidget.value = path; + } + } else { + window.alert(resp.status + " - " + resp.statusText); + } + } catch (error) { + window.alert(error); + } } \ No newline at end of file From 06163a06d4a2cdeb2c1c85fcb01c68b1f80112f5 Mon Sep 17 00:00:00 2001 From: Chengwei Ouyang Date: Mon, 10 Mar 2025 23:21:36 +0800 Subject: [PATCH 66/70] i[date --- comfy-nodes/input_image.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/comfy-nodes/input_image.py b/comfy-nodes/input_image.py index 73e96ec..83e7e79 100755 --- a/comfy-nodes/input_image.py +++ b/comfy-nodes/input_image.py @@ -47,10 +47,12 @@ def INPUT_TYPES(s): "STRING", {"multiline": False, "default": "input_image", "forceInput": False}, ), - "default_value": ( - # "STRING", {"image_upload": True, "default": files[0] if len(files) else ""}, + "image": ( sorted(files), {"image_upload": True, "forceInput": False} ), + "default_value": ( + "STRING", {"forceInput": False} + ), }, "optional": { "description": ( @@ -79,20 +81,21 @@ def validate(cls, **kwargs): return schema @classmethod - def VALIDATE_INPUTS(s, input_name, default_value, description=""): - image = default_value + def VALIDATE_INPUTS(s, input_name, default_value, image=None, description=""): + # check default_value first + image_to_check = default_value if default_value else image - if image.startswith("http"): + if image_to_check.startswith("http"): return True - if image == "": + if image_to_check == "": return "Invalid image file: please check if the image is empty or invalid" - if os.path.isfile(image): + if os.path.isfile(image_to_check): return True - if not folder_paths.exists_annotated_filepath(image): - return "Invalid image file: {}".format(image) + if not folder_paths.exists_annotated_filepath(image_to_check): + return "Invalid image file: {}".format(image_to_check) return True @@ -137,9 +140,10 @@ def convert_image_mask(self, img): return (output_image, output_mask) - def run(self, input_name, default_value=None, display_name=None, description=None): + def run(self, input_name, default_value=None, image=None, display_name=None, description=None): + # use default_value if it exists, otherwise use image + image_path = default_value if default_value else image input_dir = folder_paths.get_input_directory() - image_path = default_value try: if image_path.startswith('http'): import requests From 0599875a7902fd660f2280d7b57aafab50ebea5d Mon Sep 17 00:00:00 2001 From: shanexi Date: Wed, 12 Mar 2025 14:55:39 +0800 Subject: [PATCH 67/70] fix: input image error in new comfyui version --- comfy-nodes/input_image.py | 2 +- web/shellagent.js | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy-nodes/input_image.py b/comfy-nodes/input_image.py index 83e7e79..3dd6230 100755 --- a/comfy-nodes/input_image.py +++ b/comfy-nodes/input_image.py @@ -57,7 +57,7 @@ def INPUT_TYPES(s): "optional": { "description": ( "STRING", - {"multiline": True, "default": "", "forceInput": False}, + {"multiline": False, "default": "", "forceInput": False}, ), } } diff --git a/web/shellagent.js b/web/shellagent.js index 78781fd..8480e4a 100644 --- a/web/shellagent.js +++ b/web/shellagent.js @@ -171,7 +171,7 @@ app.registerExtension({ ) { nodeData.input.required.upload = [ "IMAGEUPLOAD", - { widget: "default_value" }, + { widget: "default_value", imageInputName: "default_value", image_upload: true }, ]; } } From ced3e0de2fe5ed737a9a9dc28675e18f0b891e64 Mon Sep 17 00:00:00 2001 From: Chengwei Ouyang Date: Wed, 12 Mar 2025 15:11:54 +0800 Subject: [PATCH 68/70] remove image --- comfy-nodes/input_image.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/comfy-nodes/input_image.py b/comfy-nodes/input_image.py index 3dd6230..986054f 100755 --- a/comfy-nodes/input_image.py +++ b/comfy-nodes/input_image.py @@ -47,11 +47,8 @@ def INPUT_TYPES(s): "STRING", {"multiline": False, "default": "input_image", "forceInput": False}, ), - "image": ( - sorted(files), {"image_upload": True, "forceInput": False} - ), "default_value": ( - "STRING", {"forceInput": False} + sorted(files), {"image_upload": True, "forceInput": False} ), }, "optional": { From 1f16525118f92c8375b4074adc2538c9a8645132 Mon Sep 17 00:00:00 2001 From: Chengwei Ouyang Date: Wed, 12 Mar 2025 20:35:19 +0800 Subject: [PATCH 69/70] update --- comfy-nodes/input_image.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/comfy-nodes/input_image.py b/comfy-nodes/input_image.py index 986054f..3399a4f 100755 --- a/comfy-nodes/input_image.py +++ b/comfy-nodes/input_image.py @@ -78,21 +78,20 @@ def validate(cls, **kwargs): return schema @classmethod - def VALIDATE_INPUTS(s, input_name, default_value, image=None, description=""): - # check default_value first - image_to_check = default_value if default_value else image + def VALIDATE_INPUTS(s, input_name, default_value, description=""): + image = default_value - if image_to_check.startswith("http"): + if image.startswith("http"): return True - if image_to_check == "": + if image == "": return "Invalid image file: please check if the image is empty or invalid" - if os.path.isfile(image_to_check): + if os.path.isfile(image): return True - if not folder_paths.exists_annotated_filepath(image_to_check): - return "Invalid image file: {}".format(image_to_check) + if not folder_paths.exists_annotated_filepath(image): + return "Invalid image file: {}".format(image) return True @@ -137,9 +136,8 @@ def convert_image_mask(self, img): return (output_image, output_mask) - def run(self, input_name, default_value=None, image=None, display_name=None, description=None): - # use default_value if it exists, otherwise use image - image_path = default_value if default_value else image + def run(self, input_name, default_value=None, display_name=None, description=None): + image_path = default_value input_dir = folder_paths.get_input_directory() try: if image_path.startswith('http'): From d27fafbb003badd19246b05049dd0caca483a2a1 Mon Sep 17 00:00:00 2001 From: Chengwei Ouyang Date: Thu, 22 May 2025 14:54:40 +0800 Subject: [PATCH 70/70] fix shellagent save audio for new version comfyui --- comfy-nodes/input_audio.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/comfy-nodes/input_audio.py b/comfy-nodes/input_audio.py index df02294..7e5997f 100644 --- a/comfy-nodes/input_audio.py +++ b/comfy-nodes/input_audio.py @@ -169,6 +169,11 @@ def save_audio(self, audio, filename_prefix="ComfyUI", prompt=None, extra_pnginf results = super().save_audio(audio, filename_prefix, prompt, extra_pnginfo) results["shellagent_kwargs"] = extra_kwargs return results + + def save_flac(self, audio, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None, **extra_kwargs): + results = super().save_flac(audio, filename_prefix, "flac", prompt, extra_pnginfo) + results["shellagent_kwargs"] = extra_kwargs + return results class ShellAgentSaveAudio(ShellAgentSaveAudios):