Skip to content
This repository has been archived by the owner on Nov 17, 2020. It is now read-only.

Commit

Permalink
Classification bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Karol authored and Karol committed Jun 21, 2020
1 parent cd3a28a commit 7725dcb
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 22 deletions.
11 changes: 7 additions & 4 deletions gcam/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,10 @@ def _extract_metadata(self, input, output): # TODO: Does not work for classific
else:
self.output_channels = self.model.gcam_dict['channels']
if self.model.gcam_dict['data_shape'] == 'default':
self.output_shape = output.shape[2:]
if len(output.shape) == 2: # Classification -> Cannot convert attention map to classification
self.output_shape = None
else: # Output is a 2D/3D image
self.output_shape = output.shape[2:]
else:
self.output_shape = self.model.gcam_dict['data_shape']

Expand Down Expand Up @@ -148,6 +151,6 @@ def _set_postprocessor_and_label(self, output):
self.postprocessor = "sigmoid"
elif output.shape[0] == self.output_batch_size and len(output.shape) == 4 and output.shape[1] > 1: # 3D segmentation (nnUNet)
self.postprocessor = torch.nn.Softmax(dim=2)
# if self.model.gcam_dict['label'] is None: # TODO: Best for classification can lead to empty attention maps in some cases; the reason is that the computed weights are negative and ReLU filters them out. No idea if it should be like that or if it's a bug
# if output.shape[0] == self.output_batch_size and len(output.shape) == 2: # classification
# self.model.gcam_dict['label'] = "best"
if self.model.gcam_dict['label'] is None:
if output.shape[0] == self.output_batch_size and len(output.shape) == 2: # classification
self.model.gcam_dict['label'] = "best"
2 changes: 1 addition & 1 deletion gcam/backends/grad_cam.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def _find(self, pool, target_layer):

def _compute_grad_weights(self, grads):
"""Computes the weights based on the gradients by average pooling."""
if len(self.output_shape) == 2:
if self.input_dim == 2:
return F.adaptive_avg_pool2d(grads, 1)
else:
return F.adaptive_avg_pool3d(grads, 1)
Expand Down
19 changes: 3 additions & 16 deletions gcam/gcam_inject.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,6 @@ def inject(model, output_dir=None, backend='gcam', layer='auto', channels=1, dat
model_clone._process_attention_maps = types.MethodType(_process_attention_maps, model_clone)
model_clone._save_attention_map = types.MethodType(_save_attention_map, model_clone)
model_clone._replace_output = types.MethodType(_replace_output, model_clone)
model_clone._extract_metadata = types.MethodType(_extract_metadata, model_clone)

model_backend, heatmap = _assign_backend(backend, model_clone, layer, postprocessor, retain_graph)
gcam_dict['model_backend'] = model_backend
Expand Down Expand Up @@ -316,20 +315,8 @@ def _replace_output(self, output, attention_map, data_shape):
if self.gcam_dict['_replace_output']:
if len(attention_map.keys()) == 1:
output = torch.tensor(self.gcam_dict['current_attention_map']).to(str(self.gcam_dict['device']))
output = gcam_utils.interpolate(output, data_shape)
if data_shape is not None: # If data_shape is None then the task is classification -> return unchanged attention map
output = gcam_utils.interpolate(output, data_shape)
else:
raise ValueError("Not possible to replace output when layer is 'full', only with 'auto' or a manually set layer")
return output

def _extract_metadata(self, input, output): # TODO: Does not work for classification output (shape: (1, 1000))
"""Extracts metadata like batch size, number of channels and the data shape from the input batch."""
output_batch_size = output.shape[0]
if self.gcam_dict['channels'] == 'default':
output_channels = output.shape[1]
else:
output_channels = self.gcam_dict['channels']
if self.gcam_dict['data_shape'] == 'default':
output_shape = output.shape[2:]
else:
output_shape = self.model.gcam_dict['data_shape']
return output_batch_size, output_channels, output_shape
return output
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setuptools.setup(
name="gcam",
version="0.0.17",
version="0.0.18",
author="Karol Gotkowski",
author_email="KarolGotkowski@gmx.de",
description="An easy to use framework that makes model predictions more interpretable for humans.",
Expand Down

0 comments on commit 7725dcb

Please sign in to comment.