added object centric risk only
manila95 committed Nov 5, 2023
1 parent 9ea66b0 commit cd0f682
Showing 1 changed file with 16 additions and 16 deletions.
32 changes: 16 additions & 16 deletions cleanrl/ppo_continuous_action_wandb.py
@@ -123,11 +123,11 @@ def parse_args():
         help="the id of the environment")
     parser.add_argument("--binary-risk", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
         help="Use risk model in the critic or not ")
-    parser.add_argument("--model-type", type=str, default="mlp",
+    parser.add_argument("--model-type", type=str, default="bayesian",
         help="specify the NN to use for the risk model")
     parser.add_argument("--risk-bnorm", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
         help="Toggles whether or not to use a clipped loss for the value function, as per the paper.")
-    parser.add_argument("--risk-type", type=str, default="binary",
+    parser.add_argument("--risk-type", type=str, default="quantile",
         help="whether the risk is binary or continuous")
     parser.add_argument("--fear-radius", type=int, default=5,
         help="fear radius for training the risk model")
@@ -149,7 +149,7 @@ def parse_args():
         help="Toggles whether or not to use a clipped loss for the value function, as per the paper.")
     parser.add_argument("--start-risk-update", type=int, default=10000,
         help="number of epochs to update the risk model")
-    parser.add_argument("--rb-type", type=str, default="balanced",
+    parser.add_argument("--rb-type", type=str, default="simple",
         help="which type of replay buffer to use for ")
     parser.add_argument("--freeze-risk-layers", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
         help="Toggles whether or not to use a clipped loss for the value function, as per the paper.")
@@ -158,6 +158,8 @@ def parse_args():
     parser.add_argument("--quantile-size", type=int, default=4, help="size of the risk quantile ")
     parser.add_argument("--quantile-num", type=int, default=5, help="number of quantiles to make")
     parser.add_argument("--risk-penalty", type=float, default=0., help="penalty to impose for entering risky states")
+    parser.add_argument("--risk-penalty-start", type=float, default=20., help="penalty to impose for entering risky states")
+
     args = parser.parse_args()
     args.batch_size = int(args.num_envs * args.num_steps)
     args.minibatch_size = int(args.batch_size // args.num_minibatches)
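
Note that --risk-penalty-start, despite its copy-pasted help string, is not a penalty magnitude: later in this diff it is compared against a counter of risk-model updates, so a risk model trained from scratch only starts contributing a penalty after roughly that many update rounds. A minimal sketch of that gating, with illustrative numbers (not from the repo):

# Hedged sketch of the intended gating behind --risk-penalty-start.
risk_penalty_start = 20.   # the new flag's default
total_risk_updates = 0     # incremented once per risk-model training epoch (see the later hunk)

for risk_update in range(1, 31):
    total_risk_updates += 1
    penalty_active = total_risk_updates >= risk_penalty_start
    if risk_update in (1, 20, 30):
        print(f"after {total_risk_updates} risk updates, penalty active: {penalty_active}")
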
@@ -467,15 +469,7 @@ def test_policy(cfg, agent, envs, device, risk_model=None):
 
 def get_risk_obs(cfg, next_obs):
     if cfg.unifying_lidar:
-        if "car" in cfg.risk_model_path.lower():
-            if "point" in cfg.env_id.lower():
-                next_risk_obs = torch.zeros((next_obs.size()[0], 120)).to(cfg.device)
-                other_crap = torch.Tensor([[0, 0, 0, 1., 0., 0., 0., 1., 0., 0., 0., 1.]]*next_obs.shape[0]).to(cfg.device)
-                next_risk_obs[:, list(range(12))] = next_obs[:, list(range(12))]
-                next_risk_obs[:, list(range(12,24))] = other_crap
-                next_risk_obs[:, list(range(24, 120))] = next_obs[:, list(range(12, 108))]
-                return next_risk_obs
-            return next_obs
+        return next_obs[:, -96:]
     if "goal" in cfg.risk_model_path.lower():
         if "push" in cfg.env_id.lower():
             #print("push")
@@ -580,7 +574,7 @@ def train(cfg):
         agent = RiskAgent(envs=envs, risk_size=risk_size).to(device)
         #else:
         #    agent = ContRiskAgent(envs=envs).to(device)
-        risk_model = risk_model_class[cfg.model_type][cfg.risk_type](obs_size=risk_obs_size, batch_norm=True, out_size=risk_size)
+        risk_model = risk_model_class[cfg.model_type][cfg.risk_type](obs_size=96, batch_norm=True, out_size=risk_size)
         if os.path.exists(cfg.risk_model_path):
             risk_model.load_state_dict(torch.load(cfg.risk_model_path, map_location=device))
             print("Pretrained risk model loaded successfully")
Expand Down Expand Up @@ -664,14 +658,14 @@ def train(cfg):
buffer_num = 0
goal_met = 0; ep_goal_met = 0
#update = 0
risk_penalty, ep_risk_penalty = 0, 0
risk_penalty, ep_risk_penalty, total_risk_updates = 0, 0, 0
for update in range(1, num_updates + 1):
# Annealing the rate if instructed to do so.
if cfg.anneal_lr:
frac = 1.0 - (update - 1.0) / num_updates
lrnow = frac * cfg.learning_rate
optimizer.param_groups[0]["lr"] = lrnow

for step in range(0, cfg.num_steps):
risk = torch.Tensor([[0.]]).to(device)
global_step += 1 * cfg.num_envs
@@ -695,7 +689,10 @@ def train(cfg):
                 id_risk = int(next_risk[:,0] >= 1 / (cfg.fear_radius + 1))
                 next_risk = torch.zeros_like(next_risk)
                 next_risk[:, id_risk] = 1
-            risk_penalty = torch.sum(torch.div(torch.exp(next_risk), quantile_means) * cfg.risk_penalty)
+            if cfg.risk_model_path == "None" or (cfg.risk_model_path == "scratch" and total_risk_updates < cfg.risk_penalty_start):
+                risk_penalty = torch.Tensor([0.]).to(device)
+            else:
+                risk_penalty = torch.sum(torch.div(torch.exp(next_risk), quantile_means) * cfg.risk_penalty)
             ep_risk_penalty += risk_penalty.item()
             # print(next_risk)
             risks[step] = torch.exp(next_risk)
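
The replacement only guards the existing penalty term: it stays zero when no risk model is used (risk_model_path == "None"), and for a from-scratch model until it has seen risk_penalty_start updates; otherwise it is the same sum of exponentiated risk outputs weighted by the inverse quantile means. A self-contained sketch of that arithmetic with made-up numbers (quantile_means here is purely illustrative):

import torch

next_risk = torch.log(torch.tensor([[0.7, 0.2, 0.1]]))  # pretend log-probabilities over 3 risk quantiles
quantile_means = torch.tensor([10.0, 5.0, 1.0])          # illustrative per-quantile normalizers
risk_penalty_coeff = 1.0                                  # cfg.risk_penalty (0. by default, so no penalty unless set)

risk_penalty = torch.sum(torch.div(torch.exp(next_risk), quantile_means) * risk_penalty_coeff)
print(risk_penalty.item())  # 0.7/10 + 0.2/5 + 0.1/1 = 0.21
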
Expand Down Expand Up @@ -771,6 +768,8 @@ def train(cfg):
elif cfg.fine_tune_risk == "off" and cfg.use_risk:
if cfg.use_risk and (global_step > cfg.start_risk_update and cfg.fine_tune_risk) and global_step % cfg.risk_update_period == 0:
for epoch in range(cfg.num_risk_epochs):
total_risk_updates +=1
print(total_risk_updates)
if cfg.finetune_risk_online:
print("I am online")
data = rb.slice_data(-cfg.risk_batch_size*cfg.num_update_risk, 0)
@@ -788,6 +787,7 @@ def train(cfg):
                 # Skip the envs that are not done
                 if info is None:
                     continue
+                print(ep_risk_penalty)
                 ep_cost = info["cost_sum"]
                 cum_cost += ep_cost
                 ep_len = info["episode"]["l"][0]
