Skip to content

Commit 85cd1f0

Browse files
update
1 parent 8e047d3 commit 85cd1f0

File tree

4 files changed

+26
-20
lines changed

4 files changed

+26
-20
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
project(rl-tools-example)
12
#set(RL_TOOLS_BACKEND_ENABLE_MKL ON) # if you have MKL installed (fastest on Intel)
23
#set(RL_TOOLS_BACKEND_ENABLE_OPENBLAS ON) # if you have OpenBLAS installed
34
#set(RL_TOOLS_BACKEND_ENABLE_ACCELERATE ON) # if you are on macOS (fastest on Apple Silicon)

README.MD

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,4 +46,4 @@ This example also includes the automatic experiment tracking available through t
4646
2. [State => JSON](https://github.com/rl-tools/example/blob/39acaa5b5402eacf5c2cab7b2e96db71f2ea110f/include/my_pendulum/operations_cpu.h#L8): Self-explanatory
4747
3. [UI Render function string](https://github.com/rl-tools/example/blob/39acaa5b5402eacf5c2cab7b2e96db71f2ea110f/include/my_pendulum/operations_cpu.h#L16): This function uses the HTML5 Canvas rendering API and can be easily created using [https://studio.rl.tools](https://studio.rl.tools). Note that due to the widespread use of the HTML5 Canvas drawing API, ChatGPT is also really good at creating render functions for different environments if you give it an example like the ones provided on [https://studio.rl.tools](https://studio.rl.tools).
4848

49-
The experiment tracking and save-trajectories step will periodically record trajectories and store them as `.json` files. After/while running the training you can run `./serve.sh` which should start a local webserver on [http://localhost:8080](http://localhost:8080) where you can see the recorded trajectories based on the render function you provided.
49+
The experiment tracking and save-trajectories step will periodically record trajectories and store them as `.json` files. After/while running the training you can run `python3 -m http.server 8080` which should start a local webserver on [http://localhost:8080](http://localhost:8080) where you can see the recorded trajectories based on the render function you provided.

include/my_pendulum/my_pendulum.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,4 +43,5 @@ struct MyPendulum: rl_tools::rl::environments::Environment<typename T_SPEC::T, t
4343
using ObservationPrivileged = Observation;
4444
static constexpr TI OBSERVATION_DIM = 3;
4545
static constexpr TI ACTION_DIM = 1;
46+
static constexpr TI EPISODE_STEP_LIMIT = 200;
4647
};

src/main.cpp

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,35 +22,41 @@ using TI = typename DEVICE::index_t;
2222
using PENDULUM_SPEC = MyPendulumSpecification<T, TI, MyPendulumParameters<T>>;
2323
using ENVIRONMENT = MyPendulum<PENDULUM_SPEC>;
2424
struct LOOP_CORE_PARAMETERS: rlt::rl::algorithms::ppo::loop::core::DefaultParameters<T, TI, ENVIRONMENT>{
25-
static constexpr TI BATCH_SIZE = 256;
26-
static constexpr TI ACTOR_HIDDEN_DIM = 64;
27-
static constexpr TI CRITIC_HIDDEN_DIM = 64;
28-
static constexpr TI ON_POLICY_RUNNER_STEPS_PER_ENV = 1024;
29-
static constexpr TI N_ENVIRONMENTS = 4;
30-
static constexpr TI TOTAL_STEP_LIMIT = 300000;
25+
26+
static constexpr TI N_ENVIRONMENTS = 8;
27+
static constexpr TI ON_POLICY_RUNNER_STEPS_PER_ENV = 128;
28+
static constexpr TI BATCH_SIZE = 128;
29+
static constexpr TI TOTAL_STEP_LIMIT = 500000;
30+
static constexpr TI ACTOR_HIDDEN_DIM = 32;
31+
static constexpr TI CRITIC_HIDDEN_DIM = 32;
32+
static constexpr auto ACTOR_ACTIVATION_FUNCTION = rlt::nn::activation_functions::ActivationFunction::FAST_TANH;
33+
static constexpr auto CRITIC_ACTIVATION_FUNCTION = rlt::nn::activation_functions::ActivationFunction::FAST_TANH;
3134
static constexpr TI STEP_LIMIT = TOTAL_STEP_LIMIT/(ON_POLICY_RUNNER_STEPS_PER_ENV * N_ENVIRONMENTS) + 1;
32-
static constexpr TI EPISODE_STEP_LIMIT = 200;
33-
using OPTIMIZER_PARAMETERS = rlt::nn::optimizers::adam::DEFAULT_PARAMETERS_PYTORCH<T>;
35+
static constexpr TI EPISODE_STEP_LIMIT = ENVIRONMENT::EPISODE_STEP_LIMIT;
36+
struct OPTIMIZER_PARAMETERS: rlt::nn::optimizers::adam::DEFAULT_PARAMETERS_TENSORFLOW<T>{
37+
static constexpr T ALPHA = 0.01;
38+
};
39+
3440
struct PPO_PARAMETERS: rlt::rl::algorithms::ppo::DefaultParameters<T, TI, BATCH_SIZE>{
3541
static constexpr T ACTION_ENTROPY_COEFFICIENT = 0.0;
36-
static constexpr TI N_EPOCHS = 2;
42+
static constexpr TI N_EPOCHS = 1;
43+
static constexpr bool NORMALIZE_OBSERVATIONS = true;
3744
static constexpr T GAMMA = 0.9;
3845
static constexpr T INITIAL_ACTION_STD = 2.0;
39-
static constexpr bool NORMALIZE_OBSERVATIONS = true;
4046
};
4147
};
4248
using LOOP_CORE_CONFIG = rlt::rl::algorithms::ppo::loop::core::Config<T, TI, RNG, ENVIRONMENT, LOOP_CORE_PARAMETERS>;
4349
#ifndef BENCHMARK
4450
using LOOP_EXTRACK_CONFIG = rlt::rl::loop::steps::extrack::Config<LOOP_CORE_CONFIG>; // Sets up the experiment tracking structure (https://docs.rl.tools/10-Experiment%20Tracking.html)
4551
template <typename NEXT>
4652
struct LOOP_EVAL_PARAMETERS: rlt::rl::loop::steps::evaluation::Parameters<T, TI, NEXT>{
47-
static constexpr TI EVALUATION_INTERVAL = 4;
53+
static constexpr TI EVALUATION_INTERVAL = LOOP_CORE_CONFIG::CORE_PARAMETERS::STEP_LIMIT / 5;
4854
static constexpr TI NUM_EVALUATION_EPISODES = 10;
4955
static constexpr TI N_EVALUATIONS = NEXT::CORE_PARAMETERS::STEP_LIMIT / EVALUATION_INTERVAL;
5056
};
5157
using LOOP_EVALUATION_CONFIG = rlt::rl::loop::steps::evaluation::Config<LOOP_EXTRACK_CONFIG, LOOP_EVAL_PARAMETERS<LOOP_EXTRACK_CONFIG>>; // Evaluates the policy in a fixed interval and logs the return
5258
struct LOOP_SAVE_TRAJECTORIES_PARAMETERS: rlt::rl::loop::steps::save_trajectories::Parameters<T, TI, LOOP_EVALUATION_CONFIG>{
53-
static constexpr TI INTERVAL_TEMP = LOOP_CORE_CONFIG::CORE_PARAMETERS::STEP_LIMIT / 10;
59+
static constexpr TI INTERVAL_TEMP = LOOP_CORE_CONFIG::CORE_PARAMETERS::STEP_LIMIT / 3;
5460
static constexpr TI INTERVAL = INTERVAL_TEMP == 0 ? 1 : INTERVAL_TEMP;
5561
static constexpr TI NUM_EPISODES = 10;
5662
};
@@ -69,23 +75,21 @@ using LOOP_STATE = typename LOOP_CONFIG::template State<LOOP_CONFIG>;
6975

7076
int main(){
7177
DEVICE device;
72-
TI seed = 1337;
78+
TI seed = 2;
7379
LOOP_STATE ls;
7480
#ifndef BENCHMARK
7581
// Set experiment tracking info
7682
ls.extrack_name = "example";
7783
#endif
7884
rlt::malloc(device, ls);
7985
rlt::init(device, ls, seed);
80-
ls.actor_optimizer.parameters.alpha = 1e-2;
81-
ls.critic_optimizer.parameters.alpha = 1e-2;
8286
auto start_time = std::chrono::high_resolution_clock::now();
8387
while(!rlt::step(device, ls)){
8488
// do what ever you want here, e.g. poor man's learning rate scheduler:
85-
if(ls.step % 1 == 0){
86-
ls.actor_optimizer.parameters.alpha *= 0.9;
87-
ls.critic_optimizer.parameters.alpha *= 0.9;
88-
}
89+
// if(ls.step % 1 == 0){
90+
// ls.actor_optimizer.parameters.alpha *= 0.9;
91+
// ls.critic_optimizer.parameters.alpha *= 0.9;
92+
// }
8993
}
9094
auto end_time = std::chrono::high_resolution_clock::now();
9195
std::chrono::duration<double> diff = end_time-start_time;

0 commit comments

Comments
 (0)