Skip to content

Commit

Permalink
[WIP] Implementation for installing node from registry
Browse files Browse the repository at this point in the history
  • Loading branch information
james03160927 committed May 10, 2024
1 parent a2fe71e commit 1698dbf
Show file tree
Hide file tree
Showing 4 changed files with 222 additions and 19 deletions.
95 changes: 94 additions & 1 deletion comfy_cli/command/custom_nodes/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,18 @@
from typing_extensions import List, Annotated
from typing import Optional

from comfy_cli import tracking
from comfy_cli import tracking, ui
import os
import pathlib
import subprocess
import sys
from rich import print
import uuid

from comfy_cli.command.models.models import download_file
from comfy_cli.config_manager import ConfigManager
from comfy_cli.workspace_manager import WorkspaceManager
import tarfile

from comfy_cli.registry import (
RegistryAPI,
Expand Down Expand Up @@ -559,3 +563,92 @@ def scaffold():
typer.echo(
"pyproject.toml created successfully. Defaults were filled in. Please check before publishing."
)


@app.command("registry-list-all", help="Init scaffolding for custom node")
@tracking.track_command("node")
def registry_list_all():
"""
Fetch and display all nodes in a table format.
"""
try:
nodes = registry_api.list_all_nodes()
# Map Node data class instances to tuples for display
node_data = [
(
node.id,
node.name,
node.description,
node.author or "N/A",
node.license or "N/A",
", ".join(node.tags),
node.latest_version.version if node.latest_version else "N/A",
)
for node in nodes
]
ui.display_table(
node_data,
[
"ID",
"Name",
"Description",
"Author",
"License",
"Tags",
"Latest Version",
],
title="List of All Nodes",
)
except Exception as e:
print(f"[red]Error: {str(e)}[/red]")


@app.command("registry-install", help="Init scaffolding for custom node")
@tracking.track_command("node")
def registry_install():
"""
Install a node from the registry.
"""
try:
node_id = typer.prompt("Enter the ID of the node you want to install")
version = typer.prompt(
"Enter the version of the node you want to install (leave blank for latest)",
default="",
show_default=False,
)

# Call the API to install the node
node_version = registry_api.install_node(node_id, version)
if node_version.download_url:
# Download the node archive
local_filename = pathlib.Path(
f"./downloads/{node_id}-{node_version.version}.tar.gz"
)
download_file(node_version.download_url, local_filename)

# Extract the downloaded archive
extract_tar_gz(local_filename)
print(
f"Node {node_id} version {node_version.version} installed successfully."
)
else:
print("Download URL not provided.")

except Exception as e:
print(f"[red]Error: {str(e)}[/red]")


def extract_tar_gz(
tar_gz_path: pathlib.Path, extract_path: pathlib.Path = pathlib.Path(".")
):
"""
Extracts a tar.gz file to a specified directory using pathlib.
Args:
tar_gz_path (pathlib.Path): Path to the .tar.gz file.
extract_path (pathlib.Path): Directory to extract the files into.
"""
import tarfile

with tarfile.open(tar_gz_path, "r:gz") as tar:
tar.extractall(path=extract_path)

Check failure

Code scanning / CodeQL

Arbitrary file write during tarfile extraction High

This file extraction depends on a
potentially untrusted source
.
3 changes: 2 additions & 1 deletion comfy_cli/registry/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .api import RegistryAPI, upload_file_to_signed_url

from .config_parser import extract_node_configuration, initialize_project_config
from .types import PyProjectConfig, PublishNodeVersionResponse, NodeVersion
from .types import PyProjectConfig, PublishNodeVersionResponse, NodeVersion, Node
from .zip import zip_files

__all__ = [
Expand All @@ -10,6 +10,7 @@
"PyProjectConfig",
"PublishNodeVersionResponse",
"NodeVersion",
"Node",
"zip_files",
"upload_file_to_signed_url",
"initialize_project_config",
Expand Down
127 changes: 111 additions & 16 deletions comfy_cli/registry/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@

import requests
import json
from comfy_cli.registry.types import (
PyProjectConfig,
PublishNodeVersionResponse,
NodeVersion,
)

# Reduced global imports from comfy_cli.registry
from comfy_cli.registry.types import NodeVersion, Node


class RegistryAPI:
Expand All @@ -19,9 +17,7 @@ def determine_base_url(self):
else:
return "https://api-frontend-dev-qod3oz2v2q-uc.a.run.app"

def publish_node_version(
self, node_config: PyProjectConfig, token: str
) -> PublishNodeVersionResponse:
def publish_node_version(self, node_config, token):
"""
Publishes a new version of a node.
Expand All @@ -32,6 +28,9 @@ def publish_node_version(
Returns:
PublishNodeVersionResponse: The response object from the API server.
"""
# Local import to prevent circular dependency
from comfy_cli.registry.types import PyProjectConfig, PublishNodeVersionResponse

url = f"{self.base_url}/publishers/{node_config.tool_comfy.publisher_id}/nodes/{node_config.project.name}/versions"
headers = {"Content-Type": "application/json"}
body = {
Expand All @@ -53,21 +52,58 @@ def publish_node_version(

if response.status_code == 201:
data = response.json()
node_version = NodeVersion(
changelog=data["node_version"]["changelog"],
dependencies=data["node_version"]["dependencies"],
deprecated=data["node_version"]["deprecated"],
id=data["node_version"]["id"],
version=data["node_version"]["version"],
)
return PublishNodeVersionResponse(
node_version=node_version, signedUrl=data["signedUrl"]
node_version=map_node_version(data["node_version"]),
signedUrl=data["signedUrl"],
)
else:
raise Exception(
f"Failed to publish node version: {response.status_code} {response.text}"
)

def list_all_nodes(self):
"""
Retrieves a list of all nodes and maps them to Node dataclass instances.
Returns:
list: A list of Node instances.
"""
url = f"{self.base_url}/nodes"
response = requests.get(url)
if response.status_code == 200:
raw_nodes = response.json()["nodes"]
mapped_nodes = [map_node_to_node_class(node) for node in raw_nodes]
return mapped_nodes
else:
raise Exception(
f"Failed to retrieve nodes: {response.status_code} - {response.text}"
)

def install_node(self, node_id, version=None):
"""
Retrieves the node version for installation.
Args:
node_id (str): The unique identifier of the node.
version (str, optional): Specific version of the node to retrieve. If omitted, the latest version is returned.
Returns:
NodeVersion: Node version data or error message.
"""
if version is None:
url = f"{self.base_url}/nodes/{node_id}/install"
else:
url = f"{self.base_url}/nodes/{node_id}/install?version={version}"

response = requests.get(url)
if response.status_code == 200:
# Convert the API response to a NodeVersion object
return map_node_version(response.json())
else:
raise Exception(
f"Failed to install node: {response.status_code} - {response.text}"
)


def upload_file_to_signed_url(signed_url: str, file_path: str):
try:
Expand All @@ -90,3 +126,62 @@ def upload_file_to_signed_url(signed_url: str, file_path: str):
except FileNotFoundError:
# Print file not found error
print(f"Error: The file {file_path} does not exist.")


def map_node_version(api_node_version):
"""
Maps node version data from API response to NodeVersion dataclass.
Args:
api_data (dict): The 'node_version' part of the API response.
Returns:
NodeVersion: An instance of NodeVersion dataclass populated with data from the API.
"""
return NodeVersion(
changelog=api_node_version.get(
"changelog", ""
), # Provide a default value if 'changelog' is missing
dependencies=api_node_version.get(
"dependencies", []
), # Provide a default empty list if 'dependencies' is missing
deprecated=api_node_version.get(
"deprecated", False
), # Assume False if 'deprecated' is not specified
id=api_node_version[
"id"
], # 'id' should be mandatory; raise KeyError if missing
version=api_node_version[
"version"
], # 'version' should be mandatory; raise KeyError if missing
download_url=api_node_version.get(
"downloadUrl", ""
), # Provide a default value if 'downloadUrl' is missing
)


def map_node_to_node_class(api_node_data):
"""
Maps node data from API response to Node dataclass.
Args:
api_node_data (dict): The node data from the API.
Returns:
Node: An instance of Node dataclass populated with API data.
"""
return Node(
id=api_node_data["id"],
name=api_node_data["name"],
description=api_node_data["description"],
author=api_node_data.get("author"),
license=api_node_data.get("license"),
icon=api_node_data.get("icon"),
repository=api_node_data.get("repository"),
tags=api_node_data.get("tags", []),
latest_version=(
map_node_version(api_node_data["latest_version"])
if "latest_version" in api_node_data
else None
),
)
16 changes: 15 additions & 1 deletion comfy_cli/registry/types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import List
from typing import List, Optional


@dataclass
Expand All @@ -9,6 +9,20 @@ class NodeVersion:
deprecated: bool
id: str
version: str
download_url: str


@dataclass
class Node:
id: str
name: str
description: str
author: Optional[str] = None
license: Optional[str] = None
icon: Optional[str] = None
repository: Optional[str] = None
tags: List[str] = field(default_factory=list)
latest_version: Optional[NodeVersion] = None


@dataclass
Expand Down

0 comments on commit 1698dbf

Please sign in to comment.