Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementation for installing node from registry #41

Merged
merged 6 commits into from
May 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 115 additions & 10 deletions comfy_cli/command/custom_nodes/command.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,28 @@
import typer
from typing_extensions import List, Annotated
from typing import Optional

from comfy_cli import tracking
import os
import pathlib
import subprocess
import sys
from rich import print
import uuid
from comfy_cli.config_manager import ConfigManager
from comfy_cli.workspace_manager import WorkspaceManager
from typing import Optional

import typer
from rich import print
from typing_extensions import List, Annotated

from comfy_cli import ui, logging, tracking
from comfy_cli.config_manager import ConfigManager
from comfy_cli.file_utils import (
download_file,
upload_file_to_signed_url,
zip_files,
extract_package_as_zip,
)
from comfy_cli.registry import (
RegistryAPI,
extract_node_configuration,
upload_file_to_signed_url,
zip_files,
initialize_project_config,
)
from comfy_cli.workspace_manager import WorkspaceManager

app = typer.Typer()
manager_app = typer.Typer()
Expand Down Expand Up @@ -559,3 +564,103 @@ def scaffold():
typer.echo(
"pyproject.toml created successfully. Defaults were filled in. Please check before publishing."
)


@app.command("registry-list", help="List all nodes in the registry", hidden=True)
@tracking.track_command("node")
def display_all_nodes():
"""
Display all nodes in the registry.
"""

nodes = None
try:
nodes = registry_api.list_all_nodes()
except Exception as e:
logging.error(f"Failed to fetch nodes from the registry: {str(e)}")
ui.display_error_message("Failed to fetch nodes from the registry.")

# Map Node data class instances to tuples for display
node_data = [
(
node.id,
node.name,
node.description,
node.author or "N/A",
node.license or "N/A",
", ".join(node.tags),
node.latest_version.version if node.latest_version else "N/A",
)
for node in nodes
]
ui.display_table(
node_data,
[
"ID",
"Name",
"Description",
"Author",
"License",
"Tags",
"Latest Version",
],
title="List of All Nodes",
)


@app.command("registry-install", help="Install a node from the registry", hidden=True)
@tracking.track_command("node")
def install(node_id: str, version: Optional[str] = None):
"""
Install a node from the registry.
Args:
node_id: The ID of the node to install.
version: The version of the node to install. If not provided, the latest version will be installed.
"""

# If the node ID is not provided, prompt the user to enter it
if not node_id:
node_id = typer.prompt("Enter the ID of the node you want to install")

node_version = None
try:
# Call the API to install the node
node_version = registry_api.install_node(node_id, version)
if not node_version.download_url:
logging.error("Download URL not provided from the registry.")
ui.display_error_message(f"Failed to download the custom node {node_id}.")
return

except Exception as e:
logging.error(
f"Encountered an error while installing the node. error: {str(e)}"
)
ui.display_error_message(f"Failed to download the custom node {node_id}.")
return

# Download the node archive
custom_nodes_path = pathlib.Path(workspace_manager.workspace_path) / "custom_nodes"
node_specific_path = custom_nodes_path / node_id # Subdirectory for the node
node_specific_path.mkdir(
parents=True, exist_ok=True
) # Create the directory if it doesn't exist

local_filename = node_specific_path / f"{node_id}-{node_version.version}.zip"
logging.debug(
f"Start downloading the node {node_id} version {node_version.version} to {local_filename}"
)
download_file(node_version.download_url, local_filename)

# Extract the downloaded archive to the custom_node directory on the workspace.
logging.debug(
f"Start extracting the node {node_id} version {node_version.version} to {custom_nodes_path}"
)
extract_package_as_zip(local_filename, node_specific_path)

# Delete the downloaded archive
logging.debug(f"Deleting the downloaded archive {local_filename}")
os.remove(local_filename)

logging.info(
f"Node {node_id} version {node_version.version} has been successfully installed."
)
46 changes: 1 addition & 45 deletions comfy_cli/command/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from comfy_cli import tracking, ui
from comfy_cli.constants import DEFAULT_COMFY_MODEL_PATH
from comfy_cli.file_utils import download_file, DownloadException
from comfy_cli.workspace_manager import WorkspaceManager

app = typer.Typer()
Expand All @@ -18,10 +19,6 @@ def get_workspace() -> pathlib.Path:
return pathlib.Path(workspace_manager.workspace_path)


class DownloadException(Exception):
pass


def potentially_strip_param_url(path_name: str) -> str:
path_name = path_name.split("?")[0]
return path_name
Expand Down Expand Up @@ -158,47 +155,6 @@ def list(
ui.display_table(data, column_names)


def guess_status_code_reason(status_code: int) -> str:
if status_code == 401:
return f"Unauthorized download ({status_code}), you might need to manually log into browser to download one"
elif status_code == 403:
return f"Forbidden url ({status_code}), you might need to manually log into browser to download one"
elif status_code == 404:
return "Sorry, your model is in another castle (404)"
return f"Unknown error occurred (status code: {status_code})"


def download_file(url: str, local_filepath: pathlib.Path):
"""Helper function to download a file."""

import httpx

local_filepath.parent.mkdir(
parents=True, exist_ok=True
) # Ensure the directory exists

with httpx.stream("GET", url, follow_redirects=True) as response:
if response.status_code == 200:
total = int(response.headers["Content-Length"])
try:
with open(local_filepath, "wb") as f:
for data in ui.show_progress(
response.iter_bytes(),
total,
description=f"Downloading {total//1024//1024} MB",
):
f.write(data)
except KeyboardInterrupt:
delete_eh = ui.prompt_confirm_action(
"Download interrupted, cleanup files?"
)
if delete_eh:
local_filepath.unlink()
else:
status_reason = guess_status_code_reason(response.status_code)
raise DownloadException(f"Failed to download file.\n{status_reason}")


def list_models(path: pathlib.Path) -> list:
"""List all models in the specified directory."""
return [file for file in path.iterdir() if file.is_file()]
108 changes: 108 additions & 0 deletions comfy_cli/file_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import os
import pathlib
import zipfile

import requests
from pathspec import pathspec

from comfy_cli import ui


class DownloadException(Exception):
pass


def guess_status_code_reason(status_code: int) -> str:
if status_code == 401:
return f"Unauthorized download ({status_code}), you might need to manually log into browser to download one"
elif status_code == 403:
return f"Forbidden url ({status_code}), you might need to manually log into browser to download one"
elif status_code == 404:
return "Sorry, your model is in another castle (404)"
return f"Unknown error occurred (status code: {status_code})"


def download_file(url: str, local_filepath: pathlib.Path):
"""Helper function to download a file."""

import httpx

local_filepath.parent.mkdir(
parents=True, exist_ok=True
) # Ensure the directory exists

with httpx.stream("GET", url, follow_redirects=True) as response:
if response.status_code == 200:
total = int(response.headers["Content-Length"])
try:
with open(local_filepath, "wb") as f:
for data in ui.show_progress(
response.iter_bytes(),
total,
description=f"Downloading {total // 1024 // 1024} MB",
):
f.write(data)
except KeyboardInterrupt:
delete_eh = ui.prompt_confirm_action(
"Download interrupted, cleanup files?"
)
if delete_eh:
local_filepath.unlink()
else:
status_reason = guess_status_code_reason(response.status_code)
raise DownloadException(f"Failed to download file.\n{status_reason}")


def zip_files(zip_filename):
gitignore_path = ".gitignore"
if not os.path.exists(gitignore_path):
print(f"No .gitignore file found in {os.getcwd()}, proceeding without it.")
gitignore = ""
else:
with open(gitignore_path, "r") as file:
gitignore = file.read()

spec = pathspec.PathSpec.from_lines("gitwildmatch", gitignore.splitlines())

with zipfile.ZipFile(zip_filename, "w", zipfile.ZIP_DEFLATED) as zipf:
for root, dirs, files in os.walk("."):
if ".git" in dirs:
dirs.remove(".git")
for file in files:
file_path = os.path.join(root, file)
if not spec.match_file(file_path):
zipf.write(
file_path, os.path.relpath(file_path, os.path.join(root, ".."))
)


def upload_file_to_signed_url(signed_url: str, file_path: str):
try:
with open(file_path, "rb") as f:
headers = {"Content-Type": "application/gzip"}
response = requests.put(signed_url, data=f, headers=headers)

# Simple success check
if response.status_code == 200:
print("Upload successful.")
else:
# Print a generic error message with status code and response text
print(
f"Upload failed with status code: {response.status_code}. Error: {response.text}"
)

except requests.exceptions.RequestException as e:
# Print error related to the HTTP request
print(f"An error occurred during the upload: {str(e)}")
except FileNotFoundError:
# Print file not found error
print(f"Error: The file {file_path} does not exist.")


def extract_package_as_zip(file_path: pathlib.Path, extract_path: pathlib.Path):
try:
with zipfile.ZipFile(file_path, "r") as zip_ref:
zip_ref.extractall(extract_path)
print(f"Extracted zip file to {extract_path}")
except zipfile.BadZipFile:
print("File is not a zip or is corrupted.")
8 changes: 3 additions & 5 deletions comfy_cli/registry/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
from .api import RegistryAPI, upload_file_to_signed_url
from .api import RegistryAPI

from .config_parser import extract_node_configuration, initialize_project_config
from .types import PyProjectConfig, PublishNodeVersionResponse, NodeVersion
from .zip import zip_files
from .types import PyProjectConfig, PublishNodeVersionResponse, NodeVersion, Node

__all__ = [
"RegistryAPI",
"extract_node_configuration",
"PyProjectConfig",
"PublishNodeVersionResponse",
"NodeVersion",
"zip_files",
"upload_file_to_signed_url",
"Node",
"initialize_project_config",
]
Loading
Loading