|
| 1 | +# convert Diffusers v1.x/v2.0 model to original Stable Diffusion |
| 2 | +# v1: initial version |
| 3 | +# v2: support safetensors |
| 4 | +# v3: fix to support another format |
| 5 | +# v4: support safetensors in Diffusers |
| 6 | + |
| 7 | +import argparse |
| 8 | +import os |
| 9 | +import torch |
| 10 | +from diffusers import StableDiffusionPipeline |
| 11 | + |
| 12 | +import model_util |
| 13 | + |
| 14 | + |
| 15 | +def convert(args): |
| 16 | + # 引数を確認する |
| 17 | + load_dtype = torch.float16 if args.fp16 else None |
| 18 | + |
| 19 | + save_dtype = None |
| 20 | + if args.fp16: |
| 21 | + save_dtype = torch.float16 |
| 22 | + elif args.bf16: |
| 23 | + save_dtype = torch.bfloat16 |
| 24 | + elif args.float: |
| 25 | + save_dtype = torch.float |
| 26 | + |
| 27 | + is_load_ckpt = os.path.isfile(args.model_to_load) |
| 28 | + is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0 |
| 29 | + |
| 30 | + assert not is_load_ckpt or args.v1 != args.v2, f"v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です" |
| 31 | + assert is_save_ckpt or args.reference_model is not None, f"reference model is required to save as Diffusers / Diffusers形式での保存には参照モデルが必要です" |
| 32 | + |
| 33 | + # モデルを読み込む |
| 34 | + msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else "")) |
| 35 | + print(f"loading {msg}: {args.model_to_load}") |
| 36 | + |
| 37 | + if is_load_ckpt: |
| 38 | + v2_model = args.v2 |
| 39 | + text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(v2_model, args.model_to_load) |
| 40 | + else: |
| 41 | + pipe = StableDiffusionPipeline.from_pretrained(args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None) |
| 42 | + text_encoder = pipe.text_encoder |
| 43 | + vae = pipe.vae |
| 44 | + unet = pipe.unet |
| 45 | + |
| 46 | + if args.v1 == args.v2: |
| 47 | + # 自動判定する |
| 48 | + v2_model = unet.config.cross_attention_dim == 1024 |
| 49 | + print("checking model version: model is " + ('v2' if v2_model else 'v1')) |
| 50 | + else: |
| 51 | + v2_model = args.v1 |
| 52 | + |
| 53 | + # 変換して保存する |
| 54 | + msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers" |
| 55 | + print(f"converting and saving as {msg}: {args.model_to_save}") |
| 56 | + |
| 57 | + if is_save_ckpt: |
| 58 | + original_model = args.model_to_load if is_load_ckpt else None |
| 59 | + key_count = model_util.save_stable_diffusion_checkpoint(v2_model, args.model_to_save, text_encoder, unet, |
| 60 | + original_model, args.epoch, args.global_step, save_dtype, vae) |
| 61 | + print(f"model saved. total converted state_dict keys: {key_count}") |
| 62 | + else: |
| 63 | + print(f"copy scheduler/tokenizer config from: {args.reference_model}") |
| 64 | + model_util.save_diffusers_checkpoint(v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors) |
| 65 | + print(f"model saved.") |
| 66 | + |
| 67 | + |
| 68 | +if __name__ == '__main__': |
| 69 | + parser = argparse.ArgumentParser() |
| 70 | + parser.add_argument("--v1", action='store_true', |
| 71 | + help='load v1.x model (v1 or v2 is required to load checkpoint) / 1.xのモデルを読み込む') |
| 72 | + parser.add_argument("--v2", action='store_true', |
| 73 | + help='load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み込む') |
| 74 | + parser.add_argument("--fp16", action='store_true', |
| 75 | + help='load as fp16 (Diffusers only) and save as fp16 (checkpoint only) / fp16形式で読み込み(Diffusers形式のみ対応)、保存する(checkpointのみ対応)') |
| 76 | + parser.add_argument("--bf16", action='store_true', help='save as bf16 (checkpoint only) / bf16形式で保存する(checkpointのみ対応)') |
| 77 | + parser.add_argument("--float", action='store_true', |
| 78 | + help='save as float (checkpoint only) / float(float32)形式で保存する(checkpointのみ対応)') |
| 79 | + parser.add_argument("--epoch", type=int, default=0, help='epoch to write to checkpoint / checkpointに記録するepoch数の値') |
| 80 | + parser.add_argument("--global_step", type=int, default=0, |
| 81 | + help='global_step to write to checkpoint / checkpointに記録するglobal_stepの値') |
| 82 | + parser.add_argument("--reference_model", type=str, default=None, |
| 83 | + help="reference model for schduler/tokenizer, required in saving Diffusers, copy schduler/tokenizer from this / scheduler/tokenizerのコピー元のDiffusersモデル、Diffusers形式で保存するときに必要") |
| 84 | + parser.add_argument("--use_safetensors", action='store_true', |
| 85 | + help="use safetensors format to save Diffusers model (checkpoint depends on the file extension) / Duffusersモデルをsafetensors形式で保存する(checkpointは拡張子で自動判定)") |
| 86 | + |
| 87 | + parser.add_argument("model_to_load", type=str, default=None, |
| 88 | + help="model to load: checkpoint file or Diffusers model's directory / 読み込むモデル、checkpointかDiffusers形式モデルのディレクトリ") |
| 89 | + parser.add_argument("model_to_save", type=str, default=None, |
| 90 | + help="model to save: checkpoint (with extension) or Diffusers model's directory (without extension) / 変換後のモデル、拡張子がある場合はcheckpoint、ない場合はDiffusesモデルとして保存") |
| 91 | + |
| 92 | + args = parser.parse_args() |
| 93 | + convert(args) |
0 commit comments