diff --git a/cleanrl/ppo_continuous_action_wandb.py b/cleanrl/ppo_continuous_action_wandb.py
index 5730f462..bb4ab8c6 100644
--- a/cleanrl/ppo_continuous_action_wandb.py
+++ b/cleanrl/ppo_continuous_action_wandb.py
@@ -761,6 +761,13 @@ def train(cfg):
             # risk_loss = train_risk(cfg, risk_model, data, criterion, opt_risk, device)
             # writer.add_scalar("risk/risk_loss", risk_loss, global_step)
 
+        if cfg.fine_tune_risk == "off" and cfg.use_risk and global_step % cfg.risk_update_period == 0 and global_step >= cfg.start_risk_update:
+            for epoch in tqdm.tqdm(range(cfg.num_risk_epochs)):
+                risk_data = rb.sample(cfg.risk_batch_size)
+                risk_loss = risk_update_step(risk_model, risk_data, criterion, opt_risk, device)
+                writer.add_scalar("risk/risk_loss", risk_loss.item(), global_step)
+
+        '''
         if cfg.fine_tune_risk == "sync" and cfg.use_risk:
             if cfg.use_risk and buffer_num > cfg.risk_batch_size and cfg.fine_tune_risk:
                 if cfg.finetune_risk_online:
@@ -782,7 +789,7 @@ def train(cfg):
                     data = rb.sample(cfg.risk_batch_size*cfg.num_update_risk)
                     risk_loss = train_risk(cfg, risk_model, data, criterion, opt_risk, device)
                     writer.add_scalar("risk/risk_loss", risk_loss, global_step)
-
+        '''
         # Only print when at least 1 env is done
         if "final_info" not in infos:
             continue
@@ -860,7 +867,7 @@ def train(cfg):
                     if cfg.risk_type == "binary":
                         rb.add(f_obs[i], f_next_obs[i], f_actions[i], f_rewards[i], f_dones[i], f_costs[i], (f_risks <= cfg.fear_radius).float(), e_risks.unsqueeze(1))
                     else:
-                        rb.add(f_obs[i], f_next_obs[i], f_actions[i], f_rewards[i], f_dones[i], f_costs[i], f_risks, f_risks)
+                        rb.add(f_obs[i], f_next_obs[i], f_actions[i], f_rewards[i], f_dones[i], f_costs[i], f_risks_quant, f_risks)
 
                     f_obs[i] = None
                     f_next_obs[i] = None
@@ -994,7 +1001,7 @@ def train(cfg):
         torch.save(agent.state_dict(), os.path.join(wandb.run.dir, "policy.pt"))
         wandb.save("policy.pt")
         if cfg.use_risk:
-            torch.save(model.state_dict(), os.path.join(wandb.run.dir, "risk_model.pt"))
+            torch.save(risk_model.state_dict(), os.path.join(wandb.run.dir, "risk_model.pt"))
             wandb.save("risk_model.pt")
     print(f_ep_len)
     envs.close()
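
The new `fine_tune_risk == "off"` branch calls `risk_update_step`, whose definition sits outside this hunk. Below is a minimal sketch of one plausible shape for that helper, to make the call site readable in isolation. The batch keys (`obs`, `risks`), the model returning per-class log-probabilities, and `criterion` being an NLL-style loss over quantized risk targets are all assumptions, not the repository's confirmed API:

```python
import torch

def risk_update_step(risk_model, risk_data, criterion, opt_risk, device):
    """Single gradient step on the risk model.

    Hypothetical sketch: the real helper is defined elsewhere in the file,
    and the batch keys / target encoding below are assumptions.
    """
    obs = risk_data["obs"].to(device)
    targets = risk_data["risks"].to(device)  # quantized targets, matching the rb.add change above

    pred = risk_model(obs)           # assumed: per-class log-probabilities
    loss = criterion(pred, targets)  # e.g. nn.NLLLoss over discretized risk bins

    opt_risk.zero_grad()
    loss.backward()
    opt_risk.step()
    return loss.detach()             # tensor, not float: the caller logs risk_loss.item()
```

Returning the detached tensor rather than a float matches the call site above, which invokes `.item()` before passing the loss to `writer.add_scalar`. Gating the update on `global_step % cfg.risk_update_period == 0` and `global_step >= cfg.start_risk_update` decouples risk-model training from the policy update schedule and defers it until the replay buffer has accumulated enough data.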