diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml new file mode 100644 index 0000000..d9a07e0 --- /dev/null +++ b/.github/workflows/pytest.yml @@ -0,0 +1,27 @@ +name: Run pytest + +on: [push, pull_request] + +jobs: + build: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.9' # Follow the min version in pyproject.toml + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install pytest + pip install -e . + + - name: Run tests + env: + PYTHONPATH: ${{ github.workspace }} + run: | + pytest diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 98f552a..b67f843 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,4 +9,4 @@ repos: hooks: - id: pylint args: - - --disable=R,C,W,E0401 + - --disable=R,C,W,E0401 \ No newline at end of file diff --git a/comfy_cli/command/models/models.py b/comfy_cli/command/models/models.py index 9d0bff8..cdc0b09 100644 --- a/comfy_cli/command/models/models.py +++ b/comfy_cli/command/models/models.py @@ -1,6 +1,7 @@ import pathlib -from typing import List, Optional +from typing import List, Optional, Tuple +import requests import typer from typing_extensions import Annotated @@ -24,10 +25,91 @@ def potentially_strip_param_url(path_name: str) -> str: return path_name +# Convert relative path to absolute path based on the current working +# directory +def check_huggingface_url(url: str) -> bool: + return "huggingface.co" in url + + +def check_civitai_url(url: str) -> Tuple[bool, bool, int, int]: + """ + Returns: + is_civitai_model_url: True if the url is a civitai model url + is_civitai_api_url: True if the url is a civitai api url + model_id: The model id or None if it's api url + version_id: The version id or None if it doesn't have version id info + """ + prefix = "civitai.com" + try: + if prefix in url: + # URL is civitai api download url: https://civitai.com/api/download/models/12345 + if "civitai.com/api/download" in url: + # This is a direct download link + version_id = url.strip("/").split("/")[-1] + return False, True, None, int(version_id) + + # URL is civitai web url (e.g. + # - https://civitai.com/models/43331 + # - https://civitai.com/models/43331/majicmix-realistic + subpath = url[url.find(prefix) + len(prefix) :].strip("/") + url_parts = subpath.split("?") + if len(url_parts) > 1: + model_id = url_parts[0].split("/")[1] + version_id = url_parts[1].split("=")[1] + return True, False, int(model_id), int(version_id) + else: + model_id = subpath.split("/")[1] + return True, False, int(model_id), None + except (ValueError, IndexError): + print("Error parsing Civitai model URL") + + return False, False, None, None + + +def request_civitai_model_version_api(version_id: int): + # Make a request to the Civitai API to get the model information + response = requests.get( + f"https://civitai.com/api/v1/model-versions/{version_id}", timeout=10 + ) + response.raise_for_status() # Raise an error for bad status codes + + model_data = response.json() + for file in model_data["files"]: + if file["primary"]: # Assuming we want the primary file + model_name = file["name"] + download_url = file["downloadUrl"] + return model_name, download_url + + +def request_civitai_model_api(model_id: int, version_id: int = None): + # Make a request to the Civitai API to get the model information + response = requests.get(f"https://civitai.com/api/v1/models/{model_id}", timeout=10) + response.raise_for_status() # Raise an error for bad status codes + + model_data = response.json() + + # If version_id is None, use the first version + if version_id is None: + version_id = model_data["modelVersions"][0]["id"] + + # Find the version with the specified version_id + for version in model_data["modelVersions"]: + if version["id"] == version_id: + # Get the model name and download URL from the files array + for file in version["files"]: + if file["primary"]: # Assuming we want the primary file + model_name = file["name"] + download_url = file["downloadUrl"] + return model_name, download_url + + # If the specified version_id is not found, raise an error + raise ValueError(f"Version ID {version_id} not found for model ID {model_id}") + + @app.command() @tracking.track_command("model") def download( - ctx: typer.Context, + _ctx: typer.Context, url: Annotated[ str, typer.Option( @@ -42,9 +124,22 @@ def download( ), ] = DEFAULT_COMFY_MODEL_PATH, ): - """Download a model to a specified relative path if it is not already downloaded.""" - # Convert relative path to absolute path based on the current working directory - local_filename = potentially_strip_param_url(url.split("/")[-1]) + + local_filename = None + + is_civitai_model_url, is_civitai_api_url, model_id, version_id = check_civitai_url( + url + ) + is_huggingface = False + if is_civitai_model_url: + local_filename, url = request_civitai_model_api(model_id, version_id) + elif is_civitai_api_url: + local_filename, url = request_civitai_model_version_api(version_id) + elif check_huggingface_url(url): + is_huggingface = True + local_filename = potentially_strip_param_url(url.split("/")[-1]) + else: + print("Model source is unknown") local_filename = ui.prompt_input( "Enter filename to save model as", default=local_filename ) diff --git a/pyproject.toml b/pyproject.toml index b7a280f..ea09c35 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,7 @@ comfycli = "comfy_cli.__main__:main" [project.urls] Repository = "https://github.com/drip-art/comfy-cli.git" -[tool.setuptools] -py-modules = ["comfy_cli"] +[tool.setuptools.packages.find] +where = ["."] +include = ["comfy_cli*"] diff --git a/requirement.txt b/requirements.txt similarity index 76% rename from requirement.txt rename to requirements.txt index 56be9f5..d9d37ce 100644 --- a/requirement.txt +++ b/requirements.txt @@ -4,7 +4,9 @@ GitPython requests pyyaml typing-extensions >= 4.7.0 +questionary mixpanel tomlkit pathspec -httpx \ No newline at end of file +httpx +packaging \ No newline at end of file diff --git a/tests/comfy_cli/command/models/test_models.py b/tests/comfy_cli/command/models/test_models.py new file mode 100644 index 0000000..94279e6 --- /dev/null +++ b/tests/comfy_cli/command/models/test_models.py @@ -0,0 +1,36 @@ +from comfy_cli.command.models.models import check_civitai_url + + +def test_valid_model_url(): + url = "https://civitai.com/models/43331" + assert check_civitai_url(url) == (True, False, 43331, None) + + +def test_valid_model_url_with_version(): + url = "https://civitai.com/models/43331/majicmix-realistic" + assert check_civitai_url(url) == (True, False, 43331, None) + + +def test_valid_model_url_with_query(): + url = "https://civitai.com/models/43331?version=12345" + assert check_civitai_url(url) == (True, False, 43331, 12345) + + +def test_valid_api_url(): + url = "https://civitai.com/api/download/models/67890" + assert check_civitai_url(url) == (False, True, None, 67890) + + +def test_invalid_url(): + url = "https://example.com/models/43331" + assert check_civitai_url(url) == (False, False, None, None) + + +def test_malformed_url(): + url = "https://civitai.com/models/" + assert check_civitai_url(url) == (False, False, None, None) + + +def test_malformed_query_url(): + url = "https://civitai.com/models/43331?version=" + assert check_civitai_url(url) == (False, False, None, None)