
Commit

Merge pull request #245 from Dartvauder/dev
Update app.py
Dartvauder authored Oct 4, 2024
2 parents a46111a + 2c31597 commit fd6a37e
Showing 1 changed file with 33 additions and 38 deletions.
71 changes: 33 additions & 38 deletions LaunchFile/app.py
@@ -8623,47 +8623,42 @@ def display_output_file(text_file, image_file, video_file, audio_file, model3d_f
     return text_files, image_files, video_files, audio_files, model3d_files, display_output_file


-def download_model(model_name_llm, model_name_sd):
-    if not model_name_llm and not model_name_sd:
-        return "Please select a model to download"
-
-    if model_name_llm and model_name_sd:
-        return "Please select one model type for downloading"
-
-    if model_name_llm:
-        model_url = ""
-        if model_name_llm == "StarlingLM(Transformers7B)":
-            model_url = "https://huggingface.co/Nexusflow/Starling-LM-7B-beta"
-        elif model_name_llm == "OpenChat3.6(Llama8B.Q4)":
-            model_url = "https://huggingface.co/bartowski/openchat-3.6-8b-20240522-GGUF/resolve/main/openchat-3.6-8b-20240522-Q4_K_M.gguf"
-        model_path = os.path.join("inputs", "text", "llm_models", model_name_llm)
-
-        if model_url:
-            if model_name_llm == "StarlingLM(Transformers7B)":
-                Repo.clone_from(model_url, model_path)
+def download_model(llm_model_url, sd_model_url):
+    if not llm_model_url and not sd_model_url:
+        return "Please enter at least one model URL to download"
+
+    messages = []
+
+    if llm_model_url:
+        try:
+            if "/" in llm_model_url and "blob/main" not in llm_model_url:
+                repo_name = llm_model_url.split("/")[-1]
+                model_path = os.path.join("inputs", "text", "llm_models", repo_name)
+                os.makedirs(model_path, exist_ok=True)
+                Repo.clone_from(f"https://huggingface.co/{llm_model_url}", model_path)
+                messages.append(f"LLM model repository {repo_name} downloaded successfully!")
             else:
-                response = requests.get(model_url, allow_redirects=True)
+                file_name = llm_model_url.split("/")[-1]
+                model_path = os.path.join("inputs", "text", "llm_models", file_name)
+                response = requests.get(llm_model_url, allow_redirects=True)
                 with open(model_path, "wb") as file:
                     file.write(response.content)
-            return f"LLM model {model_name_llm} downloaded successfully!"
-        else:
-            return "Invalid LLM model name"
-
-    if model_name_sd:
-        model_url = ""
-        if model_name_sd == "Dreamshaper8(SD1.5)":
-            model_url = "https://huggingface.co/Lykon/DreamShaper/resolve/main/DreamShaper_8_pruned.safetensors"
-        elif model_name_sd == "RealisticVisionV4.0(SDXL)":
-            model_url = "https://huggingface.co/SG161222/RealVisXL_V4.0/resolve/main/RealVisXL_V4.0.safetensors"
-        model_path = os.path.join("inputs", "image", "sd_models", f"{model_name_sd}")
-
-        if model_url:
-            response = requests.get(model_url, allow_redirects=True)
+                messages.append(f"LLM model file {file_name} downloaded successfully!")
+        except Exception as e:
+            messages.append(f"Error downloading LLM model: {str(e)}")
+
+    if sd_model_url:
+        try:
+            file_name = sd_model_url.split("/")[-1]
+            model_path = os.path.join("inputs", "image", "sd_models", file_name)
+            response = requests.get(sd_model_url, allow_redirects=True)
             with open(model_path, "wb") as file:
                 file.write(response.content)
-            return f"StableDiffusion model {model_name_sd} downloaded successfully!"
-        else:
-            return "Invalid StableDiffusion model name"
+            messages.append(f"StableDiffusion model file {file_name} downloaded successfully!")
+        except Exception as e:
+            messages.append(f"Error downloading StableDiffusion model: {str(e)}")
+
+    return "\n".join(messages)


 def settings_interface(language, share_value, debug_value, monitoring_value, auto_launch, api_status, open_api, queue_max_size, status_update_rate, gradio_auth, server_name, server_port, hf_token, theme,
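
Note on the new dispatch rule (an aside on this hunk, not part of the diff): the rewritten download_model treats any LLM input that contains a "/" but no "blob/main" as a Hugging Face repo id and clones it with GitPython; anything else is fetched with requests as a single file into inputs/text/llm_models. Below is a minimal, self-contained sketch of that branching with the network work stubbed out; the classify_llm_input helper is illustrative and does not exist in app.py.

import os

def classify_llm_input(llm_model_url: str) -> str:
    # Mirrors the branching in the new download_model, without touching the network.
    if "/" in llm_model_url and "blob/main" not in llm_model_url:
        # Repo-id path: cloned via GitPython (Repo.clone_from) into a per-repo folder.
        repo_name = llm_model_url.split("/")[-1]
        target = os.path.join("inputs", "text", "llm_models", repo_name)
        return f"clone https://huggingface.co/{llm_model_url} -> {target}"
    # Single-file path: fetched with requests and written to disk.
    file_name = llm_model_url.split("/")[-1]
    target = os.path.join("inputs", "text", "llm_models", file_name)
    return f"download {llm_model_url} -> {target}"

print(classify_llm_input("Nexusflow/Starling-LM-7B-beta"))
print(classify_llm_input("https://huggingface.co/some-org/some-model/blob/main/model.gguf"))

As written, only URLs containing "blob/main" fall through to the single-file branch; a direct "resolve/main" link would match the clone branch instead.
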
@@ -11348,8 +11343,8 @@ def reload_interface():
 model_downloader_interface = gr.Interface(
     fn=download_model,
     inputs=[
-        gr.Dropdown(choices=[None, "StarlingLM(Transformers7B)", "OpenChat3.6(Llama8B.Q4)"], label=_("Download LLM model", lang), value=None),
-        gr.Dropdown(choices=[None, "Dreamshaper8(SD1.5)", "RealisticVisionV4.0(SDXL)"], label=_("Download StableDiffusion model", lang), value=None),
+        gr.Textbox(label=_("Download LLM model", lang), placeholder="repo-author/repo-name or https://huggingface.co/.../model.gguf"),
+        gr.Textbox(label=_("Download StableDiffusion model", lang), placeholder="https://huggingface.co/.../model.safetensors"),
     ],
     outputs=[
         gr.Textbox(label=_("Message", lang), type="text"),
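On the UI side, the two fixed Dropdowns are replaced with free-form Textboxes whose values are passed straight to download_model. The following is a stand-alone sketch of that wiring, assuming only that gradio is installed; the download function here is a stub with the same two-string signature, so the demo runs without network access or the rest of app.py.

import gradio as gr

def fake_download_model(llm_model_url, sd_model_url):
    # Stub that mirrors the new download_model signature and message-based output.
    messages = []
    if llm_model_url:
        messages.append(f"would fetch LLM input: {llm_model_url}")
    if sd_model_url:
        messages.append(f"would fetch SD file: {sd_model_url.split('/')[-1]}")
    return "\n".join(messages) or "Please enter at least one model URL to download"

demo = gr.Interface(
    fn=fake_download_model,
    inputs=[
        gr.Textbox(label="Download LLM model", placeholder="repo-author/repo-name or direct file URL"),
        gr.Textbox(label="Download StableDiffusion model", placeholder="direct .safetensors URL"),
    ],
    outputs=gr.Textbox(label="Message"),
)

if __name__ == "__main__":
    demo.launch()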
