diff --git a/comfy_cli/cmdline.py b/comfy_cli/cmdline.py index fd797b85..902a5436 100644 --- a/comfy_cli/cmdline.py +++ b/comfy_cli/cmdline.py @@ -11,7 +11,7 @@ from comfy_cli.command import custom_nodes from comfy_cli.command import install as install_inner from comfy_cli.command import run as run_inner -from comfy_cli import constants, tracking +from comfy_cli import constants, tracking, logging from comfy_cli.env_checker import EnvChecker from comfy_cli.meta_data import MetadataManager from comfy_cli import env_checker @@ -29,11 +29,12 @@ def main(): def init(): - # TODO(yoland): after this + # TODO(yoland): after this metadata_manager = MetadataManager() start_time = time.time() metadata_manager.scan_dir() end_time = time.time() + logging.setup_logging() print(f"scan_dir took {end_time - start_time:.2f} seconds to run") diff --git a/comfy_cli/command/models/models.py b/comfy_cli/command/models/models.py index b81cb14a..e75f8e3d 100644 --- a/comfy_cli/command/models/models.py +++ b/comfy_cli/command/models/models.py @@ -1,46 +1,146 @@ +import pathlib +from typing import List, Optional + import typer + from typing_extensions import Annotated -from comfy_cli import tracking +from comfy_cli import tracking, ui +from comfy_cli.constants import DEFAULT_COMFY_MODEL_PATH app = typer.Typer() +def get_workspace() -> pathlib.Path: + # TODO: placeholder logic right now, need to implement a config class that + # helps to get the workspace we are working with. + return pathlib.Path.cwd() + + +@app.command() +@tracking.track_command("model") +def download( + url: Annotated[ + str, + typer.Option( + help="The URL from which to download the model", + show_default=False)], + relative_path: Annotated[ + Optional[str], + typer.Option( + help="The relative path from the current workspace to install the model.", + show_default=True)] = DEFAULT_COMFY_MODEL_PATH +): + """Download a model to a specified relative path if it is not already downloaded.""" + # Convert relative path to absolute path based on the current working directory + local_filename = url.split("/")[-1] + local_filepath = get_workspace() / relative_path / local_filename + + # Check if the file already exists + if local_filepath.exists(): + typer.echo(f"File already exists: {local_filepath}") + return + + # File does not exist, proceed with download + typer.echo(f"Start downloading URL: {url} into {local_filepath}") + download_file(url, local_filepath) + + @app.command() @tracking.track_command("model") -def get(url: Annotated[str, typer.Argument(help="The url of the model")], - path: Annotated[str, typer.Argument(help="The path to install the model.")]): - """Download model""" - print(f"Start downloading url: ${url} into ${path}") - download_model(url, path) +def remove( + relative_path: str = typer.Option( + DEFAULT_COMFY_MODEL_PATH, + help="The relative path from the current workspace where the models are stored.", + show_default=True + ), + model_names: Optional[List[str]] = typer.Option( + None, + help="List of model filenames to delete, separated by spaces", + show_default=False + ) +): + """Remove one or more downloaded models, either by specifying them directly or through an interactive selection.""" + model_dir = get_workspace() / relative_path + available_models = list_models(model_dir) + + if not available_models: + typer.echo("No models found to remove.") + return + + to_delete = [] + # Scenario #1: User provided model names to delete + if model_names: + # Validate and filter models to delete based on provided names + missing_models = [] + for name in model_names: + model_path = model_dir / name + if model_path.exists(): + to_delete.append(model_path) + else: + missing_models.append(name) + + if missing_models: + typer.echo("The following models were not found and cannot be removed: " + ", ".join(missing_models)) + if not to_delete: + return # Exit if no valid models were found + + return + + # Scenario #2: User did not provide model names, prompt for selection + else: + selections = ui.prompt_multi_select("Select models to delete:", [model.name for model in available_models]) + if not selections: + typer.echo("No models selected for deletion.") + return + to_delete = [model_dir / selection for selection in selections] + + # Confirm deletion + if to_delete and ui.confirm_action("Are you sure you want to delete the selected files?"): + for model_path in to_delete: + model_path.unlink() + typer.echo(f"Deleted: {model_path}") + else: + typer.echo("Deletion canceled.") @app.command() @tracking.track_command("model") -def remove(): - """Remove a custom node""" - # TODO +def list( + relative_path: str = typer.Option( + DEFAULT_COMFY_MODEL_PATH, + help="The relative path from the current workspace where the models are stored.", + show_default=True + ) +): + """Display a list of all models currently downloaded in a table format.""" + model_dir = get_workspace() / relative_path + models = list_models(model_dir) + + if not models: + typer.echo("No models found.") + return + # Prepare data for table display + data = [(model.name, f"{model.stat().st_size // 1024} KB") for model in models] + column_names = ["Model Name", "Size"] + ui.display_table(data, column_names) + + +def download_file(url: str, local_filepath: pathlib.Path): + """Helper function to download a file.""" -def download_model(url: str, path: str): import httpx - import pathlib - from tqdm import tqdm - local_filename = url.split("/")[-1] - local_filepath = pathlib.Path(path, local_filename) - local_filepath.parent.mkdir(parents=True, exist_ok=True) - - print(f"downloading {url} ...") - with httpx.stream("GET", url, follow_redirects=True) as stream: - total = int(stream.headers["Content-Length"]) - with open(local_filepath, "wb") as f, tqdm( - total=total, unit_scale=True, unit_divisor=1024, unit="B" - ) as progress: - num_bytes_downloaded = stream.num_bytes_downloaded - for data in stream.iter_bytes(): + local_filepath.parent.mkdir(parents=True, exist_ok=True) # Ensure the directory exists + + with httpx.stream("GET", url, follow_redirects=True) as response: + total = int(response.headers["Content-Length"]) + with open(local_filepath, "wb") as f: + for data in ui.show_progress(response.iter_bytes(), total): f.write(data) - progress.update( - stream.num_bytes_downloaded - num_bytes_downloaded - ) - num_bytes_downloaded = stream.num_bytes_downloaded + + +def list_models(path: pathlib.Path) -> list: + """List all models in the specified directory.""" + return [file for file in path.iterdir() if file.is_file()] diff --git a/comfy_cli/constants.py b/comfy_cli/constants.py index a7433a76..8eb5093d 100644 --- a/comfy_cli/constants.py +++ b/comfy_cli/constants.py @@ -11,6 +11,7 @@ class OS(Enum): COMFY_GITHUB_URL = 'https://github.com/comfyanonymous/ComfyUI' COMFY_MANAGER_GITHUB_URL = 'https://github.com/ltdrdata/ComfyUI-Manager' +DEFAULT_COMFY_MODEL_PATH = "models/checkpoints" DEFAULT_COMFY_WORKSPACE = { OS.WINDOWS: os.path.join(os.path.expanduser('~'), 'Documents', 'ComfyUI'), OS.MACOS: os.path.join(os.path.expanduser('~'), 'Documents', 'ComfyUI'), diff --git a/comfy_cli/logging.py b/comfy_cli/logging.py new file mode 100644 index 00000000..fe12534c --- /dev/null +++ b/comfy_cli/logging.py @@ -0,0 +1,35 @@ +import logging +import os + +""" +This module provides logging utilities for the CLI. + +Note: we could potentially change the logging library or the way we log messages in the future. +Therefore, it's a good idea to encapsulate logging-related code in a separate module. +""" + + +def setup_logging(): + # TODO: consider supporting different ways of outputting logs + # Note: by default, the log level is set to INFO + log_level = os.getenv("LOG_LEVEL", "WARN").upper() + logging.basicConfig(level=log_level, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S') + + +def debug(message): + logging.debug(message) + + +def info(message): + logging.info(message) + + +def warning(message): + logging.warning(message) + + +def error(message): + logging.error(message) + # TODO: consider tracking errors to Mixpanel as well. diff --git a/comfy_cli/tracking.py b/comfy_cli/tracking.py index de4d7a98..d1edc0a5 100644 --- a/comfy_cli/tracking.py +++ b/comfy_cli/tracking.py @@ -4,6 +4,8 @@ from mixpanel import Mixpanel +from comfy_cli import logging + MIXPANEL_TOKEN = "93aeab8962b622d431ac19800ccc9f67" DISABLE_TELEMETRY = os.getenv('DISABLE_TELEMETRY', False) mp = Mixpanel(MIXPANEL_TOKEN) if MIXPANEL_TOKEN else None @@ -29,7 +31,7 @@ def wrapper(*args, **kwargs): command_name = f"{sub_command}:{func.__name__}" if sub_command is not None else func.__name__ input_arguments = kwargs # Example to pass all args and kwargs - print(f"Tracking command: {command_name} with arguments: {input_arguments}") + logging.debug(f"Tracking command: {command_name} with arguments: {input_arguments}") track_event(command_name, properties=input_arguments) return func(*args, **kwargs) diff --git a/comfy_cli/ui.py b/comfy_cli/ui.py new file mode 100644 index 00000000..2dc6b6bb --- /dev/null +++ b/comfy_cli/ui.py @@ -0,0 +1,77 @@ +import questionary +from typing import List, Tuple + +from rich.table import Table +from rich.console import Console +from rich.progress import Progress + +console = Console() + + +def show_progress(iterable, total, description="Downloading..."): + """ + Display progress bar for iterable processes, especially useful for file downloads. + Each item in the iterable should be a chunk of data, and the progress bar will advance + by the size of each chunk. + + Args: + iterable (Iterable[bytes]): An iterable that yields chunks of data. + total (int): The total size of the data (e.g., total number of bytes) to be downloaded. + description (str): Description text for the progress bar. + + Yields: + bytes: Chunks of data as they are processed. + """ + with Progress() as progress: + task = progress.add_task(description, total=total) + for chunk in iterable: + yield chunk + progress.update(task, advance=len(chunk)) + + +def prompt_multi_select(prompt: str, choices: List[str]) -> List[str]: + """ + Prompts the user to select multiple items from a list of choices. + + Args: + prompt (str): The message to display to the user. + choices (List[str]): A list of choices from which the user can select. + + Returns: + List[str]: A list of the selected items. + """ + selections = questionary.checkbox(prompt, choices=choices).ask() # returns list of selected items + return selections if selections else [] + + +def confirm_action(prompt: str) -> bool: + """ + Prompts the user for confirmation before proceeding with an action. + + Args: + prompt (str): The confirmation message to display to the user. + + Returns: + bool: True if the user confirms, False otherwise. + """ + return questionary.confirm(prompt).ask() # returns True if confirmed, otherwise False + + +def display_table(data: List[Tuple], column_names: List[str], title: str = "") -> None: + """ + Displays a list of tuples in a table format using Rich. + + Args: + data (List[Tuple]): A list of tuples, where each tuple represents a row. + column_names (List[str]): A list of column names for the table. + title (str): The title of the table. + """ + table = Table(title=title) + + for name in column_names: + table.add_column(name, overflow="fold") + + for row in data: + table.add_row(*[str(item) for item in row]) + + console.print(table) diff --git a/pyproject.toml b/pyproject.toml index 94197030..45f4e4e6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ dependencies = [ "pyyaml", "typing-extensions", "mixpanel", + "questionary", ] classifiers = [