Skip to content

Commit

Permalink
Refactoring + Complete installation process
Browse files Browse the repository at this point in the history
  • Loading branch information
james03160927 committed May 11, 2024
1 parent 1698dbf commit 1252237
Show file tree
Hide file tree
Showing 6 changed files with 155 additions and 152 deletions.
93 changes: 43 additions & 50 deletions comfy_cli/command/custom_nodes/command.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,28 @@
import typer
from typing_extensions import List, Annotated
from typing import Optional

from comfy_cli import tracking, ui
import os
import pathlib
import subprocess
import sys
from rich import print
import uuid
from typing import Optional

from comfy_cli.command.models.models import download_file
from comfy_cli.config_manager import ConfigManager
from comfy_cli.workspace_manager import WorkspaceManager
import tarfile
import typer
from rich import print
from typing_extensions import List, Annotated

from comfy_cli import ui, logging, tracking
from comfy_cli.config_manager import ConfigManager
from comfy_cli.file_utils import (
download_file,
upload_file_to_signed_url,
zip_files,
extract_package_as_zip,
)
from comfy_cli.registry import (
RegistryAPI,
extract_node_configuration,
upload_file_to_signed_url,
zip_files,
initialize_project_config,
)
from comfy_cli.workspace_manager import WorkspaceManager

app = typer.Typer()
manager_app = typer.Typer()
Expand Down Expand Up @@ -565,12 +566,13 @@ def scaffold():
)


@app.command("registry-list-all", help="Init scaffolding for custom node")
@app.command("registry-list", help="List all nodes in the registry")
@tracking.track_command("node")
def registry_list_all():
def display_all_nodes():
"""
Fetch and display all nodes in a table format.
Display all nodes in the registry.
"""

try:
nodes = registry_api.list_all_nodes()
# Map Node data class instances to tuples for display
Expand Down Expand Up @@ -603,52 +605,43 @@ def registry_list_all():
print(f"[red]Error: {str(e)}[/red]")


@app.command("registry-install", help="Init scaffolding for custom node")
@app.command("registry-install", help="Install a node from the registry")
@tracking.track_command("node")
def registry_install():
def install(node_id: str = "comfyui-inspire-pack", version: str = "1.0.0"):
"""
Install a node from the registry.
Args:
node_id: The ID of the node to install.
version: The version of the node to install. If not provided, the latest version will be installed.
"""
try:

# If the node ID is not provided, prompt the user to enter it
if not node_id:
node_id = typer.prompt("Enter the ID of the node you want to install")
version = typer.prompt(
"Enter the version of the node you want to install (leave blank for latest)",
default="",
show_default=False,
)

node_version = None
try:
# Call the API to install the node
node_version = registry_api.install_node(node_id, version)
if node_version.download_url:
# Download the node archive
local_filename = pathlib.Path(
f"./downloads/{node_id}-{node_version.version}.tar.gz"
)
download_file(node_version.download_url, local_filename)

# Extract the downloaded archive
extract_tar_gz(local_filename)
print(
f"Node {node_id} version {node_version.version} installed successfully."
)
else:
if not node_version.download_url:
# TODO: print error message
print("Download URL not provided.")
return

except Exception as e:
# TODO: print error message
print(f"[red]Error: {str(e)}[/red]")
return

logging.debug(f"registry_install command - node version: {node_version}")
# Download the node archive
local_filename = pathlib.Path(
f"./downloads/{node_id}-{node_version.version}.tar.gz"
)
download_file(node_version.download_url, local_filename)

def extract_tar_gz(
tar_gz_path: pathlib.Path, extract_path: pathlib.Path = pathlib.Path(".")
):
"""
Extracts a tar.gz file to a specified directory using pathlib.
Args:
tar_gz_path (pathlib.Path): Path to the .tar.gz file.
extract_path (pathlib.Path): Directory to extract the files into.
"""
import tarfile

with tarfile.open(tar_gz_path, "r:gz") as tar:
tar.extractall(path=extract_path)
# Extract the downloaded archive to the custom_node directory on the workspace.
# workspace/custom_nodes
custom_nodes_path = pathlib.Path(workspace_manager.workspace_path) / "custom_nodes"
extract_package_as_zip(local_filename, custom_nodes_path)
print(f"Node {node_id} version {node_version.version} installed successfully.")
46 changes: 1 addition & 45 deletions comfy_cli/command/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from comfy_cli import tracking, ui
from comfy_cli.constants import DEFAULT_COMFY_MODEL_PATH
from comfy_cli.file_utils import download_file, DownloadException
from comfy_cli.workspace_manager import WorkspaceManager

app = typer.Typer()
Expand All @@ -18,10 +19,6 @@ def get_workspace() -> pathlib.Path:
return pathlib.Path(workspace_manager.workspace_path)


class DownloadException(Exception):
pass


def potentially_strip_param_url(path_name: str) -> str:
path_name = path_name.split("?")[0]
return path_name
Expand Down Expand Up @@ -158,47 +155,6 @@ def list(
ui.display_table(data, column_names)


def guess_status_code_reason(status_code: int) -> str:
if status_code == 401:
return f"Unauthorized download ({status_code}), you might need to manually log into browser to download one"
elif status_code == 403:
return f"Forbidden url ({status_code}), you might need to manually log into browser to download one"
elif status_code == 404:
return "Sorry, your model is in another castle (404)"
return f"Unknown error occurred (status code: {status_code})"


def download_file(url: str, local_filepath: pathlib.Path):
"""Helper function to download a file."""

import httpx

local_filepath.parent.mkdir(
parents=True, exist_ok=True
) # Ensure the directory exists

with httpx.stream("GET", url, follow_redirects=True) as response:
if response.status_code == 200:
total = int(response.headers["Content-Length"])
try:
with open(local_filepath, "wb") as f:
for data in ui.show_progress(
response.iter_bytes(),
total,
description=f"Downloading {total//1024//1024} MB",
):
f.write(data)
except KeyboardInterrupt:
delete_eh = ui.prompt_confirm_action(
"Download interrupted, cleanup files?"
)
if delete_eh:
local_filepath.unlink()
else:
status_reason = guess_status_code_reason(response.status_code)
raise DownloadException(f"Failed to download file.\n{status_reason}")


def list_models(path: pathlib.Path) -> list:
"""List all models in the specified directory."""
return [file for file in path.iterdir() if file.is_file()]
108 changes: 108 additions & 0 deletions comfy_cli/file_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import os
import pathlib
import zipfile

import requests
from pathspec import pathspec

from comfy_cli import ui


class DownloadException(Exception):
pass


def guess_status_code_reason(status_code: int) -> str:
if status_code == 401:
return f"Unauthorized download ({status_code}), you might need to manually log into browser to download one"
elif status_code == 403:
return f"Forbidden url ({status_code}), you might need to manually log into browser to download one"
elif status_code == 404:
return "Sorry, your model is in another castle (404)"
return f"Unknown error occurred (status code: {status_code})"


def download_file(url: str, local_filepath: pathlib.Path):
"""Helper function to download a file."""

import httpx

local_filepath.parent.mkdir(
parents=True, exist_ok=True
) # Ensure the directory exists

with httpx.stream("GET", url, follow_redirects=True) as response:
if response.status_code == 200:
total = int(response.headers["Content-Length"])
try:
with open(local_filepath, "wb") as f:
for data in ui.show_progress(
response.iter_bytes(),
total,
description=f"Downloading {total // 1024 // 1024} MB",
):
f.write(data)
except KeyboardInterrupt:
delete_eh = ui.prompt_confirm_action(
"Download interrupted, cleanup files?"
)
if delete_eh:
local_filepath.unlink()
else:
status_reason = guess_status_code_reason(response.status_code)
raise DownloadException(f"Failed to download file.\n{status_reason}")


def zip_files(zip_filename):
gitignore_path = ".gitignore"
if not os.path.exists(gitignore_path):
print(f"No .gitignore file found in {os.getcwd()}, proceeding without it.")
gitignore = ""
else:
with open(gitignore_path, "r") as file:
gitignore = file.read()

spec = pathspec.PathSpec.from_lines("gitwildmatch", gitignore.splitlines())

with zipfile.ZipFile(zip_filename, "w", zipfile.ZIP_DEFLATED) as zipf:
for root, dirs, files in os.walk("."):
if ".git" in dirs:
dirs.remove(".git")
for file in files:
file_path = os.path.join(root, file)
if not spec.match_file(file_path):
zipf.write(
file_path, os.path.relpath(file_path, os.path.join(root, ".."))
)


def upload_file_to_signed_url(signed_url: str, file_path: str):
try:
with open(file_path, "rb") as f:
headers = {"Content-Type": "application/gzip"}
response = requests.put(signed_url, data=f, headers=headers)

# Simple success check
if response.status_code == 200:
print("Upload successful.")
else:
# Print a generic error message with status code and response text
print(
f"Upload failed with status code: {response.status_code}. Error: {response.text}"
)

except requests.exceptions.RequestException as e:
# Print error related to the HTTP request
print(f"An error occurred during the upload: {str(e)}")
except FileNotFoundError:
# Print file not found error
print(f"Error: The file {file_path} does not exist.")


def extract_package_as_zip(file_path: pathlib.Path, extract_path: pathlib.Path):
try:
with zipfile.ZipFile(file_path, "r") as zip_ref:
zip_ref.extractall(extract_path)
print(f"Extracted zip file to {extract_path}")
except zipfile.BadZipFile:
print("File is not a zip or is corrupted.")
5 changes: 1 addition & 4 deletions comfy_cli/registry/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from .api import RegistryAPI, upload_file_to_signed_url
from .api import RegistryAPI

from .config_parser import extract_node_configuration, initialize_project_config
from .types import PyProjectConfig, PublishNodeVersionResponse, NodeVersion, Node
from .zip import zip_files

__all__ = [
"RegistryAPI",
Expand All @@ -11,7 +10,5 @@
"PublishNodeVersionResponse",
"NodeVersion",
"Node",
"zip_files",
"upload_file_to_signed_url",
"initialize_project_config",
]
25 changes: 2 additions & 23 deletions comfy_cli/registry/api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import os

import requests
Expand Down Expand Up @@ -98,36 +99,14 @@ def install_node(self, node_id, version=None):
response = requests.get(url)
if response.status_code == 200:
# Convert the API response to a NodeVersion object
logging.debug(f"RegistryAPI install_node response: {response.json()}")
return map_node_version(response.json())
else:
raise Exception(
f"Failed to install node: {response.status_code} - {response.text}"
)


def upload_file_to_signed_url(signed_url: str, file_path: str):
try:
with open(file_path, "rb") as f:
headers = {"Content-Type": "application/gzip"}
response = requests.put(signed_url, data=f, headers=headers)

# Simple success check
if response.status_code == 200:
print("Upload successful.")
else:
# Print a generic error message with status code and response text
print(
f"Upload failed with status code: {response.status_code}. Error: {response.text}"
)

except requests.exceptions.RequestException as e:
# Print error related to the HTTP request
print(f"An error occurred during the upload: {str(e)}")
except FileNotFoundError:
# Print file not found error
print(f"Error: The file {file_path} does not exist.")


def map_node_version(api_node_version):
"""
Maps node version data from API response to NodeVersion dataclass.
Expand Down
Loading

0 comments on commit 1252237

Please sign in to comment.