diff --git a/cleanrl/ppo_continuous_action_wandb.py b/cleanrl/ppo_continuous_action_wandb.py
index b9efecb5b..a9761bfa2 100644
--- a/cleanrl/ppo_continuous_action_wandb.py
+++ b/cleanrl/ppo_continuous_action_wandb.py
@@ -1,6 +1,7 @@
 # docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_continuous_actionpy
 import argparse
 import os
+import tqdm
 import random
 import time
 from distutils.util import strtobool
@@ -167,6 +168,38 @@ def parse_args():
     return args
 
 
+class QuantileLoss(nn.Module):
+    def __init__(self, quantiles):
+        super().__init__()
+        self.quantiles = quantiles
+
+    def forward(self, preds, target):
+        assert not target.requires_grad
+        assert preds.size(0) == target.size(0)
+        losses = []
+        for i, q in enumerate(self.quantiles):
+            errors = target - preds[:, i]
+            losses.append(torch.max((q-1) * errors, q * errors).unsqueeze(1))
+        loss = torch.mean(torch.sum(torch.cat(losses, dim=1), dim=1))
+        return loss
+
+
+class QRNN(nn.Module):
+    def __init__(self, input_size, hidden_size1, hidden_size2, output_size, quantiles):
+        super(QRNN, self).__init__()
+        self.fc1 = nn.Linear(input_size, hidden_size1)
+        self.relu = nn.ReLU()
+        self.fc2 = nn.Linear(hidden_size1, hidden_size2)
+        self.fc3 = nn.Linear(hidden_size2, output_size * len(quantiles))
+        self.quantiles = quantiles
+
+    def forward(self, x):
+        out = self.fc1(x)
+        out = self.relu(out)
+        out = self.fc2(out)
+        out = self.relu(out)
+        out = self.fc3(out)
+        out = out.view(out.size(0), len(self.quantiles))  # Reshape output to separate quantiles
+        return out
@@ -315,42 +348,6 @@ def get_action_and_value(self, x, action=None):
         return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x)
 
 
-class ContRiskAgent(nn.Module):
-    def __init__(self, envs, linear_size=64, risk_size=1):
-        super().__init__()
-        self.critic = nn.Sequential(
-            layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod()+1, linear_size)),
-            nn.Tanh(),
-            layer_init(nn.Linear(linear_size, linear_size)),
-            nn.Tanh(),
-            layer_init(nn.Linear(linear_size, 1), std=1.0),
-        )
-        self.actor_mean = nn.Sequential(
-            layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod()+1, linear_size)),
-            nn.Tanh(),
-            layer_init(nn.Linear(linear_size, linear_size)),
-            nn.Tanh(),
-            layer_init(nn.Linear(linear_size, np.prod(envs.single_action_space.shape)), std=0.01),
-        )
-        self.actor_logstd = nn.Parameter(torch.zeros(1, np.prod(envs.single_action_space.shape)))
-
-    def get_value(self, x, risk):
-        x = torch.cat([x, risk], axis=1)
-        return self.critic(x)
-
-    def get_action_and_value(self, x, risk, action=None):
-        x = torch.cat([x, risk], axis=1)
-        action_mean = self.actor_mean(x)
-        action_logstd = self.actor_logstd.expand_as(action_mean)
-        action_std = torch.exp(action_logstd)
-        probs = Normal(action_mean, action_std)
-        if action is None:
-            action = probs.sample()
-        return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x)
-
-
-
 class RiskDataset(nn.Module):
     def __init__(self, inputs, targets):
         self.inputs = inputs
@@ -360,9 +357,7 @@ def __len__(self):
         return self.inputs.size()[0]
 
     def __getitem__(self, idx):
-        y = torch.zeros(2)
-        y[int(self.targets[idx][0])] = 1.0
-        return self.inputs[idx], y
+        return self.inputs[idx], self.targets[idx]
 
 
@@ -414,16 +409,12 @@ def risk_sgd_step(cfg, model, data, criterion, opt, device):
 
 def train_risk(cfg, model, data, criterion, opt, device):
     model.train()
-    dataset = RiskyDataset(data["next_obs"].to('cpu'), None, data["risks"].to('cpu'), False, risk_type=cfg.risk_type,
-                           fear_clip=None, fear_radius=cfg.fear_radius, one_hot=True, quantile_size=cfg.quantile_size, quantile_num=cfg.quantile_num)
+    dataset = RiskDataset(data["next_obs"].to('cpu'), data["dist_to_fail"].to('cpu'))
     dataloader = DataLoader(dataset, batch_size=cfg.risk_batch_size, shuffle=True, num_workers=10, generator=torch.Generator(device='cpu'))
 
     net_loss = 0
     for batch in dataloader:
         pred = model(get_risk_obs(cfg, batch[0]).to(device))
-        if cfg.model_type == "mlp":
-            loss = criterion(pred, batch[1].squeeze().to(device))
-        else:
-            loss = criterion(pred, torch.argmax(batch[1].squeeze(), axis=1).to(device))
+        loss = criterion(pred, batch[1].squeeze(-1).to(device))
         opt.zero_grad()
         loss.backward()
         opt.step()
@@ -469,7 +460,7 @@ def test_policy(cfg, agent, envs, device, risk_model=None):
 
 def get_risk_obs(cfg, next_obs):
     if cfg.unifying_lidar:
-        return next_obs[:, -96:]
+        return next_obs
     if "goal" in cfg.risk_model_path.lower():
         if "push" in cfg.env_id.lower():
             #print("push")
@@ -510,7 +501,7 @@ def train(cfg):
     cfg.use_risk = False if cfg.risk_model_path == "None" else True
 
     import wandb
-    run = wandb.init(config=vars(cfg), entity="kaustubh95",
+    run = wandb.init(config=vars(cfg), entity="kaustubh_umontreal",
                      project="risk_aware_exploration", monitor_gym=True,
                      sync_tensorboard=True, save_code=True)
@@ -527,11 +518,10 @@ def train(cfg):
     np.random.seed(cfg.seed)
    torch.manual_seed(cfg.model_seed)
     torch.backends.cudnn.deterministic = cfg.torch_deterministic
+    quantiles = [0.01, 0.05, 0.1, 0.2, 0.5]
 
     device = torch.device("cuda" if torch.cuda.is_available() and torch.cuda.device_count() > 0 else "cpu")
     cfg.device = device
-    risk_bins = np.array([i*cfg.quantile_size for i in range(cfg.quantile_num)])
-    quantile_means = torch.Tensor(np.array([((i+0.5)*(float(cfg.quantile_size)))**(i+1) for i in range(cfg.quantile_num-1)] + [np.inf])).to(device)
 
     # env setup
     envs = gym.vector.SyncVectorEnv(
         [make_env(cfg, i, cfg.capture_video, run_name, cfg.gamma) for i in range(cfg.num_envs)]
     )
@@ -549,32 +539,13 @@ def train(cfg):
     else:
         rb = ReplayBuffer(buffer_size=cfg.total_timesteps) #, observation_space=envs.single_observation_space, action_space=envs.single_action_space)
 
-    if cfg.risk_type == "quantile":
-        weight_tensor = torch.Tensor([1]*cfg.quantile_num).to(device)
-        weight_tensor[0] = cfg.weight
-    elif cfg.risk_type == "binary":
-        weight_tensor = torch.Tensor([1., cfg.weight]).to(device)
-    if cfg.model_type == "bayesian":
-        criterion = nn.NLLLoss(weight=weight_tensor)
-    else:
-        criterion = nn.BCEWithLogitsLoss(pos_weight=weight_tensor)
-
-    if cfg.risk_model_path == "scratch":
-        risk_obs_size = np.array(envs.single_observation_space.shape).prod()
-    else:
-        if "car" in cfg.risk_model_path.lower():
-            risk_obs_size = 120
-        elif "point" in cfg.risk_model_path.lower():
-            risk_obs_size = 108
-
+    criterion = QuantileLoss(quantiles)
     if cfg.use_risk:
         print("using risk")
-        #if cfg.risk_type == "binary":
         agent = RiskAgent(envs=envs, risk_size=risk_size).to(device)
-        #else:
-        #    agent = ContRiskAgent(envs=envs).to(device)
-        risk_model = risk_model_class[cfg.model_type][cfg.risk_type](obs_size=96, batch_norm=True, out_size=risk_size)
+        risk_model = QRNN(np.array(envs.single_observation_space.shape).prod(), 64, 64, 1, quantiles).to(device)
+
         if os.path.exists(cfg.risk_model_path):
             risk_model.load_state_dict(torch.load(cfg.risk_model_path, map_location=device))
             print("Pretrained risk model loaded successfully")
@@ -678,27 +649,7 @@ def train(cfg):
             with torch.no_grad():
                 next_obs_risk = get_risk_obs(cfg, next_obs)
                 next_risk = torch.Tensor(risk_model(next_obs_risk.to(device))).to(device)
-                if cfg.risk_type == "continuous":
-                    next_risk = next_risk.unsqueeze(0)
-                #print(next_risk.size())
-                if cfg.binary_risk and cfg.risk_type == "binary":
-                    id_risk = torch.argmax(next_risk, axis=1)
-                    next_risk = torch.zeros_like(next_risk)
-                    next_risk[:, id_risk] = 1
-                elif cfg.binary_risk and cfg.risk_type == "continuous":
-                    id_risk = int(next_risk[:,0] >= 1 / (cfg.fear_radius + 1))
-                    next_risk = torch.zeros_like(next_risk)
-                    next_risk[:, id_risk] = 1
-                if cfg.risk_model_path == "None" or (cfg.risk_model_path == "scratch" and total_risk_updates < cfg.risk_penalty_start):
-                    risk_penalty = torch.Tensor([0.]).to(device)
-                else:
-                    risk_penalty = torch.sum(torch.div(torch.exp(next_risk), quantile_means) * cfg.risk_penalty)
-                ep_risk_penalty += risk_penalty.item()
-                # print(next_risk)
-                risks[step] = torch.exp(next_risk)
-                all_risks[global_step] = torch.exp(next_risk) #, axis=-1)
-
-
+                print(next_risk)
         # ALGO LOGIC: action logic
         with torch.no_grad():
             if cfg.use_risk:
@@ -745,17 +696,6 @@ def train(cfg):
                 f_dones[i] = next_done[i].unsqueeze(0).to(device) if f_dones[i] is None else torch.concat([f_dones[i], next_done[i].unsqueeze(0).to(device)], axis=0)
 
         obs_ = next_obs
-        # if global_step % cfg.update_risk_model == 0 and cfg.fine_tune_risk:
-        # if cfg.use_risk and (global_step > cfg.start_risk_update and cfg.fine_tune_risk) and global_step % cfg.risk_update_period == 0:
-        #     for epoch in range(cfg.num_risk_epochs):
-        #         if cfg.finetune_risk_online:
-        #             print("I am online")
-        #             data = rb.slice_data(-cfg.risk_batch_size*cfg.num_update_risk, 0)
-        #         else:
-        #             data = rb.sample(cfg.risk_batch_size*cfg.num_update_risk)
-        #         risk_loss = train_risk(cfg, risk_model, data, criterion, opt_risk, device)
-        #         writer.add_scalar("risk/risk_loss", risk_loss, global_step)
-
         if cfg.fine_tune_risk == "sync" and cfg.use_risk:
             if cfg.use_risk and buffer_num > cfg.risk_batch_size and cfg.fine_tune_risk:
                 if cfg.finetune_risk_online:
@@ -767,7 +707,7 @@ def train(cfg):
                     writer.add_scalar("risk/risk_loss", risk_loss, global_step)
         elif cfg.fine_tune_risk == "off" and cfg.use_risk:
             if cfg.use_risk and (global_step > cfg.start_risk_update and cfg.fine_tune_risk) and global_step % cfg.risk_update_period == 0:
-                for epoch in range(cfg.num_risk_epochs):
+                for epoch in tqdm.tqdm(range(cfg.num_risk_epochs)):
                     total_risk_updates +=1
                     print(total_risk_updates)
                     if cfg.finetune_risk_online:
@@ -825,13 +765,9 @@ def train(cfg):
                 # f_dist_to_fail = torch.Tensor(np.array(list(reversed(range(f_obs.size()[0]))))).to(device) if cost > 0 else torch.Tensor(np.array([f_obs.size()[0]]*f_obs.shape[0])).to(device)
                 e_risks = np.array(list(reversed(range(int(ep_len))))) if cum_cost > 0 else np.array([int(ep_len)]*int(ep_len))
                 # print(risks.size())
-                e_risks_quant = torch.Tensor(np.apply_along_axis(lambda x: np.histogram(x, bins=risk_bins)[0], 1, np.expand_dims(e_risks, 1)))
                 e_risks = torch.Tensor(e_risks)
-
-                print(e_risks_quant.size())
                 if cfg.fine_tune_risk != "None" and cfg.use_risk:
                     f_risks = e_risks.unsqueeze(1)
-                    f_risks_quant = e_risks_quant
                 elif cfg.collect_data:
                     f_risks = e_risks.unsqueeze(1) if f_risks is None else torch.concat([f_risks, e_risks.unsqueeze(1)], axis=0)
@@ -840,20 +776,10 @@ def train(cfg):
                     if cfg.rb_type == "balanced":
                         idx_risky = (f_dist_to_fail<=cfg.fear_radius)
                         idx_safe = (f_dist_to_fail>cfg.fear_radius)
-                        risk_ones = torch.ones_like(f_risks)
-                        risk_zeros = torch.zeros_like(f_risks)
-
-                        if cfg.risk_type == "binary":
-                            rb.add_risky(f_obs[i][idx_risky], f_next_obs[i][idx_risky], f_actions[i][idx_risky], f_rewards[i][idx_risky], f_dones[i][idx_risky], f_costs[i][idx_risky], risk_ones[idx_risky], f_dist_to_fail.unsqueeze(1)[idx_risky])
-                            rb.add_safe(f_obs[i][idx_safe], f_next_obs[i][idx_safe], f_actions[i][idx_safe], f_rewards[i][idx_safe], f_dones[i][idx_safe], f_costs[i][idx_safe], risk_zeros[idx_safe], f_dist_to_fail.unsqueeze(1)[idx_safe])
-                        else:
-                            rb.add_risky(f_obs[i][idx_risky], f_next_obs[i][idx_risky], f_actions[i][idx_risky], f_rewards[i][idx_risky], f_dones[i][idx_risky], f_costs[i][idx_risky], f_risks_quant[idx_risky], f_dist_to_fail.unsqueeze(1)[idx_risky])
-                            rb.add_safe(f_obs[i][idx_safe], f_next_obs[i][idx_safe], f_actions[i][idx_safe], f_rewards[i][idx_safe], f_dones[i][idx_safe], f_costs[i][idx_safe], f_risks_quant[idx_safe], f_dist_to_fail.unsqueeze(1)[idx_safe])
+                        rb.add_risky(f_obs[i][idx_risky], f_next_obs[i][idx_risky], f_actions[i][idx_risky], f_rewards[i][idx_risky], f_dones[i][idx_risky], f_costs[i][idx_risky], f_risks[idx_risky], f_dist_to_fail.unsqueeze(1)[idx_risky])
+                        rb.add_safe(f_obs[i][idx_safe], f_next_obs[i][idx_safe], f_actions[i][idx_safe], f_rewards[i][idx_safe], f_dones[i][idx_safe], f_costs[i][idx_safe], f_risks[idx_safe], f_dist_to_fail.unsqueeze(1)[idx_safe])
                     else:
-                        if cfg.risk_type == "binary":
-                            rb.add(f_obs[i], f_next_obs[i], f_actions[i], f_rewards[i], f_dones[i], f_costs[i], (f_risks <= cfg.fear_radius).float(), e_risks.unsqueeze(1))
-                        else:
-                            rb.add(f_obs[i], f_next_obs[i], f_actions[i], f_rewards[i], f_dones[i], f_costs[i], f_risks, f_risks)
+                        rb.add(f_obs[i], f_next_obs[i], f_actions[i], f_rewards[i], f_dones[i], f_costs[i], f_risks, f_risks)
 
                     f_obs[i] = None
                     f_next_obs[i] = None
@@ -900,7 +826,8 @@ def train(cfg):
         b_returns = returns.reshape(-1)
         b_values = values.reshape(-1)
         #if cfg.risk_type == "discrete":
-        b_risks = risks.reshape((-1, ) + (risk_size, ))
+        with torch.no_grad():
+            b_risks = risk_model(b_obs)
         #else:
         #    b_risks = risks.reshape((-1, ) + (1, ))
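
For orientation (not part of the patch): the change above replaces the binary/quantile-bin risk classifier with quantile regression. The QRNN predicts several quantiles of the remaining distance to failure for an observation, and QuantileLoss is the standard pinball loss over those quantile levels. Below is a minimal, self-contained sketch of that training setup; obs_dim, pinball_loss, and the synthetic (obs, dist_to_fail) pairs are illustrative stand-ins for the real observations and replay-buffer targets.

import torch
import torch.nn as nn

quantiles = [0.01, 0.05, 0.1, 0.2, 0.5]  # same levels as the patch
obs_dim = 8                              # hypothetical observation size

# Same architecture as the patch's QRNN: two hidden layers, one output per quantile.
model = nn.Sequential(
    nn.Linear(obs_dim, 64), nn.ReLU(),
    nn.Linear(64, 64), nn.ReLU(),
    nn.Linear(64, len(quantiles)),
)

def pinball_loss(preds, target, quantiles):
    # Mirrors QuantileLoss.forward: under-predicting the q-quantile costs q per
    # unit of error, over-predicting costs (1 - q), so each output column
    # converges to its own quantile of the target distribution.
    losses = []
    for i, q in enumerate(quantiles):
        errors = target - preds[:, i]
        losses.append(torch.max((q - 1) * errors, q * errors).unsqueeze(1))
    return torch.mean(torch.sum(torch.cat(losses, dim=1), dim=1))

opt = torch.optim.Adam(model.parameters(), lr=1e-3)

# Synthetic stand-in for replay-buffer (next_obs, dist_to_fail) pairs.
obs = torch.randn(256, obs_dim)
dist_to_fail = obs.norm(dim=1) * 10 + 5 * torch.rand(256)

for _ in range(200):
    loss = pinball_loss(model(obs), dist_to_fail, quantiles)
    opt.zero_grad()
    loss.backward()
    opt.step()

Note the target here is a flat (batch,) vector. The buffer stores dist_to_fail as (N, 1), which is why train_risk squeezes the last dimension before calling the criterion; without that, target - preds[:, i] would silently broadcast to (batch, batch).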
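Reading the trained head at rollout time, continuing the sketch above: column i estimates the quantiles[i]-quantile of steps-to-failure, so the low-quantile columns act as pessimistic bounds (the 0.05 column should be exceeded by the true distance roughly 95% of the time), which is what makes them usable as a risk signal.

with torch.no_grad():
    q_preds = model(obs[:1])  # shape (1, len(quantiles))

# Pair each quantile level with its predicted distance to failure.
print({q: round(v.item(), 2) for q, v in zip(quantiles, q_preds[0])})

One caveat: nothing in the pinball loss enforces monotonicity across columns, so predicted quantiles can cross on sparse or noisy data.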