// eval.cpp
#include <cstdint>
#include <iostream>
#include <memory>
#include <vector>

#include <torch/script.h>
#include <torch/torch.h>
int main(int argc, const char* argv[]) {
// Deserialize the ScriptModule from a file using torch::jit::load().
torch::jit::script::Module module;
try {
module = torch::jit::load(argv[1]);
}
catch (const c10::Error& e) {
std::cerr << "error loading the model\n";
return -1;
}
//Note that hist and nbrs trajectories should be computed wrt to the reference vehicle
//The following are Dummy variables for testing
//hist, nbrs, mask, lat_enc, lon_enc
torch::Tensor hist = torch::zeros({16,1,2});
torch::Tensor nbrs = torch::zeros({16,39,2});
torch::Tensor lat_enc = torch::zeros({1,3});
torch::Tensor lon_enc = torch::zeros({1,2});
torch::Tensor mask = torch::ones({1,3,13,64}, {torch::kByte});
//Concatentae the inputs into IValue Struct
//hist, nbrs, mask, lat_enc, lon_enc
std::vector<torch::jit::IValue> inputs;
inputs.push_back(hist);
inputs.push_back(nbrs);
inputs.push_back(mask);
inputs.push_back(lat_enc);
inputs.push_back(lon_enc);
//Forward Pass to evaluate the model on the given inputs
//Outputs concatenates three predicted outputs: fut_pred, lat_pred, lon_pred
auto outputs = module.forward(inputs).toTuple();
//unpack out into fut_pred, lat_pred, lon_pred
auto fut_pred = outputs->elements()[0].toTensorList();
torch::Tensor lat_pred = outputs->elements()[1].toTensor();
torch::Tensor lon_pred = outputs->elements()[2].toTensor();
//Find the manuever-based future trajectory
int lat_man = lat_pred.argmax(1).item().toInt();
int lon_man = lon_pred.argmax(1).item().toInt();
int indx = lon_man*3 + lat_man;
torch::Tensor fut_pred_max = fut_pred[indx];
//Each row of the future predicted trajectory is [muX, muY, sigX, sigY, rho]
//You can use muX, muY as the predicted X, Y positions
//Note that trajectories were computed wrt to the reference vehicle
//So the reference position should be added back
//Confirm
std::cout << lat_pred << std::endl;
std::cout << lat_man << std::endl;
std::cout << lon_pred << std::endl;
std::cout << lon_man << std::endl;
std::cout << indx << std::endl;
std::cout << fut_pred_max << std::endl;
}