Commit c6a30f4

working with velocity
1 parent da1dbfc commit c6a30f4

File tree

1 file changed: +8 -9 lines changed

cleanrl/ppo_continuous_action_wandb.py

Lines changed: 8 additions & 9 deletions
@@ -206,14 +206,13 @@ def forward(self, x):
 
 def make_env(cfg, idx, capture_video, run_name, gamma):
     def thunk():
-
-        if "safety" in cfg.env_id.lower():
+        if "velocity" in cfg.env_id.lower() or "safety" not in cfg.env_id.lower():
+            env = gym.make(cfg.env_id)
+        else:
             if capture_video:
-                env = gym.make(cfg.env_id, render_mode="rgb_array", early_termination=cfg.early_termination, term_cost=cfg.term_cost, failure_penalty=cfg.failure_penalty, reward_goal=cfg.reward_goal, reward_distance=cfg.reward_distance)
+                env = gym.make(cfg.env_id, render_mode="rgb_array", early_termination=cfg.early_termination, term_cost=cfg.term_cost, failure_penalty=cfg.failure_penalty, reward_goal=cfg.reward_goal, reward_distance=cfg.reward_distance)
             else:
                 env = gym.make(cfg.env_id, early_termination=cfg.early_termination, term_cost=cfg.term_cost, failure_penalty=cfg.failure_penalty, reward_goal=cfg.reward_goal, reward_distance=cfg.reward_distance)
-        else:
-            env = gym.make(cfg.env_id)
         env = gym.wrappers.FlattenObservation(env)  # deal with dm_control's Dict observation space
         env = gym.wrappers.RecordEpisodeStatistics(env)
         if capture_video:
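
For readability, here is a sketch (not the verbatim file) of how make_env reads after this change, assuming gymnasium is imported as gym and that cfg carries the fields referenced in the diff. Given the commit message, the presumable intent is that velocity-constrained tasks, and anything that is not a safety env, are created with a plain gym.make call, since they do not accept the custom termination/reward keyword arguments used for the goal-based safety tasks.

import gymnasium as gym

def make_env(cfg, idx, capture_video, run_name, gamma):
    def thunk():
        if "velocity" in cfg.env_id.lower() or "safety" not in cfg.env_id.lower():
            # velocity-constrained or plain tasks: no custom kwargs
            env = gym.make(cfg.env_id)
        else:
            # goal-based safety tasks keep the custom termination/reward kwargs
            kwargs = dict(
                early_termination=cfg.early_termination,
                term_cost=cfg.term_cost,
                failure_penalty=cfg.failure_penalty,
                reward_goal=cfg.reward_goal,
                reward_distance=cfg.reward_distance,
            )
            if capture_video:
                env = gym.make(cfg.env_id, render_mode="rgb_array", **kwargs)
            else:
                env = gym.make(cfg.env_id, **kwargs)
        env = gym.wrappers.FlattenObservation(env)  # deal with dm_control's Dict observation space
        env = gym.wrappers.RecordEpisodeStatistics(env)
        # ... remaining wrappers (video capture, etc.) are unchanged by this commit
        return env

    return thunk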
@@ -733,11 +732,11 @@ def train(cfg):
 if info is None:
     continue
 print(ep_risk_penalty)
-ep_cost = info["cost_sum"] if "safe" in cfg.env_id.lower() else info["cost"]
+ep_cost = info["cost"]
 cum_cost += ep_cost
 ep_len = info["episode"]["l"][0]
 buffer_num += ep_len
-goal_met_ep = info["cum_goal_met"] if "safe" in cfg.env_id.lower() else info["is_success"]
+goal_met_ep = info["cum_goal_met"] if "safe" in cfg.env_id.lower() and "velocity" not in cfg.env_id.lower() else 0
 goal_met += goal_met_ep
 #print(f"global_step={global_step}, episodic_return={info['episode']['r']}, episode_cost={ep_cost}")
 scores.append(info['episode']['r'])
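
The two changed lines in this hunk alter the per-episode bookkeeping: the cost is now always read from info["cost"], and goal completion is only counted for safety tasks that are not velocity-constrained, presumably because velocity tasks expose no goal signal. A minimal, hypothetical illustration of how those lines resolve (the episode_stats helper and the example env ids are assumptions for illustration, not part of the file):

def episode_stats(info, env_id):
    # mirrors the two changed lines of the diff
    ep_cost = info["cost"]
    is_goal_task = "safe" in env_id.lower() and "velocity" not in env_id.lower()
    goal_met_ep = info["cum_goal_met"] if is_goal_task else 0
    return ep_cost, goal_met_ep

print(episode_stats({"cost": 3.0}, "SafetyHopperVelocity-v1"))                 # (3.0, 0)
print(episode_stats({"cost": 1.0, "cum_goal_met": 2}, "SafetyPointGoal1-v0"))  # (1.0, 2)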
@@ -752,7 +751,7 @@ def train(cfg):
 writer.add_scalar("Results/Avg_Return", avg_mean_score, global_step)
 torch.save(agent.state_dict(), os.path.join(wandb.run.dir, "policy.pt"))
 wandb.save("policy.pt")
-print(f"cummulative_cost={cum_cost}, global_step={global_step}, episodic_return={avg_mean_score}, episode_cost={ep_cost}")
+print(f"cummulative_cost={cum_cost}, global_step={global_step}, episodic_return={info['episode']['r']}, avg_episodic_return={avg_mean_score}, episode_cost={ep_cost}")
 if cfg.use_risk:
     ep_risk = torch.sum(all_risks.squeeze()[last_step:global_step, 0]).item()
     cum_risk += ep_risk
@@ -769,7 +768,7 @@ def train(cfg):
 step_log = 0
 ep_risk_penalty = 0
 # f_dist_to_fail = torch.Tensor(np.array(list(reversed(range(f_obs.size()[0]))))).to(device) if cost > 0 else torch.Tensor(np.array([f_obs.size()[0]]*f_obs.shape[0])).to(device)
-e_risks = np.array(list(reversed(range(int(ep_len))))) if cum_cost > 0 else np.array([int(ep_len)]*int(ep_len))
+e_risks = np.array(list(reversed(range(int(ep_len))))) if terminated else np.array([int(ep_len)]*int(ep_len))
 # print(risks.size())
 e_risks = torch.Tensor(e_risks)
 if cfg.fine_tune_risk != "None" and cfg.use_risk:
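
The changed line switches the condition for the distance-to-failure labels from cum_cost > 0 to terminated: the countdown targets are now used whenever the episode actually ended in termination (presumably a failure-driven end rather than a time-limit truncation), instead of whenever any cost had ever accumulated over training. A small sketch of what the two branches produce; risk_targets is a hypothetical helper, not a function in the file:

import numpy as np

def risk_targets(ep_len: int, terminated: bool) -> np.ndarray:
    if terminated:
        # episode ended in failure: label each step with its distance to the end
        return np.array(list(reversed(range(int(ep_len)))))
    # no termination observed: every step is at least ep_len steps from failure
    return np.array([int(ep_len)] * int(ep_len))

print(risk_targets(5, True))   # [4 3 2 1 0]
print(risk_targets(5, False))  # [5 5 5 5 5]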
