-
Notifications
You must be signed in to change notification settings - Fork 0
/
eval.py
168 lines (110 loc) · 6.33 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import os
import argparse
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import torch
import pandas as pd
import torch.nn as nn
from tqdm import tqdm
import timm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
def parse_args(argv=None):
    """Parse command-line options for the evaluation script.

    Args:
        argv: Optional list of argument strings. Defaults to None, in which
            case argparse falls back to ``sys.argv[1:]`` — so existing
            callers that pass nothing are unaffected.

    Returns:
        argparse.Namespace with ``eval_set`` ('kaggle' or 'messidor2'),
        ``model_dict`` (path to a state dict), and ``data_dir``
        (defaults to 'test').
    """
    parser = argparse.ArgumentParser(description='Sanro Health Evaluation Script')
    parser.add_argument('--eval_set', type=str, required=True, choices=['kaggle', 'messidor2'], help='Evaluation dataset')
    parser.add_argument('--model_dict', type=str, required=True, help='State dictionary to evaluate')
    parser.add_argument('--data_dir', type=str, required=False, default='test', help='Data Directory')
    return parser.parse_args(argv)
class ScanDataset(Dataset):
    """Dataset pairing image file paths with integer ratings.

    Args:
        images: List of image file paths. Assumed aligned index-for-index
            with ``ratings`` — the caller is responsible for ordering.
        ratings: List of integer class labels.
        transform: Optional callable applied to each loaded PIL image.
            If None, the raw RGB PIL image is returned (the previous
            implementation crashed on the documented None default).
    """

    def __init__(self, images, ratings, transform=None):
        self.images = images
        self.ratings = ratings
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Context manager releases the file handle; convert('RGB') forces a
        # 3-channel image so grayscale/RGBA files don't break a downstream
        # 3-channel Normalize transform.
        with Image.open(self.images[idx]) as raw:
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, self.ratings[idx]
class Evaluator:
    """Evaluate a fine-tuned ViT-B/16 (384px, 5-class) checkpoint on a
    retinal-scan test set (Kaggle DR or Messidor-2).

    Reports test loss, accuracy, macro precision/recall/F1, saves the raw
    labels/predictions as .npy files, and writes a confusion-matrix PNG
    under ./outputs.
    """

    def __init__(self, args):
        """Load the checkpoint and build the test DataLoader.

        Args:
            args: Parsed CLI namespace with ``eval_set``, ``model_dict``,
                and ``data_dir`` attributes.
        """
        self.data_dir = args.data_dir
        self.eval_set = args.eval_set
        # ImageNet normalization stats — assumed to match the training
        # pipeline of the checkpoint being evaluated.
        self.test_transforms = transforms.Compose([
            transforms.Resize((384, 384)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # NOTE(review): GPU count is a rough proxy for loader worker
        # processes (0 on CPU-only hosts); cpu_count may be a better fit.
        self.num_workers = torch.cuda.device_count()
        self.batch_size = 16
        self.model = timm.create_model('vit_base_patch16_384', pretrained=False, num_classes=5)
        # map_location lets a GPU-trained checkpoint load on a CPU-only machine.
        self.state_dict = torch.load(args.model_dict, map_location=self.device)
        self.model.load_state_dict(self.state_dict)
        self.model = self.model.to(self.device)
        self.model.eval()
        print(f'Model {args.model_dict} successfully loaded.')
        if self.eval_set == 'kaggle':
            self.images_dir = os.path.join(self.data_dir, 'images')
            self.ratings_path = os.path.join(self.data_dir, 'labels.csv')
            self.labels_db = pd.read_csv(self.ratings_path)
            self.labels_db = self.labels_db[self.labels_db['path'].apply(lambda x: os.path.isfile(os.path.join(self.data_dir, x)))]
            on_disk = {'images/' + item for item in os.listdir(self.images_dir)}
            matched = self.labels_db[self.labels_db['path'].isin(on_disk)]
            # FIX: derive images and ratings from the SAME DataFrame rows so
            # element i of each list refers to the same scan. Previously the
            # images came from arbitrary os.listdir order while the ratings
            # followed CSV row order, silently misaligning labels.
            self.images = [os.path.join(self.data_dir, rel) for rel in matched['path']]
            self.ratings = matched['rating'].tolist()
        if self.eval_set == 'messidor2':
            self.images_dir = os.path.join(self.data_dir, 'messidor-2/messidor-2/preprocess')
            self.ratings_path = os.path.join(self.data_dir, 'messidor_data.csv')
            self.labels_db = pd.read_csv(self.ratings_path)
            self.labels_db = self.labels_db[self.labels_db['id_code'].apply(lambda x: os.path.isfile(os.path.join(self.images_dir, x)))]
            on_disk = set(os.listdir(self.images_dir))
            matched = self.labels_db[self.labels_db['id_code'].isin(on_disk)]
            # Same alignment fix as the kaggle branch: both lists come from
            # the same filtered rows, in the same order.
            self.images = [os.path.join(self.images_dir, code) for code in matched['id_code']]
            self.ratings = matched['diagnosis'].tolist()
        self.test_data = ScanDataset(self.images, self.ratings, transform=self.test_transforms)
        # num_workers was previously computed but never passed to the loader.
        self.test_loader = DataLoader(dataset=self.test_data, batch_size=self.batch_size,
                                      shuffle=False, num_workers=self.num_workers)
        print(f'{len(self.test_data)} test images found, {len(self.test_loader)} batches per epoch.')

    def evaluate(self):
        """Run inference over the test set, print metrics, save artifacts.

        Side effects: creates ./outputs if missing, writes
        ``{eval_set}_test_true.npy``, ``{eval_set}_test_preds.npy``, and
        ``{eval_set}_confusion_matrix.png`` there.
        """
        criterion = nn.CrossEntropyLoss()
        test_true = []
        test_preds = []
        test_loss = 0.0
        with torch.no_grad():
            for data, label in tqdm(self.test_loader):
                data, label = data.to(self.device), label.to(self.device)
                output = self.model(data)
                loss = criterion(output, label)
                # 1-D class indices. keepdim=True (as before) produced (N, 1)
                # column vectors and a 2-D test_preds array downstream.
                pred = output.argmax(dim=1)
                test_true.extend(label.cpu().numpy())
                test_preds.extend(pred.cpu().numpy())
                # Dividing each batch loss by the batch count yields the
                # mean batch loss after the loop completes.
                test_loss += loss.item() / len(self.test_loader)
        test_true = np.array(test_true)
        test_preds = np.array(test_preds)
        os.makedirs('outputs', exist_ok=True)  # np.save does not create directories
        np.save(os.path.join('outputs', f'{self.eval_set}_test_true.npy'), test_true)
        np.save(os.path.join('outputs', f'{self.eval_set}_test_preds.npy'), test_preds)
        accuracy = accuracy_score(test_true, test_preds)
        # zero_division=0 silences undefined-metric warnings when a class is
        # absent from the predictions; values are unchanged otherwise.
        precision = precision_score(test_true, test_preds, average='macro', zero_division=0)
        recall = recall_score(test_true, test_preds, average='macro', zero_division=0)
        f1 = f1_score(test_true, test_preds, average='macro', zero_division=0)
        conf_matrix = confusion_matrix(test_true, test_preds)
        print(f'Test Loss: {test_loss:.4f}')  # previously computed but never reported
        print(f'Accuracy: {accuracy:.4f}')
        print(f'Precision: {precision:.4f}')
        print(f'Recall: {recall:.4f}')
        print(f'F1 Score: {f1:.4f}')
        plt.figure(figsize=(10, 7))
        sns.heatmap(conf_matrix, annot=True, fmt='g', cmap='Blues', cbar=False)
        plt.xlabel('Predicted labels')
        plt.ylabel('True labels')
        plt.title('Confusion Matrix')
        cm_path = os.path.join('outputs', f'{self.eval_set}_confusion_matrix.png')
        plt.savefig(cm_path)
        print(f'Confusion Matrix saved at {cm_path}')
if __name__ == "__main__":
    # Parse CLI options, then run a single evaluation pass.
    cli_args = parse_args()
    evaluator = Evaluator(cli_args)
    evaluator.evaluate()