working with QR version of risk model
manila95 committed Feb 22, 2024
1 parent cd0f682 commit ffad08a
Showing 1 changed file with 49 additions and 122 deletions.
171 changes: 49 additions & 122 deletions cleanrl/ppo_continuous_action_wandb.py
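In summary, this commit swaps the categorical/binary risk classifier for a quantile-regression (QR) risk model: a small MLP (QRNN) predicts several quantiles of the distance-to-failure signal, trained with the pinball loss implemented by QuantileLoss below, and the risk dataset, replay buffer, and PPO update are adjusted to carry raw distances instead of class labels. As a minimal illustration of the pinball loss (illustrative only, not part of the diff):

# Illustrative only: the pinball loss that QuantileLoss below implements.
# For quantile q, under-prediction costs q*|e| and over-prediction (1-q)*|e|,
# so the minimizer is the q-th quantile of the target distribution.
import torch
q = 0.05
pred, target = torch.tensor(3.0), torch.tensor(5.0)
e = target - pred                      # +2.0: the model under-predicts
loss = torch.max((q - 1) * e, q * e)   # max(-1.9, 0.1) = 0.1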
@@ -1,6 +1,7 @@
# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_continuous_actionpy
import argparse
import os
+import tqdm
import random
import time
from distutils.util import strtobool
@@ -167,6 +168,38 @@ def parse_args():
return args


+class QuantileLoss(nn.Module):
+    def __init__(self, quantiles):
+        super().__init__()
+        self.quantiles = quantiles
+
+    def forward(self, preds, target):
+        # Pinball loss: for quantile q, under-prediction costs q*|e|,
+        # over-prediction costs (1-q)*|e|.
+        assert not target.requires_grad
+        assert preds.size(0) == target.size(0)
+        target = target.view(-1)  # flatten (batch, 1) targets so errors stay (batch,)
+        losses = []
+        for i, q in enumerate(self.quantiles):
+            errors = target - preds[:, i]
+            losses.append(torch.max((q - 1) * errors, q * errors).unsqueeze(1))
+        loss = torch.mean(torch.sum(torch.cat(losses, dim=1), dim=1))
+        return loss

+class QRNN(nn.Module):
+    def __init__(self, input_size, hidden_size1, hidden_size2, output_size, quantiles):
+        super(QRNN, self).__init__()
+        self.fc1 = nn.Linear(input_size, hidden_size1)
+        self.relu = nn.ReLU()
+        self.fc2 = nn.Linear(hidden_size1, hidden_size2)
+        self.fc3 = nn.Linear(hidden_size2, output_size * len(quantiles))
+        self.quantiles = quantiles
+
+    def forward(self, x):
+        out = self.fc1(x)
+        out = self.relu(out)
+        out = self.fc2(out)
+        out = self.relu(out)
+        out = self.fc3(out)
+        # Reshape to one column per quantile; this assumes output_size == 1,
+        # which is how the model is instantiated in train() below.
+        out = out.view(out.size(0), len(self.quantiles))
+        return out
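A quick shape sketch for the two classes above (illustrative, not part of the diff; the sizes mirror the instantiation in train() below, which uses output_size=1 and this quantile list):

quantiles = [0.01, 0.05, 0.1, 0.2, 0.5]
net = QRNN(input_size=96, hidden_size1=64, hidden_size2=64, output_size=1, quantiles=quantiles)  # 96 is a hypothetical obs width
crit = QuantileLoss(quantiles)
x = torch.randn(32, 96)                 # batch of 32 observations
preds = net(x)                          # (32, 5): one column per quantile
loss = crit(preds, torch.rand(32, 1))   # scalar pinball loss over all quantiles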



@@ -315,42 +348,6 @@ def get_action_and_value(self, x, action=None):
return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x)


-class ContRiskAgent(nn.Module):
-    def __init__(self, envs, linear_size=64, risk_size=1):
-        super().__init__()
-        self.critic = nn.Sequential(
-            layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod()+1, linear_size)),
-            nn.Tanh(),
-            layer_init(nn.Linear(linear_size, linear_size)),
-            nn.Tanh(),
-            layer_init(nn.Linear(linear_size, 1), std=1.0),
-        )
-        self.actor_mean = nn.Sequential(
-            layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod()+1, linear_size)),
-            nn.Tanh(),
-            layer_init(nn.Linear(linear_size, linear_size)),
-            nn.Tanh(),
-            layer_init(nn.Linear(linear_size, np.prod(envs.single_action_space.shape)), std=0.01),
-        )
-        self.actor_logstd = nn.Parameter(torch.zeros(1, np.prod(envs.single_action_space.shape)))
-
-    def get_value(self, x, risk):
-        x = torch.cat([x, risk], axis=1)
-        return self.critic(x)
-
-    def get_action_and_value(self, x, risk, action=None):
-        x = torch.cat([x, risk], axis=1)
-        action_mean = self.actor_mean(x)
-        action_logstd = self.actor_logstd.expand_as(action_mean)
-        action_std = torch.exp(action_logstd)
-        probs = Normal(action_mean, action_std)
-        if action is None:
-            action = probs.sample()
-        return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x)




class RiskDataset(nn.Module):
def __init__(self, inputs, targets):
self.inputs = inputs
@@ -360,9 +357,7 @@ def __len__(self):
return self.inputs.size()[0]

def __getitem__(self, idx):
-        y = torch.zeros(2)
-        y[int(self.targets[idx][0])] = 1.0
-        return self.inputs[idx], y
+        return self.inputs[idx], self.targets[idx]
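With the one-hot branch gone, each sample is now (observation, raw distance-to-failure), which is exactly what the quantile criterion expects. A tiny illustration (hypothetical shapes, not part of the diff):

ds = RiskDataset(torch.randn(4, 96), torch.tensor([[3.], [0.], [7.], [1.]]))
obs, target = ds[0]   # target == tensor([3.]): steps until failure, not a class label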



@@ -414,16 +409,12 @@ def risk_sgd_step(cfg, model, data, criterion, opt, device):

def train_risk(cfg, model, data, criterion, opt, device):
model.train()
-    dataset = RiskyDataset(data["next_obs"].to('cpu'), None, data["risks"].to('cpu'), False, risk_type=cfg.risk_type,
-                           fear_clip=None, fear_radius=cfg.fear_radius, one_hot=True, quantile_size=cfg.quantile_size, quantile_num=cfg.quantile_num)
+    dataset = RiskDataset(data["next_obs"].to('cpu'), data["dist_to_fail"].to('cpu'))
dataloader = DataLoader(dataset, batch_size=cfg.risk_batch_size, shuffle=True, num_workers=10, generator=torch.Generator(device='cpu'))
net_loss = 0
for batch in dataloader:
pred = model(get_risk_obs(cfg, batch[0]).to(device))
if cfg.model_type == "mlp":
loss = criterion(pred, batch[1].squeeze().to(device))
else:
loss = criterion(pred, torch.argmax(batch[1].squeeze(), axis=1).to(device))
loss = criterion(pred, batch[1])
opt.zero_grad()
loss.backward()
opt.step()
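train_risk now fits all quantiles jointly, so a single forward pass yields a per-state risk profile. A sketch of reading the columns back out (illustrative; assumes the quantile order set in train() below, with obs_batch a hypothetical batch of risk observations):

with torch.no_grad():
    q_hat = risk_model(obs_batch)   # (B, 5): columns follow the `quantiles` list
tail = q_hat[:, 0]     # q=0.01 column: pessimistic distance-to-failure estimate
median = q_hat[:, -1]  # q=0.5 column: median estimate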
@@ -469,7 +460,7 @@ def test_policy(cfg, agent, envs, device, risk_model=None):

def get_risk_obs(cfg, next_obs):
if cfg.unifying_lidar:
-        return next_obs[:, -96:]
+        return next_obs
if "goal" in cfg.risk_model_path.lower():
if "push" in cfg.env_id.lower():
#print("push")
@@ -510,7 +501,7 @@ def train(cfg):
cfg.use_risk = False if cfg.risk_model_path == "None" else True

import wandb
run = wandb.init(config=vars(cfg), entity="kaustubh95",
run = wandb.init(config=vars(cfg), entity="kaustubh_umontreal",
project="risk_aware_exploration",
monitor_gym=True,
sync_tensorboard=True, save_code=True)
@@ -527,11 +518,10 @@ def train(cfg):
np.random.seed(cfg.seed)
torch.manual_seed(cfg.model_seed)
torch.backends.cudnn.deterministic = cfg.torch_deterministic
+    quantiles = [0.01, 0.05, 0.1, 0.2, 0.5]

device = torch.device("cuda" if torch.cuda.is_available() and torch.cuda.device_count() > 0 else "cpu")
cfg.device = device
-    risk_bins = np.array([i*cfg.quantile_size for i in range(cfg.quantile_num)])
-    quantile_means = torch.Tensor(np.array([((i+0.5)*(float(cfg.quantile_size)))**(i+1) for i in range(cfg.quantile_num-1)] + [np.inf])).to(device)
# env setup
envs = gym.vector.SyncVectorEnv(
[make_env(cfg, i, cfg.capture_video, run_name, cfg.gamma) for i in range(cfg.num_envs)]
@@ -549,32 +539,13 @@ def train(cfg):
else:
rb = ReplayBuffer(buffer_size=cfg.total_timesteps)
#, observation_space=envs.single_observation_space, action_space=envs.single_action_space)
if cfg.risk_type == "quantile":
weight_tensor = torch.Tensor([1]*cfg.quantile_num).to(device)
weight_tensor[0] = cfg.weight
elif cfg.risk_type == "binary":
weight_tensor = torch.Tensor([1., cfg.weight]).to(device)
if cfg.model_type == "bayesian":
criterion = nn.NLLLoss(weight=weight_tensor)
else:
criterion = nn.BCEWithLogitsLoss(pos_weight=weight_tensor)

if cfg.risk_model_path == "scratch":
risk_obs_size = np.array(envs.single_observation_space.shape).prod()
else:
if "car" in cfg.risk_model_path.lower():
risk_obs_size = 120
elif "point" in cfg.risk_model_path.lower():
risk_obs_size = 108

criterion = QuantileLoss(quantiles)

if cfg.use_risk:
print("using risk")
#if cfg.risk_type == "binary":
agent = RiskAgent(envs=envs, risk_size=risk_size).to(device)
#else:
# agent = ContRiskAgent(envs=envs).to(device)
risk_model = risk_model_class[cfg.model_type][cfg.risk_type](obs_size=96, batch_norm=True, out_size=risk_size)
risk_model = QRNN(np.array(envs.single_observation_space.shape).prod(), 64, 64, 1, quantiles).to(device)

if os.path.exists(cfg.risk_model_path):
risk_model.load_state_dict(torch.load(cfg.risk_model_path, map_location=device))
print("Pretrained risk model loaded successfully")
@@ -678,27 +649,7 @@ def train(cfg):
with torch.no_grad():
next_obs_risk = get_risk_obs(cfg, next_obs)
next_risk = torch.Tensor(risk_model(next_obs_risk.to(device))).to(device)
if cfg.risk_type == "continuous":
next_risk = next_risk.unsqueeze(0)
#print(next_risk.size())
if cfg.binary_risk and cfg.risk_type == "binary":
id_risk = torch.argmax(next_risk, axis=1)
next_risk = torch.zeros_like(next_risk)
next_risk[:, id_risk] = 1
elif cfg.binary_risk and cfg.risk_type == "continuous":
id_risk = int(next_risk[:,0] >= 1 / (cfg.fear_radius + 1))
next_risk = torch.zeros_like(next_risk)
next_risk[:, id_risk] = 1
if cfg.risk_model_path == "None" or (cfg.risk_model_path == "scratch" and total_risk_updates < cfg.risk_penalty_start):
risk_penalty = torch.Tensor([0.]).to(device)
else:
risk_penalty = torch.sum(torch.div(torch.exp(next_risk), quantile_means) * cfg.risk_penalty)
ep_risk_penalty += risk_penalty.item()
# print(next_risk)
risks[step] = torch.exp(next_risk)
all_risks[global_step] = torch.exp(next_risk)#, axis=-1)


print(next_risk)
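The exp/quantile_means reward penalty is removed along with the categorical risk bookkeeping, so risk now reaches the agent only through the policy's risk input (the bare print above reads like leftover debugging). If a penalty were still wanted under the QR model, one hypothetical analogue would shrink it as the pessimistic predicted distance-to-failure grows:

# Hypothetical, not part of this commit: penalize states whose q=0.01
# predicted distance-to-failure is small.
risk_penalty = cfg.risk_penalty / (1.0 + next_risk[:, 0].clamp(min=0.0))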
# ALGO LOGIC: action logic
with torch.no_grad():
if cfg.use_risk:
@@ -745,17 +696,6 @@ def train(cfg):
f_dones[i] = next_done[i].unsqueeze(0).to(device) if f_dones[i] is None else torch.concat([f_dones[i], next_done[i].unsqueeze(0).to(device)], axis=0)

obs_ = next_obs
-            # if global_step % cfg.update_risk_model == 0 and cfg.fine_tune_risk:
-            #     if cfg.use_risk and (global_step > cfg.start_risk_update and cfg.fine_tune_risk) and global_step % cfg.risk_update_period == 0:
-            #         for epoch in range(cfg.num_risk_epochs):
-            #             if cfg.finetune_risk_online:
-            #                 print("I am online")
-            #                 data = rb.slice_data(-cfg.risk_batch_size*cfg.num_update_risk, 0)
-            #             else:
-            #                 data = rb.sample(cfg.risk_batch_size*cfg.num_update_risk)
-            #             risk_loss = train_risk(cfg, risk_model, data, criterion, opt_risk, device)
-            #             writer.add_scalar("risk/risk_loss", risk_loss, global_step)

if cfg.fine_tune_risk == "sync" and cfg.use_risk:
if cfg.use_risk and buffer_num > cfg.risk_batch_size and cfg.fine_tune_risk:
if cfg.finetune_risk_online:
@@ -767,7 +707,7 @@ def train(cfg):
writer.add_scalar("risk/risk_loss", risk_loss, global_step)
elif cfg.fine_tune_risk == "off" and cfg.use_risk:
if cfg.use_risk and (global_step > cfg.start_risk_update and cfg.fine_tune_risk) and global_step % cfg.risk_update_period == 0:
-            for epoch in range(cfg.num_risk_epochs):
+            for epoch in tqdm.tqdm(range(cfg.num_risk_epochs)):
total_risk_updates +=1
print(total_risk_updates)
if cfg.finetune_risk_online:
@@ -825,13 +765,9 @@ def train(cfg):
# f_dist_to_fail = torch.Tensor(np.array(list(reversed(range(f_obs.size()[0]))))).to(device) if cost > 0 else torch.Tensor(np.array([f_obs.size()[0]]*f_obs.shape[0])).to(device)
e_risks = np.array(list(reversed(range(int(ep_len))))) if cum_cost > 0 else np.array([int(ep_len)]*int(ep_len))
# print(risks.size())
-                e_risks_quant = torch.Tensor(np.apply_along_axis(lambda x: np.histogram(x, bins=risk_bins)[0], 1, np.expand_dims(e_risks, 1)))
e_risks = torch.Tensor(e_risks)

-                print(e_risks_quant.size())
if cfg.fine_tune_risk != "None" and cfg.use_risk:
f_risks = e_risks.unsqueeze(1)
-                    f_risks_quant = e_risks_quant
elif cfg.collect_data:
f_risks = e_risks.unsqueeze(1) if f_risks is None else torch.concat([f_risks, e_risks.unsqueeze(1)], axis=0)

@@ -840,20 +776,10 @@ def train(cfg):
if cfg.rb_type == "balanced":
idx_risky = (f_dist_to_fail<=cfg.fear_radius)
idx_safe = (f_dist_to_fail>cfg.fear_radius)
-                        risk_ones = torch.ones_like(f_risks)
-                        risk_zeros = torch.zeros_like(f_risks)
-
-                        if cfg.risk_type == "binary":
-                            rb.add_risky(f_obs[i][idx_risky], f_next_obs[i][idx_risky], f_actions[i][idx_risky], f_rewards[i][idx_risky], f_dones[i][idx_risky], f_costs[i][idx_risky], risk_ones[idx_risky], f_dist_to_fail.unsqueeze(1)[idx_risky])
-                            rb.add_safe(f_obs[i][idx_safe], f_next_obs[i][idx_safe], f_actions[i][idx_safe], f_rewards[i][idx_safe], f_dones[i][idx_safe], f_costs[i][idx_safe], risk_zeros[idx_safe], f_dist_to_fail.unsqueeze(1)[idx_safe])
-                        else:
-                            rb.add_risky(f_obs[i][idx_risky], f_next_obs[i][idx_risky], f_actions[i][idx_risky], f_rewards[i][idx_risky], f_dones[i][idx_risky], f_costs[i][idx_risky], f_risks_quant[idx_risky], f_dist_to_fail.unsqueeze(1)[idx_risky])
-                            rb.add_safe(f_obs[i][idx_safe], f_next_obs[i][idx_safe], f_actions[i][idx_safe], f_rewards[i][idx_safe], f_dones[i][idx_safe], f_costs[i][idx_safe], f_risks_quant[idx_safe], f_dist_to_fail.unsqueeze(1)[idx_safe])
+                        rb.add_risky(f_obs[i][idx_risky], f_next_obs[i][idx_risky], f_actions[i][idx_risky], f_rewards[i][idx_risky], f_dones[i][idx_risky], f_costs[i][idx_risky], f_risks[idx_risky], f_dist_to_fail.unsqueeze(1)[idx_risky])  # raw distances; f_risks_quant no longer exists
+                        rb.add_safe(f_obs[i][idx_safe], f_next_obs[i][idx_safe], f_actions[i][idx_safe], f_rewards[i][idx_safe], f_dones[i][idx_safe], f_costs[i][idx_safe], f_risks[idx_safe], f_dist_to_fail.unsqueeze(1)[idx_safe])
else:
if cfg.risk_type == "binary":
rb.add(f_obs[i], f_next_obs[i], f_actions[i], f_rewards[i], f_dones[i], f_costs[i], (f_risks <= cfg.fear_radius).float(), e_risks.unsqueeze(1))
else:
rb.add(f_obs[i], f_next_obs[i], f_actions[i], f_rewards[i], f_dones[i], f_costs[i], f_risks, f_risks)
rb.add(f_obs[i], f_next_obs[i], f_actions[i], f_rewards[i], f_dones[i], f_costs[i], f_risks, f_risks)

f_obs[i] = None
f_next_obs[i] = None
@@ -900,7 +826,8 @@ def train(cfg):
b_returns = returns.reshape(-1)
b_values = values.reshape(-1)
#if cfg.risk_type == "discrete":
-        b_risks = risks.reshape((-1, ) + (risk_size, ))
+        with torch.no_grad():
+            b_risks = risk_model(b_obs)
#else:
# b_risks = risks.reshape((-1, ) + (1, ))
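Recomputing b_risks in one forward pass over the flattened rollout is simple, but peak memory grows with the batch size; a chunked variant (hypothetical, not part of the diff) bounds it:

with torch.no_grad():
    b_risks = torch.cat([risk_model(chunk) for chunk in b_obs.split(4096)])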
