-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathevaluate_classifier.py
129 lines (116 loc) · 5.05 KB
/
evaluate_classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# COVID-CT-Mask-Net
# Torchvision detection package is locally re-implemented
# Transformed into a classification model with Mask R-CNN backend
# by Alex Ter-Sarkisov@City, University of London
# alex.ter-sarkisov@city.ac.uk
# 2020
import os
import re
import sys
import time
import config_classifier as config
import cv2
#######################################
import models.mask_net as mask_net
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import utils
from PIL import Image as PILImage
from models.mask_net.faster_rcnn import FastRCNNPredictor, TwoMLPHead
from models.mask_net.rpn import AnchorGenerator
def _build_model(config, device):
    """Rebuild the COVID-CT-Mask-Net model from the checkpoint at ``config.ckpt``.

    Prints the constructor arguments and the model, as the original script did.

    Args:
        config: configuration providing ``ckpt``, ``rpn_nms_th``, ``roi_nms_th``,
            ``backbone_name``, ``truncation``, ``roi_batch_size``,
            ``num_classes`` and ``s_features``.
        device: ``torch.device`` used as ``map_location`` and model placement.

    Returns:
        The model in eval mode, with the checkpoint's ``'model_weights'`` loaded.
    """
    import inspect

    # The checkpoint pickles project objects (the anchor generator read below),
    # so it cannot be loaded with weights_only=True, which became the
    # torch.load default in PyTorch 2.6. Opt out explicitly where supported.
    # SECURITY: full unpickling can execute arbitrary code -- only load
    # checkpoints from a trusted source.
    load_kwargs = {'map_location': device}
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    ckpt = torch.load(config.ckpt, **load_kwargs)

    # 'box_detections_per_img': batch size input in module S
    # 'box_score_thresh': negative to accept all predictions
    covid_mask_net_args = {'num_classes': None, 'min_size': 512, 'max_size': 1024,
                           'box_detections_per_img': config.roi_batch_size,
                           'box_nms_thresh': config.roi_nms_th, 'box_score_thresh': -0.01,
                           'rpn_nms_thresh': config.rpn_nms_th}
    print(covid_mask_net_args)

    # Reuse the anchor sizes / aspect ratios stored in the checkpoint.
    saved_anchors = ckpt['anchor_generator']
    covid_mask_net_args['rpn_anchor_generator'] = AnchorGenerator(saved_anchors.sizes,
                                                                  saved_anchors.aspect_ratios)
    # Faster R-CNN interfaces; masks are not implemented at this stage.
    # config.num_classes: either 2+1 or 1+1 classes.
    covid_mask_net_args['box_predictor'] = FastRCNNPredictor(in_channels=128,
                                                             num_classes=config.num_classes)
    covid_mask_net_args['box_head'] = TwoMLPHead(in_channels=256 * 7 * 7, representation_size=128)
    # Representation size of the S classification module (from the config).
    covid_mask_net_args['s_representation_size'] = config.s_features

    model = mask_net.fasterrcnn_resnet_fpn(config.backbone_name, config.truncation,
                                           **covid_mask_net_args)
    model.load_state_dict(ckpt['model_weights'])
    model.eval().to(device)
    print(model)
    return model


def _print_report(confusion_matrix):
    """Print the confusion matrix, row-normalised sensitivity and overall accuracy."""
    separator = "------------------------------------------"
    print(separator)
    print("Confusion Matrix for 3-class problem:")
    print("0: Control, 1: Normal Pneumonia, 2: COVID")
    print(confusion_matrix)
    print(separator)
    # Row-normalise: entry [i, j] is the fraction of true class i predicted as
    # class j, so the diagonal holds the per-class sensitivity. Clamping the row
    # totals keeps a class with no test images at 0 instead of NaN (0 / 0).
    cm = confusion_matrix.float()
    cm = cm / cm.sum(dim=1, keepdim=True).clamp(min=1)
    print(separator)
    print("Class Sensitivity:")
    print(cm)
    print(separator)
    print('Overall accuracy:')
    # Clamp so an empty test set reports 0 rather than NaN.
    print(confusion_matrix.diag().float().sum().div(confusion_matrix.sum().clamp(min=1)))


def main(config, step):
    """Evaluate a COVID-CT-Mask-Net classifier checkpoint on a test directory.

    Builds the model from ``config.ckpt``, calls ``step`` on every entry of
    ``config.test_data_dir`` to fill a 3x3 confusion matrix (rows: true class,
    columns: predicted class), then prints the confusion matrix, the
    row-normalised sensitivity matrix, the overall accuracy and the run time.

    Args:
        config: parsed configuration; must provide ``ckpt``, ``test_data_dir``,
            ``device`` ('cpu' or 'cuda'), ``rpn_nms_th``, ``roi_nms_th``,
            ``backbone_name``, ``truncation``, ``roi_batch_size``,
            ``num_classes`` and ``s_features``.
        step: callable ``step(file_name, model, data_dir, device, c_matrix)``
            that evaluates one file and updates ``c_matrix`` in place.

    Raises:
        AssertionError: if ``device``, ``backbone_name`` or ``truncation`` is
            not one of the supported values.
    """
    torch.manual_seed(int(time.time()))
    start_time = time.time()
    devices = ['cpu', 'cuda']
    backbones = ['resnet50', 'resnet34', 'resnet18']
    truncation_levels = ['0', '1', '2']
    assert config.device in devices, "unsupported device: {}".format(config.device)
    assert config.backbone_name in backbones, "unsupported backbone: {}".format(config.backbone_name)
    assert config.truncation in truncation_levels, "unsupported truncation: {}".format(config.truncation)

    # Fall back to CPU when CUDA was requested but is not available.
    if torch.cuda.is_available() and config.device == 'cuda':
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    covid_mask_net_model = _build_model(config, device)

    test_data_dir = config.test_data_dir
    confusion_matrix = torch.zeros(3, 3, dtype=torch.int32, device=device)
    # Pure inference: disable autograd so no computation graph is built.
    # Sorted order makes runs (and any failing file) reproducible.
    with torch.no_grad():
        for file_name in sorted(os.listdir(test_data_dir)):
            step(file_name, covid_mask_net_model, test_data_dir, device, confusion_matrix)

    _print_report(confusion_matrix)
    end_time = time.time()
    print("Evaluation took {0:.1f} seconds".format(end_time - start_time))
def test_step(im_input, model, source_dir, device, c_matrix):
    """Classify one CT image and record the outcome in the confusion matrix.

    In the CNCB NCOV datasets the integer before the first underscore of the
    file name is the correct class: 0 = control, 1 = pneumonia, 2 = COVID.

    Args:
        im_input: image file name, relative to ``source_dir``.
        model: classifier called as ``model([img])``; the first output entry
            must contain a ``'final_scores'`` tensor.
        source_dir: directory containing ``im_input``.
        device: ``torch.device`` the model lives on.
        c_matrix: confusion matrix indexed ``[correct_class, predicted_class]``;
            incremented in place.
    """
    # Extract the correct class from the file name.
    correct_class = int(os.path.basename(im_input).split('_')[0])

    preprocess = transforms.Compose([
        transforms.Resize(512),
        transforms.ToTensor()])
    # The context manager closes the file handle once the pixels are read.
    with PILImage.open(os.path.join(source_dir, im_input)) as im:
        # Converting to RGB also drops any alpha channel, so the tensor always
        # has exactly 3 channels.
        img = preprocess(im.convert(mode='RGB'))

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        out = model([img.to(device)])
    pred_class = out[0]['final_scores'].argmax().item()
    c_matrix[correct_class, pred_class] += 1
# Script entry point: evaluate the model described by the "test" config.
if __name__ == '__main__':
    main(config.get_config_pars_classifier("test"), test_step)