From 6b4786c0e689ad8d5c5f30ed05ef671f61806d81 Mon Sep 17 00:00:00 2001 From: James Kwon <96548424+hongil0316@users.noreply.github.com> Date: Sat, 11 May 2024 18:19:11 -0400 Subject: [PATCH] Implementation for installing node from registry (#41) * [WIP] Implementation for installing node from registry * Refactoring + Complete installation process * Refactor + improve logging messages * Remove default values used for testing * Fix some minor errors * fmt --- comfy_cli/command/custom_nodes/command.py | 125 ++++++++++++++++-- comfy_cli/command/models/models.py | 46 +------ comfy_cli/file_utils.py | 108 ++++++++++++++++ comfy_cli/registry/__init__.py | 8 +- comfy_cli/registry/api.py | 148 ++++++++++++++++------ comfy_cli/registry/types.py | 16 ++- comfy_cli/registry/zip.py | 30 ----- comfy_cli/ui.py | 10 ++ 8 files changed, 363 insertions(+), 128 deletions(-) create mode 100644 comfy_cli/file_utils.py delete mode 100644 comfy_cli/registry/zip.py diff --git a/comfy_cli/command/custom_nodes/command.py b/comfy_cli/command/custom_nodes/command.py index 4e18e85..948dc55 100644 --- a/comfy_cli/command/custom_nodes/command.py +++ b/comfy_cli/command/custom_nodes/command.py @@ -1,23 +1,28 @@ -import typer -from typing_extensions import List, Annotated -from typing import Optional - -from comfy_cli import tracking import os +import pathlib import subprocess import sys -from rich import print import uuid -from comfy_cli.config_manager import ConfigManager -from comfy_cli.workspace_manager import WorkspaceManager +from typing import Optional +import typer +from rich import print +from typing_extensions import List, Annotated + +from comfy_cli import ui, logging, tracking +from comfy_cli.config_manager import ConfigManager +from comfy_cli.file_utils import ( + download_file, + upload_file_to_signed_url, + zip_files, + extract_package_as_zip, +) from comfy_cli.registry import ( RegistryAPI, extract_node_configuration, - upload_file_to_signed_url, - zip_files, initialize_project_config, ) +from comfy_cli.workspace_manager import WorkspaceManager app = typer.Typer() manager_app = typer.Typer() @@ -559,3 +564,103 @@ def scaffold(): typer.echo( "pyproject.toml created successfully. Defaults were filled in. Please check before publishing." ) + + +@app.command("registry-list", help="List all nodes in the registry", hidden=True) +@tracking.track_command("node") +def display_all_nodes(): + """ + Display all nodes in the registry. + """ + + nodes = None + try: + nodes = registry_api.list_all_nodes() + except Exception as e: + logging.error(f"Failed to fetch nodes from the registry: {str(e)}") + ui.display_error_message("Failed to fetch nodes from the registry.") + + # Map Node data class instances to tuples for display + node_data = [ + ( + node.id, + node.name, + node.description, + node.author or "N/A", + node.license or "N/A", + ", ".join(node.tags), + node.latest_version.version if node.latest_version else "N/A", + ) + for node in nodes + ] + ui.display_table( + node_data, + [ + "ID", + "Name", + "Description", + "Author", + "License", + "Tags", + "Latest Version", + ], + title="List of All Nodes", + ) + + +@app.command("registry-install", help="Install a node from the registry", hidden=True) +@tracking.track_command("node") +def install(node_id: str, version: Optional[str] = None): + """ + Install a node from the registry. + Args: + node_id: The ID of the node to install. + version: The version of the node to install. If not provided, the latest version will be installed. + """ + + # If the node ID is not provided, prompt the user to enter it + if not node_id: + node_id = typer.prompt("Enter the ID of the node you want to install") + + node_version = None + try: + # Call the API to install the node + node_version = registry_api.install_node(node_id, version) + if not node_version.download_url: + logging.error("Download URL not provided from the registry.") + ui.display_error_message(f"Failed to download the custom node {node_id}.") + return + + except Exception as e: + logging.error( + f"Encountered an error while installing the node. error: {str(e)}" + ) + ui.display_error_message(f"Failed to download the custom node {node_id}.") + return + + # Download the node archive + custom_nodes_path = pathlib.Path(workspace_manager.workspace_path) / "custom_nodes" + node_specific_path = custom_nodes_path / node_id # Subdirectory for the node + node_specific_path.mkdir( + parents=True, exist_ok=True + ) # Create the directory if it doesn't exist + + local_filename = node_specific_path / f"{node_id}-{node_version.version}.zip" + logging.debug( + f"Start downloading the node {node_id} version {node_version.version} to {local_filename}" + ) + download_file(node_version.download_url, local_filename) + + # Extract the downloaded archive to the custom_node directory on the workspace. + logging.debug( + f"Start extracting the node {node_id} version {node_version.version} to {custom_nodes_path}" + ) + extract_package_as_zip(local_filename, node_specific_path) + + # Delete the downloaded archive + logging.debug(f"Deleting the downloaded archive {local_filename}") + os.remove(local_filename) + + logging.info( + f"Node {node_id} version {node_version.version} has been successfully installed." + ) diff --git a/comfy_cli/command/models/models.py b/comfy_cli/command/models/models.py index a245ee2..9d0bff8 100644 --- a/comfy_cli/command/models/models.py +++ b/comfy_cli/command/models/models.py @@ -7,6 +7,7 @@ from comfy_cli import tracking, ui from comfy_cli.constants import DEFAULT_COMFY_MODEL_PATH +from comfy_cli.file_utils import download_file, DownloadException from comfy_cli.workspace_manager import WorkspaceManager app = typer.Typer() @@ -18,10 +19,6 @@ def get_workspace() -> pathlib.Path: return pathlib.Path(workspace_manager.workspace_path) -class DownloadException(Exception): - pass - - def potentially_strip_param_url(path_name: str) -> str: path_name = path_name.split("?")[0] return path_name @@ -158,47 +155,6 @@ def list( ui.display_table(data, column_names) -def guess_status_code_reason(status_code: int) -> str: - if status_code == 401: - return f"Unauthorized download ({status_code}), you might need to manually log into browser to download one" - elif status_code == 403: - return f"Forbidden url ({status_code}), you might need to manually log into browser to download one" - elif status_code == 404: - return "Sorry, your model is in another castle (404)" - return f"Unknown error occurred (status code: {status_code})" - - -def download_file(url: str, local_filepath: pathlib.Path): - """Helper function to download a file.""" - - import httpx - - local_filepath.parent.mkdir( - parents=True, exist_ok=True - ) # Ensure the directory exists - - with httpx.stream("GET", url, follow_redirects=True) as response: - if response.status_code == 200: - total = int(response.headers["Content-Length"]) - try: - with open(local_filepath, "wb") as f: - for data in ui.show_progress( - response.iter_bytes(), - total, - description=f"Downloading {total//1024//1024} MB", - ): - f.write(data) - except KeyboardInterrupt: - delete_eh = ui.prompt_confirm_action( - "Download interrupted, cleanup files?" - ) - if delete_eh: - local_filepath.unlink() - else: - status_reason = guess_status_code_reason(response.status_code) - raise DownloadException(f"Failed to download file.\n{status_reason}") - - def list_models(path: pathlib.Path) -> list: """List all models in the specified directory.""" return [file for file in path.iterdir() if file.is_file()] diff --git a/comfy_cli/file_utils.py b/comfy_cli/file_utils.py new file mode 100644 index 0000000..6eee407 --- /dev/null +++ b/comfy_cli/file_utils.py @@ -0,0 +1,108 @@ +import os +import pathlib +import zipfile + +import requests +from pathspec import pathspec + +from comfy_cli import ui + + +class DownloadException(Exception): + pass + + +def guess_status_code_reason(status_code: int) -> str: + if status_code == 401: + return f"Unauthorized download ({status_code}), you might need to manually log into browser to download one" + elif status_code == 403: + return f"Forbidden url ({status_code}), you might need to manually log into browser to download one" + elif status_code == 404: + return "Sorry, your model is in another castle (404)" + return f"Unknown error occurred (status code: {status_code})" + + +def download_file(url: str, local_filepath: pathlib.Path): + """Helper function to download a file.""" + + import httpx + + local_filepath.parent.mkdir( + parents=True, exist_ok=True + ) # Ensure the directory exists + + with httpx.stream("GET", url, follow_redirects=True) as response: + if response.status_code == 200: + total = int(response.headers["Content-Length"]) + try: + with open(local_filepath, "wb") as f: + for data in ui.show_progress( + response.iter_bytes(), + total, + description=f"Downloading {total // 1024 // 1024} MB", + ): + f.write(data) + except KeyboardInterrupt: + delete_eh = ui.prompt_confirm_action( + "Download interrupted, cleanup files?" + ) + if delete_eh: + local_filepath.unlink() + else: + status_reason = guess_status_code_reason(response.status_code) + raise DownloadException(f"Failed to download file.\n{status_reason}") + + +def zip_files(zip_filename): + gitignore_path = ".gitignore" + if not os.path.exists(gitignore_path): + print(f"No .gitignore file found in {os.getcwd()}, proceeding without it.") + gitignore = "" + else: + with open(gitignore_path, "r") as file: + gitignore = file.read() + + spec = pathspec.PathSpec.from_lines("gitwildmatch", gitignore.splitlines()) + + with zipfile.ZipFile(zip_filename, "w", zipfile.ZIP_DEFLATED) as zipf: + for root, dirs, files in os.walk("."): + if ".git" in dirs: + dirs.remove(".git") + for file in files: + file_path = os.path.join(root, file) + if not spec.match_file(file_path): + zipf.write( + file_path, os.path.relpath(file_path, os.path.join(root, "..")) + ) + + +def upload_file_to_signed_url(signed_url: str, file_path: str): + try: + with open(file_path, "rb") as f: + headers = {"Content-Type": "application/gzip"} + response = requests.put(signed_url, data=f, headers=headers) + + # Simple success check + if response.status_code == 200: + print("Upload successful.") + else: + # Print a generic error message with status code and response text + print( + f"Upload failed with status code: {response.status_code}. Error: {response.text}" + ) + + except requests.exceptions.RequestException as e: + # Print error related to the HTTP request + print(f"An error occurred during the upload: {str(e)}") + except FileNotFoundError: + # Print file not found error + print(f"Error: The file {file_path} does not exist.") + + +def extract_package_as_zip(file_path: pathlib.Path, extract_path: pathlib.Path): + try: + with zipfile.ZipFile(file_path, "r") as zip_ref: + zip_ref.extractall(extract_path) + print(f"Extracted zip file to {extract_path}") + except zipfile.BadZipFile: + print("File is not a zip or is corrupted.") diff --git a/comfy_cli/registry/__init__.py b/comfy_cli/registry/__init__.py index 5e2c5e6..2e65c98 100644 --- a/comfy_cli/registry/__init__.py +++ b/comfy_cli/registry/__init__.py @@ -1,8 +1,7 @@ -from .api import RegistryAPI, upload_file_to_signed_url +from .api import RegistryAPI from .config_parser import extract_node_configuration, initialize_project_config -from .types import PyProjectConfig, PublishNodeVersionResponse, NodeVersion -from .zip import zip_files +from .types import PyProjectConfig, PublishNodeVersionResponse, NodeVersion, Node __all__ = [ "RegistryAPI", @@ -10,7 +9,6 @@ "PyProjectConfig", "PublishNodeVersionResponse", "NodeVersion", - "zip_files", - "upload_file_to_signed_url", + "Node", "initialize_project_config", ] diff --git a/comfy_cli/registry/api.py b/comfy_cli/registry/api.py index 6e71a90..68d4b96 100644 --- a/comfy_cli/registry/api.py +++ b/comfy_cli/registry/api.py @@ -1,12 +1,11 @@ +import logging import os import requests import json -from comfy_cli.registry.types import ( - PyProjectConfig, - PublishNodeVersionResponse, - NodeVersion, -) + +# Reduced global imports from comfy_cli.registry +from comfy_cli.registry.types import NodeVersion, Node class RegistryAPI: @@ -19,9 +18,7 @@ def determine_base_url(self): else: return "https://api-frontend-dev-qod3oz2v2q-uc.a.run.app" - def publish_node_version( - self, node_config: PyProjectConfig, token: str - ) -> PublishNodeVersionResponse: + def publish_node_version(self, node_config, token): """ Publishes a new version of a node. @@ -32,6 +29,9 @@ def publish_node_version( Returns: PublishNodeVersionResponse: The response object from the API server. """ + # Local import to prevent circular dependency + from comfy_cli.registry.types import PyProjectConfig, PublishNodeVersionResponse + url = f"{self.base_url}/publishers/{node_config.tool_comfy.publisher_id}/nodes/{node_config.project.name}/versions" headers = {"Content-Type": "application/json"} body = { @@ -53,40 +53,114 @@ def publish_node_version( if response.status_code == 201: data = response.json() - node_version = NodeVersion( - changelog=data["node_version"]["changelog"], - dependencies=data["node_version"]["dependencies"], - deprecated=data["node_version"]["deprecated"], - id=data["node_version"]["id"], - version=data["node_version"]["version"], - ) return PublishNodeVersionResponse( - node_version=node_version, signedUrl=data["signedUrl"] + node_version=map_node_version(data["node_version"]), + signedUrl=data["signedUrl"], ) else: raise Exception( f"Failed to publish node version: {response.status_code} {response.text}" ) + def list_all_nodes(self): + """ + Retrieves a list of all nodes and maps them to Node dataclass instances. + + Returns: + list: A list of Node instances. + """ + url = f"{self.base_url}/nodes" + response = requests.get(url) + if response.status_code == 200: + raw_nodes = response.json()["nodes"] + mapped_nodes = [map_node_to_node_class(node) for node in raw_nodes] + return mapped_nodes + else: + raise Exception( + f"Failed to retrieve nodes: {response.status_code} - {response.text}" + ) + + def install_node(self, node_id, version=None): + """ + Retrieves the node version for installation. + + Args: + node_id (str): The unique identifier of the node. + version (str, optional): Specific version of the node to retrieve. If omitted, the latest version is returned. + + Returns: + NodeVersion: Node version data or error message. + """ + if version is None: + url = f"{self.base_url}/nodes/{node_id}/install" + else: + url = f"{self.base_url}/nodes/{node_id}/install?version={version}" + + response = requests.get(url) + if response.status_code == 200: + # Convert the API response to a NodeVersion object + logging.debug(f"RegistryAPI install_node response: {response.json()}") + return map_node_version(response.json()) + else: + raise Exception( + f"Failed to install node: {response.status_code} - {response.text}" + ) + + +def map_node_version(api_node_version): + """ + Maps node version data from API response to NodeVersion dataclass. + + Args: + api_data (dict): The 'node_version' part of the API response. + + Returns: + NodeVersion: An instance of NodeVersion dataclass populated with data from the API. + """ + return NodeVersion( + changelog=api_node_version.get( + "changelog", "" + ), # Provide a default value if 'changelog' is missing + dependencies=api_node_version.get( + "dependencies", [] + ), # Provide a default empty list if 'dependencies' is missing + deprecated=api_node_version.get( + "deprecated", False + ), # Assume False if 'deprecated' is not specified + id=api_node_version[ + "id" + ], # 'id' should be mandatory; raise KeyError if missing + version=api_node_version[ + "version" + ], # 'version' should be mandatory; raise KeyError if missing + download_url=api_node_version.get( + "downloadUrl", "" + ), # Provide a default value if 'downloadUrl' is missing + ) + + +def map_node_to_node_class(api_node_data): + """ + Maps node data from API response to Node dataclass. + + Args: + api_node_data (dict): The node data from the API. -def upload_file_to_signed_url(signed_url: str, file_path: str): - try: - with open(file_path, "rb") as f: - headers = {"Content-Type": "application/gzip"} - response = requests.put(signed_url, data=f, headers=headers) - - # Simple success check - if response.status_code == 200: - print("Upload successful.") - else: - # Print a generic error message with status code and response text - print( - f"Upload failed with status code: {response.status_code}. Error: {response.text}" - ) - - except requests.exceptions.RequestException as e: - # Print error related to the HTTP request - print(f"An error occurred during the upload: {str(e)}") - except FileNotFoundError: - # Print file not found error - print(f"Error: The file {file_path} does not exist.") + Returns: + Node: An instance of Node dataclass populated with API data. + """ + return Node( + id=api_node_data["id"], + name=api_node_data["name"], + description=api_node_data["description"], + author=api_node_data.get("author"), + license=api_node_data.get("license"), + icon=api_node_data.get("icon"), + repository=api_node_data.get("repository"), + tags=api_node_data.get("tags", []), + latest_version=( + map_node_version(api_node_data["latest_version"]) + if "latest_version" in api_node_data + else None + ), + ) diff --git a/comfy_cli/registry/types.py b/comfy_cli/registry/types.py index a3c1b8d..1b355bd 100644 --- a/comfy_cli/registry/types.py +++ b/comfy_cli/registry/types.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import List +from typing import List, Optional @dataclass @@ -9,6 +9,20 @@ class NodeVersion: deprecated: bool id: str version: str + download_url: str + + +@dataclass +class Node: + id: str + name: str + description: str + author: Optional[str] = None + license: Optional[str] = None + icon: Optional[str] = None + repository: Optional[str] = None + tags: List[str] = field(default_factory=list) + latest_version: Optional[NodeVersion] = None @dataclass diff --git a/comfy_cli/registry/zip.py b/comfy_cli/registry/zip.py deleted file mode 100644 index 285381c..0000000 --- a/comfy_cli/registry/zip.py +++ /dev/null @@ -1,30 +0,0 @@ -import os -import zipfile -import pathspec - - -def zip_files(zip_filename): - gitignore_path = ".gitignore" - if not os.path.exists(gitignore_path): - print(f"No .gitignore file found in {os.getcwd()}, proceeding without it.") - gitignore = "" - else: - with open(gitignore_path, "r") as file: - gitignore = file.read() - - spec = pathspec.PathSpec.from_lines("gitwildmatch", gitignore.splitlines()) - - with zipfile.ZipFile(zip_filename, "w", zipfile.ZIP_DEFLATED) as zipf: - for root, dirs, files in os.walk("."): - if ".git" in dirs: - dirs.remove(".git") - for file in files: - file_path = os.path.join(root, file) - if not spec.match_file(file_path): - zipf.write( - file_path, os.path.relpath(file_path, os.path.join(root, "..")) - ) - - -# TODO: check this code. this make slow down comfy-cli extremely -# zip_files("node.tar.gz") diff --git a/comfy_cli/ui.py b/comfy_cli/ui.py index 425a15a..b36f6ff 100644 --- a/comfy_cli/ui.py +++ b/comfy_cli/ui.py @@ -129,3 +129,13 @@ def display_table(data: List[Tuple], column_names: List[str], title: str = "") - table.add_row(*[str(item) for item in row]) console.print(table) + + +def display_error_message(message: str) -> None: + """ + Displays an error message to the user in red text. + + Args: + message (str): The error message to display. + """ + console.print(f"[red]{message}[/]")