-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathTTPNet.py
executable file
·84 lines (63 loc) · 2.97 KB
/
TTPNet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from PredictionBiLSTM import PredictionBiLSTM
class TTPNet(nn.Module):
    """Regression head stacked on top of ``PredictionBiLSTM``.

    The recurrent encoder produces one hidden vector per step; this module
    accumulates them into running (prefix) sums and maps every prefix through
    a three-layer MLP to a single scalar, so one prediction is emitted per
    prefix of length >= 2.
    """

    def __init__(self):
        super().__init__()
        self.build()
        # self.init_weight()

    def build(self):
        """Create the encoder and the 128 -> 128 -> 64 -> 1 MLP layers."""
        self.bi_lstm = PredictionBiLSTM()
        self.input2hid = nn.Linear(128, 128)
        self.hid2hid = nn.Linear(128, 64)
        self.hid2out = nn.Linear(64, 1)

    def forward(self, attr, traj):
        """Return one scalar prediction per prefix of the encoded sequence.

        Args:
            attr: Forwarded unchanged to ``self.bi_lstm``.
            traj: Forwarded unchanged to ``self.bi_lstm``.

        Returns:
            Tensor indexed as (batch, prefix, 1). For an encoder output with
            ``n`` steps along dim 1 there are ``n - 1`` prefixes (lengths
            2..n); when ``n <= 1`` a single entry holding the sum over dim 1
            is produced instead.
        """
        # The encoder output is summed over dim 1 and its last dim feeds a
        # Linear(128, ...) layer, i.e. it is used as (batch, step, 128).
        hiddens = self.bi_lstm(attr, traj)
        n = hiddens.size(1)

        if n > 1:
            # Running sum over the step axis; dropping index 0 leaves the
            # prefix sums of length 2..n in a single pass instead of
            # re-summing every prefix from scratch.
            h_f = torch.cumsum(hiddens, dim=1)[:, 1:]
        else:
            # Degenerate sequence: keep a single "full sum" entry.
            h_f = torch.sum(hiddens, dim=1, keepdim=True)

        T_f_hat = F.relu(self.input2hid(h_f))
        T_f_hat = F.relu(self.hid2hid(T_f_hat))
        return self.hid2out(T_f_hat)

    def dual_loss(self, T_f_hat, traj, mean, std):
        """Masked root-mean-square relative error over all prefixes.

        Args:
            T_f_hat: Normalised predictions as returned by ``forward``.
            traj: Mapping providing ``'T_f'`` (targets) and ``'M_f'``
                (per-prefix weights). Both are 2-D here: they are expanded to
                3-D with ``unsqueeze`` before being passed to ``torch.bmm``.
            mean: Offset used to de-normalise ``T_f_hat``.
            std: Scale used to de-normalise ``T_f_hat``.

        Returns:
            ``({'pred': last-prefix prediction (de-normalised)}, loss)`` where
            ``loss`` is the batch mean of the per-sample error.
        """
        T_f_hat = T_f_hat * std + mean

        T_f = torch.unsqueeze(traj['T_f'], dim=2)
        M_f = torch.unsqueeze(traj['M_f'], dim=1)

        # NOTE(review): if 'T_f' holds zeros at positions where 'M_f' is zero
        # (e.g. padding), the division yields inf/nan and 0 * nan stays nan
        # inside the bmm -- confirm the data loader never produces that.
        squared_rel_err = torch.pow((T_f_hat - T_f) / T_f, 2)

        # bmm(M_f, x) is the M_f-weighted sum of x over the prefix axis;
        # bmm(M_f, M_f^T) is the sum of squared weights used as normaliser.
        weighted_sum = torch.bmm(M_f, squared_rel_err)
        normaliser = torch.bmm(M_f, M_f.permute(0, 2, 1))
        loss_f = torch.sqrt(weighted_sum / normaliser)

        return {'pred': T_f_hat[:, -1]}, loss_f.mean()

    def MAPE_loss(self, pred, label, mean, std):
        """Mean absolute percentage error against a de-normalised label.

        Args:
            pred: Predictions already on the original scale; must broadcast
                against a column of shape (N, 1).
            label: Normalised labels; reshaped to (N, 1).
            mean: Offset used to de-normalise ``label``.
            std: Scale used to de-normalise ``label``.

        Returns:
            ``({'label': de-normalised label, 'pred': pred}, loss)``.
        """
        label = label.view(-1, 1) * std + mean
        loss = torch.abs(pred - label) / label
        return {'label': label, 'pred': pred}, loss.mean()

    # Disabled weight initialisation, kept for reference:
    # def init_weight(self):
    #     for name, param in self.named_parameters():
    #         if name.find('.ln') == -1:
    #             print(name)
    #             if name.find('.bias') != -1:
    #                 param.data.fill_(0)
    #             elif name.find('.weight') != -1:
    #                 nn.init.xavier_uniform_(param.data)

    def eval_on_batch(self, attr, traj, config):
        """Run the model on one batch and compute its metrics.

        Args:
            attr: Mapping passed to ``forward``; ``attr['time']`` is the
                normalised label used for the MAPE metric.
            traj: Mapping passed to ``forward`` and, in training mode, to
                ``dual_loss``.
            config: Mapping with ``'time_gap_mean'``/``'time_gap_std'`` (used
                to de-normalise predictions) and ``'time_mean'``/``'time_std'``
                (used to de-normalise the label).

        Returns:
            In training mode ``(pred_dict, loss, MAPE_dict, MAPE_loss)``;
            otherwise ``(MAPE_dict, MAPE_loss)``.
        """
        gap_mean = config['time_gap_mean']
        gap_std = config['time_gap_std']

        T_f_hat = self(attr, traj)

        # The MAPE metric is identical in both modes: it scores the last
        # prefix (the whole sequence) on the original scale.
        pred = T_f_hat * gap_std + gap_mean
        mape_dict, mape_loss = self.MAPE_loss(
            pred[:, -1], attr['time'], config['time_mean'], config['time_std']
        )

        if not self.training:
            return mape_dict, mape_loss

        pred_dict, loss = self.dual_loss(T_f_hat, traj, gap_mean, gap_std)
        return pred_dict, loss, mape_dict, mape_loss