From 8d2d788fc5d118b20f3b81aacda60987b7e5263e Mon Sep 17 00:00:00 2001 From: Kaustubh Mani Date: Thu, 11 Jan 2024 16:53:05 -0500 Subject: [PATCH] fixing bugs --- cleanrl/ppo_continuous_action_wandb.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/cleanrl/ppo_continuous_action_wandb.py b/cleanrl/ppo_continuous_action_wandb.py index 4200025c..93146e18 100644 --- a/cleanrl/ppo_continuous_action_wandb.py +++ b/cleanrl/ppo_continuous_action_wandb.py @@ -654,7 +654,7 @@ def train(cfg): if cfg.collect_data: #os.system("rm -rf %s"%cfg.storage_path) - storage_path = os.path.join(cfg.storage_path, run.sweep_id, cfg.env_id, run.name) + storage_path = os.path.join(cfg.storage_path, "sweep" if run.sweep_id is None else run.sweep_id, cfg.env_id, run.name) #make_dirs(storage_path, episode) buffer_num = 0 @@ -875,12 +875,20 @@ def train(cfg): ## Save all the data if cfg.collect_data: - os.makedirs(os.path.join(storage_path, episode)) - torch.save(f_obs[0], os.path.join(storage_path, episode, "obs.pt")) - torch.save(f_next_obs[0], os.path.join(storage_path, episode, "next_obs.pt")) - torch.save(f_actions[0], os.path.join(storage_path, episode, "actions.pt")) - torch.save(f_costs[0], os.path.join(storage_path, episode, "costs.pt")) - torch.save(f_risks[0], os.path.join(storage_path, episode, "risks.pt")) + os.makedirs(os.path.join(storage_path, str(episode))) + torch.save(f_obs[0], os.path.join(storage_path, str(episode), "obs.pt")) + torch.save(f_next_obs[0], os.path.join(storage_path, str(episode), "next_obs.pt")) + torch.save(f_actions[0], os.path.join(storage_path, str(episode), "actions.pt")) + torch.save(f_costs[0], os.path.join(storage_path, str(episode), "costs.pt")) + torch.save(f_risks, os.path.join(storage_path, str(episode), "risks.pt")) + f_obs[i] = None + f_next_obs[i] = None + f_risks = None + #f_ep_len = None + f_actions[i] = None + f_rewards[i] = None + f_dones[i] = None + f_costs[i] = None #torch.save(torch.Tensor(f_ep_len), os.path.join(storage_path, "ep_len.pt")) #make_dirs(storage_path, episode)