Commit

collect data
manila95 committed Jan 11, 2024
1 parent cf4fd7a commit baa934e
Showing 1 changed file with 12 additions and 17 deletions.
29 changes: 12 additions & 17 deletions cleanrl/ppo_continuous_action_wandb.py
@@ -654,8 +654,8 @@ def train(cfg):

if cfg.collect_data:
#os.system("rm -rf %s"%cfg.storage_path)
storage_path = os.path.join(cfg.storage_path, cfg.env_id, run.name)
make_dirs(storage_path, episode)
storage_path = os.path.join(cfg.storage_path, run.sweep_id, cfg.env_id, run.name)
#make_dirs(storage_path, episode)

buffer_num = 0
goal_met = 0; ep_goal_met = 0
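
The new storage layout nests each run under its wandb sweep id, and directory creation is deferred until episodes are actually written out (see the per-episode os.makedirs call further down). A minimal sketch of that layout, assuming the cfg fields and wandb run attributes visible in the diff; the helper name episode_dir is hypothetical:

import os

def episode_dir(cfg, run, episode):
    # <storage_path>/<sweep_id>/<env_id>/<run_name>/<episode>, created lazily
    root = os.path.join(cfg.storage_path, run.sweep_id, cfg.env_id, run.name)
    path = os.path.join(root, str(episode))
    os.makedirs(path, exist_ok=True)
    return path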
@@ -761,13 +761,6 @@ def train(cfg):
# risk_loss = train_risk(cfg, risk_model, data, criterion, opt_risk, device)
# writer.add_scalar("risk/risk_loss", risk_loss, global_step)

if cfg.fine_tune_risk == "off" and cfg.use_risk and global_step % cfg.risk_update_period == 0 and global_step >= cfg.start_risk_update:
for epoch in tqdm.tqdm(range(cfg.num_risk_epochs)):
risk_data = rb.sample(cfg.risk_batch_size)
risk_loss = risk_update_step(risk_model, risk_data, criterion, opt_risk, device)
writer.add_scalar("risk/risk_loss", risk_loss.item(), global_step)

'''
if cfg.fine_tune_risk == "sync" and cfg.use_risk:
if cfg.use_risk and buffer_num > cfg.risk_batch_size and cfg.fine_tune_risk:
if cfg.finetune_risk_online:
@@ -789,7 +782,7 @@ def train(cfg):
data = rb.sample(cfg.risk_batch_size*cfg.num_update_risk)
risk_loss = train_risk(cfg, risk_model, data, criterion, opt_risk, device)
writer.add_scalar("risk/risk_loss", risk_loss, global_step)
'''

# Only print when at least 1 env is done
if "final_info" not in infos:
continue
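
For context, the block removed above performed the "off"-schedule fine-tuning of the risk model: every cfg.risk_update_period global steps (once past cfg.start_risk_update), it ran cfg.num_risk_epochs epochs, each sampling cfg.risk_batch_size transitions from the replay buffer and taking one gradient step via risk_update_step. A hedged sketch of what such an update step typically looks like; the real risk_update_step in this repo is not shown in the diff, and the batch layout here is assumed:

def risk_update_step(risk_model, data, criterion, opt_risk, device):
    # Assumed batch layout: observations plus (discretized) risk targets.
    obs = data["obs"].to(device)
    targets = data["risks"].to(device)
    pred = risk_model(obs)              # risk logits / distribution over horizons
    loss = criterion(pred, targets)
    opt_risk.zero_grad()
    loss.backward()
    opt_risk.step()
    return loss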
@@ -867,7 +860,7 @@ def train(cfg):
if cfg.risk_type == "binary":
rb.add(f_obs[i], f_next_obs[i], f_actions[i], f_rewards[i], f_dones[i], f_costs[i], (f_risks <= cfg.fear_radius).float(), e_risks.unsqueeze(1))
else:
rb.add(f_obs[i], f_next_obs[i], f_actions[i], f_rewards[i], f_dones[i], f_costs[i], f_risks_quant, f_risks)
rb.add(f_obs[i], f_next_obs[i], f_actions[i], f_rewards[i], f_dones[i], f_costs[i], f_risks, f_risks)

f_obs[i] = None
f_next_obs[i] = None
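
The binary branch above labels a timestep as risky when its distance-to-failure signal falls within cfg.fear_radius, while the other branch now stores the raw f_risks in both target slots (previously the quantized f_risks_quant). A small illustrative example of the fear-radius test, with made-up numbers; tensor names mirror the diff but the shapes are assumed:

import torch

f_risks = torch.tensor([[12.0], [3.0], [1.0], [0.0]])  # e.g. steps until a cost is incurred
fear_radius = 5
binary_labels = (f_risks <= fear_radius).float()
# binary_labels -> tensor([[0.], [1.], [1.], [1.]])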
@@ -882,11 +875,13 @@ def train(cfg):

## Save all the data
if cfg.collect_data:
torch.save(f_obs, os.path.join(storage_path, "obs.pt"))
torch.save(f_actions, os.path.join(storage_path, "actions.pt"))
torch.save(f_costs, os.path.join(storage_path, "costs.pt"))
torch.save(f_risks, os.path.join(storage_path, "risks.pt"))
torch.save(torch.Tensor(f_ep_len), os.path.join(storage_path, "ep_len.pt"))
os.makedirs(os.path.join(storage_path, episode))
torch.save(f_obs[0], os.path.join(storage_path, episode, "obs.pt"))
torch.save(f_next_obs[0], os.path.join(storage_path, episode, "next_obs.pt"))
torch.save(f_actions[0], os.path.join(storage_path, episode, "actions.pt"))
torch.save(f_costs[0], os.path.join(storage_path, episode, "costs.pt"))
torch.save(f_risks[0], os.path.join(storage_path, episode, "risks.pt"))
#torch.save(torch.Tensor(f_ep_len), os.path.join(storage_path, "ep_len.pt"))
#make_dirs(storage_path, episode)

# bootstrap value if not done
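
With this change each finished episode gets its own subdirectory under storage_path containing obs.pt, next_obs.pt, actions.pt, costs.pt and risks.pt for the first environment. A sketch of reading one episode back, assuming only the layout written above; this loader is not part of the repository:

import os
import torch

def load_episode(storage_path, episode):
    ep_dir = os.path.join(storage_path, str(episode))
    return {name: torch.load(os.path.join(ep_dir, f"{name}.pt"))
            for name in ("obs", "next_obs", "actions", "costs", "risks")}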
@@ -1001,7 +996,7 @@ def train(cfg):
torch.save(agent.state_dict(), os.path.join(wandb.run.dir, "policy.pt"))
wandb.save("policy.pt")
if cfg.use_risk:
torch.save(risk_model.state_dict(), os.path.join(wandb.run.dir, "risk_model.pt"))
torch.save(model.state_dict(), os.path.join(wandb.run.dir, "risk_model.pt"))
wandb.save("risk_model.pt")
print(f_ep_len)
envs.close()
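
At the end of training the policy weights and, when cfg.use_risk is set, the risk model weights are written into wandb.run.dir and registered with wandb.save, so they are uploaded with the run. A hedged example of restoring them afterwards; the local files path and the commented-out load calls are assumptions, since the model constructors are outside this diff:

import os
import torch

run_dir = "./wandb/latest-run/files"   # assumed location of the synced run files
policy_state = torch.load(os.path.join(run_dir, "policy.pt"), map_location="cpu")
risk_state = torch.load(os.path.join(run_dir, "risk_model.pt"), map_location="cpu")
# agent.load_state_dict(policy_state)      # requires constructing the repo's Agent first
# risk_model.load_state_dict(risk_state)   # likewise for the risk model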