-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrainer.py
51 lines (34 loc) · 2.44 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import time
import torch
class Trainer:
    """Run one update step from a replay batch plus an expert batch.

    ``update`` turns the array-like fields of both batches into float32
    tensors, moves them to ``self.device`` and passes them to
    ``model.update`` in a fixed positional order.
    """

    def __init__(self, device="cuda"):
        """Create a trainer.

        Args:
            device: Torch device (string or ``torch.device``) the batch
                tensors are moved to before calling ``model.update``.
                Defaults to ``"cuda"``, matching the previously hard-coded
                ``.cuda()`` transfer.
        """
        self.device = device

    @staticmethod
    def _to_tensor(data):
        """Return array-like *data* as a float32 tensor (no device move)."""
        return torch.as_tensor(data, dtype=torch.float32)

    def update(self, model, batch, expert_batch, eval=False):
        """Convert both batches to tensors and call ``model.update`` once.

        Args:
            model: Object whose ``update`` method accepts, positionally,
                ``(obs, acts, next_obs, terminals, is_expert, expert_obs,
                expert_acts, expert_next_obs, expert_terminals)`` and returns
                a metrics mapping.
            batch: Object with ``observations``, ``actions``,
                ``next_observations``, ``masks`` and ``is_expert`` attributes.
            expert_batch: Object with ``observations``, ``actions``,
                ``next_observations`` and ``masks`` attributes.
            eval: Unused; kept for interface compatibility.

        Returns:
            ``(metrics, message)``. ``metrics`` is whatever ``model.update``
            returned. ``message`` reports wall-clock durations of three
            phases (Load = building host tensors, Batch = moving them to
            ``self.device``, Update = ``model.update``) and
            ``metrics['v_loss']`` (``None`` if the key is absent).
        """
        t0 = time.perf_counter()
        # Load phase: build float32 tensors. Terminal flags are derived as the
        # complement of the masks (terminal = 1 - mask).
        host_tensors = (
            self._to_tensor(batch.observations),
            self._to_tensor(batch.actions),
            self._to_tensor(batch.next_observations),
            1 - self._to_tensor(batch.masks),
            self._to_tensor(batch.is_expert),
            self._to_tensor(expert_batch.observations),
            self._to_tensor(expert_batch.actions),
            self._to_tensor(expert_batch.next_observations),
            1 - self._to_tensor(expert_batch.masks),
        )
        t1 = time.perf_counter()
        # Batch phase: move every tensor to the configured device.
        device_tensors = [tensor.to(self.device) for tensor in host_tensors]
        t2 = time.perf_counter()
        # Update phase. NOTE(review): CUDA kernels may run asynchronously, so
        # this wall-clock figure may not include all GPU work.
        metrics = model.update(*device_tensors)
        t3 = time.perf_counter()
        message = (
            f"Load time {t1 - t0}, Batch time {t2 - t1}, "
            f"Update time {t3 - t2}, V Loss {metrics.get('v_loss')}"
        )
        return metrics, message
class TrainerSNS:
    """Run one update step using batches that also carry next-next states.

    Like the single-step trainer, but both batches additionally provide
    ``next_next_observations``, which are forwarded to ``model.update``.
    All fields become float32 tensors on ``self.device``.
    """

    def __init__(self, device="cuda"):
        """Create a trainer.

        Args:
            device: Torch device (string or ``torch.device``) the batch
                tensors are moved to before calling ``model.update``.
                Defaults to ``"cuda"``, matching the previously hard-coded
                ``.cuda()`` transfer.
        """
        self.device = device

    @staticmethod
    def _to_tensor(data):
        """Return array-like *data* as a float32 tensor (no device move)."""
        return torch.as_tensor(data, dtype=torch.float32)

    def update(self, model, batch, expert_batch, eval=False):
        """Convert both batches to tensors and call ``model.update`` once.

        Args:
            model: Object whose ``update`` method accepts, positionally,
                ``(obs, acts, next_obs, next_next_obs, terminals, is_expert,
                expert_obs, expert_acts, expert_next_obs,
                expert_next_next_obs, expert_terminals)`` and returns a
                metrics mapping.
            batch: Object with ``observations``, ``actions``,
                ``next_observations``, ``next_next_observations``, ``masks``
                and ``is_expert`` attributes.
            expert_batch: Object with ``observations``, ``actions``,
                ``next_observations``, ``next_next_observations`` and
                ``masks`` attributes.
            eval: Unused; kept for interface compatibility.

        Returns:
            ``(metrics, message)``. ``metrics`` is whatever ``model.update``
            returned. ``message`` reports wall-clock durations of three
            phases (Load = building host tensors, Batch = moving them to
            ``self.device``, Update = ``model.update``) and
            ``metrics['v_loss']`` (``None`` if the key is absent).
        """
        t0 = time.perf_counter()
        # Load phase: build float32 tensors. Terminal flags are derived as the
        # complement of the masks (terminal = 1 - mask).
        host_tensors = (
            self._to_tensor(batch.observations),
            self._to_tensor(batch.actions),
            self._to_tensor(batch.next_observations),
            self._to_tensor(batch.next_next_observations),
            1 - self._to_tensor(batch.masks),
            self._to_tensor(batch.is_expert),
            self._to_tensor(expert_batch.observations),
            self._to_tensor(expert_batch.actions),
            self._to_tensor(expert_batch.next_observations),
            self._to_tensor(expert_batch.next_next_observations),
            1 - self._to_tensor(expert_batch.masks),
        )
        t1 = time.perf_counter()
        # Batch phase: move every tensor to the configured device.
        device_tensors = [tensor.to(self.device) for tensor in host_tensors]
        t2 = time.perf_counter()
        # Update phase. NOTE(review): CUDA kernels may run asynchronously, so
        # this wall-clock figure may not include all GPU work.
        metrics = model.update(*device_tensors)
        t3 = time.perf_counter()
        message = (
            f"Load time {t1 - t0}, Batch time {t2 - t1}, "
            f"Update time {t3 - t2}, V Loss {metrics.get('v_loss')}"
        )
        return metrics, message