From b33bb7c4061432a308722f97b760fcc71bb882cf Mon Sep 17 00:00:00 2001 From: Antoine Rouxel Date: Sat, 9 Mar 2024 18:26:19 +0100 Subject: [PATCH] update with resnet checkpoint --- Resnet_only.py | 36 +++++++++++++++++--------- optimization_modules_with_resnet_v2.py | 17 +++++++++++- train_resnet_for_noise_generation.py | 4 +-- 3 files changed, 42 insertions(+), 15 deletions(-) diff --git a/Resnet_only.py b/Resnet_only.py index 81af723..1365597 100644 --- a/Resnet_only.py +++ b/Resnet_only.py @@ -17,11 +17,10 @@ class ResnetOnly(pl.LightningModule): - def __init__(self, model_name,log_dir="tb_logs", reconstruction_checkpoint=None): + def __init__(self,log_dir="tb_logs"): super().__init__() self.mask_generation = UnetModel(classes=1,encoder_weights=None,in_channels=1) - self.loss_fn = nn.MSELoss() self.writer = SummaryWriter(log_dir) @@ -49,15 +48,16 @@ def forward(self, x, pattern=None): self.acquisition = self.acquisition.flip(2) self.acquisition = self.acquisition.unsqueeze(1).float() - print("acquisition shape: ", self.acquisition.shape) - plt.imshow(self.acquisition[0,0,:,:].cpu().numpy()) - plt.show() + # print("acquisition shape: ", self.acquisition.shape) + # plt.imshow(self.acquisition[0,0,:,:].cpu().numpy()) + # plt.show() self.pattern = self.mask_generation(self.acquisition) + self.pattern = BinarizeFunction.apply(self.pattern) - print("pattern shape: ", self.pattern.shape) - plt.imshow(self.pattern[0,0,:,:].cpu().numpy()) - plt.show() + # print("pattern shape: ", self.pattern.shape) + # plt.imshow(self.pattern[0,0,:,:].detach().cpu().numpy()) + # plt.show() return self.pattern @@ -67,8 +67,8 @@ def training_step(self, batch, batch_idx): loss = self._common_step(batch, batch_idx) - input_images = self._convert_output_to_images(self._normalize_image_tensor(self.input_image)) - patterns = self._convert_output_to_images(self._normalize_image_tensor(self.pattern)) + input_images = self._convert_output_to_images(self._normalize_image_tensor(self.acquisition[:,0,:,:])) + patterns = self._convert_output_to_images(self._normalize_image_tensor(self.pattern[:,0,:,:])) if self.global_step % 30 == 0: self._log_images('train/input_images', input_images, self.global_step) @@ -129,8 +129,9 @@ def _common_step(self, batch, batch_idx): output_pattern = self.forward(batch) - sum_result = torch.mean(output_pattern,dim=(1,2)) - sum_final = torch.sum(sum_result - 0.5) + sum_result = torch.mean(output_pattern,dim=(2,3)) + print("sum_result: ", sum_result) + sum_final = torch.sum(torch.abs(sum_result - 0.5)) loss1 = sum_final loss2 = calculate_spectral_flatness(output_pattern) @@ -256,3 +257,14 @@ def calculate_spectral_flatness(pattern): spectral_flatness = geometric_mean / arithmetic_mean return spectral_flatness + +class BinarizeFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + # Forward pass is the binary threshold operation + return (input > 0.5).float() + + @staticmethod + def backward(ctx, grad_output): + # For backward pass, just pass the gradients through unchanged + return grad_output diff --git a/optimization_modules_with_resnet_v2.py b/optimization_modules_with_resnet_v2.py index 83a88be..25bf528 100755 --- a/optimization_modules_with_resnet_v2.py +++ b/optimization_modules_with_resnet_v2.py @@ -26,12 +26,27 @@ def forward(self,x): return x class JointReconstructionModule_V3(pl.LightningModule): - def __init__(self, recon_lightning_module, log_dir="tb_logs",reconstruction_checkpoint=None): + def __init__(self, recon_lightning_module, log_dir="tb_logs",resnet_checkpoint=None): super().__init__() self.reconstruction_module = recon_lightning_module self.mask_generation = UnetModel(classes=1,encoder_weights=None,in_channels=1) + if resnet_checkpoint is not None: + # Load the weights from the checkpoint into self.seg_model + checkpoint = torch.load(resnet_checkpoint, map_location=self.device) + # Adjust the keys + adjusted_state_dict = {key.replace('mask_generation.', ''): value + for key, value in checkpoint['state_dict'].items()} + # Filter out unexpected keys + model_keys = set(self.mask_generation.state_dict().keys()) + filtered_state_dict = {k: v for k, v in adjusted_state_dict.items() if k in model_keys} + self.mask_generation.load_state_dict(filtered_state_dict) + + # Freeze the seg_model parameters + # for param in self.mask_generation.parameters(): + # param.requires_grad = False + self.loss_fn = nn.MSELoss() self.ssim_loss = SSIM(window_size=11, n_channels=28) self.reconstruction_module.ssim_loss = SSIM(window_size=11, n_channels=28) diff --git a/train_resnet_for_noise_generation.py b/train_resnet_for_noise_generation.py index 59ae320..05afcd6 100644 --- a/train_resnet_for_noise_generation.py +++ b/train_resnet_for_noise_generation.py @@ -7,7 +7,7 @@ import datetime -data_dir = "./datasets_reconstruction/mst_datasets/cave_1024_28" +data_dir = "/local/users/ademaio/lpaillet/mst_datasets/cave_1024_28/" datamodule = CubesDataModule(data_dir, batch_size=4, num_workers=11) datetime_ = datetime.datetime.now().strftime('%y-%m-%d_%Hh%M') @@ -37,7 +37,7 @@ save_last=True # Additionally, save the last checkpoint to a file named 'last.ckpt' ) -reconstruction_module = ResnetOnly(log_dir=log_dir+'/'+ name,t) +reconstruction_module = ResnetOnly(log_dir=log_dir+'/'+ name) if torch.cuda.is_available():