training.py
from rlberry.manager import ExperimentManager
from rlberry.envs import gym_make
from rlberry.agents.stable_baselines import StableBaselinesAgent
from rlberry.seeding import Seeder
from stable_baselines3 import PPO
from avec_ppo import AVECPPO
seeder = Seeder(42)
# The ExperimentManager class is a compact way of experimenting with a deep RL agent.
default_xp = ExperimentManager(
    StableBaselinesAgent,  # The Agent class.
    (gym_make, dict(id="Acrobot-v1")),  # The Environment to solve.
    fit_budget=5e4,  # The number of interactions
    # between the agent and the
    # environment during training.
    init_kwargs=dict(algo_cls=PPO),  # Init kwargs for StableBaselinesAgent:
    # the SB3 algorithm class to wrap.
    eval_kwargs=dict(eval_horizon=500),  # The maximum number of interactions
    # between the agent and the
    # environment during each evaluation episode.
    n_fit=5,  # The number of agents to train.
    # Usually, it is good to do more
    # than 1 because the training is
    # stochastic.
    seed=seeder,
    agent_name="default_ppo",  # The agent's name.
    output_dir="data_training_default_ppo",
)

avec_xp = ExperimentManager(
    StableBaselinesAgent,  # The Agent class.
    (gym_make, dict(id="Acrobot-v1")),  # The Environment to solve.
    fit_budget=5e4,  # The number of interactions
    # between the agent and the
    # environment during training.
    init_kwargs=dict(algo_cls=AVECPPO),  # Init kwargs for StableBaselinesAgent:
    # the AVEC variant of PPO to wrap.
    eval_kwargs=dict(eval_horizon=500),  # The maximum number of interactions
    # between the agent and the
    # environment during each evaluation episode.
    n_fit=5,  # The number of agents to train.
    # Usually, it is good to do more
    # than 1 because the training is
    # stochastic.
    seed=seeder,
    agent_name="avec_ppo",  # The agent's name.
    output_dir="data_training_avec_ppo",
)
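
# The script above only defines the two experiment managers. A minimal sketch of
# how they could be trained and compared is shown below; the fit()/evaluate_agents
# calls are not part of the original file, and the exact evaluate_agents signature
# may differ across rlberry versions, so treat this as an illustrative assumption.
#
# from rlberry.manager import evaluate_agents
#
# for xp in (default_xp, avec_xp):
#     xp.fit()  # trains n_fit=5 agents for fit_budget interactions each
#
# # Evaluate both fitted managers and collect the evaluation returns.
# evaluations = evaluate_agents([default_xp, avec_xp], n_simulations=50, show=True)
# print(evaluations.describe())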