Skip to content

Commit

Permalink
Merge branch 'main' into fix/config
Browse files Browse the repository at this point in the history
  • Loading branch information
ltdrdata committed Apr 29, 2024
2 parents 91c2232 + c478adf commit e0e7439
Show file tree
Hide file tree
Showing 7 changed files with 248 additions and 31 deletions.
5 changes: 3 additions & 2 deletions comfy_cli/cmdline.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from comfy_cli.command import custom_nodes
from comfy_cli.command import install as install_inner
from comfy_cli.command import run as run_inner
from comfy_cli import constants, tracking
from comfy_cli import constants, tracking, logging
from comfy_cli.env_checker import EnvChecker
from comfy_cli.meta_data import MetadataManager
from comfy_cli import env_checker
Expand All @@ -29,11 +29,12 @@ def main():


def init():
# TODO(yoland): after this
# TODO(yoland): after this
metadata_manager = MetadataManager()
start_time = time.time()
metadata_manager.scan_dir()
end_time = time.time()
logging.setup_logging()

print(f"scan_dir took {end_time - start_time:.2f} seconds to run")

Expand Down
156 changes: 128 additions & 28 deletions comfy_cli/command/models/models.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,146 @@
import pathlib
from typing import List, Optional

import typer

from typing_extensions import Annotated

from comfy_cli import tracking
from comfy_cli import tracking, ui
from comfy_cli.constants import DEFAULT_COMFY_MODEL_PATH

app = typer.Typer()


def get_workspace() -> pathlib.Path:
# TODO: placeholder logic right now, need to implement a config class that
# helps to get the workspace we are working with.
return pathlib.Path.cwd()


@app.command()
@tracking.track_command("model")
def download(
url: Annotated[
str,
typer.Option(
help="The URL from which to download the model",
show_default=False)],
relative_path: Annotated[
Optional[str],
typer.Option(
help="The relative path from the current workspace to install the model.",
show_default=True)] = DEFAULT_COMFY_MODEL_PATH
):
"""Download a model to a specified relative path if it is not already downloaded."""
# Convert relative path to absolute path based on the current working directory
local_filename = url.split("/")[-1]
local_filepath = get_workspace() / relative_path / local_filename

# Check if the file already exists
if local_filepath.exists():
typer.echo(f"File already exists: {local_filepath}")
return

# File does not exist, proceed with download
typer.echo(f"Start downloading URL: {url} into {local_filepath}")
download_file(url, local_filepath)


@app.command()
@tracking.track_command("model")
def get(url: Annotated[str, typer.Argument(help="The url of the model")],
path: Annotated[str, typer.Argument(help="The path to install the model.")]):
"""Download model"""
print(f"Start downloading url: ${url} into ${path}")
download_model(url, path)
def remove(
relative_path: str = typer.Option(
DEFAULT_COMFY_MODEL_PATH,
help="The relative path from the current workspace where the models are stored.",
show_default=True
),
model_names: Optional[List[str]] = typer.Option(
None,
help="List of model filenames to delete, separated by spaces",
show_default=False
)
):
"""Remove one or more downloaded models, either by specifying them directly or through an interactive selection."""
model_dir = get_workspace() / relative_path
available_models = list_models(model_dir)

if not available_models:
typer.echo("No models found to remove.")
return

to_delete = []
# Scenario #1: User provided model names to delete
if model_names:
# Validate and filter models to delete based on provided names
missing_models = []
for name in model_names:
model_path = model_dir / name
if model_path.exists():
to_delete.append(model_path)
else:
missing_models.append(name)

if missing_models:
typer.echo("The following models were not found and cannot be removed: " + ", ".join(missing_models))
if not to_delete:
return # Exit if no valid models were found

return

# Scenario #2: User did not provide model names, prompt for selection
else:
selections = ui.prompt_multi_select("Select models to delete:", [model.name for model in available_models])
if not selections:
typer.echo("No models selected for deletion.")
return
to_delete = [model_dir / selection for selection in selections]

# Confirm deletion
if to_delete and ui.confirm_action("Are you sure you want to delete the selected files?"):
for model_path in to_delete:
model_path.unlink()
typer.echo(f"Deleted: {model_path}")
else:
typer.echo("Deletion canceled.")


@app.command()
@tracking.track_command("model")
def remove():
"""Remove a custom node"""
# TODO
def list(
relative_path: str = typer.Option(
DEFAULT_COMFY_MODEL_PATH,
help="The relative path from the current workspace where the models are stored.",
show_default=True
)
):
"""Display a list of all models currently downloaded in a table format."""
model_dir = get_workspace() / relative_path
models = list_models(model_dir)

if not models:
typer.echo("No models found.")
return

# Prepare data for table display
data = [(model.name, f"{model.stat().st_size // 1024} KB") for model in models]
column_names = ["Model Name", "Size"]
ui.display_table(data, column_names)


def download_file(url: str, local_filepath: pathlib.Path):
"""Helper function to download a file."""

def download_model(url: str, path: str):
import httpx
import pathlib
from tqdm import tqdm

local_filename = url.split("/")[-1]
local_filepath = pathlib.Path(path, local_filename)
local_filepath.parent.mkdir(parents=True, exist_ok=True)

print(f"downloading {url} ...")
with httpx.stream("GET", url, follow_redirects=True) as stream:
total = int(stream.headers["Content-Length"])
with open(local_filepath, "wb") as f, tqdm(
total=total, unit_scale=True, unit_divisor=1024, unit="B"
) as progress:
num_bytes_downloaded = stream.num_bytes_downloaded
for data in stream.iter_bytes():
local_filepath.parent.mkdir(parents=True, exist_ok=True) # Ensure the directory exists

with httpx.stream("GET", url, follow_redirects=True) as response:
total = int(response.headers["Content-Length"])
with open(local_filepath, "wb") as f:
for data in ui.show_progress(response.iter_bytes(), total):
f.write(data)
progress.update(
stream.num_bytes_downloaded - num_bytes_downloaded
)
num_bytes_downloaded = stream.num_bytes_downloaded


def list_models(path: pathlib.Path) -> list:
"""List all models in the specified directory."""
return [file for file in path.iterdir() if file.is_file()]
1 change: 1 addition & 0 deletions comfy_cli/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class OS(Enum):
COMFY_GITHUB_URL = 'https://github.com/comfyanonymous/ComfyUI'
COMFY_MANAGER_GITHUB_URL = 'https://github.com/ltdrdata/ComfyUI-Manager'

DEFAULT_COMFY_MODEL_PATH = "models/checkpoints"
DEFAULT_COMFY_WORKSPACE = {
OS.WINDOWS: os.path.join(os.path.expanduser('~'), 'Documents', 'ComfyUI'),
OS.MACOS: os.path.join(os.path.expanduser('~'), 'Documents', 'ComfyUI'),
Expand Down
35 changes: 35 additions & 0 deletions comfy_cli/logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import logging
import os

"""
This module provides logging utilities for the CLI.
Note: we could potentially change the logging library or the way we log messages in the future.
Therefore, it's a good idea to encapsulate logging-related code in a separate module.
"""


def setup_logging():
# TODO: consider supporting different ways of outputting logs
# Note: by default, the log level is set to INFO
log_level = os.getenv("LOG_LEVEL", "WARN").upper()
logging.basicConfig(level=log_level,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S')


def debug(message):
logging.debug(message)


def info(message):
logging.info(message)


def warning(message):
logging.warning(message)


def error(message):
logging.error(message)
# TODO: consider tracking errors to Mixpanel as well.
4 changes: 3 additions & 1 deletion comfy_cli/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from mixpanel import Mixpanel

from comfy_cli import logging

MIXPANEL_TOKEN = "93aeab8962b622d431ac19800ccc9f67"
DISABLE_TELEMETRY = os.getenv('DISABLE_TELEMETRY', False)
mp = Mixpanel(MIXPANEL_TOKEN) if MIXPANEL_TOKEN else None
Expand All @@ -29,7 +31,7 @@ def wrapper(*args, **kwargs):
command_name = f"{sub_command}:{func.__name__}" if sub_command is not None else func.__name__
input_arguments = kwargs # Example to pass all args and kwargs

print(f"Tracking command: {command_name} with arguments: {input_arguments}")
logging.debug(f"Tracking command: {command_name} with arguments: {input_arguments}")
track_event(command_name, properties=input_arguments)
return func(*args, **kwargs)

Expand Down
77 changes: 77 additions & 0 deletions comfy_cli/ui.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import questionary
from typing import List, Tuple

from rich.table import Table
from rich.console import Console
from rich.progress import Progress

console = Console()


def show_progress(iterable, total, description="Downloading..."):
"""
Display progress bar for iterable processes, especially useful for file downloads.
Each item in the iterable should be a chunk of data, and the progress bar will advance
by the size of each chunk.
Args:
iterable (Iterable[bytes]): An iterable that yields chunks of data.
total (int): The total size of the data (e.g., total number of bytes) to be downloaded.
description (str): Description text for the progress bar.
Yields:
bytes: Chunks of data as they are processed.
"""
with Progress() as progress:
task = progress.add_task(description, total=total)
for chunk in iterable:
yield chunk
progress.update(task, advance=len(chunk))


def prompt_multi_select(prompt: str, choices: List[str]) -> List[str]:
"""
Prompts the user to select multiple items from a list of choices.
Args:
prompt (str): The message to display to the user.
choices (List[str]): A list of choices from which the user can select.
Returns:
List[str]: A list of the selected items.
"""
selections = questionary.checkbox(prompt, choices=choices).ask() # returns list of selected items
return selections if selections else []


def confirm_action(prompt: str) -> bool:
"""
Prompts the user for confirmation before proceeding with an action.
Args:
prompt (str): The confirmation message to display to the user.
Returns:
bool: True if the user confirms, False otherwise.
"""
return questionary.confirm(prompt).ask() # returns True if confirmed, otherwise False


def display_table(data: List[Tuple], column_names: List[str], title: str = "") -> None:
"""
Displays a list of tuples in a table format using Rich.
Args:
data (List[Tuple]): A list of tuples, where each tuple represents a row.
column_names (List[str]): A list of column names for the table.
title (str): The title of the table.
"""
table = Table(title=title)

for name in column_names:
table.add_column(name, overflow="fold")

for row in data:
table.add_row(*[str(item) for item in row])

console.print(table)
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ dependencies = [
"pyyaml",
"typing-extensions",
"mixpanel",
"questionary",
]

classifiers = [
Expand Down

0 comments on commit e0e7439

Please sign in to comment.