diff --git a/apps/shark_studio/api/sd.py b/apps/shark_studio/api/sd.py index c9e062e32f..aa0a66f907 100644 --- a/apps/shark_studio/api/sd.py +++ b/apps/shark_studio/api/sd.py @@ -434,6 +434,14 @@ def shark_sd_fn( return (generated_imgs, "") +def unload_sd(): + print("Unloading models.") + import apps.shark_studio.web.utils.globals as global_obj + + global_obj.clear_cache() + gc.collect() + + def cancel_sd(): print("Inject call to cancel longer API calls.") return diff --git a/apps/shark_studio/web/ui/sd.py b/apps/shark_studio/web/ui/sd.py index fc018dbbfa..20330bcf75 100644 --- a/apps/shark_studio/web/ui/sd.py +++ b/apps/shark_studio/web/ui/sd.py @@ -19,6 +19,7 @@ from apps.shark_studio.api.sd import ( shark_sd_fn_dict_input, cancel_sd, + unload_sd, ) from apps.shark_studio.api.controlnet import ( cnet_preview, @@ -611,11 +612,9 @@ def base_model_changed(base_model_id): ) with gr.Row(): stable_diffusion = gr.Button("Start") - random_seed = gr.Button("Randomize Seed") - random_seed.click( - lambda: -1, - inputs=[], - outputs=[seed], + unload = gr.Button("Unload Models") + unload.click( + fn=unload_sd, queue=False, show_progress=False, )