-
Notifications
You must be signed in to change notification settings - Fork 0
/
evaluate.py
54 lines (48 loc) · 1.97 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch
import numpy as np
class IVEvaluator:
    """Score a goal-conditioned action predictor against recorded trajectories.

    Each trajectory step is a ``(state, next_state, action)`` triple; only
    ``state`` and ``action`` are used. The model is queried as
    ``model(state, goal)``. The predicted and the recorded action are each
    scaled to unit Euclidean norm. The step error is the squared distance
    between the two unit vectors, which equals ``2 - 2 * cos(angle)``. It
    therefore lies in ``[0, 4]`` and ignores action magnitude.

    A zero-length action (recorded or predicted) cannot be normalised. It
    yields NaN, and NumPy emits a RuntimeWarning.
    """

    def __init__(self):
        pass

    def evaluate(self, model: torch.nn.Module, tau: list, goals: list, reduce=False):
        """Evaluate ``model`` on every trajectory in ``tau``.

        Args:
            model: Callable ``model(state_tensor, goal_tensor)`` returning a
                tensor with the predicted action.
            tau: Trajectories; each is an iterable of
                ``(state, next_state, action)`` triples.
            goals: One goal per trajectory, aligned with ``tau``.
            reduce: If True, return one mean error per trajectory. Otherwise
                return every per-step error.

        Returns:
            With ``reduce=True``, a 1-D float array of length ``len(tau)``.
            Otherwise, the result of ``_compute_error_across_time``.

        Raises:
            ValueError: If ``tau`` and ``goals`` differ in length.
        """
        if reduce:
            return self._compute_mean_error(model, tau, goals)
        return self._compute_error_across_time(model, tau, goals)

    def _compute_mean_error(self, model: torch.nn.Module, tau: list, goals: list):
        """Return the mean per-step error of each trajectory as a 1-D array.

        The mean is taken over the steps of that trajectory. An empty
        trajectory yields NaN.
        """
        return np.array(
            [
                sum(errors) / len(errors) if errors else float("nan")
                for errors in self._per_step_errors(model, tau, goals)
            ]
        )

    def _compute_error_across_time(
        self, model: torch.nn.Module, tau: list, goals: list
    ):
        """Return every per-step error, grouped by trajectory.

        If all trajectories have the same length, the result is a 2-D float
        array of shape ``(len(tau), steps)``. Otherwise it is a 1-D object
        array holding one float array per trajectory. Recent NumPy refuses
        to build a ragged float array, which is why the object form exists.
        """
        per_trajectory = self._per_step_errors(model, tau, goals)
        if len({len(errors) for errors in per_trajectory}) <= 1:
            return np.array(per_trajectory)
        ragged = np.empty(len(per_trajectory), dtype=object)
        for index, errors in enumerate(per_trajectory):
            ragged[index] = np.array(errors)
        return ragged

    def _per_step_errors(self, model, tau, goals):
        """Return one list of per-step errors for each (trajectory, goal) pair."""
        if len(tau) != len(goals):
            raise ValueError(
                f"tau and goals must have the same length, "
                f"got {len(tau)} and {len(goals)}"
            )
        device = self._model_device(model)
        all_errors = []
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            for trajectory, goal in zip(tau, goals):
                # The goal is constant within a trajectory, so convert it once.
                goal_tensor = torch.tensor(goal, dtype=torch.float32, device=device)
                errors = []
                for state, _next_state, action in trajectory:
                    state_tensor = torch.tensor(
                        state, dtype=torch.float32, device=device
                    )
                    predicted = (
                        model(state_tensor, goal_tensor).detach().cpu().numpy()
                    )
                    errors.append(self._direction_error(predicted, action))
                all_errors.append(errors)
        return all_errors

    @staticmethod
    def _direction_error(predicted, action):
        """Squared distance between the unit-normalised ``predicted`` and ``action``."""
        predicted = np.asarray(predicted, dtype=np.float64)
        action = np.asarray(action, dtype=np.float64)
        predicted = predicted / np.linalg.norm(predicted)
        action = action / np.linalg.norm(action)
        return float(np.linalg.norm(predicted - action) ** 2)

    @staticmethod
    def _model_device(model):
        """Pick the device that input tensors should be placed on.

        Order of preference: an explicit ``model.device`` attribute, then the
        device of the model's first parameter, then CPU.
        """
        device = getattr(model, "device", None)
        if device is not None:
            return device
        parameters = getattr(model, "parameters", None)
        if callable(parameters):
            first = next(iter(parameters()), None)
            if first is not None:
                return first.device
        return torch.device("cpu")