Skip to content

Commit

Permalink
adding evaluation capabilities
Browse files (browse the repository at this point in the history)
  • Loading branch information
manila95 committed Nov 23, 2023
1 parent 7d97bf3 commit 020c631
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion cleanrl/dqn.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -295,6 +295,8 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
# TRY NOT TO MODIFY: execute the game and log data.
# Advance the environments by one step with the selected actions.
next_obs, rewards, terminated, truncated, infos = envs.step(actions)
# `x and y` evaluates to x when x is falsy, otherwise to y: cost is 0 unless
# int(terminated) is non-zero, in which case it is the result of `rewards == 0`.
# NOTE(review): int(terminated) presumably needs a scalar or single-element
# value, yet nearby code loops over args.num_envs -- confirm with num_envs > 1.
cost = int(terminated) and (rewards == 0)
# Same pattern for goals: a termination whose reward is positive.
goal = int(terminated) and (rewards > 0)
# Running count of goal terminations; logged later as "charts/Total Goals".
# NOTE(review): total_goals is initialised outside this excerpt; when the
# `and` yields the comparison, its type follows `rewards > 0` -- TODO confirm.
total_goals += goal
if (args.fine_tune_risk != "None" and args.use_risk):
for i in range(args.num_envs):
f_obs[i] = torch.Tensor(obs["image"][i]).reshape(1, -1).to(device) if f_obs[i] is None else torch.concat([f_obs[i], torch.Tensor(obs["image"][i]).reshape(1, -1).to(device)], axis=0)
Expand All @@ -313,7 +315,6 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
continue
total_cost += cost
ep_len = info["episode"]["l"]

if args.use_risk and args.fine_tune_risk != "None":
e_risks = np.array(list(reversed(range(int(ep_len))))) if cost > 0 else np.array([int(ep_len)]*int(ep_len))
e_risks_quant = torch.Tensor(np.apply_along_axis(lambda x: np.histogram(x, bins=risk_bins)[0], 1, np.expand_dims(e_risks, 1)))
Expand All @@ -327,6 +328,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
# Keep the finished episode's return so the running average below can use it.
scores.append(info['episode']['r'])
print(f"global_step={global_step}, episodic_return={info['episode']['r']}, total cost={total_cost}")
# NOTE(review): `writer` is defined outside this excerpt; presumably a
# TensorBoard-style SummaryWriter -- confirm.
# Return of the episode that just ended, indexed by global_step.
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
# Cumulative goal counter (incremented via `total_goals += goal`).
# NOTE(review): this tag and "charts/Avg Return" use spaces and capitals while
# the other tags are snake_case -- consider aligning if dashboards allow it.
writer.add_scalar("charts/Total Goals", total_goals, global_step)
# Mean of the most recent 100 recorded returns (all of them if fewer exist).
writer.add_scalar("charts/Avg Return", np.mean(scores[-100:]), global_step)
# Cumulative cost counter across episodes.
writer.add_scalar("charts/total_cost", total_cost, global_step)
# Cost value computed for the current step/episode.
writer.add_scalar("charts/episodic_cost", cost, global_step)
Expand Down

0 comments on commit 020c631

Please sign in to comment.