-
Notifications
You must be signed in to change notification settings - Fork 23
Open
Description
你好,我参考了 https://github.com/inclusionAI/TwinFlow/tree/main/src 的lora使用介绍去蒸馏qwen edit,但是效果很差,能否帮忙看看是哪里的问题呀,谢谢
这是我的推理代码:
import os
import sys
sys.path.append("/home/workspace/code/TwinFlow/src")
import torch
from functools import partial
from torchvision.utils import save_image
from PIL import Image
from torchvision import transforms
from src.networks.qwen_image.modeling_qwen_image import QwenImage
from peft import PeftModel
from unified_sampler import UnifiedSampler
from safetensors.torch import load_file
from torch.amp import autocast as torch_autocast
# Reproducibility and device setup.
seed = 42
torch.manual_seed(seed)
device = torch.device("cuda")

# Model / checkpoint / input locations and the editing instruction.
base_model_path = "/home/workspace/hf_model/Qwen-Image-Edit-2509"
lora_checkpoint_path = "../outputs/qwenimage_task/qwenimage_lora_2order/checkpoints/global_step_30/model"
input_image_path = "/home/workspace/code/DiffSynth-Studio/example_image_dataset/edit/image1_F.jpg"
prompt = "将裙子改为粉色"

# Target resolution for the edit.
height = 512
width = 512

# Load the source image and turn it into a [-1, 1]-normalized bf16 batch tensor.
pil_image = Image.open(input_image_path)
pil_image = pil_image.convert("RGB")
pil_image = pil_image.resize((width, height), Image.LANCZOS)
transform = transforms.Compose([
    # NOTE(review): the image was already resized to width x height above, so
    # for a square target this Resize is a no-op — presumably kept defensively.
    transforms.Resize(min(height, width), interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),  # maps [0, 1] -> [-1, 1]
])
input_tensor = transform(pil_image).unsqueeze(0).to(device, dtype=torch.bfloat16)
print("load base model...")
# Build the Qwen-Image editing pipeline in bf16 and move it to the GPU.
model_kwargs = dict(
    model_type="edit",
    aux_time_embed=False,
    text_dtype=torch.bfloat16,
    imgs_dtype=torch.bfloat16,
    device=device,
)
model = QwenImage(base_model_path, **model_kwargs).to(device)
print("load lora...")
base_transformer = model.model.transformer
if lora_checkpoint_path and os.path.exists(lora_checkpoint_path):
    print(f"Loading LoRA from {lora_checkpoint_path}...")
    # Wrap the base transformer with the trained adapter. from_pretrained
    # already loads whatever adapter weights it finds in lora_checkpoint_path.
    lora_transformer = PeftModel.from_pretrained(
        base_transformer,
        lora_checkpoint_path,
        adapter_name="default",
        is_trainable=False,
    )
    # NOTE(review): manually reloading the state dict below is normally
    # redundant after from_pretrained; it only matters when the checkpoint
    # keys carry the extra training-time "transformer." prefix remapped here.
    # If from_pretrained already matched the keys, this is a harmless re-load.
    adapter_path = os.path.join(lora_checkpoint_path, "adapter_model.safetensors")
    raw_state_dict = load_file(adapter_path)
    final_state_dict = {}
    for key, value in raw_state_dict.items():
        # Drop the wrapper prefix used at training time (no-op when absent).
        new_key = key.replace("base_model.model.transformer.", "base_model.model.")
        # PEFT stores runtime weights as "...lora_X.<adapter>.weight"; insert
        # the "default" adapter name when the checkpoint omitted it.
        for stem in ("lora_A", "lora_B", "lora_embedding_A", "lora_embedding_B"):
            plain = f"{stem}.weight"
            if plain in new_key and ".default.weight" not in new_key:
                new_key = new_key.replace(plain, f"{stem}.default.weight")
                break
        final_state_dict[new_key] = value
    missing, unexpected = lora_transformer.load_state_dict(final_state_dict, strict=False)
    # Only LoRA-specific keys matter; base weights are expected to be missing
    # from the adapter file. real_unexpected is computed for debugging only.
    real_missing = [k for k in missing if "lora" in k]
    real_unexpected = [k for k in unexpected if "lora" in k]
    if len(real_missing) > 0:
        print("missing")
    else:
        print("success!")
    lora_transformer.set_adapter("default")
    lora_transformer.to(device, dtype=torch.bfloat16)
    lora_transformer.eval()
    # Swap the adapted transformer into both places the pipeline references it.
    # NOTE(review): confirm `model.transformer.transformer` is a real attribute
    # path in QwenImage — this mirrors the original script exactly.
    model.model.transformer = lora_transformer
    model.transformer.transformer = lora_transformer
else:
    print("can not find lora")
# Switch to inference mode and configure the few-step sampler used for the
# distilled model (4 sampling steps).
model.eval()
sampler = partial(
    UnifiedSampler().sampling_loop,
    sampling_steps=4,
    stochast_ratio=1.0,
    extrapol_ratio=0.0,
    sampling_order=1,
    time_dist_ctrl=[1.0, 1.0, 1.0],
    rfba_gap_steps=[0.001, 0.5],
)
# Run sampling under bf16 autocast with gradients disabled.
with (
    torch.no_grad(),
    torch_autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"),
):
    edited_image = model.sample(
        prompts=[prompt],
        images=input_tensor,
        cfg_scale=0.0,
        seed=seed,
        height=height,
        width=width,
        sampler=sampler,
        return_traj=False,
    )
# Outputs are in [-1, 1]; map to [0, 1] before writing the image file.
save_image((edited_image.squeeze(0) + 1) / 2, "edited_output.jpg")
print("done")
我保存模型 checkpoint 的实现如下:
def save_ckpt(
    ckpt_root_dir,
    model_to_save,
    global_step,
):
    """Save the (possibly wrapper-enclosed) transformer as a checkpoint.

    Writes to ``<ckpt_root_dir>/global_step_<global_step>/model`` via
    ``save_pretrained`` with safetensors serialization.

    Args:
        ckpt_root_dir: Root directory holding all checkpoints.
        model_to_save: Object exposing a ``transformer`` attribute — the
            PEFT-wrapped transformer, optionally inside a distributed
            wrapper that exposes the real model via ``.module``.
        global_step: Training step used to name the checkpoint directory.
    """
    model_dir = os.path.join(ckpt_root_dir, f"global_step_{global_step}", "model")
    os.makedirs(model_dir, exist_ok=True)
    # Unwrap a distributed wrapper if present (e.g. DDP exposes `.module`).
    transformer = model_to_save.transformer
    target = transformer.module if hasattr(transformer, "module") else transformer
    target.save_pretrained(model_dir, safe_serialization=True)
左图是蒸馏前的效果,右图是蒸馏后的效果。我只训练了一张图,跑了 120 个 step,仍然无法过拟合。配置文件没怎么改,只改了下面这 2 个地方:
model_name: QwenImageEdit #Flux
aux_time_embed: false
lora_rank: 64
lora_alpha: 64

Metadata
Metadata
Assignees
Labels
No labels