
Commit ae6ebdc

adding termination condition
1 parent ca5a5e4 commit ae6ebdc

File tree

1 file changed: +18 -2 lines changed

cleanrl/ppo_rnd_envpool.py

Lines changed: 18 additions & 2 deletions
@@ -37,6 +37,22 @@ def parse_args():
         help="the wandb's project name")
     parser.add_argument("--wandb-entity", type=str, default=None,
         help="the entity (team) of wandb's project")
+    parser.add_argument("--reward-goal", type=float, default=1.0,
+        help="reward to give when the goal is achieved")
+    parser.add_argument("--reward-distance", type=float, default=1.0,
+        help="dense reward for reducing the distance to the goal")
+    parser.add_argument("--early-termination", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
+        help="whether to terminate early, i.e. when the catastrophe has happened")
+    parser.add_argument("--unifying-lidar", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
+        help="whether to use the same (unified) lidar sensor for every environment")
+    parser.add_argument("--term-cost", type=int, default=1,
+        help="how many cost violations before the episode terminates")
+    parser.add_argument("--failure-penalty", type=float, default=0.0,
+        help="reward penalty applied on failure")
+    parser.add_argument("--collect-data", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
+        help="store data while training")
+    parser.add_argument("--storage-path", type=str, default="./data/ppo/term_1",
+        help="the storage path for the data collected")
 
     # Algorithm specific arguments
     parser.add_argument("--env-id", type=str, default="MontezumaRevenge-v5",
@@ -140,9 +156,9 @@ def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
 def make_env(cfg, idx, capture_video, run_name, gamma):
     def thunk():
         if capture_video:
-            env = gym.make(cfg.env_id)#, render_mode="rgb_array", early_termination=cfg.early_termination, term_cost=cfg.term_cost, failure_penalty=cfg.failure_penalty, reward_goal=cfg.reward_goal, reward_distance=cfg.reward_distance)
+            env = gym.make(cfg.env_id, render_mode="rgb_array", early_termination=cfg.early_termination, term_cost=cfg.term_cost, failure_penalty=cfg.failure_penalty, reward_goal=cfg.reward_goal, reward_distance=cfg.reward_distance)
         else:
-            env = gym.make(cfg.env_id)#, early_termination=cfg.early_termination, term_cost=cfg.term_cost, failure_penalty=cfg.failure_penalty, reward_goal=cfg.reward_goal, reward_distance=cfg.reward_distance)
+            env = gym.make(cfg.env_id, early_termination=cfg.early_termination, term_cost=cfg.term_cost, failure_penalty=cfg.failure_penalty, reward_goal=cfg.reward_goal, reward_distance=cfg.reward_distance)
         env = gym.wrappers.FlattenObservation(env)  # deal with dm_control's Dict observation space
         env = gym.wrappers.RecordEpisodeStatistics(env)
         if capture_video:
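
Downstream, this thunk-style factory is what allows the environment to be constructed lazily inside a vector env, which is why make_env returns a closure instead of an environment. A rough usage sketch, assuming the script's gym import exposes the standard vector API; cfg stands in for the parsed arguments and the env count is chosen arbitrarily:

# Illustrative usage of the thunk factory from the diff; not part of the commit.
envs = gym.vector.SyncVectorEnv(
    [make_env(cfg, idx, capture_video=False, run_name="demo", gamma=0.99)
     for idx in range(4)]  # 4 parallel envs, arbitrary for this sketch
)
obs = envs.reset()  # newer gym/gymnasium versions return (obs, info) instead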
