qwen edit 2509: LoRA training gives very poor results #23

@hello-program

Hi, I followed the LoRA usage guide at https://github.com/inclusionAI/TwinFlow/tree/main/src to distill qwen edit, but the results are very poor. Could you help me figure out where the problem is? Thanks.
Here is my inference code:

import os
import sys
sys.path.append("/home/workspace/code/TwinFlow/src")
import torch
from functools import partial
from torchvision.utils import save_image
from PIL import Image
from torchvision import transforms

from src.networks.qwen_image.modeling_qwen_image import QwenImage
from peft import PeftModel
from unified_sampler import UnifiedSampler
from safetensors.torch import load_file
from torch.amp import autocast as torch_autocast

seed = 42
torch.manual_seed(seed)
device = torch.device("cuda")

base_model_path = "/home/workspace/hf_model/Qwen-Image-Edit-2509"
lora_checkpoint_path = "../outputs/qwenimage_task/qwenimage_lora_2order/checkpoints/global_step_30/model"

input_image_path = "/home/workspace/code/DiffSynth-Studio/example_image_dataset/edit/image1_F.jpg"
prompt = "将裙子改为粉色"

height = 512
width = 512

input_image = Image.open(input_image_path).convert("RGB").resize((width, height), Image.LANCZOS)
transform = transforms.Compose([
    transforms.Resize(min(height, width), interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])
input_tensor = transform(input_image).unsqueeze(0).to(device, dtype=torch.bfloat16)

print("load base model...")
model = QwenImage(
    base_model_path,
    model_type="edit",
    aux_time_embed=False,
    text_dtype=torch.bfloat16,
    imgs_dtype=torch.bfloat16,
    device=device
).to(device)

print("load lora...")
base_transformer = model.model.transformer

if lora_checkpoint_path and os.path.exists(lora_checkpoint_path):
    print(f"Loading LoRA from {lora_checkpoint_path}...")
    
    lora_transformer = PeftModel.from_pretrained(
        base_transformer,
        lora_checkpoint_path,
        adapter_name="default",
        is_trainable=False
    )
    
    # save_pretrained writes keys without the adapter-name infix (and, here, with
    # an extra "transformer." prefix, presumably from the training wrapper), so
    # remap them to the runtime names PEFT expects before load_state_dict.
    adapter_path = os.path.join(lora_checkpoint_path, "adapter_model.safetensors")
    raw_state_dict = load_file(adapter_path)
    final_state_dict = {}
    
    for key, value in raw_state_dict.items():
        new_key = key
        # Strip the extra wrapper prefix left over from training.
        if "base_model.model.transformer." in new_key:
            new_key = new_key.replace("base_model.model.transformer.", "base_model.model.")
        # Re-insert the ".default" adapter name that save_pretrained strips.
        if "lora_A.weight" in new_key and ".default.weight" not in new_key:
            new_key = new_key.replace("lora_A.weight", "lora_A.default.weight")
        elif "lora_B.weight" in new_key and ".default.weight" not in new_key:
            new_key = new_key.replace("lora_B.weight", "lora_B.default.weight")
        elif "lora_embedding_A" in new_key and ".default.weight" not in new_key:
            new_key = new_key.replace("lora_embedding_A.weight", "lora_embedding_A.default.weight")
        elif "lora_embedding_B" in new_key and ".default.weight" not in new_key:
            new_key = new_key.replace("lora_embedding_B.weight", "lora_embedding_B.default.weight")
        
        final_state_dict[new_key] = value
    
    missing, unexpected = lora_transformer.load_state_dict(final_state_dict, strict=False)

    # Only LoRA keys matter here; the base weights are expected to be "missing"
    # from an adapter-only state dict.
    real_missing = [k for k in missing if "lora" in k]
    real_unexpected = [k for k in unexpected if "lora" in k]
    
    if real_missing or real_unexpected:
        print(f"LoRA load mismatch: {len(real_missing)} missing, {len(real_unexpected)} unexpected")
        print("missing:", real_missing[:5])
        print("unexpected:", real_unexpected[:5])
    else:
        print("LoRA weights loaded successfully")

    lora_transformer.set_adapter("default")
    lora_transformer.to(device, dtype=torch.bfloat16)
    lora_transformer.eval()
    
    # Swap the LoRA-wrapped transformer into the model
    model.model.transformer = lora_transformer
    model.transformer.transformer = lora_transformer
else:
    print("can not find lora")
model.eval()

sampler_config = {
    "sampling_steps": 4,
    "stochast_ratio": 1.0,
    "extrapol_ratio": 0.0,
    "sampling_order": 1,
    "time_dist_ctrl": [1.0, 1.0, 1.0],
    "rfba_gap_steps": [0.001, 0.5],
}
sampler = partial(UnifiedSampler().sampling_loop, **sampler_config)

with (
    torch.no_grad(),
    torch_autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"),
):
    edited_image = model.sample(
        prompts=[prompt],
        images=input_tensor,
        cfg_scale=0.0,
        seed=seed,
        height=height,
        width=width,
        sampler=sampler,
        return_traj=False,
    )

save_image((edited_image.squeeze(0) + 1) / 2, "edited_output.jpg")
print("done")

My checkpoint-saving implementation:

def save_ckpt(
    ckpt_root_dir,
    model_to_save,
    global_step,
):
    model_dir = os.path.join(ckpt_root_dir, f"global_step_{global_step}", "model")
    os.makedirs(model_dir, exist_ok=True)

    # Unwrap the distributed wrapper (.module) if present before saving.
    if hasattr(model_to_save.transformer, 'module'):
        model_to_save.transformer.module.save_pretrained(
            model_dir,
            safe_serialization=True
        )
    else:
        model_to_save.transformer.save_pretrained(
            model_dir,
            safe_serialization=True
        )
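
To rule out a save-side mismatch, it also helps to inspect the keys that save_pretrained actually wrote before hand-remapping them at load time. A minimal sketch, assuming the adapter file sits under the global_step_{N}/model directory produced by save_ckpt:

from safetensors import safe_open

adapter_path = "../outputs/qwenimage_task/qwenimage_lora_2order/checkpoints/global_step_30/model/adapter_model.safetensors"
with safe_open(adapter_path, framework="pt") as f:
    # Print a few keys to see the actual prefix layout on disk.
    for key in sorted(f.keys())[:10]:
        print(key, f.get_slice(key).get_shape())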

The left image is the result before distillation, the right one after. I trained on only a single image for 120 steps and it still cannot overfit. I barely changed the config file, only these two places:

  model_name: QwenImageEdit #Flux
  aux_time_embed: false
  lora_rank: 64
  lora_alpha: 64
[Image: before (left) vs. after (right) distillation]
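
Since a single image for 120 steps should overfit easily, it is worth confirming that the optimizer is actually moving the LoRA parameters between checkpoints. A minimal sketch, assuming a second checkpoint was saved with the same layout (the global_step_120 path is an assumption):

from safetensors.torch import load_file

ckpt = "../outputs/qwenimage_task/qwenimage_lora_2order/checkpoints"
a = load_file(f"{ckpt}/global_step_30/model/adapter_model.safetensors")
b = load_file(f"{ckpt}/global_step_120/model/adapter_model.safetensors")  # assumed path

# If this is ~0, gradients never reach the adapter (e.g. the LoRA params are
# frozen or missing from the optimizer's parameter groups).
max_diff = max((a[k].float() - b[k].float()).abs().max().item() for k in a)
print(f"max |delta| between step-30 and step-120 adapters: {max_diff:.3e}")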
