-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy patheval_seg.py
52 lines (45 loc) · 1.45 KB
/
eval_seg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch
import numpy as np
from loader.viah_loader import *
from loader.bing_loader import *
from utils.utils_args import *
from utils.utils_eval import *
from utils.utils_train import *
from utils.utils_tri import *
from utils.utils_vis import *
from utils.loss import *
from models.model_seg import *
def eval_ds(ds, model, threshold=0.5, device='cpu', metric_fn=None):
    """Evaluate a segmentation model on a dataset and report mean Dice / IoU.

    Each batch's prediction is binarised at ``threshold`` and scored against
    its target with ``metric_fn``. The per-batch scores are averaged with
    ``np.mean``, printed as a ``(Dice, IoU)`` tuple, and returned.

    Args:
        ds: Iterable yielding ``(inputs, targets)`` tensor pairs, e.g. a
            ``torch.utils.data.DataLoader``.
        model: Callable mapping an input tensor to a prediction tensor. The
            output is compared against ``threshold``, so it presumably holds
            probabilities in [0, 1] (e.g. a sigmoid output) — TODO confirm.
        threshold: Predictions ``>= threshold`` become 1, all others 0.
        device: Device that inputs and targets are moved to before inference.
            It must match the device the model lives on.
        metric_fn: Callable ``(binary_mask, target) -> (dice, iou)``.
            Defaults to ``get_dice_ji``.

    Returns:
        Tuple ``(Dice, IoU)`` with the mean scores over all batches. Both are
        ``nan`` when ``ds`` yields no batches.
    """
    if metric_fn is None:
        metric_fn = get_dice_ji

    dice_scores = []
    iou_scores = []
    # Pure inference: disable autograd so no graph is built or kept per batch.
    with torch.no_grad():
        for _x, _y in ds:
            _x = _x.float().to(device)
            _y = _y.float().to(device)
            pred = model(_x)
            # Binarise into a new tensor (same dtype) instead of mutating the
            # model output in place.
            mask = (pred >= threshold).to(pred.dtype)
            c_dice, c_iou = metric_fn(mask, _y)
            dice_scores.append(c_dice)
            iou_scores.append(c_iou)

    if dice_scores:
        mean_dice = np.mean(dice_scores)
        mean_iou = np.mean(iou_scores)
    else:
        # np.mean([]) would warn ("Mean of empty slice"); report nan explicitly.
        mean_dice = mean_iou = float('nan')
    print((mean_dice, mean_iou))
    return mean_dice, mean_iou
def main():
    """Evaluate the best saved segmentation checkpoint on a task's test split.

    Reads options with ``get_args()`` and passes them to ``save_args``. It then
    builds the ``'test'`` split for ``args['task']`` (``'viah'`` or ``'bing'``)
    and loads ``results/<task>/best/SEG.pt`` into a fresh ``Segmentation``
    model on the CPU. Finally it runs ``eval_ds``, which prints the mean
    Dice / IoU.

    Raises:
        ValueError: If ``args['task']`` is not one of the supported tasks.
    """
    # Local imports keep this block self-contained.
    import inspect
    import os

    # Only affects cuDNN (GPU) kernels. Evaluation below runs on the CPU, so
    # this is currently a no-op.
    torch.backends.cudnn.benchmark = True
    args = get_args()
    save_args(args)

    # Dispatch table: task name -> (dataset class, directory of the best checkpoint).
    tasks = {
        'viah': (viah_segmentation, r'results/viah/best/'),
        'bing': (bing_segmentation, r'results/bing/best/'),
    }
    task = args['task']
    if task not in tasks:
        # Previously this fell through and crashed later with UnboundLocalError.
        raise ValueError(
            f"Unsupported task {task!r}; expected one of {sorted(tasks)}")
    dataset_cls, ckpt_dir = tasks[task]
    testset = dataset_cls(ann='test', args=args)

    segnet = Segmentation(args)

    ckpt_path = os.path.join(ckpt_dir, 'SEG.pt')
    # map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only host.
    load_kwargs = {'map_location': 'cpu'}
    # The checkpoint is expected to be a whole pickled model. torch >= 2.6
    # defaults to weights_only=True, which refuses such objects, so opt out
    # explicitly when the keyword exists (it does not on torch < 1.13).
    # NOTE: unpickling can execute arbitrary code; only load trusted checkpoints.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    checkpoint = torch.load(ckpt_path, **load_kwargs)
    # Accept either a pickled model (original format) or a bare state_dict.
    if hasattr(checkpoint, 'state_dict'):
        state_dict = checkpoint.state_dict()
    else:
        state_dict = checkpoint
    segnet.load_state_dict(state_dict)
    segnet.cpu().eval()

    ds_val = torch.utils.data.DataLoader(
        testset, batch_size=1, shuffle=False, num_workers=1, drop_last=False)
    eval_ds(ds_val, segnet)
# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()