diff --git a/cleanrl/ppo_continuous_action_wandb.py b/cleanrl/ppo_continuous_action_wandb.py
index 28b4faf0..79b2abbb 100644
--- a/cleanrl/ppo_continuous_action_wandb.py
+++ b/cleanrl/ppo_continuous_action_wandb.py
@@ -838,7 +838,7 @@ def train(cfg):
                     real_next_obs[idx] = infos["final_observation"][idx]
 
             if cfg.use_csc:
-                csc_rb.add(obs_, real_next_obs, action, np.array(terminated).astype(float), done, infos)
+                csc_rb.add(obs_, real_next_obs, action.cpu().numpy(), np.array(terminated).astype(float), done, infos)
 
             info_dict = {'reward': reward, 'done': done, 'cost': cost, 'obs': obs}
             # if cfg.collect_data:
@@ -868,7 +868,7 @@ def train(cfg):
                         f_dones[i] = next_done[i].unsqueeze(0).to(device) if f_dones[i] is None else torch.concat([f_dones[i], next_done[i].unsqueeze(0).to(device)], axis=0)
 
             obs_ = next_obs_
-            obs = next_obs
+            obs_old = next_obs
             # if global_step % cfg.update_risk_model == 0 and cfg.fine_tune_risk:
             # if cfg.use_risk and (global_step > cfg.start_risk_update and cfg.fine_tune_risk) and global_step % cfg.risk_update_period == 0:
             #     for epoch in range(cfg.num_risk_epochs):
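
Note on the first hunk: csc_rb appears to store transitions in NumPy arrays, so passing the raw policy output (a torch tensor, possibly on the GPU) as action would fail or be stored with the wrong type; hence the action.cpu().numpy() conversion. Below is a minimal sketch of the pattern, assuming a NumPy-backed buffer with a hypothetical SimpleBuffer.add interface (not the actual csc_rb API):

    import numpy as np
    import torch

    class SimpleBuffer:
        """Hypothetical NumPy-backed replay buffer, standing in for csc_rb."""

        def __init__(self, size, obs_dim, act_dim):
            self.obs = np.zeros((size, obs_dim), dtype=np.float32)
            self.actions = np.zeros((size, act_dim), dtype=np.float32)
            self.ptr = 0

        def add(self, obs, action):
            # np.asarray raises on CUDA tensors, so callers must move data
            # to the CPU and convert to NumPy before storing.
            self.obs[self.ptr] = np.asarray(obs, dtype=np.float32)
            self.actions[self.ptr] = np.asarray(action, dtype=np.float32)
            self.ptr += 1

    buf = SimpleBuffer(size=8, obs_dim=4, act_dim=2)
    obs = np.zeros(4, dtype=np.float32)
    action = torch.randn(2)              # action sampled by the policy as a tensor
    buf.add(obs, action.cpu().numpy())   # convert before adding, as in the diff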