Skip to content

Commit

Permalink
simplifying the training of risk model
Browse files Browse the repository at this point in the history
  • Loading branch information
manila95 committed Jan 11, 2024
1 parent 60c05a0 commit cf4fd7a
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions cleanrl/ppo_continuous_action_wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,13 @@ def train(cfg):
# risk_loss = train_risk(cfg, risk_model, data, criterion, opt_risk, device)
# writer.add_scalar("risk/risk_loss", risk_loss, global_step)

if cfg.fine_tune_risk == "off" and cfg.use_risk and global_step % cfg.risk_update_period == 0 and global_step >= cfg.start_risk_update:
for epoch in tqdm.tqdm(range(cfg.num_risk_epochs)):
risk_data = rb.sample(cfg.risk_batch_size)
risk_loss = risk_update_step(risk_model, risk_data, criterion, opt_risk, device)
writer.add_scalar("risk/risk_loss", risk_loss.item(), global_step)

'''
if cfg.fine_tune_risk == "sync" and cfg.use_risk:
if cfg.use_risk and buffer_num > cfg.risk_batch_size and cfg.fine_tune_risk:
if cfg.finetune_risk_online:
Expand All @@ -782,7 +789,7 @@ def train(cfg):
data = rb.sample(cfg.risk_batch_size*cfg.num_update_risk)
risk_loss = train_risk(cfg, risk_model, data, criterion, opt_risk, device)
writer.add_scalar("risk/risk_loss", risk_loss, global_step)

'''
# Only print when at least 1 env is done
if "final_info" not in infos:
continue
Expand Down Expand Up @@ -860,7 +867,7 @@ def train(cfg):
if cfg.risk_type == "binary":
rb.add(f_obs[i], f_next_obs[i], f_actions[i], f_rewards[i], f_dones[i], f_costs[i], (f_risks <= cfg.fear_radius).float(), e_risks.unsqueeze(1))
else:
rb.add(f_obs[i], f_next_obs[i], f_actions[i], f_rewards[i], f_dones[i], f_costs[i], f_risks, f_risks)
rb.add(f_obs[i], f_next_obs[i], f_actions[i], f_rewards[i], f_dones[i], f_costs[i], f_risks_quant, f_risks)

f_obs[i] = None
f_next_obs[i] = None
Expand Down Expand Up @@ -994,7 +1001,7 @@ def train(cfg):
torch.save(agent.state_dict(), os.path.join(wandb.run.dir, "policy.pt"))
wandb.save("policy.pt")
if cfg.use_risk:
torch.save(model.state_dict(), os.path.join(wandb.run.dir, "risk_model.pt"))
torch.save(risk_model.state_dict(), os.path.join(wandb.run.dir, "risk_model.pt"))
wandb.save("risk_model.pt")
print(f_ep_len)
envs.close()
Expand Down

0 comments on commit cf4fd7a

Please sign in to comment.