Skip to content
This repository has been archived by the owner on Nov 17, 2020. It is now read-only.

Commit

Permalink
Bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Karol authored and Karol committed Jun 20, 2020
1 parent 2f143f8 commit ec57e70
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions gcam/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ def _mask_output(self, output, label):
indices = torch.argmax(output).detach().cpu().numpy()
mask = np.zeros(output.shape)
np.put(mask, indices, 1)
# indices = torch.argmax(output).unsqueeze(0).unsqueeze(0)
# mask = torch.zeros_like(self.logits).to(self.device)
# mask.scatter_(1, indices, 1.0)
elif isinstance(label, int): # Only for classification
indices = (output == label).nonzero()
indices = [index[0] * output.shape[1] + index[1] for index in indices]
Expand Down Expand Up @@ -145,6 +148,6 @@ def _set_postprocessor_and_label(self, output):
self.postprocessor = "sigmoid"
elif output.shape[0] == self.output_batch_size and len(output.shape) == 4 and output.shape[1] > 1: # 3D segmentation (nnUNet)
self.postprocessor = torch.nn.Softmax(dim=2)
if self.model.gcam_dict['label'] is None:
if output.shape[0] == self.output_batch_size and len(output.shape) == 2: # classification
self.model.gcam_dict['label'] = "best"
# if self.model.gcam_dict['label'] is None: # TODO: Best for classification can lead to empty attention maps in some cases; the reason is that the computed weights are negative and ReLU filters them out. No idea if it should be like that or if it's a bug
# if output.shape[0] == self.output_batch_size and len(output.shape) == 2: # classification
# self.model.gcam_dict['label'] = "best"

0 comments on commit ec57e70

Please sign in to comment.