Skip to content

Commit b637e01

Browse files
authored
442 fix model checkpoint dictionary issue (Project-MONAI#443)
* [DLMED] fix dictionary issue of checkpoint * [DLMED] fix flake8 error * [DLMED] update according to comments
1 parent eb10439 commit b637e01

File tree

3 files changed

+102
-5
lines changed

3 files changed

+102
-5
lines changed

monai/handlers/checkpoint_loader.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ class CheckpointLoader:
1919
"""
2020
CheckpointLoader acts as an Ignite handler to load checkpoint data from file.
2121
It can load variables for network, optimizer, lr_scheduler.
22-
And also can restore training if load the state_dict of Ignite engine.
22+
If a checkpoint is saved after wrapping the model with `torch.nn.DataParallel`, save `model.module` instead,
23+
as PyTorch recommends, and then use this loader to load the model.
24+
It can also restore a training session if the state_dict of the Ignite engine is loaded.
2325
2426
Args:
2527
load_path (str): the file path of the checkpoint; it should be a PyTorch `pth` file.
@@ -48,5 +50,10 @@ def attach(self, engine):
4850

4951
def __call__(self, engine):
5052
checkpoint = torch.load(self.load_path)
53+
if len(self.load_dict) == 1:
54+
key = list(self.load_dict.keys())[0]
55+
if not (key in checkpoint):
56+
checkpoint = {key: checkpoint}
57+
5158
Checkpoint.load_objects(to_load=self.load_dict, checkpoint=checkpoint)
5259
self.logger.info(f"Restored all variables from {self.load_path}")
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# Copyright 2020 MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import os
13+
import tempfile
14+
import shutil
15+
import torch
16+
import unittest
17+
from ignite.engine import Engine
18+
import torch.optim as optim
19+
from monai.handlers import CheckpointSaver, CheckpointLoader
20+
import logging
21+
import sys
22+
23+
24+
class TestHandlerCheckpointLoader(unittest.TestCase):
25+
def test_one_save_one_load(self):
26+
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
27+
net1 = torch.nn.PReLU()
28+
data1 = net1.state_dict()
29+
data1["weight"] = torch.tensor([0.1])
30+
net1.load_state_dict(data1)
31+
net2 = torch.nn.PReLU()
32+
data2 = net2.state_dict()
33+
data2["weight"] = torch.tensor([0.2])
34+
net2.load_state_dict(data2)
35+
engine = Engine(lambda e, b: None)
36+
with tempfile.TemporaryDirectory() as tempdir:
37+
save_dir = os.path.join(tempdir, "checkpoint")
38+
CheckpointSaver(save_dir=save_dir, save_dict={"net": net1}, save_final=True).attach(engine)
39+
engine.run([0] * 8, max_epochs=5)
40+
path = save_dir + "/net_final_iteration=40.pth"
41+
CheckpointLoader(load_path=path, load_dict={"net": net2}).attach(engine)
42+
engine.run([0] * 8, max_epochs=1)
43+
torch.testing.assert_allclose(net2.state_dict()["weight"], 0.1)
44+
shutil.rmtree(save_dir)
45+
46+
def test_two_save_one_load(self):
47+
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
48+
net1 = torch.nn.PReLU()
49+
optimizer = optim.SGD(net1.parameters(), lr=0.02)
50+
data1 = net1.state_dict()
51+
data1["weight"] = torch.tensor([0.1])
52+
net1.load_state_dict(data1)
53+
net2 = torch.nn.PReLU()
54+
data2 = net2.state_dict()
55+
data2["weight"] = torch.tensor([0.2])
56+
net2.load_state_dict(data2)
57+
engine = Engine(lambda e, b: None)
58+
with tempfile.TemporaryDirectory() as tempdir:
59+
save_dir = os.path.join(tempdir, "checkpoint")
60+
save_dict = {"net": net1, "opt": optimizer}
61+
CheckpointSaver(save_dir=save_dir, save_dict=save_dict, save_final=True).attach(engine)
62+
engine.run([0] * 8, max_epochs=5)
63+
path = save_dir + "/checkpoint_final_iteration=40.pth"
64+
CheckpointLoader(load_path=path, load_dict={"net": net2}).attach(engine)
65+
engine.run([0] * 8, max_epochs=1)
66+
torch.testing.assert_allclose(net2.state_dict()["weight"], 0.1)
67+
shutil.rmtree(save_dir)
68+
69+
def test_save_single_device_load_multi_devices(self):
70+
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
71+
net1 = torch.nn.PReLU()
72+
data1 = net1.state_dict()
73+
data1["weight"] = torch.tensor([0.1])
74+
net1.load_state_dict(data1)
75+
net2 = torch.nn.PReLU()
76+
data2 = net2.state_dict()
77+
data2["weight"] = torch.tensor([0.2])
78+
net2.load_state_dict(data2)
79+
net2 = torch.nn.DataParallel(net2)
80+
engine = Engine(lambda e, b: None)
81+
with tempfile.TemporaryDirectory() as tempdir:
82+
save_dir = os.path.join(tempdir, "checkpoint")
83+
CheckpointSaver(save_dir=save_dir, save_dict={"net": net1}, save_final=True).attach(engine)
84+
engine.run([0] * 8, max_epochs=5)
85+
path = save_dir + "/net_final_iteration=40.pth"
86+
CheckpointLoader(load_path=path, load_dict={"net": net2}).attach(engine)
87+
engine.run([0] * 8, max_epochs=1)
88+
torch.testing.assert_allclose(net2.state_dict()["module.weight"], 0.1)
89+
shutil.rmtree(save_dir)
90+
91+
92+
if __name__ == "__main__":
93+
unittest.main()

tests/test_handler_checkpoint_saver.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import torch
1616
import unittest
1717
from ignite.engine import Engine
18-
from monai.handlers import CheckpointSaver, CheckpointLoader
18+
from monai.handlers import CheckpointSaver
1919
import torch.optim as optim
2020
from parameterized import parameterized
2121
import logging
@@ -97,9 +97,6 @@ def _train_func(engine, batch):
9797
engine.run(data, max_epochs=5)
9898
for filename in filenames:
9999
self.assertTrue(os.path.exists(os.path.join(save_dir, filename)))
100-
loader = CheckpointLoader(load_path=os.path.join(save_dir, filename), load_dict={"net": net})
101-
loader.attach(engine)
102-
engine.run(data, max_epochs=1)
103100
shutil.rmtree(save_dir)
104101

105102

0 commit comments

Comments
 (0)