diff --git a/tests/test_arraydataset.py b/tests/test_arraydataset.py index f6dca3e97c..8fe6ff08ae 100644 --- a/tests/test_arraydataset.py +++ b/tests/test_arraydataset.py @@ -11,6 +11,7 @@ import unittest import os +import shutil import numpy as np import tempfile import nibabel as nib @@ -68,6 +69,7 @@ def test_shape(self, img_transform, label_transform, indexes, expected_shape): self.assertTupleEqual(data2[indexes[0]].shape, expected_shape) self.assertTupleEqual(data2[indexes[1]].shape, expected_shape) np.testing.assert_allclose(data2[indexes[0]], data2[indexes[0]]) + shutil.rmtree(tempdir) if __name__ == "__main__": diff --git a/tests/test_data_stats.py b/tests/test_data_stats.py index ace79c6e6a..d3c492eadb 100644 --- a/tests/test_data_stats.py +++ b/tests/test_data_stats.py @@ -11,6 +11,7 @@ import unittest import os +import shutil import logging import tempfile import numpy as np @@ -97,25 +98,25 @@ def test_value(self, input_param, input_data, expected_print): @parameterized.expand([TEST_CASE_6]) def test_file(self, input_data, expected_print): - with tempfile.TemporaryDirectory() as tempdir: - filename = os.path.join(tempdir, "test_stats.log") - handler = logging.FileHandler(filename, mode="w") - input_param = { - "prefix": "test data", - "data_shape": True, - "intensity_range": True, - "data_value": True, - "additional_info": lambda x: np.mean(x), - "logger_handler": handler, - } - transform = DataStats(**input_param) - _ = transform(input_data) - handler.stream.close() - transform._logger.removeHandler(handler) - with open(filename, "r") as f: - content = f.read() - self.assertEqual(content, expected_print) - os.remove(filename) + tempdir = tempfile.mkdtemp() + filename = os.path.join(tempdir, "test_stats.log") + handler = logging.FileHandler(filename, mode="w") + input_param = { + "prefix": "test data", + "data_shape": True, + "intensity_range": True, + "data_value": True, + "additional_info": lambda x: np.mean(x), + "logger_handler": handler, + } + transform = DataStats(**input_param) + _ = transform(input_data) + handler.stream.close() + transform._logger.removeHandler(handler) + with open(filename, "r") as f: + content = f.read() + self.assertEqual(content, expected_print) + shutil.rmtree(tempdir) if __name__ == "__main__": diff --git a/tests/test_data_statsd.py b/tests/test_data_statsd.py index e26dfdb70b..0c9c6e23fb 100644 --- a/tests/test_data_statsd.py +++ b/tests/test_data_statsd.py @@ -11,6 +11,7 @@ import unittest import os +import shutil import logging import tempfile import numpy as np @@ -110,26 +111,26 @@ def test_value(self, input_param, input_data, expected_print): @parameterized.expand([TEST_CASE_7]) def test_file(self, input_data, expected_print): - with tempfile.TemporaryDirectory() as tempdir: - filename = os.path.join(tempdir, "test_stats.log") - handler = logging.FileHandler(filename, mode="w") - input_param = { - "keys": "img", - "prefix": "test data", - "data_shape": True, - "intensity_range": True, - "data_value": True, - "additional_info": lambda x: np.mean(x), - "logger_handler": handler, - } - transform = DataStatsd(**input_param) - _ = transform(input_data) - handler.stream.close() - transform.printer._logger.removeHandler(handler) - with open(filename, "r") as f: - content = f.read() - self.assertEqual(content, expected_print) - os.remove(filename) + tempdir = tempfile.mkdtemp() + filename = os.path.join(tempdir, "test_stats.log") + handler = logging.FileHandler(filename, mode="w") + input_param = { + "keys": "img", + "prefix": "test data", + "data_shape": True, + "intensity_range": True, + "data_value": True, + "additional_info": lambda x: np.mean(x), + "logger_handler": handler, + } + transform = DataStatsd(**input_param) + _ = transform(input_data) + handler.stream.close() + transform.printer._logger.removeHandler(handler) + with open(filename, "r") as f: + content = f.read() + self.assertEqual(content, expected_print) + shutil.rmtree(tempdir) if __name__ == "__main__": diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 791fcf122f..3e175cdc3a 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -11,6 +11,7 @@ import unittest import os +import shutil import numpy as np import tempfile import nibabel as nib @@ -71,6 +72,7 @@ def test_shape(self, expected_shape): self.assertTupleEqual(data2_simple["image"].shape, expected_shape) self.assertTupleEqual(data2_simple["label"].shape, expected_shape) self.assertTupleEqual(data2_simple["extra"].shape, expected_shape) + shutil.rmtree(tempdir) if __name__ == "__main__": diff --git a/tests/test_handler_checkpoint_loader.py b/tests/test_handler_checkpoint_loader.py index 9b7b395a15..654290c082 100644 --- a/tests/test_handler_checkpoint_loader.py +++ b/tests/test_handler_checkpoint_loader.py @@ -9,7 +9,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import tempfile import shutil import torch @@ -33,15 +32,14 @@ def test_one_save_one_load(self): data2["weight"] = torch.tensor([0.2]) net2.load_state_dict(data2) engine = Engine(lambda e, b: None) - with tempfile.TemporaryDirectory() as tempdir: - save_dir = os.path.join(tempdir, "checkpoint") - CheckpointSaver(save_dir=save_dir, save_dict={"net": net1}, save_final=True).attach(engine) - engine.run([0] * 8, max_epochs=5) - path = save_dir + "/net_final_iteration=40.pth" - CheckpointLoader(load_path=path, load_dict={"net": net2}).attach(engine) - engine.run([0] * 8, max_epochs=1) - torch.testing.assert_allclose(net2.state_dict()["weight"], 0.1) - shutil.rmtree(save_dir) + tempdir = tempfile.mkdtemp() + CheckpointSaver(save_dir=tempdir, save_dict={"net": net1}, save_final=True).attach(engine) + engine.run([0] * 8, max_epochs=5) + path = tempdir + "/net_final_iteration=40.pth" + CheckpointLoader(load_path=path, load_dict={"net": net2}).attach(engine) + engine.run([0] * 8, max_epochs=1) + torch.testing.assert_allclose(net2.state_dict()["weight"], 0.1) + shutil.rmtree(tempdir) def test_two_save_one_load(self): logging.basicConfig(stream=sys.stdout, level=logging.INFO) @@ -55,16 +53,15 @@ def test_two_save_one_load(self): data2["weight"] = torch.tensor([0.2]) net2.load_state_dict(data2) engine = Engine(lambda e, b: None) - with tempfile.TemporaryDirectory() as tempdir: - save_dir = os.path.join(tempdir, "checkpoint") - save_dict = {"net": net1, "opt": optimizer} - CheckpointSaver(save_dir=save_dir, save_dict=save_dict, save_final=True).attach(engine) - engine.run([0] * 8, max_epochs=5) - path = save_dir + "/checkpoint_final_iteration=40.pth" - CheckpointLoader(load_path=path, load_dict={"net": net2}).attach(engine) - engine.run([0] * 8, max_epochs=1) - torch.testing.assert_allclose(net2.state_dict()["weight"], 0.1) - shutil.rmtree(save_dir) + tempdir = tempfile.mkdtemp() + save_dict = {"net": net1, "opt": optimizer} + CheckpointSaver(save_dir=tempdir, save_dict=save_dict, save_final=True).attach(engine) + engine.run([0] * 8, max_epochs=5) + path = tempdir + "/checkpoint_final_iteration=40.pth" + CheckpointLoader(load_path=path, load_dict={"net": net2}).attach(engine) + engine.run([0] * 8, max_epochs=1) + torch.testing.assert_allclose(net2.state_dict()["weight"], 0.1) + shutil.rmtree(tempdir) def test_save_single_device_load_multi_devices(self): logging.basicConfig(stream=sys.stdout, level=logging.INFO) @@ -78,15 +75,14 @@ def test_save_single_device_load_multi_devices(self): net2.load_state_dict(data2) net2 = torch.nn.DataParallel(net2) engine = Engine(lambda e, b: None) - with tempfile.TemporaryDirectory() as tempdir: - save_dir = os.path.join(tempdir, "checkpoint") - CheckpointSaver(save_dir=save_dir, save_dict={"net": net1}, save_final=True).attach(engine) - engine.run([0] * 8, max_epochs=5) - path = save_dir + "/net_final_iteration=40.pth" - CheckpointLoader(load_path=path, load_dict={"net": net2}).attach(engine) - engine.run([0] * 8, max_epochs=1) - torch.testing.assert_allclose(net2.state_dict()["module.weight"], 0.1) - shutil.rmtree(save_dir) + tempdir = tempfile.mkdtemp() + CheckpointSaver(save_dir=tempdir, save_dict={"net": net1}, save_final=True).attach(engine) + engine.run([0] * 8, max_epochs=5) + path = tempdir + "/net_final_iteration=40.pth" + CheckpointLoader(load_path=path, load_dict={"net": net2}).attach(engine) + engine.run([0] * 8, max_epochs=1) + torch.testing.assert_allclose(net2.state_dict()["module.weight"], 0.1) + shutil.rmtree(tempdir) if __name__ == "__main__": diff --git a/tests/test_handler_checkpoint_saver.py b/tests/test_handler_checkpoint_saver.py index c893fab694..74b22ad310 100644 --- a/tests/test_handler_checkpoint_saver.py +++ b/tests/test_handler_checkpoint_saver.py @@ -78,26 +78,25 @@ def _train_func(engine, batch): if multi_devices: net = torch.nn.DataParallel(net) optimizer = optim.SGD(net.parameters(), lr=0.02) - with tempfile.TemporaryDirectory() as tempdir: - save_dir = os.path.join(tempdir, "checkpoint") - handler = CheckpointSaver( - save_dir, - {"net": net, "opt": optimizer}, - "CheckpointSaver", - "test", - save_final, - save_key_metric, - key_metric_name, - key_metric_n_saved, - epoch_level, - save_interval, - n_saved, - ) - handler.attach(engine) - engine.run(data, max_epochs=5) - for filename in filenames: - self.assertTrue(os.path.exists(os.path.join(save_dir, filename))) - shutil.rmtree(save_dir) + tempdir = tempfile.mkdtemp() + handler = CheckpointSaver( + tempdir, + {"net": net, "opt": optimizer}, + "CheckpointSaver", + "test", + save_final, + save_key_metric, + key_metric_name, + key_metric_n_saved, + epoch_level, + save_interval, + n_saved, + ) + handler.attach(engine) + engine.run(data, max_epochs=5) + for filename in filenames: + self.assertTrue(os.path.exists(os.path.join(tempdir, filename))) + shutil.rmtree(tempdir) if __name__ == "__main__": diff --git a/tests/test_load_nifti.py b/tests/test_load_nifti.py index 7b19324c84..c7d6d8894d 100644 --- a/tests/test_load_nifti.py +++ b/tests/test_load_nifti.py @@ -11,6 +11,7 @@ import unittest import os +import shutil import numpy as np import tempfile import nibabel as nib @@ -38,11 +39,12 @@ class TestLoadNifti(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) def test_shape(self, input_param, filenames, expected_shape): test_image = np.random.randint(0, 2, size=[128, 128, 128]) - with tempfile.TemporaryDirectory() as tempdir: - for i, name in enumerate(filenames): - filenames[i] = os.path.join(tempdir, name) - nib.save(nib.Nifti1Image(test_image, np.eye(4)), filenames[i]) - result = LoadNifti(**input_param)(filenames) + tempdir = tempfile.mkdtemp() + for i, name in enumerate(filenames): + filenames[i] = os.path.join(tempdir, name) + nib.save(nib.Nifti1Image(test_image, np.eye(4)), filenames[i]) + result = LoadNifti(**input_param)(filenames) + if isinstance(result, tuple): result, header = result self.assertTrue("affine" in header) @@ -50,6 +52,7 @@ def test_shape(self, input_param, filenames, expected_shape): if input_param["as_closest_canonical"]: np.testing.asesrt_allclose(header["original_affine"], np.eye(4)) self.assertTupleEqual(result.shape, expected_shape) + shutil.rmtree(tempdir) if __name__ == "__main__": diff --git a/tests/test_load_niftid.py b/tests/test_load_niftid.py index 7451958a4e..0b9239a85e 100644 --- a/tests/test_load_niftid.py +++ b/tests/test_load_niftid.py @@ -11,6 +11,7 @@ import unittest import os +import shutil import numpy as np import tempfile import nibabel as nib @@ -27,13 +28,14 @@ class TestLoadNiftid(unittest.TestCase): def test_shape(self, input_param, expected_shape): test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]), np.eye(4)) test_data = dict() - with tempfile.TemporaryDirectory() as tempdir: - for key in KEYS: - nib.save(test_image, os.path.join(tempdir, key + ".nii.gz")) - test_data.update({key: os.path.join(tempdir, key + ".nii.gz")}) - result = LoadNiftid(**input_param)(test_data) + tempdir = tempfile.mkdtemp() + for key in KEYS: + nib.save(test_image, os.path.join(tempdir, key + ".nii.gz")) + test_data.update({key: os.path.join(tempdir, key + ".nii.gz")}) + result = LoadNiftid(**input_param)(test_data) for key in KEYS: self.assertTupleEqual(result[key].shape, expected_shape) + shutil.rmtree(tempdir) if __name__ == "__main__": diff --git a/tests/test_load_png.py b/tests/test_load_png.py index 218555c795..40c0526fcd 100644 --- a/tests/test_load_png.py +++ b/tests/test_load_png.py @@ -11,6 +11,7 @@ import unittest import os +import shutil import numpy as np import tempfile from PIL import Image @@ -29,13 +30,13 @@ class TestLoadPNG(unittest.TestCase): def test_shape(self, data_shape, filenames, expected_shape, meta_shape): test_image = np.random.randint(0, 256, size=data_shape) tempdir = tempfile.mkdtemp() - with tempfile.TemporaryDirectory() as tempdir: - for i, name in enumerate(filenames): - filenames[i] = os.path.join(tempdir, name) - Image.fromarray(test_image.astype("uint8")).save(filenames[i]) - result = LoadPNG()(filenames) + for i, name in enumerate(filenames): + filenames[i] = os.path.join(tempdir, name) + Image.fromarray(test_image.astype("uint8")).save(filenames[i]) + result = LoadPNG()(filenames) self.assertTupleEqual(result[1]["spatial_shape"], meta_shape) self.assertTupleEqual(result[0].shape, expected_shape) + shutil.rmtree(tempdir) if __name__ == "__main__": diff --git a/tests/test_load_pngd.py b/tests/test_load_pngd.py index 2891c95687..6fef42a286 100644 --- a/tests/test_load_pngd.py +++ b/tests/test_load_pngd.py @@ -11,6 +11,7 @@ import unittest import os +import shutil import numpy as np import tempfile from PIL import Image @@ -28,13 +29,13 @@ def test_shape(self, input_param, expected_shape): test_image = np.random.randint(0, 256, size=[128, 128, 3]) tempdir = tempfile.mkdtemp() test_data = dict() - with tempfile.TemporaryDirectory() as tempdir: - for key in KEYS: - Image.fromarray(test_image.astype("uint8")).save(os.path.join(tempdir, key + ".png")) - test_data.update({key: os.path.join(tempdir, key + ".png")}) - result = LoadPNGd(**input_param)(test_data) + for key in KEYS: + Image.fromarray(test_image.astype("uint8")).save(os.path.join(tempdir, key + ".png")) + test_data.update({key: os.path.join(tempdir, key + ".png")}) + result = LoadPNGd(**input_param)(test_data) for key in KEYS: self.assertTupleEqual(result[key].shape, expected_shape) + shutil.rmtree(tempdir) if __name__ == "__main__":