1
1
import pathlib
2
- from typing import List , Optional
2
+ from typing import List , Optional , Tuple
3
3
4
+ import requests
4
5
import typer
5
6
6
7
from typing_extensions import Annotated
@@ -24,10 +25,60 @@ def potentially_strip_param_url(path_name: str) -> str:
24
25
return path_name
25
26
26
27
28
+ # Convert relative path to absolute path based on the current working
29
+ # directory
30
+ def is_huggingface_model (url : str ) -> bool :
31
+ return "huggingface.co" in url
32
+
33
+
34
+ def is_civitai_model (url : str ) -> Tuple [bool , int , int ]:
35
+ prefix = "civitai.com"
36
+ try :
37
+ if prefix in url :
38
+ subpath = url [url .find (prefix ) + len (prefix ) :].strip ("/" )
39
+ url_parts = subpath .split ("?" )
40
+ if len (url_parts ) > 1 :
41
+ model_id = url_parts [0 ].split ("/" )[1 ]
42
+ version_id = url_parts [1 ].split ("=" )[1 ]
43
+ return True , int (model_id ), int (version_id )
44
+ else :
45
+ model_id = subpath .split ("/" )[1 ]
46
+ return True , int (model_id ), None
47
+ except ValueError :
48
+ print ("Error parsing Civitai model URL" )
49
+ pass
50
+ return False , None , None
51
+
52
+
53
+ def request_civitai_api (model_id : int , version_id : int = None ):
54
+ # Make a request to the Civitai API to get the model information
55
+ response = requests .get (f"https://civitai.com/api/v1/models/{ model_id } " , timeout = 10 )
56
+ response .raise_for_status () # Raise an error for bad status codes
57
+
58
+ model_data = response .json ()
59
+
60
+ # If version_id is None, use the first version
61
+ if version_id is None :
62
+ version_id = model_data ["modelVersions" ][0 ]["id" ]
63
+
64
+ # Find the version with the specified version_id
65
+ for version in model_data ["modelVersions" ]:
66
+ if version ["id" ] == version_id :
67
+ # Get the model name and download URL from the files array
68
+ for file in version ["files" ]:
69
+ if file ["primary" ]: # Assuming we want the primary file
70
+ model_name = file ["name" ]
71
+ download_url = file ["downloadUrl" ]
72
+ return model_name , download_url
73
+
74
+ # If the specified version_id is not found, raise an error
75
+ raise ValueError (f"Version ID { version_id } not found for model ID { model_id } " )
76
+
77
+
27
78
@app .command ()
28
79
@tracking .track_command ("model" )
29
80
def download (
30
- ctx : typer .Context ,
81
+ _ctx : typer .Context ,
31
82
url : Annotated [
32
83
str ,
33
84
typer .Option (
@@ -42,9 +93,18 @@ def download(
42
93
),
43
94
] = DEFAULT_COMFY_MODEL_PATH ,
44
95
):
45
- """Download a model to a specified relative path if it is not already downloaded."""
46
- # Convert relative path to absolute path based on the current working directory
47
- local_filename = potentially_strip_param_url (url .split ("/" )[- 1 ])
96
+
97
+ local_filename = None
98
+
99
+ is_civitai , model_id , version_id = is_civitai_model (url )
100
+ is_huggingface = False
101
+ if is_civitai :
102
+ local_filename , url = request_civitai_api (model_id , version_id )
103
+ elif is_huggingface_model (url ):
104
+ is_huggingface = True
105
+ local_filename = potentially_strip_param_url (url .split ("/" )[- 1 ])
106
+ else :
107
+ print ("Model source is unknown" )
48
108
local_filename = ui .prompt_input (
49
109
"Enter filename to save model as" , default = local_filename
50
110
)
0 commit comments