-
Notifications
You must be signed in to change notification settings - Fork 15
Open
Description
I attempted to use the following script for batch image generation:
import SimpleITK as sitk
import torch
import os
import numpy as np
from utils import trim_state_dict_name
from matplotlib import pyplot as plt
import nibabel as nib

# Batch generation of synthetic volumes from a pretrained HA-GAN generator,
# written out as NIfTI files.

latent_dim = 1024
save_step = 80000
batch_size = 1
img_size = 256
num_class = 0
exp_name = "HA_GAN_run1"
num_images = 3039  # Number of images to generate in a loop

if img_size == 256:
    from models.Model_HA_GAN_256 import Generator, Encoder, Sub_Encoder
elif img_size == 128:
    from models.Model_HA_GAN_128 import Generator, Encoder, Sub_Encoder
else:
    # Fail fast: otherwise an unsupported size only surfaces later as a
    # NameError on Generator.
    raise ValueError(f"Unsupported img_size={img_size}; expected 128 or 256")

G = Generator(mode='eval', latent_dim=latent_dim, num_class=num_class).cuda()
# NOTE: E and Sub_E are loaded but not used by the generation loop below.
E = Encoder().cuda()
Sub_E = Sub_Encoder(latent_dim=latent_dim).cuda()

# ----------------------
# Load pretrained weights.
# Checkpoint file names are derived from save_step so the two cannot drift apart.
ckpt_dir = "/HA-GAN/GSP_HA_GAN_pretrained"


def load_weights(model, prefix):
    """Load ``<ckpt_dir>/<prefix>_iter<save_step>.pth`` into ``model``.

    The file must contain a dict whose ``'model'`` entry is the state dict.
    Its key names are passed through ``trim_state_dict_name`` before
    ``model.load_state_dict`` is called.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
        RuntimeError: if ``torch.load`` cannot deserialize the file (for
            example a corrupted or truncated download). The original error
            is chained as ``__cause__``.
        KeyError: if the loaded checkpoint has no ``'model'`` entry.
    """
    ckpt_path = os.path.join(ckpt_dir, f"{prefix}_iter{save_step}.pth")
    try:
        # Deserialize on the CPU. load_state_dict then copies the tensors into
        # the model's own (CUDA) parameters, so the saving device is irrelevant.
        ckpt = torch.load(ckpt_path, map_location="cpu")
    except FileNotFoundError:
        raise
    except Exception as err:
        # A damaged file surfaces as assorted low-level errors, e.g.
        # KeyError: "filename 'storages' not found" from torch's legacy tar
        # reader. Re-raise with the offending path so the cause is obvious.
        raise RuntimeError(
            f"Could not read checkpoint {ckpt_path!r}; the file may be "
            "corrupted or incompletely downloaded."
        ) from err
    model.load_state_dict(trim_state_dict_name(ckpt['model']))


load_weights(G, "G")
load_weights(E, "E")
load_weights(Sub_E, "Sub_E")
print(exp_name, save_step, "step weights loaded.")

outpath = "/HA-GAN/GSP_HA_GAN_images"
os.makedirs(outpath, exist_ok=True)

G.eval()
E.eval()
Sub_E.eval()
torch.cuda.empty_cache()

# ----------------------
# Loop to generate multiple images and save them
low_threshold = -1024
high_threshold = 600

with torch.no_grad():
    # Step by batch_size so every generated sample is saved. The final batch
    # is shrunk so that exactly num_images files are written.
    for start in range(0, num_images, batch_size):
        cur_batch = min(batch_size, num_images - start)
        z_rand = torch.randn((cur_batch, latent_dim)).cuda()
        x_rand = G(z_rand, 0)
        x_rand = x_rand.cpu().numpy()
        # Map the generated output from [-1,1] to [0,1]; adjust this step based on your model's output
        x_rand = 0.5 * x_rand + 0.5
        for b in range(cur_batch):
            i = start + b  # global sample index, used in the file name
            # Channel 0 of sample b. This assumes G returns a 5-D array
            # (batch, channel, d0, d1, d2), as the original [0, 0, :, :, :]
            # indexing implied.
            volume = x_rand[b, 0]
            # Map to the typical CT intensity range [low_threshold, high_threshold].
            # Clip so generator overshoot cannot leave the window, and round
            # rather than truncate toward zero when casting to int16.
            volume = volume * (high_threshold - low_threshold) + low_threshold
            volume = np.clip(volume, low_threshold, high_threshold)
            volume = np.rint(volume).astype(np.int16)
            # Reverse the axis order and wrap in NIfTI. The identity affine means
            # unit voxel spacing and no reorientation.
            x_rand_nifti = nib.Nifti1Image(volume.transpose((2, 1, 0)), affine=np.eye(4))
            out_filename = os.path.join(outpath, f"x_rand_nifti_{i}.nii.gz")
            nib.save(x_rand_nifti, out_filename)
            print(f"Saved {out_filename}")

# However, when loading the GSP_HA_GAN pretrained weights, I encountered the following error:
Traceback (most recent call last):
File "inference_gsp.py", line 31, in <module>
ckpt = torch.load(ckpt_path)['model']
File ".../torch/serialization.py", line 593, in load
return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
File ".../torch/serialization.py", line 747, in _legacy_load
return legacy_load(f)
File ".../torch/serialization.py", line 672, in legacy_load
tar.extract('storages', path=tmpdir)
File ".../tarfile.py", line 2060, in extract
tarinfo = self.getmember(member)
File ".../tarfile.py", line 1782, in getmember
raise KeyError("filename %r not found" % name)
KeyError: "filename 'storages' not found"
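The traceback shows `torch.load` falling back to its legacy tar-based loader, which could not find the `storages` entry it expects inside the checkpoint. After searching online, this error appears to indicate that the weight file is corrupted.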
Could you please update the repository with a new, valid version of the pretrained weight files? Your assistance would be greatly appreciated.
Thank you very much!
Metadata
Assignees
Labels
No labels