From 6db359dd4116d39dcf1300adc4397eb6926e1085 Mon Sep 17 00:00:00 2001 From: Pengfei Guo <32000655+guopengf@users.noreply.github.com> Date: Wed, 18 Sep 2024 03:59:14 -0400 Subject: [PATCH] Fix the bug that MAISI ckpt cannot be loaded after finetune. (#654) Fixes # . ### Description MAISI output checkpoint after fine-tuning cannot be used as ``trained_controlnet_path``. This problem comes from the `CheckpointSaver`. When a single key is provided for `save_dict` such as: "save_dict": { "controlnet_state_dict": "@controlnet" }, the saved checkpoint does not contain the key "controlnet_state_dict"; instead, the state_dict of controlnet is saved directly as the checkpoint. The workaround is that we also save the optimizer state. For example, "save_dict": { "controlnet_state_dict": "@controlnet", "optimizer": "@optimizer" }. Then, the MAISI output checkpoint after fine-tuning can be properly loaded. In addition, `scale_factor` is now read from the diffusion UNet checkpoint (`@checkpoint_diffusion_unet`) instead of the ControlNet checkpoint. ### Status **Ready/Work in progress/Hold** ### Please ensure all the checkboxes: - [x] Codeformat tests passed locally by running `./runtests.sh --codeformat`. - [ ] In-line docstrings updated. - [ ] Update `version` and `changelog` in `metadata.json` if changing an existing bundle. - [ ] Please ensure the naming rules in config files meet our requirements (please refer to: `CONTRIBUTING.md`). - [ ] Ensure versions of packages such as `monai`, `pytorch` and `numpy` are correct in `metadata.json`. - [ ] Descriptions should be consistent with the content, such as `eval_metrics` of the provided weights and TorchScript modules. - [ ] Files larger than 25MB are excluded and replaced by providing download links in `large_file.yml`. - [ ] Avoid using path that contains personal information within config files (such as use `/home/your_name/` for `"bundle_root"`). 
--------- Signed-off-by: Pengfei Guo --- models/maisi_ct_generative/configs/metadata.json | 3 ++- models/maisi_ct_generative/configs/train.json | 5 +++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/models/maisi_ct_generative/configs/metadata.json b/models/maisi_ct_generative/configs/metadata.json index d6cc5330..2838ca92 100644 --- a/models/maisi_ct_generative/configs/metadata.json +++ b/models/maisi_ct_generative/configs/metadata.json @@ -1,7 +1,8 @@ { "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_generator_ldm_20240318.json", - "version": "0.4.1", + "version": "0.4.2", "changelog": { + "0.4.2": "update train.json to fix finetune ckpt bug", "0.4.1": "update large files", "0.4.0": "update to use monai 1.4, model ckpt updated, rm GenerativeAI repo, add quality check", "0.3.6": "first oss version" diff --git a/models/maisi_ct_generative/configs/train.json b/models/maisi_ct_generative/configs/train.json index ab88ffad..f2c91fab 100644 --- a/models/maisi_ct_generative/configs/train.json +++ b/models/maisi_ct_generative/configs/train.json @@ -104,7 +104,7 @@ "copy_controlnet_state": "$monai.networks.utils.copy_model_state(@controlnet, @diffusion_unet.state_dict())", "checkpoint_controlnet": "$torch.load(@trained_controlnet_path)", "load_controlnet": "$@controlnet.load_state_dict(@checkpoint_controlnet['controlnet_state_dict'], strict=True)", - "scale_factor": "$@checkpoint_controlnet['scale_factor'].to(@device)", + "scale_factor": "$@checkpoint_diffusion_unet['scale_factor'].to(@device)", "loss": { "_target_": "torch.nn.L1Loss", "reduction": "none" @@ -214,7 +214,8 @@ "_target_": "CheckpointSaver", "save_dir": "@ckpt_dir", "save_dict": { - "controlnet_state_dict": "@controlnet" + "controlnet_state_dict": "@controlnet", + "optimizer": "@optimizer" }, "save_interval": 1, "n_saved": 5