
Commit cea80a6
8029 update load old weights function for diffusion_model_unet.py (Project-MONAI#8031)
Fixes Project-MONAI#8029.

### Description

Update the `load_old_state_dict` function in `diffusion_model_unet.py` so that old-format checkpoints load correctly: self-attention blocks are now identified by their `attn.to_k.weight` keys, and their old flat `to_q`/`to_k`/`to_v`/`proj_attn` weights are remapped into the new `attn.*` names, while cross-attention layers inside `transformer_blocks` are handled separately via their `to_out.0.*` projection weights.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

Signed-off-by: Yiheng Wang <vennw@nvidia.com>
Signed-off-by: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
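For context, here is a minimal usage sketch of the updated loader. It assumes a MONAI build that exposes `DiffusionModelUNet` from `monai.networks.nets`; the constructor arguments and checkpoint path are illustrative placeholders, not taken from this PR:

```python
# Hedged sketch: loading an old-format checkpoint into the refactored UNet.
# The architecture arguments and the file name below are placeholders; they
# must match whatever model actually produced the checkpoint.
import torch

from monai.networks.nets import DiffusionModelUNet

model = DiffusionModelUNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    num_res_blocks=2,
    channels=(32, 64, 64),
    attention_levels=(False, True, True),
    num_head_channels=8,
)

# Old checkpoints store attention weights under flat names such as
# "<block>.to_q.weight" and "<block>.proj_attn.weight"; load_old_state_dict
# pops them into the new "<block>.attn.*" layout before loading.
old_state_dict = torch.load("old_diffusion_unet.pt")  # hypothetical path
model.load_old_state_dict(old_state_dict, verbose=True)
```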
1 parent 3a6f620 · commit cea80a6

File tree: 1 file changed, +18 −1 lines


monai/networks/nets/diffusion_model_unet.py

+18 −1
```diff
@@ -1837,9 +1837,26 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None:
             new_state_dict[k] = old_state_dict.pop(k)
 
         # fix the attention blocks
-        attention_blocks = [k.replace(".out_proj.weight", "") for k in new_state_dict if "out_proj.weight" in k]
+        attention_blocks = [k.replace(".attn.to_k.weight", "") for k in new_state_dict if "attn.to_k.weight" in k]
         for block in attention_blocks:
+            new_state_dict[f"{block}.attn.to_q.weight"] = old_state_dict.pop(f"{block}.to_q.weight")
+            new_state_dict[f"{block}.attn.to_k.weight"] = old_state_dict.pop(f"{block}.to_k.weight")
+            new_state_dict[f"{block}.attn.to_v.weight"] = old_state_dict.pop(f"{block}.to_v.weight")
+            new_state_dict[f"{block}.attn.to_q.bias"] = old_state_dict.pop(f"{block}.to_q.bias")
+            new_state_dict[f"{block}.attn.to_k.bias"] = old_state_dict.pop(f"{block}.to_k.bias")
+            new_state_dict[f"{block}.attn.to_v.bias"] = old_state_dict.pop(f"{block}.to_v.bias")
+
             # projection
+            new_state_dict[f"{block}.attn.out_proj.weight"] = old_state_dict.pop(f"{block}.proj_attn.weight")
+            new_state_dict[f"{block}.attn.out_proj.bias"] = old_state_dict.pop(f"{block}.proj_attn.bias")
+
+        # fix the cross attention blocks
+        cross_attention_blocks = [
+            k.replace(".out_proj.weight", "")
+            for k in new_state_dict
+            if "out_proj.weight" in k and "transformer_blocks" in k
+        ]
+        for block in cross_attention_blocks:
             new_state_dict[f"{block}.out_proj.weight"] = old_state_dict.pop(f"{block}.to_out.0.weight")
             new_state_dict[f"{block}.out_proj.bias"] = old_state_dict.pop(f"{block}.to_out.0.bias")
```
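The key change is the block-matching rule. The old code found "attention blocks" by searching new-model keys for `out_proj.weight`, which matches both plain self-attention blocks and the cross-attention layers inside `transformer_blocks`, and then popped `to_out.0.*` weights for all of them; as the diff suggests, old checkpoints only store `to_out.0.*` for the transformer layers, while self-attention blocks store `proj_attn.*` plus separate `to_q`/`to_k`/`to_v`. The fix keys self-attention off `attn.to_k.weight` and handles `transformer_blocks` keys separately. Below is a standalone toy sketch of the same remapping; the block paths (`down.attn1`, `mid.transformer_blocks.0.attn2`) are invented for illustration, and only the suffix renames mirror the diff:

```python
# Toy reproduction of the remapping in the diff above:
#   "<b>.to_q"      -> "<b>.attn.to_q"   (likewise to_k, to_v)
#   "<b>.proj_attn" -> "<b>.attn.out_proj"
#   "<b>.to_out.0"  -> "<b>.out_proj"    (cross-attention only)
old_sd = {
    # old-style self-attention block
    "down.attn1.to_q.weight": "q_w", "down.attn1.to_q.bias": "q_b",
    "down.attn1.to_k.weight": "k_w", "down.attn1.to_k.bias": "k_b",
    "down.attn1.to_v.weight": "v_w", "down.attn1.to_v.bias": "v_b",
    "down.attn1.proj_attn.weight": "p_w", "down.attn1.proj_attn.bias": "p_b",
    # old-style cross-attention layer inside a transformer block
    "mid.transformer_blocks.0.attn2.to_out.0.weight": "o_w",
    "mid.transformer_blocks.0.attn2.to_out.0.bias": "o_b",
}

new_sd = {}

# self-attention: found via the new ".attn.to_k.weight" suffix in practice
for block in ["down.attn1"]:
    for proj in ("to_q", "to_k", "to_v"):
        for part in ("weight", "bias"):
            new_sd[f"{block}.attn.{proj}.{part}"] = old_sd.pop(f"{block}.{proj}.{part}")
    new_sd[f"{block}.attn.out_proj.weight"] = old_sd.pop(f"{block}.proj_attn.weight")
    new_sd[f"{block}.attn.out_proj.bias"] = old_sd.pop(f"{block}.proj_attn.bias")

# cross-attention: found via "transformer_blocks" plus "out_proj.weight"
for block in ["mid.transformer_blocks.0.attn2"]:
    new_sd[f"{block}.out_proj.weight"] = old_sd.pop(f"{block}.to_out.0.weight")
    new_sd[f"{block}.out_proj.bias"] = old_sd.pop(f"{block}.to_out.0.bias")

assert not old_sd  # every old attention weight found a new home
```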
