Skip to content

Commit cf5b661

Browse files
committed
Feat: add support for civitai url model download
1 parent 2506ffc commit cf5b661

File tree

1 file changed

+65
-5
lines changed

1 file changed

+65
-5
lines changed

comfy_cli/command/models/models.py

Lines changed: 65 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pathlib
2-
from typing import List, Optional
2+
from typing import List, Optional, Tuple
33

4+
import requests
45
import typer
56

67
from typing_extensions import Annotated
@@ -24,10 +25,60 @@ def potentially_strip_param_url(path_name: str) -> str:
2425
return path_name
2526

2627

28+
# Convert relative path to absolute path based on the current working
29+
# directory
30+
def is_huggingface_model(url: str) -> bool:
31+
return "huggingface.co" in url
32+
33+
34+
def is_civitai_model(url: str) -> Tuple[bool, int, int]:
35+
prefix = "civitai.com"
36+
try:
37+
if prefix in url:
38+
subpath = url[url.find(prefix) + len(prefix) :].strip("/")
39+
url_parts = subpath.split("?")
40+
if len(url_parts) > 1:
41+
model_id = url_parts[0].split("/")[1]
42+
version_id = url_parts[1].split("=")[1]
43+
return True, int(model_id), int(version_id)
44+
else:
45+
model_id = subpath.split("/")[1]
46+
return True, int(model_id), None
47+
except ValueError:
48+
print("Error parsing Civitai model URL")
49+
pass
50+
return False, None, None
51+
52+
53+
def request_civitai_api(model_id: int, version_id: int = None):
54+
# Make a request to the Civitai API to get the model information
55+
response = requests.get(f"https://civitai.com/api/v1/models/{model_id}", timeout=10)
56+
response.raise_for_status() # Raise an error for bad status codes
57+
58+
model_data = response.json()
59+
60+
# If version_id is None, use the first version
61+
if version_id is None:
62+
version_id = model_data["modelVersions"][0]["id"]
63+
64+
# Find the version with the specified version_id
65+
for version in model_data["modelVersions"]:
66+
if version["id"] == version_id:
67+
# Get the model name and download URL from the files array
68+
for file in version["files"]:
69+
if file["primary"]: # Assuming we want the primary file
70+
model_name = file["name"]
71+
download_url = file["downloadUrl"]
72+
return model_name, download_url
73+
74+
# If the specified version_id is not found, raise an error
75+
raise ValueError(f"Version ID {version_id} not found for model ID {model_id}")
76+
77+
2778
@app.command()
2879
@tracking.track_command("model")
2980
def download(
30-
ctx: typer.Context,
81+
_ctx: typer.Context,
3182
url: Annotated[
3283
str,
3384
typer.Option(
@@ -42,9 +93,18 @@ def download(
4293
),
4394
] = DEFAULT_COMFY_MODEL_PATH,
4495
):
45-
"""Download a model to a specified relative path if it is not already downloaded."""
46-
# Convert relative path to absolute path based on the current working directory
47-
local_filename = potentially_strip_param_url(url.split("/")[-1])
96+
97+
local_filename = None
98+
99+
is_civitai, model_id, version_id = is_civitai_model(url)
100+
is_huggingface = False
101+
if is_civitai:
102+
local_filename, url = request_civitai_api(model_id, version_id)
103+
elif is_huggingface_model(url):
104+
is_huggingface = True
105+
local_filename = potentially_strip_param_url(url.split("/")[-1])
106+
else:
107+
print("Model source is unknown")
48108
local_filename = ui.prompt_input(
49109
"Enter filename to save model as", default=local_filename
50110
)

0 commit comments

Comments
 (0)