Skip to content

Commit

Permalink
Merge branch 'main' into dt-install-script
Browse files Browse the repository at this point in the history
  • Loading branch information
ltdrdata committed May 12, 2024
2 parents 9dadfc4 + 6b4786c commit e82ba8c
Show file tree
Hide file tree
Showing 8 changed files with 363 additions and 128 deletions.
125 changes: 115 additions & 10 deletions comfy_cli/command/custom_nodes/command.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,28 @@
import typer
from typing_extensions import List, Annotated
from typing import Optional

from comfy_cli import tracking
import os
import pathlib
import subprocess
import sys
from rich import print
import uuid
from comfy_cli.config_manager import ConfigManager
from comfy_cli.workspace_manager import WorkspaceManager
from typing import Optional

import typer
from rich import print
from typing_extensions import List, Annotated

from comfy_cli import ui, logging, tracking
from comfy_cli.config_manager import ConfigManager
from comfy_cli.file_utils import (
download_file,
upload_file_to_signed_url,
zip_files,
extract_package_as_zip,
)
from comfy_cli.registry import (
RegistryAPI,
extract_node_configuration,
upload_file_to_signed_url,
zip_files,
initialize_project_config,
)
from comfy_cli.workspace_manager import WorkspaceManager

app = typer.Typer()
manager_app = typer.Typer()
Expand Down Expand Up @@ -559,3 +564,103 @@ def scaffold():
typer.echo(
"pyproject.toml created successfully. Defaults were filled in. Please check before publishing."
)


@app.command("registry-list", help="List all nodes in the registry", hidden=True)
@tracking.track_command("node")
def display_all_nodes():
"""
Display all nodes in the registry.
"""

nodes = None
try:
nodes = registry_api.list_all_nodes()
except Exception as e:
logging.error(f"Failed to fetch nodes from the registry: {str(e)}")
ui.display_error_message("Failed to fetch nodes from the registry.")

# Map Node data class instances to tuples for display
node_data = [
(
node.id,
node.name,
node.description,
node.author or "N/A",
node.license or "N/A",
", ".join(node.tags),
node.latest_version.version if node.latest_version else "N/A",
)
for node in nodes
]
ui.display_table(
node_data,
[
"ID",
"Name",
"Description",
"Author",
"License",
"Tags",
"Latest Version",
],
title="List of All Nodes",
)


@app.command("registry-install", help="Install a node from the registry", hidden=True)
@tracking.track_command("node")
def install(node_id: str, version: Optional[str] = None):
"""
Install a node from the registry.
Args:
node_id: The ID of the node to install.
version: The version of the node to install. If not provided, the latest version will be installed.
"""

# If the node ID is not provided, prompt the user to enter it
if not node_id:
node_id = typer.prompt("Enter the ID of the node you want to install")

node_version = None
try:
# Call the API to install the node
node_version = registry_api.install_node(node_id, version)
if not node_version.download_url:
logging.error("Download URL not provided from the registry.")
ui.display_error_message(f"Failed to download the custom node {node_id}.")
return

except Exception as e:
logging.error(
f"Encountered an error while installing the node. error: {str(e)}"
)
ui.display_error_message(f"Failed to download the custom node {node_id}.")
return

# Download the node archive
custom_nodes_path = pathlib.Path(workspace_manager.workspace_path) / "custom_nodes"
node_specific_path = custom_nodes_path / node_id # Subdirectory for the node
node_specific_path.mkdir(
parents=True, exist_ok=True
) # Create the directory if it doesn't exist

local_filename = node_specific_path / f"{node_id}-{node_version.version}.zip"
logging.debug(
f"Start downloading the node {node_id} version {node_version.version} to {local_filename}"
)
download_file(node_version.download_url, local_filename)

# Extract the downloaded archive to the custom_node directory on the workspace.
logging.debug(
f"Start extracting the node {node_id} version {node_version.version} to {custom_nodes_path}"
)
extract_package_as_zip(local_filename, node_specific_path)

# Delete the downloaded archive
logging.debug(f"Deleting the downloaded archive {local_filename}")
os.remove(local_filename)

logging.info(
f"Node {node_id} version {node_version.version} has been successfully installed."
)
46 changes: 1 addition & 45 deletions comfy_cli/command/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from comfy_cli import tracking, ui
from comfy_cli.constants import DEFAULT_COMFY_MODEL_PATH
from comfy_cli.file_utils import download_file, DownloadException
from comfy_cli.workspace_manager import WorkspaceManager

app = typer.Typer()
Expand All @@ -18,10 +19,6 @@ def get_workspace() -> pathlib.Path:
return pathlib.Path(workspace_manager.workspace_path)


class DownloadException(Exception):
pass


def potentially_strip_param_url(path_name: str) -> str:
path_name = path_name.split("?")[0]
return path_name
Expand Down Expand Up @@ -158,47 +155,6 @@ def list(
ui.display_table(data, column_names)


def guess_status_code_reason(status_code: int) -> str:
if status_code == 401:
return f"Unauthorized download ({status_code}), you might need to manually log into browser to download one"
elif status_code == 403:
return f"Forbidden url ({status_code}), you might need to manually log into browser to download one"
elif status_code == 404:
return "Sorry, your model is in another castle (404)"
return f"Unknown error occurred (status code: {status_code})"


def download_file(url: str, local_filepath: pathlib.Path):
"""Helper function to download a file."""

import httpx

local_filepath.parent.mkdir(
parents=True, exist_ok=True
) # Ensure the directory exists

with httpx.stream("GET", url, follow_redirects=True) as response:
if response.status_code == 200:
total = int(response.headers["Content-Length"])
try:
with open(local_filepath, "wb") as f:
for data in ui.show_progress(
response.iter_bytes(),
total,
description=f"Downloading {total//1024//1024} MB",
):
f.write(data)
except KeyboardInterrupt:
delete_eh = ui.prompt_confirm_action(
"Download interrupted, cleanup files?"
)
if delete_eh:
local_filepath.unlink()
else:
status_reason = guess_status_code_reason(response.status_code)
raise DownloadException(f"Failed to download file.\n{status_reason}")


def list_models(path: pathlib.Path) -> list:
"""List all models in the specified directory."""
return [file for file in path.iterdir() if file.is_file()]
108 changes: 108 additions & 0 deletions comfy_cli/file_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import os
import pathlib
import zipfile

import requests
from pathspec import pathspec

from comfy_cli import ui


class DownloadException(Exception):
pass


def guess_status_code_reason(status_code: int) -> str:
if status_code == 401:
return f"Unauthorized download ({status_code}), you might need to manually log into browser to download one"
elif status_code == 403:
return f"Forbidden url ({status_code}), you might need to manually log into browser to download one"
elif status_code == 404:
return "Sorry, your model is in another castle (404)"
return f"Unknown error occurred (status code: {status_code})"


def download_file(url: str, local_filepath: pathlib.Path):
"""Helper function to download a file."""

import httpx

local_filepath.parent.mkdir(
parents=True, exist_ok=True
) # Ensure the directory exists

with httpx.stream("GET", url, follow_redirects=True) as response:
if response.status_code == 200:
total = int(response.headers["Content-Length"])
try:
with open(local_filepath, "wb") as f:
for data in ui.show_progress(
response.iter_bytes(),
total,
description=f"Downloading {total // 1024 // 1024} MB",
):
f.write(data)
except KeyboardInterrupt:
delete_eh = ui.prompt_confirm_action(
"Download interrupted, cleanup files?"
)
if delete_eh:
local_filepath.unlink()
else:
status_reason = guess_status_code_reason(response.status_code)
raise DownloadException(f"Failed to download file.\n{status_reason}")


def zip_files(zip_filename):
gitignore_path = ".gitignore"
if not os.path.exists(gitignore_path):
print(f"No .gitignore file found in {os.getcwd()}, proceeding without it.")
gitignore = ""
else:
with open(gitignore_path, "r") as file:
gitignore = file.read()

spec = pathspec.PathSpec.from_lines("gitwildmatch", gitignore.splitlines())

with zipfile.ZipFile(zip_filename, "w", zipfile.ZIP_DEFLATED) as zipf:
for root, dirs, files in os.walk("."):
if ".git" in dirs:
dirs.remove(".git")
for file in files:
file_path = os.path.join(root, file)
if not spec.match_file(file_path):
zipf.write(
file_path, os.path.relpath(file_path, os.path.join(root, ".."))
)


def upload_file_to_signed_url(signed_url: str, file_path: str):
try:
with open(file_path, "rb") as f:
headers = {"Content-Type": "application/gzip"}
response = requests.put(signed_url, data=f, headers=headers)

# Simple success check
if response.status_code == 200:
print("Upload successful.")
else:
# Print a generic error message with status code and response text
print(
f"Upload failed with status code: {response.status_code}. Error: {response.text}"
)

except requests.exceptions.RequestException as e:
# Print error related to the HTTP request
print(f"An error occurred during the upload: {str(e)}")
except FileNotFoundError:
# Print file not found error
print(f"Error: The file {file_path} does not exist.")


def extract_package_as_zip(file_path: pathlib.Path, extract_path: pathlib.Path):
try:
with zipfile.ZipFile(file_path, "r") as zip_ref:
zip_ref.extractall(extract_path)
print(f"Extracted zip file to {extract_path}")
except zipfile.BadZipFile:
print("File is not a zip or is corrupted.")
8 changes: 3 additions & 5 deletions comfy_cli/registry/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
from .api import RegistryAPI, upload_file_to_signed_url
from .api import RegistryAPI

from .config_parser import extract_node_configuration, initialize_project_config
from .types import PyProjectConfig, PublishNodeVersionResponse, NodeVersion
from .zip import zip_files
from .types import PyProjectConfig, PublishNodeVersionResponse, NodeVersion, Node

__all__ = [
"RegistryAPI",
"extract_node_configuration",
"PyProjectConfig",
"PublishNodeVersionResponse",
"NodeVersion",
"zip_files",
"upload_file_to_signed_url",
"Node",
"initialize_project_config",
]
Loading

0 comments on commit e82ba8c

Please sign in to comment.