Skip to content

Commit

Permalink
update with resnet checkpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
arouxel-laas committed Mar 9, 2024
1 parent 62a164b commit b33bb7c
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 15 deletions.
36 changes: 24 additions & 12 deletions Resnet_only.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,10 @@

class ResnetOnly(pl.LightningModule):

def __init__(self, model_name,log_dir="tb_logs", reconstruction_checkpoint=None):
def __init__(self,log_dir="tb_logs"):
super().__init__()

self.mask_generation = UnetModel(classes=1,encoder_weights=None,in_channels=1)
self.loss_fn = nn.MSELoss()
self.writer = SummaryWriter(log_dir)


Expand Down Expand Up @@ -49,15 +48,16 @@ def forward(self, x, pattern=None):
self.acquisition = self.acquisition.flip(2)
self.acquisition = self.acquisition.unsqueeze(1).float()

print("acquisition shape: ", self.acquisition.shape)
plt.imshow(self.acquisition[0,0,:,:].cpu().numpy())
plt.show()
# print("acquisition shape: ", self.acquisition.shape)
# plt.imshow(self.acquisition[0,0,:,:].cpu().numpy())
# plt.show()

self.pattern = self.mask_generation(self.acquisition)
self.pattern = BinarizeFunction.apply(self.pattern)

print("pattern shape: ", self.pattern.shape)
plt.imshow(self.pattern[0,0,:,:].cpu().numpy())
plt.show()
# print("pattern shape: ", self.pattern.shape)
# plt.imshow(self.pattern[0,0,:,:].detach().cpu().numpy())
# plt.show()

return self.pattern

Expand All @@ -67,8 +67,8 @@ def training_step(self, batch, batch_idx):

loss = self._common_step(batch, batch_idx)

input_images = self._convert_output_to_images(self._normalize_image_tensor(self.input_image))
patterns = self._convert_output_to_images(self._normalize_image_tensor(self.pattern))
input_images = self._convert_output_to_images(self._normalize_image_tensor(self.acquisition[:,0,:,:]))
patterns = self._convert_output_to_images(self._normalize_image_tensor(self.pattern[:,0,:,:]))

if self.global_step % 30 == 0:
self._log_images('train/input_images', input_images, self.global_step)
Expand Down Expand Up @@ -129,8 +129,9 @@ def _common_step(self, batch, batch_idx):

output_pattern = self.forward(batch)

sum_result = torch.mean(output_pattern,dim=(1,2))
sum_final = torch.sum(sum_result - 0.5)
sum_result = torch.mean(output_pattern,dim=(2,3))
print("sum_result: ", sum_result)
sum_final = torch.sum(torch.abs(sum_result - 0.5))
loss1 = sum_final

loss2 = calculate_spectral_flatness(output_pattern)
Expand Down Expand Up @@ -256,3 +257,14 @@ def calculate_spectral_flatness(pattern):
spectral_flatness = geometric_mean / arithmetic_mean

return spectral_flatness

class BinarizeFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
# Forward pass is the binary threshold operation
return (input > 0.5).float()

@staticmethod
def backward(ctx, grad_output):
# For backward pass, just pass the gradients through unchanged
return grad_output
17 changes: 16 additions & 1 deletion optimization_modules_with_resnet_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,27 @@ def forward(self,x):
return x

class JointReconstructionModule_V3(pl.LightningModule):
def __init__(self, recon_lightning_module, log_dir="tb_logs",reconstruction_checkpoint=None):
def __init__(self, recon_lightning_module, log_dir="tb_logs",resnet_checkpoint=None):
super().__init__()

self.reconstruction_module = recon_lightning_module
self.mask_generation = UnetModel(classes=1,encoder_weights=None,in_channels=1)

if resnet_checkpoint is not None:
# Load the weights from the checkpoint into self.seg_model
checkpoint = torch.load(resnet_checkpoint, map_location=self.device)
# Adjust the keys
adjusted_state_dict = {key.replace('mask_generation.', ''): value
for key, value in checkpoint['state_dict'].items()}
# Filter out unexpected keys
model_keys = set(self.mask_generation.state_dict().keys())
filtered_state_dict = {k: v for k, v in adjusted_state_dict.items() if k in model_keys}
self.mask_generation.load_state_dict(filtered_state_dict)

# Freeze the seg_model parameters
# for param in self.mask_generation.parameters():
# param.requires_grad = False

self.loss_fn = nn.MSELoss()
self.ssim_loss = SSIM(window_size=11, n_channels=28)
self.reconstruction_module.ssim_loss = SSIM(window_size=11, n_channels=28)
Expand Down
4 changes: 2 additions & 2 deletions train_resnet_for_noise_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import datetime


data_dir = "./datasets_reconstruction/mst_datasets/cave_1024_28"
data_dir = "/local/users/ademaio/lpaillet/mst_datasets/cave_1024_28/"
datamodule = CubesDataModule(data_dir, batch_size=4, num_workers=11)

datetime_ = datetime.datetime.now().strftime('%y-%m-%d_%Hh%M')
Expand Down Expand Up @@ -37,7 +37,7 @@
save_last=True # Additionally, save the last checkpoint to a file named 'last.ckpt'
)

reconstruction_module = ResnetOnly(log_dir=log_dir+'/'+ name,t)
reconstruction_module = ResnetOnly(log_dir=log_dir+'/'+ name)


if torch.cuda.is_available():
Expand Down

0 comments on commit b33bb7c

Please sign in to comment.