Skip to content

Commit

Permalink
Add support for Flux models to img2img module
Browse files Browse the repository at this point in the history
  • Loading branch information
Woolverine94 committed Nov 8, 2024
1 parent ba8ad93 commit b34665c
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 10 deletions.
62 changes: 54 additions & 8 deletions ressources/img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
import PIL
import torch
from diffusers import StableDiffusionImg2ImgPipeline, StableDiffusionXLImg2ImgPipeline, AutoPipelineForImage2Image, StableDiffusion3Img2ImgPipeline
from diffusers import StableDiffusionImg2ImgPipeline, StableDiffusionXLImg2ImgPipeline, AutoPipelineForImage2Image, StableDiffusion3Img2ImgPipeline, FluxImg2ImgPipeline
from huggingface_hub import hf_hub_download
from compel import Compel, ReturnedEmbeddingsType
import random
Expand All @@ -19,6 +19,8 @@
# Gestion des modèles
model_path_img2img = "./models/Stable_Diffusion/"
os.makedirs(model_path_img2img, exist_ok=True)
model_path_flux_img2img = "./models/Flux/"
os.makedirs(model_path_flux_img2img, exist_ok=True)

model_list_img2img_local = []

Expand Down Expand Up @@ -105,6 +107,8 @@
"ariG23498/sd-3.5-merged",
"-[ 👏 🐢 SD3.5 Medium ]-",
"adamo1139/stable-diffusion-3.5-medium-ungated",
"-[ 🏆 🐢 Flux ]-",
"Freepik/flux.1-lite-8B-alpha",
"-[ 🏠 Local models ]-",
]

Expand Down Expand Up @@ -177,7 +181,7 @@ def image_img2img(
lora_weight_array = []

if lora_model_img2img != "":
if is_sd3(modelid_img2img) and lora_model_img2img == "ByteDance/Hyper-SD":
if (is_sd3(modelid_img2img) or is_flux(modelid_img2img)) and lora_model_img2img == "ByteDance/Hyper-SD":
lora_weight_img2img = 0.12
lora_array.append(f"{lora_model_img2img}")
lora_weight_array.append(float(lora_weight_img2img))
Expand Down Expand Up @@ -229,6 +233,12 @@ def image_img2img(
else :
is_bin_img2img: bool = False

if is_flux(modelid_img2img):
is_flux_img2img: bool = True
else :
is_flux_img2img: bool = False


if is_turbo_img2img and is_sd35_img2img:
is_turbo_img2img: bool = False

Expand Down Expand Up @@ -312,7 +322,24 @@ def image_img2img(
resume_download=True,
local_files_only=True if offline_test() else None
)

elif (is_flux_img2img == True):
if modelid_img2img[0:9] == "./models/" :
pipe_img2img = FluxImg2ImgPipeline.from_single_file(
modelid_img2img,
torch_dtype=model_arch,
use_safetensors=True if not is_bin_img2img else False,
# load_safety_checker=False if (nsfw_filter_final == None) else True,
local_files_only=True if offline_test() else None
)
else :
pipe_img2img = FluxImg2ImgPipeline.from_pretrained(
modelid_img2img,
cache_dir=model_path_flux_img2img,
torch_dtype=model_arch,
use_safetensors=True if not is_bin_img2img else False,
resume_download=True,
local_files_only=True if offline_test() else None
)
else :
if modelid_img2img[0:9] == "./models/" :
pipe_img2img = StableDiffusionImg2ImgPipeline.from_single_file(
Expand All @@ -338,7 +365,7 @@ def image_img2img(

pipe_img2img = schedulerer(pipe_img2img, sampler_img2img)
pipe_img2img.enable_attention_slicing("max")
if not is_sd3_img2img and not is_sd35_img2img and not is_sd35m_img2img:
if not is_sd3_img2img and not is_sd35_img2img and not is_sd35m_img2img and not is_flux_img2img:
tomesd.apply_patch(pipe_img2img, ratio=tkme_img2img)
if device_label_img2img == "cuda" :
pipe_img2img.enable_sequential_cpu_offload()
Expand All @@ -365,6 +392,8 @@ def image_img2img(
lora_model_path = model_path_lora_sd3
elif is_sd35_img2img or is_sd35m_img2img:
lora_model_path = model_path_lora_sd35
elif is_flux_img2img:
lora_model_path = model_path_lora_flux
else:
lora_model_path = model_path_lora_sd

Expand Down Expand Up @@ -425,13 +454,14 @@ def image_img2img(

if source_type_img2img == "sketch" :
dim_size=[512, 512]
elif ((is_xl_img2img == True) or (is_sd3_img2img == True)) and not (is_turbo_img2img == True) :
elif (is_xl_img2img or is_sd3_img2img or is_sd35_img2img or is_sd35m_img2img or is_flux_img2img) and not (is_turbo_img2img == True) :
dim_size = correct_size(width_img2img, height_img2img, 1024)
else :
dim_size = correct_size(width_img2img, height_img2img, 512)
image_input = PIL.Image.open(img_img2img)
image_input = image_input.convert("RGB")
image_input = image_input.resize((dim_size[0], dim_size[1]))
print(dim_size[0], dim_size[1])

prompt_img2img = str(prompt_img2img)
negative_prompt_img2img = str(negative_prompt_img2img)
Expand All @@ -451,7 +481,7 @@ def image_img2img(
conditioning, pooled = compel(prompt_img2img)
neg_conditioning, neg_pooled = compel(negative_prompt_img2img)
[conditioning, neg_conditioning] = compel.pad_conditioning_tensors_to_same_length([conditioning, neg_conditioning])
elif is_sd3_img2img or is_sd35_img2img or is_sd35m_img2img:
elif is_sd3_img2img or is_sd35_img2img or is_sd35m_img2img or is_flux_img2img:
pass
else :
compel = Compel(tokenizer=pipe_img2img.tokenizer, text_encoder=pipe_img2img.text_encoder, truncate_long_prompts=False, device=device_img2img)
Expand Down Expand Up @@ -510,6 +540,22 @@ def image_img2img(
callback_on_step_end=check_img2img,
callback_on_step_end_tensor_inputs=['latents'],
).images
elif is_flux_img2img:
image = pipe_img2img(
image=image_input,
prompt=prompt_img2img,
width=dim_size[0],
height=dim_size[1],
max_sequence_length=512,
num_images_per_prompt=num_images_per_prompt_img2img,
guidance_scale=guidance_scale_img2img,
strength=denoising_strength_img2img,
num_inference_steps=num_inference_step_img2img,
timesteps=sampling_schedule_img2img,
generator=generator,
callback_on_step_end=check_img2img,
callback_on_step_end_tensor_inputs=['latents'],
).images
else :
image = pipe_img2img(
image=image_input,
Expand All @@ -527,7 +573,7 @@ def image_img2img(
).images

for j in range(len(image)):
if is_xl_img2img or is_sd3_img2img or is_sd35_img2img or is_sd35m_img2img or (modelid_img2img[0:9] == "./models/"):
if is_xl_img2img or is_sd3_img2img or is_sd35_img2img or is_sd35m_img2img or is_flux_img2img or (modelid_img2img[0:9] == "./models/"):
image[j] = safety_checker_sdxl(model_path_img2img, image[j], nsfw_filter)
savename = name_image()
if use_gfpgan_img2img == True :
Expand Down Expand Up @@ -563,7 +609,7 @@ def image_img2img(

exif_writer_png(reporting_img2img, final_image)

if is_sd3_img2img or is_sd35_img2img or is_sd35m_img2img:
if is_sd3_img2img or is_sd35_img2img or is_sd35m_img2img or is_flux_img2img:
del nsfw_filter_final, feat_ex, pipe_img2img, generator, image_input, image
else:
del nsfw_filter_final, feat_ex, pipe_img2img, generator, image_input, compel, conditioning, neg_conditioning, image
Expand Down
10 changes: 8 additions & 2 deletions webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -935,6 +935,8 @@ def change_model_type_img2img(model_img2img):
return sampler_img2img.update(value="Flow Match Euler"), width_img2img.update(), height_img2img.update(), num_inference_step_img2img.update(value=10), guidance_scale_img2img.update(value=4.5), lora_model_img2img.update(choices=list(lora_model_list(model_img2img).keys()), value="", interactive=True), txtinv_img2img.update(choices=list(txtinv_list(model_img2img).keys()), value="", interactive=False), negative_prompt_img2img.update(interactive=True)
elif is_sd3(model_img2img):
return sampler_img2img.update(value="Flow Match Euler"), width_img2img.update(), height_img2img.update(), num_inference_step_img2img.update(value=10), guidance_scale_img2img.update(value=7.5), lora_model_img2img.update(choices=list(lora_model_list(model_img2img).keys()), value="", interactive=True), txtinv_img2img.update(choices=list(txtinv_list(model_img2img).keys()), value="", interactive=False), negative_prompt_img2img.update(interactive=True)
elif is_flux(model_img2img):
return sampler_img2img.update(value="Flow Match Euler"), width_img2img.update(), height_img2img.update(), num_inference_step_img2img.update(value=10), guidance_scale_img2img.update(value=3.5), lora_model_img2img.update(choices=list(lora_model_list(model_img2img).keys()), value="", interactive=True), txtinv_img2img.update(choices=list(txtinv_list(model_img2img).keys()), value="", interactive=False), negative_prompt_img2img.update(interactive=True)
elif is_sdxl(model_img2img):
return sampler_img2img.update(value=list(SCHEDULER_MAPPING.keys())[0]), width_img2img.update(), height_img2img.update(), num_inference_step_img2img.update(value=10), guidance_scale_img2img.update(value=7.5), lora_model_img2img.update(choices=list(lora_model_list(model_img2img).keys()), value="", interactive=True), txtinv_img2img.update(choices=list(txtinv_list(model_img2img).keys()), value="", interactive=True), negative_prompt_img2img.update(interactive=True)
else:
Expand Down Expand Up @@ -1006,7 +1008,7 @@ def change_lora_model_img2img(model, lora_model, prompt, steps, cfg_scale, sampl
biniou_internal_previous_sampler_img2img = sampler
if (lora_model == "ByteDance/SDXL-Lightning") or (lora_model == "GraydientPlatformAPI/lightning-faster-lora"):
return prompt_img2img.update(value=lora_prompt_img2img), num_inference_step_img2img.update(value=4), guidance_scale_img2img.update(value=0.0), sampler_img2img.update(value="LCM")
elif ((lora_model == "ByteDance/Hyper-SD") or ("H1T/TCD-SD" in lora_model.upper())) and not is_sd3(model):
elif ((lora_model == "ByteDance/Hyper-SD") or ("H1T/TCD-SD" in lora_model.upper())) and not is_sd3(model) and not is_flux(model):
return prompt_img2img.update(value=lora_prompt_img2img), num_inference_step_img2img.update(value=4), guidance_scale_img2img.update(value=0.0), sampler_img2img.update(value="TCD")
elif (lora_model == "openskyml/lcm-lora-sdxl-turbo"):
return prompt_img2img.update(value=lora_prompt_img2img), num_inference_step_img2img.update(value=4), guidance_scale_img2img.update(value=0.0), sampler_img2img.update(value="LCM")
Expand All @@ -1022,8 +1024,12 @@ def change_lora_model_img2img(model, lora_model, prompt, steps, cfg_scale, sampl
return prompt_img2img.update(value=lora_prompt_img2img), num_inference_step_img2img.update(value=6), guidance_scale_img2img.update(value=3.0), sampler_img2img.update(value="DPM++ SDE")
elif (lora_model == "mann-e/Mann-E_Turbo"):
return prompt_img2img.update(value=lora_prompt_img2img), num_inference_step_img2img.update(value=6), guidance_scale_img2img.update(value=3.0), sampler_img2img.update(value="DPM++ SDE Karras")
elif (lora_model == "ByteDance/Hyper-SD") and is_sd3(model):
elif (lora_model == "ByteDance/Hyper-SD") and (is_sd3(model)):
return prompt_img2img.update(value=lora_prompt_img2img), num_inference_step_img2img.update(value=4), guidance_scale_img2img.update(value=3.0), sampler_img2img.update(value="Flow Match Euler")
elif (lora_model == "ByteDance/Hyper-SD") and (is_flux(model)):
return prompt_img2img.update(value=lora_prompt_img2img), num_inference_step_img2img.update(value=8), guidance_scale_img2img.update(value=3.5), sampler_img2img.update(value="Flow Match Euler")
elif (lora_model == "Lingyuzhou/Hyper_Flux.1_Dev_4_step_Lora"):
return prompt_img2img.update(value=lora_prompt_img2img), num_inference_step_img2img.update(value=4), guidance_scale_img2img.update(value=3.5), sampler_img2img.update(value="Flow Match Euler")
else:
if ((biniou_internal_previous_model_img2img == "") and (biniou_internal_previous_steps_img2img == "") and (biniou_internal_previous_cfg_img2img == "") and (biniou_internal_previous_sampler_img2img == "")):
return prompt_img2img.update(value=lora_prompt_img2img), num_inference_step_img2img.update(), guidance_scale_img2img.update(), sampler_img2img.update()
Expand Down

0 comments on commit b34665c

Please sign in to comment.