Skip to content

Commit

Permalink
LATUPNet Interpretation Code
Browse files Browse the repository at this point in the history
  • Loading branch information
xis10z committed Mar 26, 2024
2 parents 6a6c348 + 28b99f7 commit 54372a4
Show file tree
Hide file tree
Showing 11 changed files with 352 additions and 48 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ architectures.

The architectures consist of basic 3D UNets and the 3D LATUP-Net segmentation model with some variations/parameters.

Results of using these modules are available at https://qyber.black/ca/results-bca-unet.
Results of using these modules are available at https://qyber.black/ca/results-bca-latup.

## Installation

Expand Down
3 changes: 2 additions & 1 deletion bca/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
* **bca.model**: Base class for model generator classes (used by models below).
* **bca.loss**: Custom loss classes.
* **bca.metric**: Custom metrics classes.
* **bca.interpret**: Explainability functionality.
* **bca.scheduler**: Scheduler to train models remotely from jupyter notebooks or from the command line
(mostly limited to Linux as `rsync` and `ssh` are needed). Tasks for the scheduler are created by trainer.
Expand All @@ -34,4 +35,4 @@
also provides configuration options.
"""

__all__ = ["dataset", "model", "trainer", "scheduler", "unet", "latupnet", "loss", "metric", "cfg"]
__all__ = ["dataset", "model", "trainer", "scheduler", "unet", "latupnet", "loss", "metric", "interpret", "cfg"]
4 changes: 2 additions & 2 deletions bca/cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(self):
The constructor of the `Cfg` class should never be called as it hold all values in the class. Use the
static methods of the class only.
"""
raise Exception("Tried to create a Cfg objects; values should only be held in the class")
raise RuntimeError("Tried to create a Cfg objects; values should only be held in the class")

# Log levels
Error = 1
Expand Down Expand Up @@ -154,7 +154,7 @@ def init(bin_path):
Cfg.val[k] = js[k]
else:
if fc != root_cfg_file: # We fix this here later
raise Exception(f"Unknown config file entry {k} in {fc}")
raise RuntimeError(f"Unknown config file entry {k} in {fc}")
# Setup plot defaults
if Cfg.val["screen_dpi"] == None:
Cfg.val["screen_dpi"] = Cfg._screen_dpi()
Expand Down
16 changes: 8 additions & 8 deletions bca/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def filter_low_labels(self, c, min_label_per):
elif self.crops_type == "bb" or self.crops_type == "orig":
data_cache = self.crops_type
else:
raise Exception(f"Unknown crop type {self.crops_type}")
raise RuntimeError(f"Unknown crop type {self.crops_type}")
cache = os.path.join(self.cache, f"voxel_counts_{data_cache}.csv")
voxel_labels = [None] * len(self)
voxel_counts = [None] * len(self)
Expand Down Expand Up @@ -126,7 +126,7 @@ def filter_low_labels(self, c, min_label_per):
elif self.crops_type == "bb":
crp = self.crops[k]
else:
raise Exception(f"Illegal crops {self.crops_type}")
raise RuntimeError(f"Illegal crops {self.crops_type}")
data = data.slicer[crp[1][0]:crp[1][1],crp[0][0]:crp[0][1],crp[2][0]:crp[2][1]]
data = data.get_fdata()
labels, label_counts = np.unique(data, return_counts=True)
Expand Down Expand Up @@ -226,7 +226,7 @@ def cropped(self, idx, channels=None):
elif self.crops_type == "bb":
crp = self.crops[idx]
else:
raise Exception(f"Illegal crops {self.crops_type}")
raise RuntimeError(f"Illegal crops {self.crops_type}")
data[c] = data[c].slicer[crp[0][0]:crp[0][1],crp[1][0]:crp[1][1],crp[2][0]:crp[2][1]]
return data

Expand Down Expand Up @@ -256,7 +256,7 @@ def browse(self):
"""Interactive widget to browse data in notebooks.
"""
if len(self) < 1:
raise Exception("Dataset empty")
raise RuntimeError("Dataset empty")
from IPython.display import display, clear_output
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, BoundaryNorm
Expand Down Expand Up @@ -312,7 +312,7 @@ def view(idx, slice, overlay):
linewidth=1, edgecolor='r', facecolor='none')
ax[k].add_patch(rect)
else:
raise Exception(f"Illegal crops {self.crops_type}")
raise RuntimeError(f"Illegal crops {self.crops_type}")
ax[k].set_title(self.patients[idx-1]+"-"+c)
plt.tight_layout()
plt.show()
Expand Down Expand Up @@ -411,7 +411,7 @@ def sequences(self, k, dim, inp, out, batch_size, pre_proc=None, seed=None, fixe
elif self.crops_type == "bb" or self.crops_type == "orig":
data_cache = self.crops_type
else:
raise Exception(f"Illegal crops {self.crops_type}")
raise RuntimeError(f"Illegal crops {self.crops_type}")
data_cache = os.path.join(self.cache, data_cache+"-"+"_".join([str(d) for d in dim])+"-"+("none" if pre_proc is None else pre_proc.__name__))
os.makedirs(data_cache, exist_ok=True)

Expand Down Expand Up @@ -474,7 +474,7 @@ def _split(self, k, seed):
np.random.default_rng(seed=seed).shuffle(idx)
self.set = np.floor(idx % k).astype(np.uint8)
else:
raise Exception(f"Illegal k: {k}")
raise RuntimeError(f"Illegal k: {k}")

def _create_sample(self, pidx, dim, pre_proc, data_cache, seg_mask):
# Create single input/output sample in cache
Expand All @@ -499,7 +499,7 @@ def _create_sample(self, pidx, dim, pre_proc, data_cache, seg_mask):
elif self.crops_type == "bb":
crp = self.crops[pidx]
else:
raise Exception(f"Illegal crops {self.crops_type}")
raise RuntimeError(f"Illegal crops {self.crops_type}")
stack = data[c].slicer[crp[1][0]:crp[1][1],crp[0][0]:crp[0][1],crp[2][0]:crp[2][1]]
vs = stack.header.get_zooms()
sx = (crp[1][1] - crp[1][0]) * vs[0] / dim[0]
Expand Down
199 changes: 199 additions & 0 deletions bca/interpret.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
# bca/interpret.py - interpret models
#
# SPDX-FileCopyrightText: Copyright (C) 2023-2024 Ebtihal Alwadee <AlwadeeEJ@cardiff.ac.uk>, PhD student at Cardiff University
# SPDX-FileCopyrightText: Copyright (C) 2024 Frank C Langbein <frank@langbein.org>, Cardiff University
# SPDX-License-Identifier: AGPL-3.0-or-later

from .cfg import Cfg

import tensorflow as tf
from sklearn.metrics import confusion_matrix
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colors
import seaborn as sns

class GradCamVisualizer:
"""GradCam visualiser.
Visualiser for GradCam to explore impact of attention.
"""

def __init__(self, model,seq):
"""Create GradCAM visualiser.
Args:
* `model`: tensorflow model.
* `seq`: data sequence.
"""
self.model = model
self.seq = seq

def visualise(self, index_image, index_slice, index_layer, layers):
"""Visualise GradCAM for a given image and layer.
Note `index_image` is actually the batch index, but as we use batch size 1
this corresponds to the image index. May have to change this eventually.
Args:
* `index_image`: index of image in sequence.
* `index_slice`: index of slice in image.
* `index_layer`: index of layer.
* `layers`: list of model layers to analyse.
"""
image, mask = self.seq[index_image]
mask = np.argmax(mask, axis=-1)[0,:,:,:]
prediction = self.model.predict(image)
prediction_mask = np.argmax(prediction, axis=-1)[0,:,:,:]
heatmap_per_class = []
for class_idx in range(prediction.shape[-1]):
heatmap = self._get_heatmap_layers(image, layers, class_idx)
heatmap_per_class.append(heatmap)
print(f"GradCAM for {index_image}-{index_slice}")
self._plot(image, mask, prediction_mask, heatmap_per_class, index_slice, index_layer)

def _get_heatmap_layers(self, image, list_layers, class_idx):
# Helper to process layers for heatmap
features_layer=[]
for layer_name in list_layers:
heatmap = self._make_gradcam_heatmap(image, layer_name, class_idx)
features_layer.append(heatmap)
return features_layer

def _make_gradcam_heatmap(self, img_array, layer_name, class_idx):
# Helper to compute GradCAM
grad_model = tf.keras.models.Model([self.model.inputs], [self.model.get_layer(layer_name).output, self.model.output])
with tf.GradientTape() as tape:
conv_outputs, predictions = grad_model(img_array)
loss = predictions[0,:,:,:,class_idx]
output = conv_outputs[0]
grads = tape.gradient(loss, conv_outputs)[0]
guided_grads = tf.cast(output > 0, 'float32') * tf.cast(grads > 0, 'float32') * grads
weights = tf.reduce_mean(guided_grads, axis=(0,1,2))
cam = np.ones(output.shape[0: 3], dtype = np.float32)
for l, w in enumerate(weights):
cam += w * output[:, :, :, l]
heatmap = tf.maximum(cam, 0) / tf.math.reduce_max(cam)
return heatmap

def _plot(self, image, mask, prediction, heatmaps, index_slice, layer_index):
# Helper to produce GradCAM plot.

# Adjust subplot creation for a single row
fig, axs = plt.subplots(1,len(heatmaps)+3,sharex=True,sharey=True,dpi=Cfg.val["screen_dpi"],figsize=(Cfg.val["figsize"][0]*(len(heatmaps)+3),Cfg.val["figsize"][1]))

class_names = ['BG', 'NCR/NET', 'ED', 'ET'] # FIXME: for now, specific to model this has been used for
heatmap_combined = np.zeros(heatmaps[0][layer_index].shape[:2])
for class_idx in range(len(heatmaps)):
axs[class_idx].imshow(image[0,:,:,index_slice, 0], cmap='gray')
axs[class_idx].imshow(heatmaps[class_idx][layer_index][:,:,index_slice], cmap='hot', alpha=0.5)
axs[class_idx].set_title(f'{class_names[class_idx]}')
heatmap_combined += heatmaps[class_idx][layer_index][:,:,index_slice]

dcmap = colors.ListedColormap(['k','r','g','b']) # FIXME: for now, specific to model this has been used for

axs[len(heatmaps)].imshow(image[0,:,:,index_slice, 0], cmap='gray')
masked_mask = np.ma.masked_where(mask[:,:,index_slice] == 0, mask[:,:,index_slice])
axs[len(heatmaps)].imshow(masked_mask, cmap=dcmap, alpha=0.5, vmin=0, vmax=3)
axs[len(heatmaps)].set_title('Ground Truth')

axs[len(heatmaps)+1].imshow(image[0,:,:,index_slice, 0], cmap='gray')
masked_prediction = np.ma.masked_where(prediction[:,:,index_slice] == 0, prediction[:,:,index_slice])
axs[len(heatmaps)+1].imshow(masked_prediction, cmap=dcmap, alpha=0.5, vmin=0, vmax=3)
axs[len(heatmaps)+1].set_title('Prediction')

axs[len(heatmaps)+2].imshow(heatmap_combined, cmap='hot', alpha=0.5)
axs[len(heatmaps)+2].set_title("Combined")

plt.tight_layout()
plt.show()

class ConfusionMatrices:
"""Confusion matrix between model predictions and ground truth.
Class to compute and visualise the confusion matrix between model predictions and ground truth.
"""

def __init__(self, model, seq, class_names=['BackGround', '(NCR/NET)', '(ED)', '(ET)']):
"""Create new confusion matrix.
Args:
* `model`: tensorflow model.
* `seq`: data sequence.
* `class_names`: list of class names for classification.
"""
self.model = model
self.seq = seq
self.class_names = class_names

def get_all(self):
"""Get all confusion matrices.
Note, this class removes samples/patients without an ET class to avoid inconsistent results,
as some samples do not have ET regions.
Return:
* Pair of aggregated confusion matrix across all classes and individual confusion matrices per class.
"""
et_class_index = self.class_names.index('(ET)')
aggregated_cm = np.zeros([len(self.class_names), len(self.class_names)])
conf_matrices_all = []
for image,mask in self.seq:
prediction = self.model.predict(image, verbose=0)
prediction_mask = np.argmax(prediction, axis=-1)[0,:,:,:]
mask = np.argmax(mask, axis=-1)[0,:,:,:]
cm = ConfusionMatrices._confusion_matrix(mask, prediction_mask)
# Include this matrix only if 'ET' class is present in ground truth or predictions
if et_class_index in np.unique(mask):
conf_matrices_all.append(cm)
aggregated_cm += cm
return aggregated_cm, conf_matrices_all

@staticmethod
def _confusion_matrix(y_true, y_pred):
# Helper to construct confusion matrix
y_true = np.reshape(y_true, -1)
y_pred = np.reshape(y_pred, -1)
return confusion_matrix(y_true, y_pred)

@staticmethod
def calculate_mean_std(matrices):
"""Calculate mean and std across confusion matrices.
Args:
* `matrices`: list of matrices
Return:
* `mean_matrix`: mean confusion matrix.
* `std_matrix`: std confusion matrix.
"""
normalized_matrices = []
for cm in matrices:
row_sums = cm.sum(axis=1)[:, np.newaxis]
normalized_cm = cm.astype('float') / (row_sums + 1e-10)
normalized_matrices.append(normalized_cm)

normalized_matrices_array = np.array(normalized_matrices)
mean_matrix = np.mean(normalized_matrices_array, axis=0)
std_matrix = np.std(normalized_matrices_array, axis=0)

return mean_matrix, std_matrix

def plot_heatmap(self, matrix, title, model_path):
"""Plot confusion matrix as heatmap.
Args:
* `matrix`: confusion matrix.
* `title`: title for plot.
* `model_path`: path to model / name.
"""
print(f"Model: {model_path}")
row_sums = matrix.sum(axis=1)[:, np.newaxis]
normalized_matrix = matrix.astype('float') / (row_sums + 1e-10)
fig = plt.figure(dpi=Cfg.val["screen_dpi"],figsize=(Cfg.val["figsize"][0],Cfg.val["figsize"][1]))
sns.heatmap(normalized_matrix, annot=True, fmt=".2f", cmap="RdBu_r", xticklabels=self.class_names, yticklabels=self.class_names)
plt.title(title)
plt.ylabel('Actual')
plt.xlabel('Predicted')
plt.tight_layout()
plt.show()
4 changes: 2 additions & 2 deletions bca/latupnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def construct(self, seq, batch_size, jit_compile=True):
num_classes = seq[-1] # Last element of tuple is number of classes
else:
if len(seq.cache.inp_chs) != 1 or len(seq.cache.out_chs) != 1:
raise Exception("Invalid input/output numbers")
raise RuntimeError("Invalid input/output numbers")
inputs = Input(shape=(*seq.dim,len(seq.cache.inp_chs[0])),batch_size=batch_size if self.fixed_batch_size else None, name='input_layer')
num_classes = len(seq.cache.out_chs[0]) # Determine classes from output shape

Expand Down Expand Up @@ -224,7 +224,7 @@ def _attention(self, x, name):
elif self.attention == "multimodal":
return self._att_multimodal(x, name)
if self.attention is not None and self.attention != "None":
raise Exception(f"Unknown attention module {self.attention}")
raise RuntimeError(f"Unknown attention module {self.attention}")
return x

# Squeeze and excitation attention
Expand Down
4 changes: 2 additions & 2 deletions bca/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def compute_channel_weights(ds, mask_channel, channels, mode="enet", normalise=F
elif ds.crops_type == "bb" or ds.crops_type == "orig":
data_cache = ds.crops_type
else:
raise Exception(f"Unknown crop type {ds.crops_type}")
raise RuntimeError(f"Unknown crop type {ds.crops_type}")
cache = os.path.join(ds.cache, f"voxel_counts_{data_cache}.csv")
voxel_labels = [None] * len(ds)
voxel_counts = [None] * len(ds)
Expand Down Expand Up @@ -79,7 +79,7 @@ def compute_channel_weights(ds, mask_channel, channels, mode="enet", normalise=F
for ch in channels:
weights[ch] = 1/np.log(enet_c + float(channel_volumes[ch])/float(total_volume)) if channels[ch][1] else 0.0
else:
raise Exception(f"Unknown class_weight mode {mode}")
raise RuntimeError(f"Unknown class_weight mode {mode}")
if normalise:
# Compute total weight for normalisation
total_weight = 0.0
Expand Down
2 changes: 1 addition & 1 deletion bca/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def plot(self, dim, text=False):
except:
p.kill()
if p.exitcode != 0:
raise Exception("Process failed")
raise RuntimeError("Process failed")
else:
self._plot(file, dim, text)
if text:
Expand Down
Loading

0 comments on commit 54372a4

Please sign in to comment.