From cd0f6824239401103e12c32fc487abfedce92c44 Mon Sep 17 00:00:00 2001
From: Kaustubh Mani
Date: Sun, 5 Nov 2023 01:20:48 -0400
Subject: [PATCH] added object centric risk only

---
 cleanrl/ppo_continuous_action_wandb.py | 32 ++++++++++++++++----------------
 1 file changed, 16 insertions(+), 16 deletions(-)

diff --git a/cleanrl/ppo_continuous_action_wandb.py b/cleanrl/ppo_continuous_action_wandb.py
index 1b5ae082..b9efecb5 100644
--- a/cleanrl/ppo_continuous_action_wandb.py
+++ b/cleanrl/ppo_continuous_action_wandb.py
@@ -123,11 +123,11 @@ def parse_args():
         help="the id of the environment")
     parser.add_argument("--binary-risk", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
         help="Use risk model in the critic or not ")
-    parser.add_argument("--model-type", type=str, default="mlp",
+    parser.add_argument("--model-type", type=str, default="bayesian",
         help="specify the NN to use for the risk model")
     parser.add_argument("--risk-bnorm", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
         help="Toggles whether or not to use a clipped loss for the value function, as per the paper.")
-    parser.add_argument("--risk-type", type=str, default="binary",
+    parser.add_argument("--risk-type", type=str, default="quantile",
         help="whether the risk is binary or continuous")
     parser.add_argument("--fear-radius", type=int, default=5,
         help="fear radius for training the risk model")
@@ -149,7 +149,7 @@ def parse_args():
         help="Toggles whether or not to use a clipped loss for the value function, as per the paper.")
     parser.add_argument("--start-risk-update", type=int, default=10000,
         help="number of epochs to update the risk model")
-    parser.add_argument("--rb-type", type=str, default="balanced",
+    parser.add_argument("--rb-type", type=str, default="simple",
         help="which type of replay buffer to use for ")
     parser.add_argument("--freeze-risk-layers", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
         help="Toggles whether or not to use a clipped loss for the value function, as per the paper.")
@@ -158,6 +158,8 @@ def parse_args():
     parser.add_argument("--quantile-size", type=int, default=4, help="size of the risk quantile ")
     parser.add_argument("--quantile-num", type=int, default=5, help="number of quantiles to make")
     parser.add_argument("--risk-penalty", type=float, default=0., help="penalty to impose for entering risky states")
+    parser.add_argument("--risk-penalty-start", type=float, default=20., help="number of risk model updates after which the risk penalty is applied")
+
     args = parser.parse_args()
     args.batch_size = int(args.num_envs * args.num_steps)
     args.minibatch_size = int(args.batch_size // args.num_minibatches)
@@ -467,15 +469,7 @@ def test_policy(cfg, agent, envs, device, risk_model=None):
 
 def get_risk_obs(cfg, next_obs):
     if cfg.unifying_lidar:
-        if "car" in cfg.risk_model_path.lower():
-            if "point" in cfg.env_id.lower():
-                next_risk_obs = torch.zeros((next_obs.size()[0], 120)).to(cfg.device)
-                other_crap = torch.Tensor([[0, 0, 0, 1., 0., 0., 0., 1., 0., 0., 0., 1.]]*next_obs.shape[0]).to(cfg.device)
-                next_risk_obs[:, list(range(12))] = next_obs[:, list(range(12))]
-                next_risk_obs[:, list(range(12,24))] = other_crap
-                next_risk_obs[:, list(range(24, 120))] = next_obs[:, list(range(12, 108))]
-                return next_risk_obs
-            return next_obs
+        return next_obs[:, -96:]
     if "goal" in cfg.risk_model_path.lower():
         if "push" in cfg.env_id.lower():
             #print("push")
@@ -580,7 +574,7 @@ def train(cfg):
         agent = RiskAgent(envs=envs, risk_size=risk_size).to(device)
     #else:
     #    agent = ContRiskAgent(envs=envs).to(device)
-    risk_model = risk_model_class[cfg.model_type][cfg.risk_type](obs_size=risk_obs_size, batch_norm=True, out_size=risk_size)
+    risk_model = risk_model_class[cfg.model_type][cfg.risk_type](obs_size=96, batch_norm=True, out_size=risk_size)
     if os.path.exists(cfg.risk_model_path):
         risk_model.load_state_dict(torch.load(cfg.risk_model_path, map_location=device))
         print("Pretrained risk model loaded successfully")
@@ -664,14 +658,14 @@ def train(cfg):
     buffer_num = 0
     goal_met = 0; ep_goal_met = 0
     #update = 0
-    risk_penalty, ep_risk_penalty = 0, 0
+    risk_penalty, ep_risk_penalty, total_risk_updates = 0, 0, 0
     for update in range(1, num_updates + 1):
         # Annealing the rate if instructed to do so.
         if cfg.anneal_lr:
             frac = 1.0 - (update - 1.0) / num_updates
             lrnow = frac * cfg.learning_rate
             optimizer.param_groups[0]["lr"] = lrnow
-
+
         for step in range(0, cfg.num_steps):
             risk = torch.Tensor([[0.]]).to(device)
             global_step += 1 * cfg.num_envs
@@ -695,7 +689,10 @@ def train(cfg):
                     id_risk = int(next_risk[:,0] >= 1 / (cfg.fear_radius + 1))
                     next_risk = torch.zeros_like(next_risk)
                     next_risk[:, id_risk] = 1
-                risk_penalty = torch.sum(torch.div(torch.exp(next_risk), quantile_means) * cfg.risk_penalty)
+                if cfg.risk_model_path == "None" or (cfg.risk_model_path == "scratch" and total_risk_updates < cfg.risk_penalty_start):
+                    risk_penalty = torch.Tensor([0.]).to(device)
+                else:
+                    risk_penalty = torch.sum(torch.div(torch.exp(next_risk), quantile_means) * cfg.risk_penalty)
                 ep_risk_penalty += risk_penalty.item()
                 # print(next_risk)
                 risks[step] = torch.exp(next_risk)
@@ -771,6 +768,8 @@ def train(cfg):
             elif cfg.fine_tune_risk == "off" and cfg.use_risk:
                 if cfg.use_risk and (global_step > cfg.start_risk_update and cfg.fine_tune_risk) and global_step % cfg.risk_update_period == 0:
                     for epoch in range(cfg.num_risk_epochs):
+                        total_risk_updates += 1
+                        print(total_risk_updates)
                         if cfg.finetune_risk_online:
                             print("I am online")
                             data = rb.slice_data(-cfg.risk_batch_size*cfg.num_update_risk, 0)
@@ -788,6 +787,7 @@ def train(cfg):
             # Skip the envs that are not done
             if info is None:
                 continue
+            print(ep_risk_penalty)
             ep_cost = info["cost_sum"]
             cum_cost += ep_cost
             ep_len = info["episode"]["l"][0]
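
Note on the new object-centric risk input (not part of the commit): with unifying_lidar set, get_risk_obs now simply slices the last 96 observation dimensions, and the risk model is constructed with a hard-coded obs_size=96 to match. The quick check below sketches that assumption; the batch size and full observation width (8 and 108) are illustrative values, not taken from the patch.

    # Sketch only: checks the assumed layout where the object-centric lidar
    # block occupies the last 96 observation dims (sizes below are examples).
    import torch

    next_obs = torch.randn(8, 108)    # batch of 8 full observations, 108 dims (illustrative)
    risk_obs = next_obs[:, -96:]      # object-centric slice fed to the risk model
    assert risk_obs.shape == (8, 96)  # must agree with obs_size=96 in the risk model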
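
Note on the penalty gating (not part of the commit): the new branch in the step loop zeroes the shaping penalty when no risk model is used, or while a from-scratch risk model has received fewer than --risk-penalty-start updates; otherwise it keeps the existing quantile-weighted penalty. A standalone sketch of that logic, using a hypothetical helper name, for discussion:

    # Sketch only: mirrors the gating added in the @@ -695,7 +689,10 @@ hunk.
    # compute_risk_penalty is a hypothetical helper, not part of the patch.
    import torch

    def compute_risk_penalty(next_risk, quantile_means, cfg, total_risk_updates, device):
        # Penalty is disabled when no risk model is used, or while a from-scratch
        # risk model has had fewer than cfg.risk_penalty_start updates.
        if cfg.risk_model_path == "None" or (
            cfg.risk_model_path == "scratch" and total_risk_updates < cfg.risk_penalty_start
        ):
            return torch.Tensor([0.]).to(device)
        # Otherwise: exponentiated predicted risk, weighted by the inverse of the
        # per-quantile means and scaled by the --risk-penalty coefficient.
        return torch.sum(torch.div(torch.exp(next_risk), quantile_means) * cfg.risk_penalty)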