Commit a974cd6

state action risk
1 parent cd0f682 commit a974cd6

File tree

1 file changed: +34 -34 lines changed

cleanrl/ppo_continuous_action_wandb.py

Lines changed: 34 additions & 34 deletions
@@ -9,6 +9,7 @@
 import gymnasium as gym
 import numpy as np
 import torch
+import tqdm
 import torch.nn as nn
 import torch.optim as optim
 from torch.utils.data import Dataset
@@ -247,7 +248,9 @@ def get_action_and_value(self, x, risk, action=None):
         probs = Normal(action_mean, action_std)
         if action is None:
             action = probs.sample()
-        return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.get_value(x, risk)
+        candidates = probs.sample_n(5)
+
+        return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.get_value(x, risk), candidates
 
 class RiskAgent1(nn.Module):
     def __init__(self, envs, linear_size=64, risk_size=2):
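
Note: each of the agents' get_action_and_value methods now returns a fifth value, candidates, i.e. five extra actions drawn from the same policy distribution; the rollout loop later screens these with the risk model. A minimal sketch of what this sampling produces (the batch and action sizes are illustrative, not taken from the repo); newer PyTorch versions warn that Distribution.sample_n is deprecated in favor of sample((n,)), which yields the same shape:

import torch
from torch.distributions import Normal

# Illustrative shapes: batch of 1, 3-dimensional continuous action.
action_mean = torch.zeros(1, 3)
action_std = torch.ones(1, 3)
probs = Normal(action_mean, action_std)

action = probs.sample()          # action stored for the PPO update, shape (1, 3)
candidates = probs.sample((5,))  # same result as the deprecated probs.sample_n(5), shape (5, 1, 3)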
@@ -280,7 +283,9 @@ def get_action_and_value(self, x, risk, action=None):
         probs = Normal(action_mean, action_std)
         if action is None:
             action = probs.sample()
-        return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x)
+        candidates = probs.sample_n(5)
+
+        return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x), candidates
 
 
 class Agent(nn.Module):
@@ -312,7 +317,8 @@ def get_action_and_value(self, x, action=None):
         probs = Normal(action_mean, action_std)
         if action is None:
             action = probs.sample()
-        return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x)
+        candidates = probs.sample_n(5)
+        return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x), candidates
 
 
 class ContRiskAgent(nn.Module):
@@ -412,28 +418,6 @@ def risk_sgd_step(cfg, model, data, criterion, opt, device):
     return loss
 
 
-def train_risk(cfg, model, data, criterion, opt, device):
-    model.train()
-    dataset = RiskyDataset(data["next_obs"].to('cpu'), None, data["risks"].to('cpu'), False, risk_type=cfg.risk_type,
-                           fear_clip=None, fear_radius=cfg.fear_radius, one_hot=True, quantile_size=cfg.quantile_size, quantile_num=cfg.quantile_num)
-    dataloader = DataLoader(dataset, batch_size=cfg.risk_batch_size, shuffle=True, num_workers=10, generator=torch.Generator(device='cpu'))
-    net_loss = 0
-    for batch in dataloader:
-        pred = model(get_risk_obs(cfg, batch[0]).to(device))
-        if cfg.model_type == "mlp":
-            loss = criterion(pred, batch[1].squeeze().to(device))
-        else:
-            loss = criterion(pred, torch.argmax(batch[1].squeeze(), axis=1).to(device))
-        opt.zero_grad()
-        loss.backward()
-        opt.step()
-
-        net_loss += loss.item()
-    torch.save(model.state_dict(), os.path.join(wandb.run.dir, "risk_model.pt"))
-    wandb.save("risk_model.pt")
-    model.eval()
-    print("risk_loss:", net_loss)
-    return net_loss
 
 def test_policy(cfg, agent, envs, device, risk_model=None):
     envs = gym.vector.SyncVectorEnv(
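
The local train_risk above is deleted; the off-policy fine-tuning branch later in this diff instead calls train_risk(risk_model, risk_dataloader, criterion, opt_risk, 1, device, train_mode="state_action"), a helper with a different signature that is presumably imported from elsewhere in the repo and is not shown in this commit. A hypothetical stand-in, assuming the dataloader yields (obs, action, label) batches and the model returns per-class log-probabilities:

import torch

# Hypothetical stand-in for the external train_risk used by the new call site; not code from the repo.
def train_risk(model, dataloader, criterion, opt, num_epochs, device, train_mode="state_action"):
    model.train()
    net_loss = 0.0
    for _ in range(num_epochs):
        for obs, actions, labels in dataloader:
            if train_mode == "state_action":
                pred = model(obs.to(device), actions.to(device))   # state-action risk prediction
            else:
                pred = model(obs.to(device))                       # state-only risk prediction
            loss = criterion(pred, torch.argmax(labels.squeeze(), dim=1).to(device))
            opt.zero_grad()
            loss.backward()
            opt.step()
            net_loss += loss.item()
    model.eval()
    return net_loss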
@@ -574,7 +558,7 @@ def train(cfg):
         agent = RiskAgent(envs=envs, risk_size=risk_size).to(device)
         #else:
         #    agent = ContRiskAgent(envs=envs).to(device)
-        risk_model = risk_model_class[cfg.model_type][cfg.risk_type](obs_size=96, batch_norm=True, out_size=risk_size)
+        risk_model = risk_model_class[cfg.model_type][cfg.risk_type](obs_size=96, batch_norm=True, out_size=risk_size, action_size=envs.single_action_space.shape[0], model_type="state_action_risk")
         if os.path.exists(cfg.risk_model_path):
             risk_model.load_state_dict(torch.load(cfg.risk_model_path, map_location=device))
             print("Pretrained risk model loaded successfully")
@@ -702,14 +686,20 @@ def train(cfg):
             # ALGO LOGIC: action logic
             with torch.no_grad():
                 if cfg.use_risk:
-                    action, logprob, _, value = agent.get_action_and_value(next_obs, next_risk)
+                    action, logprob, _, value, candidates = agent.get_action_and_value(next_obs, next_risk)
                 else:
-                    action, logprob, _, value = agent.get_action_and_value(next_obs)
+                    action, logprob, _, value, candidates = agent.get_action_and_value(next_obs)
 
                 values[step] = value.flatten()
             actions[step] = action
             logprobs[step] = logprob
 
+            if cfg.use_risk:
+                with torch.no_grad():
+                    candidates = candidates.squeeze()
+                    # print(next_obs_risk.repeat(5, 1).size(), candidates.size())
+                    candidates_risk = torch.sum(torch.exp(risk_model(next_obs_risk.repeat(5, 1).to(device), candidates))[:, :2], -1)
+                    action = candidates[torch.argmin(candidates_risk)]
             # TRY NOT TO MODIFY: execute the game and log data.
             next_obs, reward, terminated, truncated, infos = envs.step(action.cpu().numpy())
             done = np.logical_or(terminated, truncated)
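
When cfg.use_risk is set, the executed action is no longer simply the sampled one: the five candidates are scored by the risk model, the probability mass of the first two output bins (taken here to be the near-failure classes) is summed per candidate, and the lowest-risk candidate is executed. The same logic as a standalone helper, for clarity (a sketch, not code from the repo):

import torch

def pick_safest(candidates, obs_risk, risk_model, device="cpu"):
    # candidates: (5, 1, action_dim) from sample_n(5) with a single env
    candidates = candidates.squeeze()                      # -> (5, action_dim)
    obs_batch = obs_risk.repeat(candidates.size(0), 1)     # tile the risk observation per candidate
    log_probs = risk_model(obs_batch.to(device), candidates.to(device))
    candidates_risk = torch.exp(log_probs)[:, :2].sum(-1)  # mass assigned to the risky bins
    return candidates[torch.argmin(candidates_risk)]       # execute the least risky candidate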
@@ -738,12 +728,13 @@ def train(cfg):
             for i in range(cfg.num_envs):
                 f_obs[i] = obs_[i].unsqueeze(0).to(device) if f_obs[i] is None else torch.concat([f_obs[i], obs_[i].unsqueeze(0).to(device)], axis=0)
                 f_next_obs[i] = next_obs[i].unsqueeze(0).to(device) if f_next_obs[i] is None else torch.concat([f_next_obs[i], next_obs[i].unsqueeze(0).to(device)], axis=0)
-                f_actions[i] = action[i].unsqueeze(0).to(device) if f_actions[i] is None else torch.concat([f_actions[i], action[i].unsqueeze(0).to(device)], axis=0)
+                f_actions[i] = action.unsqueeze(0).to(device) if f_actions[i] is None else torch.concat([f_actions[i], action.unsqueeze(0).to(device)], axis=0)
                 f_rewards[i] = reward[i].unsqueeze(0).to(device) if f_rewards[i] is None else torch.concat([f_rewards[i], reward[i].unsqueeze(0).to(device)], axis=0)
                 # f_risks = risk_ if f_risks is None else torch.concat([f_risks, risk_], axis=0)
                 f_costs[i] = cost[i].unsqueeze(0).to(device) if f_costs[i] is None else torch.concat([f_costs[i], cost[i].unsqueeze(0).to(device)], axis=0)
                 f_dones[i] = next_done[i].unsqueeze(0).to(device) if f_dones[i] is None else torch.concat([f_dones[i], next_done[i].unsqueeze(0).to(device)], axis=0)
 
+            # print(f_actions[0].size())
             obs_ = next_obs
             # if global_step % cfg.update_risk_model == 0 and cfg.fine_tune_risk:
             # if cfg.use_risk and (global_step > cfg.start_risk_update and cfg.fine_tune_risk) and global_step % cfg.risk_update_period == 0:
@@ -767,15 +758,24 @@ def train(cfg):
                     writer.add_scalar("risk/risk_loss", risk_loss, global_step)
             elif cfg.fine_tune_risk == "off" and cfg.use_risk:
                 if cfg.use_risk and (global_step > cfg.start_risk_update and cfg.fine_tune_risk) and global_step % cfg.risk_update_period == 0:
-                    for epoch in range(cfg.num_risk_epochs):
+                    for epoch in tqdm.tqdm(range(cfg.num_risk_epochs)):
                         total_risk_updates +=1
                         print(total_risk_updates)
                         if cfg.finetune_risk_online:
                             print("I am online")
                             data = rb.slice_data(-cfg.risk_batch_size*cfg.num_update_risk, 0)
                         else:
                             data = rb.sample(cfg.risk_batch_size*cfg.num_update_risk)
-                        risk_loss = train_risk(cfg, risk_model, data, criterion, opt_risk, device)
+                        state = torch.cat([data["obs"], data["next_obs"]], axis=0)
+                        actions = torch.cat([data["actions"], torch.zeros_like(data["actions"])], axis=0)
+                        dist_to_fail = torch.cat([data["dist_to_fail"], data["dist_to_fail"]], axis=0)
+                        print(state.size(), actions.size(), dist_to_fail.size())
+                        risk_dataset = RiskyDataset(state.to(device), actions.to(device), dist_to_fail.to(device), True, risk_type=cfg.risk_type,
+                                                    fear_clip=None, fear_radius=cfg.fear_radius, one_hot=True, quantile_size=cfg.quantile_size, quantile_num=cfg.quantile_num)
+                        risk_dataloader = DataLoader(risk_dataset, batch_size=cfg.risk_batch_size, shuffle=True, num_workers=4, generator=torch.Generator(device=device))
+
+                        risk_loss = train_risk(risk_model, risk_dataloader, criterion, opt_risk, 1, device, train_mode="state_action")
+
                     writer.add_scalar("risk/risk_loss", risk_loss, global_step)
 
             # Only print when at least 1 env is done
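
Risk fine-tuning now trains on (state, action) pairs: each sampled transition contributes its real (obs, action) pair plus an extra (next_obs, zero-action) pair, both labelled with the same distance-to-failure, before being wrapped in RiskyDataset and passed to the new train_risk. The pairing step in isolation (a sketch using the replay-buffer keys from the diff):

import torch

def build_state_action_batch(data):
    # Real transitions: (obs, action); augmented pairs: (next_obs, zero action).
    state = torch.cat([data["obs"], data["next_obs"]], dim=0)
    actions = torch.cat([data["actions"], torch.zeros_like(data["actions"])], dim=0)
    # Both halves share the same distance-to-failure labels.
    dist_to_fail = torch.cat([data["dist_to_fail"], data["dist_to_fail"]], dim=0)
    return state, actions, dist_to_fail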
@@ -853,7 +853,7 @@ def train(cfg):
                     if cfg.risk_type == "binary":
                         rb.add(f_obs[i], f_next_obs[i], f_actions[i], f_rewards[i], f_dones[i], f_costs[i], (f_risks <= cfg.fear_radius).float(), e_risks.unsqueeze(1))
                     else:
-                        rb.add(f_obs[i], f_next_obs[i], f_actions[i], f_rewards[i], f_dones[i], f_costs[i], f_risks, f_risks)
+                        rb.add(get_risk_obs(cfg, f_obs[i]), get_risk_obs(cfg, f_next_obs[i]), f_actions[i], f_rewards[i], f_dones[i], f_costs[i], f_risks, f_risks)
 
                     f_obs[i] = None
                     f_next_obs[i] = None
@@ -915,9 +915,9 @@ def train(cfg):
                 mb_inds = b_inds[start:end]
 
                 if cfg.use_risk:
-                    _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_risks[mb_inds], b_actions[mb_inds])
+                    _, newlogprob, entropy, newvalue, cands = agent.get_action_and_value(b_obs[mb_inds], b_risks[mb_inds], b_actions[mb_inds])
                 else:
-                    _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions[mb_inds])
+                    _, newlogprob, entropy, newvalue, cands = agent.get_action_and_value(b_obs[mb_inds], b_actions[mb_inds])
 
                 logratio = newlogprob - b_logprobs[mb_inds]
                 ratio = logratio.exp()
