From 7725dcb8baa3aa2bccc34fcc63e3ca940d9c136c Mon Sep 17 00:00:00 2001 From: Karol Date: Sun, 21 Jun 2020 11:09:06 +0200 Subject: [PATCH] Classification bug fix --- gcam/backends/base.py | 11 +++++++---- gcam/backends/grad_cam.py | 2 +- gcam/gcam_inject.py | 19 +++---------------- setup.py | 2 +- 4 files changed, 12 insertions(+), 22 deletions(-) diff --git a/gcam/backends/base.py b/gcam/backends/base.py index e5b0fe3..4f39658 100644 --- a/gcam/backends/base.py +++ b/gcam/backends/base.py @@ -102,7 +102,10 @@ def _extract_metadata(self, input, output): # TODO: Does not work for classific else: self.output_channels = self.model.gcam_dict['channels'] if self.model.gcam_dict['data_shape'] == 'default': - self.output_shape = output.shape[2:] + if len(output.shape) == 2: # Classification -> Cannot convert attention map to classifiaction + self.output_shape = None + else: # Output is an 2D/3D image + self.output_shape = output.shape[2:] else: self.output_shape = self.model.gcam_dict['data_shape'] @@ -148,6 +151,6 @@ def _set_postprocessor_and_label(self, output): self.postprocessor = "sigmoid" elif output.shape[0] == self.output_batch_size and len(output.shape) == 4 and output.shape[1] > 1: # 3D segmentation (nnUNet) self.postprocessor = torch.nn.Softmax(dim=2) - # if self.model.gcam_dict['label'] is None: # TODO: Best for classification can lead to empty attention maps in some cases, reason is that computed weights are negative and relu filters them out. No idea if it should be like that or if its a bug - # if output.shape[0] == self.output_batch_size and len(output.shape) == 2: # classification - # self.model.gcam_dict['label'] = "best" + if self.model.gcam_dict['label'] is None: + if output.shape[0] == self.output_batch_size and len(output.shape) == 2: # classification + self.model.gcam_dict['label'] = "best" diff --git a/gcam/backends/grad_cam.py b/gcam/backends/grad_cam.py index 1b8eaf3..1a6bb48 100644 --- a/gcam/backends/grad_cam.py +++ b/gcam/backends/grad_cam.py @@ -148,7 +148,7 @@ def _find(self, pool, target_layer): def _compute_grad_weights(self, grads): """Computes the weights based on the gradients by average pooling.""" - if len(self.output_shape) == 2: + if self.input_dim == 2: return F.adaptive_avg_pool2d(grads, 1) else: return F.adaptive_avg_pool3d(grads, 1) diff --git a/gcam/gcam_inject.py b/gcam/gcam_inject.py index 1aa2ad6..e72bf69 100644 --- a/gcam/gcam_inject.py +++ b/gcam/gcam_inject.py @@ -166,7 +166,6 @@ def inject(model, output_dir=None, backend='gcam', layer='auto', channels=1, dat model_clone._process_attention_maps = types.MethodType(_process_attention_maps, model_clone) model_clone._save_attention_map = types.MethodType(_save_attention_map, model_clone) model_clone._replace_output = types.MethodType(_replace_output, model_clone) - model_clone._extract_metadata = types.MethodType(_extract_metadata, model_clone) model_backend, heatmap = _assign_backend(backend, model_clone, layer, postprocessor, retain_graph) gcam_dict['model_backend'] = model_backend @@ -316,20 +315,8 @@ def _replace_output(self, output, attention_map, data_shape): if self.gcam_dict['_replace_output']: if len(attention_map.keys()) == 1: output = torch.tensor(self.gcam_dict['current_attention_map']).to(str(self.gcam_dict['device'])) - output = gcam_utils.interpolate(output, data_shape) + if data_shape is not None: # If data_shape is None then the task is classification -> return unchanged attention map + output = gcam_utils.interpolate(output, data_shape) else: raise ValueError("Not possible to replace output when layer is 'full', only with 'auto' or a manually set layer") - return output - -def _extract_metadata(self, input, output): # TODO: Does not work for classification output (shape: (1, 1000)) - """Extracts metadata like batch size, number of channels and the data shape from the input batch.""" - output_batch_size = output.shape[0] - if self.gcam_dict['channels'] == 'default': - output_channels = output.shape[1] - else: - output_channels = self.gcam_dict['channels'] - if self.gcam_dict['data_shape'] == 'default': - output_shape = output.shape[2:] - else: - output_shape = self.model.gcam_dict['data_shape'] - return output_batch_size, output_channels, output_shape \ No newline at end of file + return output \ No newline at end of file diff --git a/setup.py b/setup.py index fc5253b..23e64c5 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setuptools.setup( name="gcam", - version="0.0.17", + version="0.0.18", author="Karol Gotkowski", author_email="KarolGotkowski@gmx.de", description="An easy to use framework that makes model predictions more interpretable for humans.",