-
Notifications
You must be signed in to change notification settings - Fork 27
/
Copy pathtest.py
33 lines (30 loc) · 959 Bytes
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from time import time
import os
def test(model, loader, device, CONFIG, metrics):
    """
    Test for dot-based model.

    Switches ``model`` to eval mode, calls ``model.propagate()`` once, then
    scores every batch from ``loader`` with ``model.evaluate`` and feeds the
    masked scores to each metric. Elapsed time and one ``title:value`` entry
    per metric are printed to stdout.

    Args:
        model: object providing ``eval()``, ``propagate()`` and
            ``evaluate(rs, users)``.
        loader: iterable yielding ``(users, ground_truth_u_b,
            train_mask_u_b)`` batches; each element must support
            ``.to(device)``.
        device: device the batch tensors are moved to.
        CONFIG: not read by this function; kept in the signature so existing
            callers keep working.
        metrics: re-iterable collection (it is traversed several times) of
            metric objects providing ``start()``, ``stop()``, ``get_title()``,
            a ``metric`` attribute and ``__call__(pred, ground_truth)``.

    Returns:
        The same ``metrics`` object that was passed in.
    """
    model.eval()
    for metric in metrics:
        metric.start()

    start = time()
    with torch.no_grad():
        # The propagation result is computed once and reused for every batch.
        rs = model.propagate()
        for users, ground_truth_u_b, train_mask_u_b in loader:
            pred_b = model.evaluate(rs, users.to(device))
            # Subtract a huge constant wherever the train mask is non-zero;
            # presumably this keeps training interactions out of the top
            # ranks -- confirm against the metric implementations.
            pred_b -= 1e8 * train_mask_u_b.to(device)
            # Move the ground truth once per batch instead of once per
            # metric: the transfer does not depend on the metric.
            ground_truth_u_b = ground_truth_u_b.to(device)
            for metric in metrics:
                metric(pred_b, ground_truth_u_b)

    print('Test: time={:d}s'.format(int(time()-start)))
    for metric in metrics:
        metric.stop()
        print('{}:{}'.format(metric.get_title(), metric.metric), end='\t')
    print('')
    return metrics