Skip to content

Commit

Permalink
Remove unnecessary tempdir (Project-MONAI#446)
Browse files Browse the repository at this point in the history
* [DLMED] remove tempdir

* [DLMED] update all new features

* [DLMED] delete tempdir
  • Loading branch information
Nic-Ma authored May 28, 2020
1 parent 253d1aa commit 3181e3e
Show file tree
Hide file tree
Showing 10 changed files with 116 additions and 108 deletions.
2 changes: 2 additions & 0 deletions tests/test_arraydataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import unittest
import os
import shutil
import numpy as np
import tempfile
import nibabel as nib
Expand Down Expand Up @@ -68,6 +69,7 @@ def test_shape(self, img_transform, label_transform, indexes, expected_shape):
self.assertTupleEqual(data2[indexes[0]].shape, expected_shape)
self.assertTupleEqual(data2[indexes[1]].shape, expected_shape)
np.testing.assert_allclose(data2[indexes[0]], data2[indexes[0]])
shutil.rmtree(tempdir)


if __name__ == "__main__":
Expand Down
39 changes: 20 additions & 19 deletions tests/test_data_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import unittest
import os
import shutil
import logging
import tempfile
import numpy as np
Expand Down Expand Up @@ -97,25 +98,25 @@ def test_value(self, input_param, input_data, expected_print):

@parameterized.expand([TEST_CASE_6])
def test_file(self, input_data, expected_print):
    """DataStats should write its report to a log file via a user-supplied handler.

    Creates a temp dir, attaches a FileHandler, runs the transform, then
    compares the file content against the expected printout.
    """
    tempdir = tempfile.mkdtemp()
    try:
        filename = os.path.join(tempdir, "test_stats.log")
        handler = logging.FileHandler(filename, mode="w")
        input_param = {
            "prefix": "test data",
            "data_shape": True,
            "intensity_range": True,
            "data_value": True,
            "additional_info": lambda x: np.mean(x),
            "logger_handler": handler,
        }
        transform = DataStats(**input_param)
        _ = transform(input_data)
        # close the stream and detach the handler so the log file is flushed
        # and can be read (and later removed) safely
        handler.stream.close()
        transform._logger.removeHandler(handler)
        with open(filename, "r") as f:
            content = f.read()
        self.assertEqual(content, expected_print)
    finally:
        # clean up even when the assertion above fails, so no tempdir leaks
        shutil.rmtree(tempdir)


if __name__ == "__main__":
Expand Down
41 changes: 21 additions & 20 deletions tests/test_data_statsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import unittest
import os
import shutil
import logging
import tempfile
import numpy as np
Expand Down Expand Up @@ -110,26 +111,26 @@ def test_value(self, input_param, input_data, expected_print):

@parameterized.expand([TEST_CASE_7])
def test_file(self, input_data, expected_print):
    """DataStatsd should write its report to a log file via a user-supplied handler.

    Dictionary-transform variant of the DataStats file-logging test: attach a
    FileHandler, run the transform on a keyed input, and verify the file content.
    """
    tempdir = tempfile.mkdtemp()
    try:
        filename = os.path.join(tempdir, "test_stats.log")
        handler = logging.FileHandler(filename, mode="w")
        input_param = {
            "keys": "img",
            "prefix": "test data",
            "data_shape": True,
            "intensity_range": True,
            "data_value": True,
            "additional_info": lambda x: np.mean(x),
            "logger_handler": handler,
        }
        transform = DataStatsd(**input_param)
        _ = transform(input_data)
        # close the stream and detach the handler so the log file is flushed
        # and can be read (and later removed) safely
        handler.stream.close()
        transform.printer._logger.removeHandler(handler)
        with open(filename, "r") as f:
            content = f.read()
        self.assertEqual(content, expected_print)
    finally:
        # clean up even when the assertion above fails, so no tempdir leaks
        shutil.rmtree(tempdir)


if __name__ == "__main__":
Expand Down
2 changes: 2 additions & 0 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import unittest
import os
import shutil
import numpy as np
import tempfile
import nibabel as nib
Expand Down Expand Up @@ -71,6 +72,7 @@ def test_shape(self, expected_shape):
self.assertTupleEqual(data2_simple["image"].shape, expected_shape)
self.assertTupleEqual(data2_simple["label"].shape, expected_shape)
self.assertTupleEqual(data2_simple["extra"].shape, expected_shape)
shutil.rmtree(tempdir)


if __name__ == "__main__":
Expand Down
54 changes: 25 additions & 29 deletions tests/test_handler_checkpoint_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile
import shutil
import torch
Expand All @@ -33,15 +32,14 @@ def test_one_save_one_load(self):
data2["weight"] = torch.tensor([0.2])
net2.load_state_dict(data2)
engine = Engine(lambda e, b: None)
with tempfile.TemporaryDirectory() as tempdir:
save_dir = os.path.join(tempdir, "checkpoint")
CheckpointSaver(save_dir=save_dir, save_dict={"net": net1}, save_final=True).attach(engine)
engine.run([0] * 8, max_epochs=5)
path = save_dir + "/net_final_iteration=40.pth"
CheckpointLoader(load_path=path, load_dict={"net": net2}).attach(engine)
engine.run([0] * 8, max_epochs=1)
torch.testing.assert_allclose(net2.state_dict()["weight"], 0.1)
shutil.rmtree(save_dir)
tempdir = tempfile.mkdtemp()
CheckpointSaver(save_dir=tempdir, save_dict={"net": net1}, save_final=True).attach(engine)
engine.run([0] * 8, max_epochs=5)
path = tempdir + "/net_final_iteration=40.pth"
CheckpointLoader(load_path=path, load_dict={"net": net2}).attach(engine)
engine.run([0] * 8, max_epochs=1)
torch.testing.assert_allclose(net2.state_dict()["weight"], 0.1)
shutil.rmtree(tempdir)

def test_two_save_one_load(self):
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
Expand All @@ -55,16 +53,15 @@ def test_two_save_one_load(self):
data2["weight"] = torch.tensor([0.2])
net2.load_state_dict(data2)
engine = Engine(lambda e, b: None)
with tempfile.TemporaryDirectory() as tempdir:
save_dir = os.path.join(tempdir, "checkpoint")
save_dict = {"net": net1, "opt": optimizer}
CheckpointSaver(save_dir=save_dir, save_dict=save_dict, save_final=True).attach(engine)
engine.run([0] * 8, max_epochs=5)
path = save_dir + "/checkpoint_final_iteration=40.pth"
CheckpointLoader(load_path=path, load_dict={"net": net2}).attach(engine)
engine.run([0] * 8, max_epochs=1)
torch.testing.assert_allclose(net2.state_dict()["weight"], 0.1)
shutil.rmtree(save_dir)
tempdir = tempfile.mkdtemp()
save_dict = {"net": net1, "opt": optimizer}
CheckpointSaver(save_dir=tempdir, save_dict=save_dict, save_final=True).attach(engine)
engine.run([0] * 8, max_epochs=5)
path = tempdir + "/checkpoint_final_iteration=40.pth"
CheckpointLoader(load_path=path, load_dict={"net": net2}).attach(engine)
engine.run([0] * 8, max_epochs=1)
torch.testing.assert_allclose(net2.state_dict()["weight"], 0.1)
shutil.rmtree(tempdir)

def test_save_single_device_load_multi_devices(self):
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
Expand All @@ -78,15 +75,14 @@ def test_save_single_device_load_multi_devices(self):
net2.load_state_dict(data2)
net2 = torch.nn.DataParallel(net2)
engine = Engine(lambda e, b: None)
with tempfile.TemporaryDirectory() as tempdir:
save_dir = os.path.join(tempdir, "checkpoint")
CheckpointSaver(save_dir=save_dir, save_dict={"net": net1}, save_final=True).attach(engine)
engine.run([0] * 8, max_epochs=5)
path = save_dir + "/net_final_iteration=40.pth"
CheckpointLoader(load_path=path, load_dict={"net": net2}).attach(engine)
engine.run([0] * 8, max_epochs=1)
torch.testing.assert_allclose(net2.state_dict()["module.weight"], 0.1)
shutil.rmtree(save_dir)
tempdir = tempfile.mkdtemp()
CheckpointSaver(save_dir=tempdir, save_dict={"net": net1}, save_final=True).attach(engine)
engine.run([0] * 8, max_epochs=5)
path = tempdir + "/net_final_iteration=40.pth"
CheckpointLoader(load_path=path, load_dict={"net": net2}).attach(engine)
engine.run([0] * 8, max_epochs=1)
torch.testing.assert_allclose(net2.state_dict()["module.weight"], 0.1)
shutil.rmtree(tempdir)


if __name__ == "__main__":
Expand Down
39 changes: 19 additions & 20 deletions tests/test_handler_checkpoint_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,26 +78,25 @@ def _train_func(engine, batch):
if multi_devices:
net = torch.nn.DataParallel(net)
optimizer = optim.SGD(net.parameters(), lr=0.02)
with tempfile.TemporaryDirectory() as tempdir:
save_dir = os.path.join(tempdir, "checkpoint")
handler = CheckpointSaver(
save_dir,
{"net": net, "opt": optimizer},
"CheckpointSaver",
"test",
save_final,
save_key_metric,
key_metric_name,
key_metric_n_saved,
epoch_level,
save_interval,
n_saved,
)
handler.attach(engine)
engine.run(data, max_epochs=5)
for filename in filenames:
self.assertTrue(os.path.exists(os.path.join(save_dir, filename)))
shutil.rmtree(save_dir)
tempdir = tempfile.mkdtemp()
handler = CheckpointSaver(
tempdir,
{"net": net, "opt": optimizer},
"CheckpointSaver",
"test",
save_final,
save_key_metric,
key_metric_name,
key_metric_n_saved,
epoch_level,
save_interval,
n_saved,
)
handler.attach(engine)
engine.run(data, max_epochs=5)
for filename in filenames:
self.assertTrue(os.path.exists(os.path.join(tempdir, filename)))
shutil.rmtree(tempdir)


if __name__ == "__main__":
Expand Down
13 changes: 8 additions & 5 deletions tests/test_load_nifti.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import unittest
import os
import shutil
import numpy as np
import tempfile
import nibabel as nib
Expand Down Expand Up @@ -38,18 +39,20 @@ class TestLoadNifti(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
def test_shape(self, input_param, filenames, expected_shape):
    """LoadNifti should load saved NIfTI files with the expected shape and affine."""
    test_image = np.random.randint(0, 2, size=[128, 128, 128])
    tempdir = tempfile.mkdtemp()
    try:
        for i, name in enumerate(filenames):
            filenames[i] = os.path.join(tempdir, name)
            nib.save(nib.Nifti1Image(test_image, np.eye(4)), filenames[i])
        result = LoadNifti(**input_param)(filenames)

        if isinstance(result, tuple):
            result, header = result
            self.assertTrue("affine" in header)
            np.testing.assert_allclose(header["affine"], np.eye(4))
            if input_param["as_closest_canonical"]:
                # bug fix: was `np.testing.asesrt_allclose` (typo), which would
                # raise AttributeError instead of checking the original affine
                np.testing.assert_allclose(header["original_affine"], np.eye(4))
        self.assertTupleEqual(result.shape, expected_shape)
    finally:
        # clean up even when an assertion fails, so no tempdir leaks
        shutil.rmtree(tempdir)


if __name__ == "__main__":
Expand Down
12 changes: 7 additions & 5 deletions tests/test_load_niftid.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import unittest
import os
import shutil
import numpy as np
import tempfile
import nibabel as nib
Expand All @@ -27,13 +28,14 @@ class TestLoadNiftid(unittest.TestCase):
def test_shape(self, input_param, expected_shape):
    """LoadNiftid should load a dict of saved NIfTI files with the expected shapes."""
    test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]), np.eye(4))
    test_data = dict()
    tempdir = tempfile.mkdtemp()
    try:
        for key in KEYS:
            nib.save(test_image, os.path.join(tempdir, key + ".nii.gz"))
            test_data.update({key: os.path.join(tempdir, key + ".nii.gz")})
        result = LoadNiftid(**input_param)(test_data)
        for key in KEYS:
            self.assertTupleEqual(result[key].shape, expected_shape)
    finally:
        # clean up even when an assertion fails, so no tempdir leaks
        shutil.rmtree(tempdir)


if __name__ == "__main__":
Expand Down
11 changes: 6 additions & 5 deletions tests/test_load_png.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import unittest
import os
import shutil
import numpy as np
import tempfile
from PIL import Image
Expand All @@ -29,13 +30,13 @@ class TestLoadPNG(unittest.TestCase):
def test_shape(self, data_shape, filenames, expected_shape, meta_shape):
    """LoadPNG should load saved PNG files with the expected image and meta shapes."""
    test_image = np.random.randint(0, 256, size=data_shape)
    tempdir = tempfile.mkdtemp()
    try:
        for i, name in enumerate(filenames):
            filenames[i] = os.path.join(tempdir, name)
            Image.fromarray(test_image.astype("uint8")).save(filenames[i])
        result = LoadPNG()(filenames)
        self.assertTupleEqual(result[1]["spatial_shape"], meta_shape)
        self.assertTupleEqual(result[0].shape, expected_shape)
    finally:
        # clean up even when an assertion fails, so no tempdir leaks
        shutil.rmtree(tempdir)


if __name__ == "__main__":
Expand Down
11 changes: 6 additions & 5 deletions tests/test_load_pngd.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import unittest
import os
import shutil
import numpy as np
import tempfile
from PIL import Image
def test_shape(self, input_param, expected_shape):
    """LoadPNGd should load a dict of saved PNG files with the expected shapes."""
    test_image = np.random.randint(0, 256, size=[128, 128, 3])
    test_data = dict()
    tempdir = tempfile.mkdtemp()
    try:
        for key in KEYS:
            Image.fromarray(test_image.astype("uint8")).save(os.path.join(tempdir, key + ".png"))
            test_data.update({key: os.path.join(tempdir, key + ".png")})
        result = LoadPNGd(**input_param)(test_data)
        for key in KEYS:
            self.assertTupleEqual(result[key].shape, expected_shape)
    finally:
        # clean up even when an assertion fails, so no tempdir leaks
        shutil.rmtree(tempdir)


if __name__ == "__main__":
Expand Down

0 comments on commit 3181e3e

Please sign in to comment.