-
Notifications
You must be signed in to change notification settings - Fork 53
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
248 additions
and
31 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,46 +1,146 @@ | ||
import pathlib | ||
from typing import List, Optional | ||
|
||
import typer | ||
|
||
from typing_extensions import Annotated | ||
|
||
from comfy_cli import tracking | ||
from comfy_cli import tracking, ui | ||
from comfy_cli.constants import DEFAULT_COMFY_MODEL_PATH | ||
|
||
app = typer.Typer() | ||
|
||
|
||
def get_workspace() -> pathlib.Path: | ||
# TODO: placeholder logic right now, need to implement a config class that | ||
# helps to get the workspace we are working with. | ||
return pathlib.Path.cwd() | ||
|
||
|
||
@app.command() | ||
@tracking.track_command("model") | ||
def download( | ||
url: Annotated[ | ||
str, | ||
typer.Option( | ||
help="The URL from which to download the model", | ||
show_default=False)], | ||
relative_path: Annotated[ | ||
Optional[str], | ||
typer.Option( | ||
help="The relative path from the current workspace to install the model.", | ||
show_default=True)] = DEFAULT_COMFY_MODEL_PATH | ||
): | ||
"""Download a model to a specified relative path if it is not already downloaded.""" | ||
# Convert relative path to absolute path based on the current working directory | ||
local_filename = url.split("/")[-1] | ||
local_filepath = get_workspace() / relative_path / local_filename | ||
|
||
# Check if the file already exists | ||
if local_filepath.exists(): | ||
typer.echo(f"File already exists: {local_filepath}") | ||
return | ||
|
||
# File does not exist, proceed with download | ||
typer.echo(f"Start downloading URL: {url} into {local_filepath}") | ||
download_file(url, local_filepath) | ||
|
||
|
||
@app.command() | ||
@tracking.track_command("model") | ||
def get(url: Annotated[str, typer.Argument(help="The url of the model")], | ||
path: Annotated[str, typer.Argument(help="The path to install the model.")]): | ||
"""Download model""" | ||
print(f"Start downloading url: ${url} into ${path}") | ||
download_model(url, path) | ||
def remove( | ||
relative_path: str = typer.Option( | ||
DEFAULT_COMFY_MODEL_PATH, | ||
help="The relative path from the current workspace where the models are stored.", | ||
show_default=True | ||
), | ||
model_names: Optional[List[str]] = typer.Option( | ||
None, | ||
help="List of model filenames to delete, separated by spaces", | ||
show_default=False | ||
) | ||
): | ||
"""Remove one or more downloaded models, either by specifying them directly or through an interactive selection.""" | ||
model_dir = get_workspace() / relative_path | ||
available_models = list_models(model_dir) | ||
|
||
if not available_models: | ||
typer.echo("No models found to remove.") | ||
return | ||
|
||
to_delete = [] | ||
# Scenario #1: User provided model names to delete | ||
if model_names: | ||
# Validate and filter models to delete based on provided names | ||
missing_models = [] | ||
for name in model_names: | ||
model_path = model_dir / name | ||
if model_path.exists(): | ||
to_delete.append(model_path) | ||
else: | ||
missing_models.append(name) | ||
|
||
if missing_models: | ||
typer.echo("The following models were not found and cannot be removed: " + ", ".join(missing_models)) | ||
if not to_delete: | ||
return # Exit if no valid models were found | ||
|
||
return | ||
|
||
# Scenario #2: User did not provide model names, prompt for selection | ||
else: | ||
selections = ui.prompt_multi_select("Select models to delete:", [model.name for model in available_models]) | ||
if not selections: | ||
typer.echo("No models selected for deletion.") | ||
return | ||
to_delete = [model_dir / selection for selection in selections] | ||
|
||
# Confirm deletion | ||
if to_delete and ui.confirm_action("Are you sure you want to delete the selected files?"): | ||
for model_path in to_delete: | ||
model_path.unlink() | ||
typer.echo(f"Deleted: {model_path}") | ||
else: | ||
typer.echo("Deletion canceled.") | ||
|
||
|
||
@app.command() | ||
@tracking.track_command("model") | ||
def remove(): | ||
"""Remove a custom node""" | ||
# TODO | ||
def list( | ||
relative_path: str = typer.Option( | ||
DEFAULT_COMFY_MODEL_PATH, | ||
help="The relative path from the current workspace where the models are stored.", | ||
show_default=True | ||
) | ||
): | ||
"""Display a list of all models currently downloaded in a table format.""" | ||
model_dir = get_workspace() / relative_path | ||
models = list_models(model_dir) | ||
|
||
if not models: | ||
typer.echo("No models found.") | ||
return | ||
|
||
# Prepare data for table display | ||
data = [(model.name, f"{model.stat().st_size // 1024} KB") for model in models] | ||
column_names = ["Model Name", "Size"] | ||
ui.display_table(data, column_names) | ||
|
||
|
||
def download_file(url: str, local_filepath: pathlib.Path): | ||
"""Helper function to download a file.""" | ||
|
||
def download_model(url: str, path: str): | ||
import httpx | ||
import pathlib | ||
from tqdm import tqdm | ||
|
||
local_filename = url.split("/")[-1] | ||
local_filepath = pathlib.Path(path, local_filename) | ||
local_filepath.parent.mkdir(parents=True, exist_ok=True) | ||
|
||
print(f"downloading {url} ...") | ||
with httpx.stream("GET", url, follow_redirects=True) as stream: | ||
total = int(stream.headers["Content-Length"]) | ||
with open(local_filepath, "wb") as f, tqdm( | ||
total=total, unit_scale=True, unit_divisor=1024, unit="B" | ||
) as progress: | ||
num_bytes_downloaded = stream.num_bytes_downloaded | ||
for data in stream.iter_bytes(): | ||
local_filepath.parent.mkdir(parents=True, exist_ok=True) # Ensure the directory exists | ||
|
||
with httpx.stream("GET", url, follow_redirects=True) as response: | ||
total = int(response.headers["Content-Length"]) | ||
with open(local_filepath, "wb") as f: | ||
for data in ui.show_progress(response.iter_bytes(), total): | ||
f.write(data) | ||
progress.update( | ||
stream.num_bytes_downloaded - num_bytes_downloaded | ||
) | ||
num_bytes_downloaded = stream.num_bytes_downloaded | ||
|
||
|
||
def list_models(path: pathlib.Path) -> list: | ||
"""List all models in the specified directory.""" | ||
return [file for file in path.iterdir() if file.is_file()] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import logging | ||
import os | ||
|
||
""" | ||
This module provides logging utilities for the CLI. | ||
Note: we could potentially change the logging library or the way we log messages in the future. | ||
Therefore, it's a good idea to encapsulate logging-related code in a separate module. | ||
""" | ||
|
||
|
||
def setup_logging(): | ||
# TODO: consider supporting different ways of outputting logs | ||
# Note: by default, the log level is set to INFO | ||
log_level = os.getenv("LOG_LEVEL", "WARN").upper() | ||
logging.basicConfig(level=log_level, | ||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', | ||
datefmt='%Y-%m-%d %H:%M:%S') | ||
|
||
|
||
def debug(message): | ||
logging.debug(message) | ||
|
||
|
||
def info(message): | ||
logging.info(message) | ||
|
||
|
||
def warning(message): | ||
logging.warning(message) | ||
|
||
|
||
def error(message): | ||
logging.error(message) | ||
# TODO: consider tracking errors to Mixpanel as well. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
import questionary | ||
from typing import List, Tuple | ||
|
||
from rich.table import Table | ||
from rich.console import Console | ||
from rich.progress import Progress | ||
|
||
console = Console() | ||
|
||
|
||
def show_progress(iterable, total, description="Downloading..."): | ||
""" | ||
Display progress bar for iterable processes, especially useful for file downloads. | ||
Each item in the iterable should be a chunk of data, and the progress bar will advance | ||
by the size of each chunk. | ||
Args: | ||
iterable (Iterable[bytes]): An iterable that yields chunks of data. | ||
total (int): The total size of the data (e.g., total number of bytes) to be downloaded. | ||
description (str): Description text for the progress bar. | ||
Yields: | ||
bytes: Chunks of data as they are processed. | ||
""" | ||
with Progress() as progress: | ||
task = progress.add_task(description, total=total) | ||
for chunk in iterable: | ||
yield chunk | ||
progress.update(task, advance=len(chunk)) | ||
|
||
|
||
def prompt_multi_select(prompt: str, choices: List[str]) -> List[str]: | ||
""" | ||
Prompts the user to select multiple items from a list of choices. | ||
Args: | ||
prompt (str): The message to display to the user. | ||
choices (List[str]): A list of choices from which the user can select. | ||
Returns: | ||
List[str]: A list of the selected items. | ||
""" | ||
selections = questionary.checkbox(prompt, choices=choices).ask() # returns list of selected items | ||
return selections if selections else [] | ||
|
||
|
||
def confirm_action(prompt: str) -> bool: | ||
""" | ||
Prompts the user for confirmation before proceeding with an action. | ||
Args: | ||
prompt (str): The confirmation message to display to the user. | ||
Returns: | ||
bool: True if the user confirms, False otherwise. | ||
""" | ||
return questionary.confirm(prompt).ask() # returns True if confirmed, otherwise False | ||
|
||
|
||
def display_table(data: List[Tuple], column_names: List[str], title: str = "") -> None: | ||
""" | ||
Displays a list of tuples in a table format using Rich. | ||
Args: | ||
data (List[Tuple]): A list of tuples, where each tuple represents a row. | ||
column_names (List[str]): A list of column names for the table. | ||
title (str): The title of the table. | ||
""" | ||
table = Table(title=title) | ||
|
||
for name in column_names: | ||
table.add_column(name, overflow="fold") | ||
|
||
for row in data: | ||
table.add_row(*[str(item) for item in row]) | ||
|
||
console.print(table) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,6 +26,7 @@ dependencies = [ | |
"pyyaml", | ||
"typing-extensions", | ||
"mixpanel", | ||
"questionary", | ||
] | ||
|
||
classifiers = [ | ||
|