-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.cpp
62 lines (53 loc) · 1.97 KB
/
main.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#include "stdio.h"
#include "stdlib.h"
#include "string.h"
#include <iostream>
#include <torch/torch.h>
#include <boost/circular_buffer.hpp>
#include "mujoco.h"
#include "glfw3.h"
#include <vector>
#include "agent.h"
#include "env.h"
using namespace std;
// Run an episodic training loop: total_epochs episodes, each capped at
// max_time simulated seconds (time measured via MuJoCo's env.d->time).
// Per step the agent acts in the environment and immediately performs a
// learning update; the window is redrawn at most every 10 ms of sim time.
// Loop exits early if the GLFW window is closed.
void train(Env &env, Agent &agent, short total_epochs=1500, float max_time = 10.0/*sec*/)
{
	for(short epoch=0; epoch<total_epochs; epoch++)
	{
		agent.noise.reset();   // fresh exploration noise every episode
		env.reset();
		float total_reward = 0.0;
		float reward = 0.0;
		float done = 0.0;      // episode-termination flag, treated as bool via (done > 0.5)
		float reward_type[5] = {0.0,0.0,0.0,0.0,0.0};//type of total rewards: [task torque time reached_top too_speedy]
		double total_actor_loss = 0.0;
		double total_critic_loss = 0.0;
		env.set_epoch_time();//env.epoch_start_time is set to zero
		env.set_render_time();//env.render_time is set to zero
		while(env.d->time - env.epoch_start_time < max_time)
		{
			// One environment step + one learning update; the helpers
			// accumulate into the reward/loss outputs passed by reference.
			agent.step(env, done, reward, reward_type, total_reward);
			agent.learn(total_actor_loss, total_critic_loss);
			// Throttle rendering to at most once per 10 ms of simulated time.
			if(env.d->time-env.render_time>0.01)
			{
				env.render();
				env.set_render_time();
			}
			if(done>0.5 || glfwWindowShouldClose(env.window))
				break;
		}
		// FIX: original output ran the reward value and "time:" together
		// ("...reward:12.3time: 4.5"); separate the fields with spaces.
		cout<<"end of epoch: "<<epoch<<" total reward: "<<total_reward<<" time: "<<env.d->time - env.epoch_start_time<<endl;
		cout<<"reward_type: distance: "<<reward_type[0]<<" torque: "<<reward_type[1]<<" time: "<<reward_type[2]<<" done: "<<reward_type[3]<<" too_speedy: "<<reward_type[4]<<endl;
		cout<<"actor_loss(-critic) : "<<total_actor_loss<<endl;
		cout<<"critic_loss : "<<total_critic_loss<<endl<<endl;
		if(glfwWindowShouldClose(env.window))
			break;
	}
}
int main(int argc, const char** argv)
{
Env env(argv);
Agent agent;
train(env,agent);
return 0;
}