From a4a99408c46159dd0207b77d68a3625aa20a73e4 Mon Sep 17 00:00:00 2001 From: James Kwon <96548424+hongil0316@users.noreply.github.com> Date: Sat, 27 Apr 2024 23:07:52 -0400 Subject: [PATCH 1/3] Create a new logging & setup functions for standardization --- comfy_cli/cmdline.py | 3 ++- comfy_cli/command/models/models.py | 32 +++++++++++++++++++++++++-- comfy_cli/logging.py | 35 ++++++++++++++++++++++++++++++ 3 files changed, 67 insertions(+), 3 deletions(-) create mode 100644 comfy_cli/logging.py diff --git a/comfy_cli/cmdline.py b/comfy_cli/cmdline.py index 17defc3..6a0371b 100644 --- a/comfy_cli/cmdline.py +++ b/comfy_cli/cmdline.py @@ -27,11 +27,12 @@ def main(): def init(): - # TODO(yoland): after this + # TODO(yoland): after this metadata_manager = MetadataManager() start_time = time.time() metadata_manager.scan_dir() end_time = time.time() + logging.setup_logging() print(f"scan_dir took {end_time - start_time:.2f} seconds to run") diff --git a/comfy_cli/command/models/models.py b/comfy_cli/command/models/models.py index b81cb14..48fd57c 100644 --- a/comfy_cli/command/models/models.py +++ b/comfy_cli/command/models/models.py @@ -1,7 +1,7 @@ import typer from typing_extensions import Annotated -from comfy_cli import tracking +from comfy_cli import tracking, logging app = typer.Typer() @@ -31,7 +31,7 @@ def download_model(url: str, path: str): local_filepath = pathlib.Path(path, local_filename) local_filepath.parent.mkdir(parents=True, exist_ok=True) - print(f"downloading {url} ...") + logging.debug(f"downloading {url} ...") with httpx.stream("GET", url, follow_redirects=True) as stream: total = int(stream.headers["Content-Length"]) with open(local_filepath, "wb") as f, tqdm( @@ -44,3 +44,31 @@ def download_model(url: str, path: str): stream.num_bytes_downloaded - num_bytes_downloaded ) num_bytes_downloaded = stream.num_bytes_downloaded + + # def download_model(url: str, path: str): + # # Set up logging to file + # logging.basicConfig(level=logging.INFO, filename='download.log', filemode='w', + # format='%(asctime)s - %(levelname)s - %(message)s') + # + # local_filename = url.split("/")[-1] + # local_filepath = pathlib.Path(path, local_filename) + # local_filepath.parent.mkdir(parents=True, exist_ok=True) + # + # # Log the URL being downloaded + # logging.info(f"Downloading {url} ...") + # + # with httpx.stream("GET", url, follow_redirects=True) as stream: + # total = int(stream.headers["Content-Length"]) + # with open(local_filepath, "wb") as f, tqdm( + # total=total, unit_scale=True, unit_divisor=1024, unit="B" + # ) as progress: + # num_bytes_downloaded = stream.num_bytes_downloaded + # for data in stream.iter_bytes(): + # f.write(data) + # progress.update( + # stream.num_bytes_downloaded - num_bytes_downloaded + # ) + # num_bytes_downloaded = stream.num_bytes_downloaded + # + # # Log the completion of the download + # logging.info(f"Download completed. File saved to {local_filepath}") diff --git a/comfy_cli/logging.py b/comfy_cli/logging.py new file mode 100644 index 0000000..1a1a68b --- /dev/null +++ b/comfy_cli/logging.py @@ -0,0 +1,35 @@ +import logging +import os + +""" +This module provides logging utilities for the CLI. + +Note: we could potentially change the logging library or the way we log messages in the future. +Therefore, it's a good idea to encapsulate logging-related code in a separate module. +""" + + +def setup_logging(): + # TODO: consider supporting different ways of outputting logs + # Note: by default, the log level is set to INFO + log_level = os.getenv("LOG_LEVEL", "INFO").upper() + logging.basicConfig(level=log_level, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S') + + +def debug(message): + logging.debug(message) + + +def info(message): + logging.info(message) + + +def warning(message): + logging.warning(message) + + +def error(message): + logging.error(message) + # TODO: consider tracking errors to Mixpanel as well. From 3a6000b5da922f914f8286a696effb6f2d8161d8 Mon Sep 17 00:00:00 2001 From: James Kwon <96548424+hongil0316@users.noreply.github.com> Date: Sun, 28 Apr 2024 12:23:06 -0400 Subject: [PATCH 2/3] Add missing import statement for logging --- comfy_cli/cmdline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_cli/cmdline.py b/comfy_cli/cmdline.py index 6a0371b..b8cadd2 100644 --- a/comfy_cli/cmdline.py +++ b/comfy_cli/cmdline.py @@ -11,7 +11,7 @@ from comfy_cli.command import custom_nodes from comfy_cli.command import install as install_inner from comfy_cli.command import run as run_inner -from comfy_cli import constants, tracking +from comfy_cli import constants, tracking, logging from comfy_cli.env_checker import EnvChecker from comfy_cli.meta_data import MetadataManager from comfy_cli import env_checker From 607172fccb6c2774565e0dcf3826b67c47aca91b Mon Sep 17 00:00:00 2001 From: James Kwon <96548424+hongil0316@users.noreply.github.com> Date: Sun, 28 Apr 2024 14:10:49 -0400 Subject: [PATCH 3/3] Implement logics to download/remove/list models + ui methods for consistency --- comfy_cli/command/models/models.py | 184 ++++++++++++++++++++--------- comfy_cli/constants.py | 1 + comfy_cli/logging.py | 2 +- comfy_cli/tracking.py | 4 +- comfy_cli/ui.py | 77 ++++++++++++ pyproject.toml | 2 + 6 files changed, 212 insertions(+), 58 deletions(-) create mode 100644 comfy_cli/ui.py diff --git a/comfy_cli/command/models/models.py b/comfy_cli/command/models/models.py index 48fd57c..e75f8e3 100644 --- a/comfy_cli/command/models/models.py +++ b/comfy_cli/command/models/models.py @@ -1,74 +1,146 @@ +import pathlib +from typing import List, Optional + import typer + from typing_extensions import Annotated -from comfy_cli import tracking, logging +from comfy_cli import tracking, ui +from comfy_cli.constants import DEFAULT_COMFY_MODEL_PATH app = typer.Typer() +def get_workspace() -> pathlib.Path: + # TODO: placeholder logic right now, need to implement a config class that + # helps to get the workspace we are working with. + return pathlib.Path.cwd() + + +@app.command() +@tracking.track_command("model") +def download( + url: Annotated[ + str, + typer.Option( + help="The URL from which to download the model", + show_default=False)], + relative_path: Annotated[ + Optional[str], + typer.Option( + help="The relative path from the current workspace to install the model.", + show_default=True)] = DEFAULT_COMFY_MODEL_PATH +): + """Download a model to a specified relative path if it is not already downloaded.""" + # Convert relative path to absolute path based on the current working directory + local_filename = url.split("/")[-1] + local_filepath = get_workspace() / relative_path / local_filename + + # Check if the file already exists + if local_filepath.exists(): + typer.echo(f"File already exists: {local_filepath}") + return + + # File does not exist, proceed with download + typer.echo(f"Start downloading URL: {url} into {local_filepath}") + download_file(url, local_filepath) + + @app.command() @tracking.track_command("model") -def get(url: Annotated[str, typer.Argument(help="The url of the model")], - path: Annotated[str, typer.Argument(help="The path to install the model.")]): - """Download model""" - print(f"Start downloading url: ${url} into ${path}") - download_model(url, path) +def remove( + relative_path: str = typer.Option( + DEFAULT_COMFY_MODEL_PATH, + help="The relative path from the current workspace where the models are stored.", + show_default=True + ), + model_names: Optional[List[str]] = typer.Option( + None, + help="List of model filenames to delete, separated by spaces", + show_default=False + ) +): + """Remove one or more downloaded models, either by specifying them directly or through an interactive selection.""" + model_dir = get_workspace() / relative_path + available_models = list_models(model_dir) + + if not available_models: + typer.echo("No models found to remove.") + return + + to_delete = [] + # Scenario #1: User provided model names to delete + if model_names: + # Validate and filter models to delete based on provided names + missing_models = [] + for name in model_names: + model_path = model_dir / name + if model_path.exists(): + to_delete.append(model_path) + else: + missing_models.append(name) + + if missing_models: + typer.echo("The following models were not found and cannot be removed: " + ", ".join(missing_models)) + if not to_delete: + return # Exit if no valid models were found + + return + + # Scenario #2: User did not provide model names, prompt for selection + else: + selections = ui.prompt_multi_select("Select models to delete:", [model.name for model in available_models]) + if not selections: + typer.echo("No models selected for deletion.") + return + to_delete = [model_dir / selection for selection in selections] + + # Confirm deletion + if to_delete and ui.confirm_action("Are you sure you want to delete the selected files?"): + for model_path in to_delete: + model_path.unlink() + typer.echo(f"Deleted: {model_path}") + else: + typer.echo("Deletion canceled.") @app.command() @tracking.track_command("model") -def remove(): - """Remove a custom node""" - # TODO +def list( + relative_path: str = typer.Option( + DEFAULT_COMFY_MODEL_PATH, + help="The relative path from the current workspace where the models are stored.", + show_default=True + ) +): + """Display a list of all models currently downloaded in a table format.""" + model_dir = get_workspace() / relative_path + models = list_models(model_dir) + + if not models: + typer.echo("No models found.") + return + # Prepare data for table display + data = [(model.name, f"{model.stat().st_size // 1024} KB") for model in models] + column_names = ["Model Name", "Size"] + ui.display_table(data, column_names) + + +def download_file(url: str, local_filepath: pathlib.Path): + """Helper function to download a file.""" -def download_model(url: str, path: str): import httpx - import pathlib - from tqdm import tqdm - local_filename = url.split("/")[-1] - local_filepath = pathlib.Path(path, local_filename) - local_filepath.parent.mkdir(parents=True, exist_ok=True) - - logging.debug(f"downloading {url} ...") - with httpx.stream("GET", url, follow_redirects=True) as stream: - total = int(stream.headers["Content-Length"]) - with open(local_filepath, "wb") as f, tqdm( - total=total, unit_scale=True, unit_divisor=1024, unit="B" - ) as progress: - num_bytes_downloaded = stream.num_bytes_downloaded - for data in stream.iter_bytes(): + local_filepath.parent.mkdir(parents=True, exist_ok=True) # Ensure the directory exists + + with httpx.stream("GET", url, follow_redirects=True) as response: + total = int(response.headers["Content-Length"]) + with open(local_filepath, "wb") as f: + for data in ui.show_progress(response.iter_bytes(), total): f.write(data) - progress.update( - stream.num_bytes_downloaded - num_bytes_downloaded - ) - num_bytes_downloaded = stream.num_bytes_downloaded - - # def download_model(url: str, path: str): - # # Set up logging to file - # logging.basicConfig(level=logging.INFO, filename='download.log', filemode='w', - # format='%(asctime)s - %(levelname)s - %(message)s') - # - # local_filename = url.split("/")[-1] - # local_filepath = pathlib.Path(path, local_filename) - # local_filepath.parent.mkdir(parents=True, exist_ok=True) - # - # # Log the URL being downloaded - # logging.info(f"Downloading {url} ...") - # - # with httpx.stream("GET", url, follow_redirects=True) as stream: - # total = int(stream.headers["Content-Length"]) - # with open(local_filepath, "wb") as f, tqdm( - # total=total, unit_scale=True, unit_divisor=1024, unit="B" - # ) as progress: - # num_bytes_downloaded = stream.num_bytes_downloaded - # for data in stream.iter_bytes(): - # f.write(data) - # progress.update( - # stream.num_bytes_downloaded - num_bytes_downloaded - # ) - # num_bytes_downloaded = stream.num_bytes_downloaded - # - # # Log the completion of the download - # logging.info(f"Download completed. File saved to {local_filepath}") + + +def list_models(path: pathlib.Path) -> list: + """List all models in the specified directory.""" + return [file for file in path.iterdir() if file.is_file()] diff --git a/comfy_cli/constants.py b/comfy_cli/constants.py index ccb8619..6426469 100644 --- a/comfy_cli/constants.py +++ b/comfy_cli/constants.py @@ -9,6 +9,7 @@ class OS(Enum): COMFY_GITHUB_URL = 'https://github.com/comfyanonymous/ComfyUI' COMFY_MANAGER_GITHUB_URL = 'https://github.com/ltdrdata/ComfyUI-Manager' +DEFAULT_COMFY_MODEL_PATH = "models/checkpoints" DEFAULT_COMFY_WORKSPACE = { OS.WINDOWS: os.path.join(os.path.expanduser('~'), 'Documents', 'ComfyUI'), OS.MACOS: os.path.join(os.path.expanduser('~'), 'Documents', 'ComfyUI'), diff --git a/comfy_cli/logging.py b/comfy_cli/logging.py index 1a1a68b..fe12534 100644 --- a/comfy_cli/logging.py +++ b/comfy_cli/logging.py @@ -12,7 +12,7 @@ def setup_logging(): # TODO: consider supporting different ways of outputting logs # Note: by default, the log level is set to INFO - log_level = os.getenv("LOG_LEVEL", "INFO").upper() + log_level = os.getenv("LOG_LEVEL", "WARN").upper() logging.basicConfig(level=log_level, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S') diff --git a/comfy_cli/tracking.py b/comfy_cli/tracking.py index de4d7a9..d1edc0a 100644 --- a/comfy_cli/tracking.py +++ b/comfy_cli/tracking.py @@ -4,6 +4,8 @@ from mixpanel import Mixpanel +from comfy_cli import logging + MIXPANEL_TOKEN = "93aeab8962b622d431ac19800ccc9f67" DISABLE_TELEMETRY = os.getenv('DISABLE_TELEMETRY', False) mp = Mixpanel(MIXPANEL_TOKEN) if MIXPANEL_TOKEN else None @@ -29,7 +31,7 @@ def wrapper(*args, **kwargs): command_name = f"{sub_command}:{func.__name__}" if sub_command is not None else func.__name__ input_arguments = kwargs # Example to pass all args and kwargs - print(f"Tracking command: {command_name} with arguments: {input_arguments}") + logging.debug(f"Tracking command: {command_name} with arguments: {input_arguments}") track_event(command_name, properties=input_arguments) return func(*args, **kwargs) diff --git a/comfy_cli/ui.py b/comfy_cli/ui.py new file mode 100644 index 0000000..2dc6b6b --- /dev/null +++ b/comfy_cli/ui.py @@ -0,0 +1,77 @@ +import questionary +from typing import List, Tuple + +from rich.table import Table +from rich.console import Console +from rich.progress import Progress + +console = Console() + + +def show_progress(iterable, total, description="Downloading..."): + """ + Display progress bar for iterable processes, especially useful for file downloads. + Each item in the iterable should be a chunk of data, and the progress bar will advance + by the size of each chunk. + + Args: + iterable (Iterable[bytes]): An iterable that yields chunks of data. + total (int): The total size of the data (e.g., total number of bytes) to be downloaded. + description (str): Description text for the progress bar. + + Yields: + bytes: Chunks of data as they are processed. + """ + with Progress() as progress: + task = progress.add_task(description, total=total) + for chunk in iterable: + yield chunk + progress.update(task, advance=len(chunk)) + + +def prompt_multi_select(prompt: str, choices: List[str]) -> List[str]: + """ + Prompts the user to select multiple items from a list of choices. + + Args: + prompt (str): The message to display to the user. + choices (List[str]): A list of choices from which the user can select. + + Returns: + List[str]: A list of the selected items. + """ + selections = questionary.checkbox(prompt, choices=choices).ask() # returns list of selected items + return selections if selections else [] + + +def confirm_action(prompt: str) -> bool: + """ + Prompts the user for confirmation before proceeding with an action. + + Args: + prompt (str): The confirmation message to display to the user. + + Returns: + bool: True if the user confirms, False otherwise. + """ + return questionary.confirm(prompt).ask() # returns True if confirmed, otherwise False + + +def display_table(data: List[Tuple], column_names: List[str], title: str = "") -> None: + """ + Displays a list of tuples in a table format using Rich. + + Args: + data (List[Tuple]): A list of tuples, where each tuple represents a row. + column_names (List[str]): A list of column names for the table. + title (str): The title of the table. + """ + table = Table(title=title) + + for name in column_names: + table.add_column(name, overflow="fold") + + for row in data: + table.add_row(*[str(item) for item in row]) + + console.print(table) diff --git a/pyproject.toml b/pyproject.toml index e799820..45f4e4e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,8 @@ dependencies = [ "requests", "pyyaml", "typing-extensions", + "mixpanel", + "questionary", ] classifiers = [