-
Notifications
You must be signed in to change notification settings - Fork 0
/
inference.py
109 lines (88 loc) · 4.08 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import numpy as np
import pandas as pd
import os
import joblib
from dataset import MRIDataset, CustomImageDataset
from sklearn.metrics import ConfusionMatrixDisplay, accuracy_score
from sklearn.manifold import TSNE
from torch.utils.data import DataLoader, Dataset, RandomSampler
import torch
from losses import GeneralizedSupervisedNTXenLoss, NTXenLoss, SupConLoss
from torch.nn import CrossEntropyLoss
from models.densenet import densenet121
from models.unet import UNet
import argparse
import matplotlib.pyplot as plt
import matplotlib
from tqdm import tqdm
from config import Config, PRETRAINING, FINE_TUNING, CLASSES
def get_predictions(loader, net, is_encoder=False):
    """Run ``net`` over every batch of ``loader`` and collect outputs and labels.

    The network is put in evaluation mode and run without gradient tracking,
    so layers whose behavior differs between training and evaluation (e.g.
    batch norm, dropout) behave as they should for inference, and no autograd
    graph is built.

    Args:
        loader: iterable yielding ``(inputs, labels, paths)`` triples; the
            paths are ignored.
        net: model to evaluate. It is left in eval mode afterwards.
        is_encoder: if True, collect the raw network outputs; otherwise
            collect the index of the largest output along dim 1 (the
            predicted class).

    Returns:
        ``(y_pred, y_true)``: two lists with one entry per sample, in loader
        order.
    """
    net.eval()
    y_pred = []
    y_true = []
    with torch.no_grad():
        for inputs, labels, _paths in tqdm(loader, desc="Getting data predictions"):
            outputs = net(inputs)
            if not is_encoder:
                # Predicted class = index of the highest score in each row.
                outputs = outputs.argmax(dim=1)
            y_pred.extend(outputs.cpu().numpy())  # Save Prediction
            y_true.extend(labels.cpu().numpy())  # Save Truth
    return y_pred, y_true
def get_embeddings(loader, net, unknown=False):
    """Collect the hidden representation ``net`` produces for every sample.

    The network is put in evaluation mode and run without gradient tracking,
    as for any inference pass.

    Args:
        loader: iterable yielding ``(inputs, labels, paths)`` triples; the
            paths are ignored.
        net: model whose forward pass accepts ``return_hidden=True``. It is
            left in eval mode afterwards.
        unknown: unused; kept so existing callers keep working.

    Returns:
        ``(embed, y_true)``: a list of per-sample embedding arrays and a list
        of the matching labels, in loader order.
    """
    net.eval()
    embed = []
    y_true = []
    with torch.no_grad():
        for inputs, labels, _paths in tqdm(loader, desc="Getting data embeddings"):
            hidden = net(inputs, return_hidden=True)
            embed.extend(hidden.cpu().numpy())  # Save Prediction
            y_true.extend(labels.cpu().numpy())  # Save Truth
    return embed, y_true
def plot_losses(losses):
    """Report the best validation epoch and save the train/validation loss curves.

    Args:
        losses: mapping with ``'train'`` and ``'validation'`` sequences of
            per-epoch loss values.

    Side effects:
        Prints the 1-based epoch with the lowest validation loss (when there
        is any validation history) and writes ``losses.png`` to the current
        working directory.
    """
    train_losses = list(losses['train'])
    val_losses = list(losses['validation'])
    if val_losses:
        # 1-based epoch of the first occurrence of the minimum validation loss.
        min_epoch = val_losses.index(min(val_losses)) + 1
        print(f'Epoch with minimum validation loss: {min_epoch}')
    plt.figure(figsize=(3, 2), dpi=300)
    # Epochs are numbered from 1 on the x-axis so the plot matches the printed epoch.
    plt.plot(range(1, len(train_losses) + 1), train_losses, label="train")
    plt.plot(range(1, len(val_losses) + 1), val_losses, label="validation")
    plt.legend()
    plt.ylabel('Cross Entropy Loss')
    plt.xlabel('Epoch')
    plt.savefig('losses.png', bbox_inches='tight', dpi=300)
    plt.close()
def plot_confusion_matrix(loader, net):
    """Print test accuracy and save a row-normalized confusion matrix.

    Args:
        loader: iterable yielding ``(inputs, labels, paths)`` triples.
        net: classifier whose outputs are scored per class along dim 1.

    Side effects:
        Prints the accuracy and writes ``confusion_matrix.png`` to the current
        working directory.
    """
    y_pred, y_true = get_predictions(loader, net)
    print(f'Accuracy: {accuracy_score(y_true, y_pred):.2%}')
    # Fix the label set to every class index so the matrix always matches
    # CLASSES, even if some class never appears in y_true or y_pred.
    # NOTE(review): assumes labels are integer indices 0..len(CLASSES)-1,
    # matching the argmax predictions — confirm against the dataset.
    ConfusionMatrixDisplay.from_predictions(
        y_true,
        y_pred,
        labels=list(range(len(CLASSES))),
        display_labels=CLASSES,
        cmap='Blues',
        values_format='.2%',
        normalize='true',
    )
    plt.savefig('confusion_matrix.png', bbox_inches='tight', dpi=300)
    plt.close()
def plot_latent_space(loader, net):
    """Project the hidden embeddings to 2-D with t-SNE and save a scatter plot.

    Points are colored by their true label using the first ``len(CLASSES)``
    colors of Matplotlib's ``Set1`` palette.

    Args:
        loader: iterable yielding ``(inputs, labels, paths)`` triples.
        net: model whose forward pass accepts ``return_hidden=True``.

    Side effects:
        Writes ``latent_space.png`` to the current working directory.
    """
    embed, y_true = get_embeddings(loader, net)
    features = np.asarray(embed)
    try:
        # scikit-learn >= 1.5 calls the iteration budget ``max_iter``
        # (``n_iter`` was deprecated in 1.5 and removed in 1.7).
        tsne = TSNE(n_components=2, random_state=123, max_iter=10000)
    except TypeError:
        # Older scikit-learn only accepts ``n_iter``.
        tsne = TSNE(n_components=2, random_state=123, n_iter=10000)
    z = tsne.fit_transform(features)
    # ``plt.cm.get_cmap`` was removed in Matplotlib 3.9; the registry lookup
    # returns a fresh copy of the colormap.
    base_cmap = matplotlib.colormaps['Set1']
    cmap = matplotlib.colors.ListedColormap(base_cmap.colors[:len(CLASSES)])
    scatter = plt.scatter(x=z[:, 0], y=z[:, 1], c=y_true, cmap=cmap)
    plt.legend(handles=scatter.legend_elements()[0], labels=CLASSES)
    plt.savefig('latent_space.png', bbox_inches='tight', dpi=300)
    plt.close()
if __name__ == "__main__":
    # Entry point: load a fine-tuned checkpoint, plot its loss history and,
    # unless -l is given, evaluate it on the test split.
    config = Config(FINE_TUNING)
    parser = argparse.ArgumentParser()
    parser.add_argument("model_path", type=str, help="Path to the checkpoint file of the model")
    parser.add_argument("--dir", type=str, default='data/',
                        help="The input directory that contains the labels and processed data. By default: data/")
    parser.add_argument("-l", action='store_true', help="Output only the training losses")
    args = parser.parse_args()

    # Use the GPU when there is one. Without CUDA, DataParallel simply
    # forwards calls to the wrapped module.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    checkpoint = torch.load(args.model_path, map_location=device)
    net = densenet121(mode="classifier", drop_rate=0.0, num_classes=len(CLASSES))
    # Wrap before loading: the saved state dict is loaded into a DataParallel
    # wrapper, so its keys presumably carry the "module." prefix.
    net = torch.nn.DataParallel(net).to(device)
    net.load_state_dict(checkpoint['model'])
    net.eval()

    plot_losses(checkpoint['losses'])
    if not args.l:
        # Only build the test set when it is actually used. The trailing ''
        # keeps the trailing separator on the image directory.
        dataset_test = CustomImageDataset(config,
                                          os.path.join(args.dir, 'test.csv'),
                                          os.path.join(args.dir, 'Processed', ''),
                                          FINE_TUNING)
        loader_test = DataLoader(dataset_test,
                                 batch_size=config.batch_size,
                                 pin_memory=config.pin_mem,
                                 num_workers=config.num_cpu_workers
                                 )
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            plot_confusion_matrix(loader_test, net)
            plot_latent_space(loader_test, net)