@@ -37,6 +37,22 @@ def parse_args():
         help="the wandb's project name")
     parser.add_argument("--wandb-entity", type=str, default=None,
         help="the entity (team) of wandb's project")
+    parser.add_argument("--reward-goal", type=float, default=1.0,
+        help="reward to give when the goal is achieved")
+    parser.add_argument("--reward-distance", type=float, default=1.0,
+        help="the dense reward scale for reducing the distance to the goal")
+    parser.add_argument("--early-termination", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
+        help="whether to terminate early, i.e. when a catastrophe has happened")
+    parser.add_argument("--unifying-lidar", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
+        help="whether to use the same (unified) lidar sensor in every environment")
+    parser.add_argument("--term-cost", type=int, default=1,
+        help="the number of constraint violations before termination")
+    parser.add_argument("--failure-penalty", type=float, default=0.0,
+        help="the reward penalty applied on failure")
+    parser.add_argument("--collect-data", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
+        help="whether to store data while training")
+    parser.add_argument("--storage-path", type=str, default="./data/ppo/term_1",
+        help="the storage path for the data collected")

     # Algorithm specific arguments
     parser.add_argument("--env-id", type=str, default="MontezumaRevenge-v5",
@@ -140,9 +156,9 @@ def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
 def make_env(cfg, idx, capture_video, run_name, gamma):
     def thunk():
         if capture_video:
-            env = gym.make(cfg.env_id)  # , render_mode="rgb_array", early_termination=cfg.early_termination, term_cost=cfg.term_cost, failure_penalty=cfg.failure_penalty, reward_goal=cfg.reward_goal, reward_distance=cfg.reward_distance)
+            env = gym.make(cfg.env_id, render_mode="rgb_array", early_termination=cfg.early_termination, term_cost=cfg.term_cost, failure_penalty=cfg.failure_penalty, reward_goal=cfg.reward_goal, reward_distance=cfg.reward_distance)
         else:
-            env = gym.make(cfg.env_id)  # , early_termination=cfg.early_termination, term_cost=cfg.term_cost, failure_penalty=cfg.failure_penalty, reward_goal=cfg.reward_goal, reward_distance=cfg.reward_distance)
+            env = gym.make(cfg.env_id, early_termination=cfg.early_termination, term_cost=cfg.term_cost, failure_penalty=cfg.failure_penalty, reward_goal=cfg.reward_goal, reward_distance=cfg.reward_distance)
         env = gym.wrappers.FlattenObservation(env)  # deal with dm_control's Dict observation space
         env = gym.wrappers.RecordEpisodeStatistics(env)
         if capture_video:
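For context, here is a minimal, self-contained sketch (not the repository's code) of the two patterns this diff relies on: the `strtobool` lambda with `nargs="?"`/`const=True` for boolean flags, and forwarding the parsed config as keyword arguments to `gym.make`. The env id in the trailing comment is hypothetical; it stands in for an environment whose constructor accepts the kwargs that `make_env` passes.

```python
import argparse
from distutils.util import strtobool  # deprecated in Python 3.12+; kept to match the diff

parser = argparse.ArgumentParser()
# `nargs="?"` with `const=True` makes the flag usable both as a bare switch
# ("--early-termination" -> True) and with an explicit value
# ("--early-termination false" -> False via strtobool).
parser.add_argument("--early-termination", type=lambda x: bool(strtobool(x)),
    default=True, nargs="?", const=True)
parser.add_argument("--failure-penalty", type=float, default=0.0)

cfg = parser.parse_args(["--early-termination", "false", "--failure-penalty", "0.5"])
print(cfg.early_termination, cfg.failure_penalty)  # False 0.5

# The parsed config is then forwarded as env kwargs, as in make_env above
# ("SafetyEnv-v0" is a hypothetical, registered env id):
# env = gym.make("SafetyEnv-v0",
#                early_termination=cfg.early_termination,
#                failure_penalty=cfg.failure_penalty)
```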