Skip to content

Commit

Permalink
Adding preliminary support for Stable Diffusion 3
Browse files Browse the repository at this point in the history
  • Loading branch information
Woolverine94 committed Jun 13, 2024
1 parent 5a73742 commit e697fe2
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 9 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

## Updates

* 🆕 **2024-06-14** : 🔥🔥🔥 ***Support for Stable Diffusion 3 !!!*** 🔥🔥🔥 > Adding preliminary support for model [v2ray/stable-diffusion-3-medium-diffusers](https://hf.co/v2ray/stable-diffusion-3-medium-diffusers) to module Stable Diffusion.

* 🆕 **2024-06-13** : 🔥 ***Support for Visionix-alpha*** 🔥 > Adding support for model [ehristoforu/Visionix-alpha](https://hf.co/ehristoforu/Visionix-alpha) to modules Stable Diffusion, Img2img, IP-Adapter, Controlnet, Photobooth and Text2Video-Zero.

* 🆕 **2024-06-10** : 🔥 ***Support for SPO-SDXL_4k-p_10ep*** 🔥 > Adding support for model [SPO-Diffusion-Models/SPO-SDXL_4k-p_10ep](https://hf.co/SPO-Diffusion-Models/SPO-SDXL_4k-p_10ep) to modules Stable Diffusion, Img2img, IP-Adapter, Controlnet, Photobooth and Text2Video-Zero.
Expand Down
10 changes: 10 additions & 0 deletions ressources/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1800,6 +1800,16 @@ def is_sdxl(model):
is_sdxl_value = False
return is_sdxl_value

def is_sd3(model):
if (\
(model == "v2ray/stable-diffusion-3-medium-diffusers")\
):
is_sd3_value = True
else:
is_sd3_value = False
return is_sd3_value


def lora_model_list(model):
if is_sdxl(model):
model_path_lora = "./models/lora/SDXL"
Expand Down
2 changes: 2 additions & 0 deletions ressources/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
EDMDPMSolverMultistepScheduler,
EDMEulerScheduler,
TCDScheduler,
FlowMatchEulerDiscreteScheduler,
)

SCHEDULER_MAPPING = {
Expand Down Expand Up @@ -45,6 +46,7 @@
"EDM DPM++ 2M": EDMDPMSolverMultistepScheduler,
"EDM Euler": EDMEulerScheduler,
"TCD": TCDScheduler,
"Flow Match Euler": FlowMatchEulerDiscreteScheduler,
}

SCHEDULER_MAPPING_MUSICLDM = {
Expand Down
65 changes: 56 additions & 9 deletions ressources/txt2img_sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# txt2img_sd.py
import gradio as gr
import os
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, AutoPipelineForText2Image
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, AutoPipelineForText2Image, StableDiffusion3Pipeline
from huggingface_hub import hf_hub_download
from compel import Compel, ReturnedEmbeddingsType
import torch
Expand Down Expand Up @@ -33,6 +33,7 @@
"IDKiro/sdxs-512-0.9",
"sd-community/sdxl-flash",
"ehristoforu/Visionix-alpha",
"v2ray/stable-diffusion-3-medium-diffusers",
"RunDiffusion/Juggernaut-X-Hyper",
"cutycat2000x/InterDiffusion-4.0",
"RunDiffusion/Juggernaut-XL-Lightning",
Expand Down Expand Up @@ -126,6 +127,11 @@ def image_txt2img_sd(
else :
is_xl_txt2img_sd: bool = False

if is_sd3(modelid_txt2img_sd):
is_sd3_txt2img_sd: bool = True
else :
is_sd3_txt2img_sd: bool = False

if ("dataautogpt3/ProteusV0.4" in modelid_txt2img_sd) or (modelid_txt2img_sd == "RunDiffusion/Juggernaut-XL-Lightning") or (modelid_txt2img_sd == "RunDiffusion/Juggernaut-X-Hyper"):
is_bin_txt2img_sd: bool = True
else :
Expand All @@ -135,6 +141,8 @@ def image_txt2img_sd(
if is_sdxl(modelid_txt2img_sd):
sampling_schedule_txt2img_sd = AysSchedules["StableDiffusionXLTimesteps"]
sampler_txt2img_sd = "DPM++ SDE"
elif is_sd3(modelid_txt2img_sd):
pass
else:
sampling_schedule_txt2img_sd = AysSchedules["StableDiffusionTimesteps"]
sampler_txt2img_sd = "Euler"
Expand Down Expand Up @@ -184,6 +192,24 @@ def image_txt2img_sd(
resume_download=True,
local_files_only=True if offline_test() else None
)
elif (is_sd3_txt2img_sd == True) :
if modelid_txt2img_sd[0:9] == "./models/" :
pipe_txt2img_sd = StableDiffusion3Pipeline.from_single_file(
modelid_txt2img_sd,
torch_dtype=model_arch,
use_safetensors=True if not is_bin_txt2img_sd else False,
load_safety_checker=False if (nsfw_filter_final == None) else True,
local_files_only=True if offline_test() else None
)
else :
pipe_txt2img_sd = StableDiffusion3Pipeline.from_pretrained(
modelid_txt2img_sd,
cache_dir=model_path_txt2img_sd,
torch_dtype=model_arch,
use_safetensors=True if not is_bin_txt2img_sd else False,
resume_download=True,
local_files_only=True if offline_test() else None
)
else :
if modelid_txt2img_sd[0:9] == "./models/" :
pipe_txt2img_sd = StableDiffusionPipeline.from_single_file(
Expand Down Expand Up @@ -211,12 +237,14 @@ def image_txt2img_sd(
pipe_txt2img_sd = schedulerer(pipe_txt2img_sd, sampler_txt2img_sd)
# if lora_model_txt2img_sd == "":
pipe_txt2img_sd.enable_attention_slicing("max")
tomesd.apply_patch(pipe_txt2img_sd, ratio=tkme_txt2img_sd)
if not is_sd3_txt2img_sd:
tomesd.apply_patch(pipe_txt2img_sd, ratio=tkme_txt2img_sd)
if device_label_txt2img_sd == "cuda" :
pipe_txt2img_sd.enable_sequential_cpu_offload()
else :
pipe_txt2img_sd = pipe_txt2img_sd.to(device_txt2img_sd)
pipe_txt2img_sd.enable_vae_slicing()
if not is_sd3_txt2img_sd:
pipe_txt2img_sd.enable_vae_slicing()

if lora_model_txt2img_sd != "":
model_list_lora_txt2img_sd = lora_model_list(modelid_txt2img_sd)
Expand Down Expand Up @@ -306,16 +334,18 @@ def image_txt2img_sd(
conditioning, pooled = compel(prompt_txt2img_sd)
neg_conditioning, neg_pooled = compel(negative_prompt_txt2img_sd)
[conditioning, neg_conditioning] = compel.pad_conditioning_tensors_to_same_length([conditioning, neg_conditioning])
elif (is_sd3_txt2img_sd == True):
pass
else :
compel = Compel(tokenizer=pipe_txt2img_sd.tokenizer, text_encoder=pipe_txt2img_sd.text_encoder, truncate_long_prompts=False, device=device_txt2img_sd)
conditioning = compel.build_conditioning_tensor(prompt_txt2img_sd)
neg_conditioning = compel.build_conditioning_tensor(negative_prompt_txt2img_sd)
neg_conditioning = compel.build_conditioning_tensor(negative_prompt_txt2img_sd)
[conditioning, neg_conditioning] = compel.pad_conditioning_tensors_to_same_length([conditioning, neg_conditioning])

final_image = []
final_seed = []
for i in range (num_prompt_txt2img_sd):
if (is_xl_txt2img_sd == True) :
if (is_xl_txt2img_sd == True):
image = pipe_txt2img_sd(
prompt_embeds=conditioning,
pooled_prompt_embeds=pooled,
Expand All @@ -331,7 +361,21 @@ def image_txt2img_sd(
callback_on_step_end=check_txt2img_sd,
callback_on_step_end_tensor_inputs=['latents'],
).images
else :
elif (is_sd3_txt2img_sd == True):
image = pipe_txt2img_sd(
prompt=prompt_txt2img_sd,
negative_prompt=negative_prompt_txt2img_sd,
height=height_txt2img_sd,
width=width_txt2img_sd,
num_images_per_prompt=num_images_per_prompt_txt2img_sd,
num_inference_steps=num_inference_step_txt2img_sd,
timesteps=sampling_schedule_txt2img_sd,
guidance_scale=guidance_scale_txt2img_sd,
generator=generator[i],
callback_on_step_end=check_txt2img_sd,
callback_on_step_end_tensor_inputs=['latents'],
).images
else:
image = pipe_txt2img_sd(
prompt_embeds=conditioning,
negative_prompt_embeds=neg_conditioning,
Expand All @@ -348,7 +392,7 @@ def image_txt2img_sd(
).images

for j in range(len(image)):
if is_xl_txt2img_sd:
if is_xl_txt2img_sd or is_sd3_txt2img_sd:
image[j] = safety_checker_sdxl(model_path_txt2img_sd, image[j], nsfw_filter)
seed_id = random_seed + i*num_images_per_prompt_txt2img_sd + j if (seed_txt2img_sd == 0) else seed_txt2img_sd + i*num_images_per_prompt_txt2img_sd + j
savename = name_seeded_image(seed_id)
Expand Down Expand Up @@ -380,8 +424,11 @@ def image_txt2img_sd(
print(reporting_txt2img_sd)

exif_writer_png(reporting_txt2img_sd, final_image)

del nsfw_filter_final, feat_ex, pipe_txt2img_sd, generator, compel, conditioning, neg_conditioning, image

if is_sd3_txt2img_sd:
del nsfw_filter_final, feat_ex, pipe_txt2img_sd, generator, image
else:
del nsfw_filter_final, feat_ex, pipe_txt2img_sd, generator, compel, conditioning, neg_conditioning, image
clean_ram()

print(f">>>[Stable Diffusion 🖼️ ]: leaving module")
Expand Down
2 changes: 2 additions & 0 deletions webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,8 @@ def change_model_type_txt2img_sd(model_txt2img_sd):
return sampler_txt2img_sd.update(value=list(SCHEDULER_MAPPING.keys())[0]), width_txt2img_sd.update(value=biniou_global_sd15_width), height_txt2img_sd.update(value=biniou_global_sd15_height), num_inference_step_txt2img_sd.update(value=10), guidance_scale_txt2img_sd.update(value=3.0), lora_model_txt2img_sd.update(choices=list(lora_model_list(model_txt2img_sd).keys()), value="", interactive=False), txtinv_txt2img_sd.update(choices=list(txtinv_list(model_txt2img_sd).keys()), value=""), negative_prompt_txt2img_sd.update(interactive=True)
elif is_sdxl(model_txt2img_sd):
return sampler_txt2img_sd.update(value=list(SCHEDULER_MAPPING.keys())[0]), width_txt2img_sd.update(value=biniou_global_sdxl_width), height_txt2img_sd.update(value=biniou_global_sdxl_height), num_inference_step_txt2img_sd.update(value=10), guidance_scale_txt2img_sd.update(value=7.0), lora_model_txt2img_sd.update(choices=list(lora_model_list(model_txt2img_sd).keys()), value="", interactive=True), txtinv_txt2img_sd.update(choices=list(txtinv_list(model_txt2img_sd).keys()), value=""), negative_prompt_txt2img_sd.update(interactive=True)
elif is_sd3(model_txt2img_sd):
return sampler_txt2img_sd.update(value="Flow Match Euler"), width_txt2img_sd.update(value=biniou_global_sdxl_width), height_txt2img_sd.update(value=biniou_global_sdxl_height), num_inference_step_txt2img_sd.update(value=20), guidance_scale_txt2img_sd.update(value=7.0), lora_model_txt2img_sd.update(choices=list(lora_model_list(model_txt2img_sd).keys()), value="", interactive=False), txtinv_txt2img_sd.update(choices=list(txtinv_list(model_txt2img_sd).keys()), value=""), negative_prompt_txt2img_sd.update(interactive=True)
else:
return sampler_txt2img_sd.update(value=list(SCHEDULER_MAPPING.keys())[0]), width_txt2img_sd.update(value=biniou_global_sd15_width), height_txt2img_sd.update(value=biniou_global_sd15_height), num_inference_step_txt2img_sd.update(value=10), guidance_scale_txt2img_sd.update(value=7.0), lora_model_txt2img_sd.update(choices=list(lora_model_list(model_txt2img_sd).keys()), value="", interactive=True), txtinv_txt2img_sd.update(choices=list(txtinv_list(model_txt2img_sd).keys()), value=""), negative_prompt_txt2img_sd.update(interactive=True)

Expand Down

0 comments on commit e697fe2

Please sign in to comment.