 import gymnasium as gym
 import numpy as np
 import torch
+import tqdm
 import torch.nn as nn
 import torch.optim as optim
 from torch.utils.data import Dataset
@@ -247,7 +248,9 @@ def get_action_and_value(self, x, risk, action=None):
         probs = Normal(action_mean, action_std)
         if action is None:
             action = probs.sample()
-        return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.get_value(x, risk)
+        candidates = probs.sample_n(5)
+
+        return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.get_value(x, risk), candidates
 
 class RiskAgent1(nn.Module):
     def __init__(self, envs, linear_size=64, risk_size=2):
@@ -280,7 +283,9 @@ def get_action_and_value(self, x, risk, action=None):
         probs = Normal(action_mean, action_std)
         if action is None:
             action = probs.sample()
-        return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x)
+        candidates = probs.sample_n(5)
+
+        return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x), candidates
 
 
 class Agent(nn.Module):
@@ -312,7 +317,8 @@ def get_action_and_value(self, x, action=None):
         probs = Normal(action_mean, action_std)
         if action is None:
             action = probs.sample()
-        return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x)
+        candidates = probs.sample_n(5)
+        return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x), candidates
 
 
 class ContRiskAgent(nn.Module):
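
The three policy heads above now return five candidate actions drawn from the same Normal distribution, alongside the usual PPO outputs. A minimal sketch of the same idea (not the PR's code, helper name is hypothetical): note that `Distribution.sample_n(k)` is deprecated in current PyTorch in favour of `sample((k,))`, which produces the same `(k, batch, act_dim)` shaped tensor.

# Minimal sketch (not the PR's code): drawing k candidate actions from a
# diagonal Gaussian policy head. probs.sample((k,)) is the non-deprecated
# equivalent of the probs.sample_n(k) used in the diff.
import torch
from torch.distributions import Normal

def sample_with_candidates(action_mean, action_std, k=5):
    probs = Normal(action_mean, action_std)
    action = probs.sample()            # action used for the PPO rollout
    candidates = probs.sample((k,))    # k extra candidates for risk filtering
    return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), candidates
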
@@ -412,28 +418,6 @@ def risk_sgd_step(cfg, model, data, criterion, opt, device):
     return loss
 
 
-def train_risk(cfg, model, data, criterion, opt, device):
-    model.train()
-    dataset = RiskyDataset(data["next_obs"].to('cpu'), None, data["risks"].to('cpu'), False, risk_type=cfg.risk_type,
-                           fear_clip=None, fear_radius=cfg.fear_radius, one_hot=True, quantile_size=cfg.quantile_size, quantile_num=cfg.quantile_num)
-    dataloader = DataLoader(dataset, batch_size=cfg.risk_batch_size, shuffle=True, num_workers=10, generator=torch.Generator(device='cpu'))
-    net_loss = 0
-    for batch in dataloader:
-        pred = model(get_risk_obs(cfg, batch[0]).to(device))
-        if cfg.model_type == "mlp":
-            loss = criterion(pred, batch[1].squeeze().to(device))
-        else:
-            loss = criterion(pred, torch.argmax(batch[1].squeeze(), axis=1).to(device))
-        opt.zero_grad()
-        loss.backward()
-        opt.step()
-
-        net_loss += loss.item()
-    torch.save(model.state_dict(), os.path.join(wandb.run.dir, "risk_model.pt"))
-    wandb.save("risk_model.pt")
-    model.eval()
-    print("risk_loss:", net_loss)
-    return net_loss
 
 def test_policy(cfg, agent, envs, device, risk_model=None):
     envs = gym.vector.SyncVectorEnv(
@@ -574,7 +558,7 @@ def train(cfg):
         agent = RiskAgent(envs=envs, risk_size=risk_size).to(device)
         #else:
         #    agent = ContRiskAgent(envs=envs).to(device)
-        risk_model = risk_model_class[cfg.model_type][cfg.risk_type](obs_size=96, batch_norm=True, out_size=risk_size)
+        risk_model = risk_model_class[cfg.model_type][cfg.risk_type](obs_size=96, batch_norm=True, out_size=risk_size, action_size=envs.single_action_space.shape[0], model_type="state_action_risk")
         if os.path.exists(cfg.risk_model_path):
             risk_model.load_state_dict(torch.load(cfg.risk_model_path, map_location=device))
             print("Pretrained risk model loaded successfully")
@@ -702,14 +686,20 @@ def train(cfg):
             # ALGO LOGIC: action logic
             with torch.no_grad():
                 if cfg.use_risk:
-                    action, logprob, _, value = agent.get_action_and_value(next_obs, next_risk)
+                    action, logprob, _, value, candidates = agent.get_action_and_value(next_obs, next_risk)
                 else:
-                    action, logprob, _, value = agent.get_action_and_value(next_obs)
+                    action, logprob, _, value, candidates = agent.get_action_and_value(next_obs)
 
             values[step] = value.flatten()
             actions[step] = action
             logprobs[step] = logprob
 
+            if cfg.use_risk:
+                with torch.no_grad():
+                    candidates = candidates.squeeze()
+                    # print(next_obs_risk.repeat(5, 1).size(), candidates.size())
+                    candidates_risk = torch.sum(torch.exp(risk_model(next_obs_risk.repeat(5, 1).to(device), candidates))[:, :2], -1)
+                    action = candidates[torch.argmin(candidates_risk)]
             # TRY NOT TO MODIFY: execute the game and log data.
             next_obs, reward, terminated, truncated, infos = envs.step(action.cpu().numpy())
             done = np.logical_or(terminated, truncated)
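
The rollout hunk above is where the candidates are used: when `cfg.use_risk` is set, each candidate action is scored by the state-action risk model and the executed action is replaced by the candidate with the lowest risk score (the probability mass in the first two risk bins, assuming the model returns log-probabilities, as the diff's `torch.exp` suggests). A minimal sketch of that selection step, assuming a single environment (the function and argument names are illustrative, not from the PR):

# Hedged sketch of the candidate-filtering step above. Assumes risk_model(obs, act)
# returns log-probabilities over risk bins and that a single env is being stepped.
import torch

@torch.no_grad()
def pick_safest_candidate(risk_model, risk_obs, candidates, device):
    k = candidates.shape[0]
    obs_rep = risk_obs.repeat(k, 1).to(device)              # score every candidate against the same obs
    bin_log_probs = risk_model(obs_rep, candidates.to(device))
    risk_scores = torch.exp(bin_log_probs)[:, :2].sum(-1)   # mass in the two "near failure" bins
    return candidates[torch.argmin(risk_scores)]
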
@@ -738,12 +728,13 @@ def train(cfg):
             for i in range(cfg.num_envs):
                 f_obs[i] = obs_[i].unsqueeze(0).to(device) if f_obs[i] is None else torch.concat([f_obs[i], obs_[i].unsqueeze(0).to(device)], axis=0)
                 f_next_obs[i] = next_obs[i].unsqueeze(0).to(device) if f_next_obs[i] is None else torch.concat([f_next_obs[i], next_obs[i].unsqueeze(0).to(device)], axis=0)
-                f_actions[i] = action[i].unsqueeze(0).to(device) if f_actions[i] is None else torch.concat([f_actions[i], action[i].unsqueeze(0).to(device)], axis=0)
+                f_actions[i] = action.unsqueeze(0).to(device) if f_actions[i] is None else torch.concat([f_actions[i], action.unsqueeze(0).to(device)], axis=0)
                 f_rewards[i] = reward[i].unsqueeze(0).to(device) if f_rewards[i] is None else torch.concat([f_rewards[i], reward[i].unsqueeze(0).to(device)], axis=0)
                 # f_risks = risk_ if f_risks is None else torch.concat([f_risks, risk_], axis=0)
                 f_costs[i] = cost[i].unsqueeze(0).to(device) if f_costs[i] is None else torch.concat([f_costs[i], cost[i].unsqueeze(0).to(device)], axis=0)
                 f_dones[i] = next_done[i].unsqueeze(0).to(device) if f_dones[i] is None else torch.concat([f_dones[i], next_done[i].unsqueeze(0).to(device)], axis=0)
 
+            # print(f_actions[0].size())
             obs_ = next_obs
             # if global_step % cfg.update_risk_model == 0 and cfg.fine_tune_risk:
             # if cfg.use_risk and (global_step > cfg.start_risk_update and cfg.fine_tune_risk) and global_step % cfg.risk_update_period == 0:
@@ -767,15 +758,24 @@ def train(cfg):
                 writer.add_scalar("risk/risk_loss", risk_loss, global_step)
             elif cfg.fine_tune_risk == "off" and cfg.use_risk:
                 if cfg.use_risk and (global_step > cfg.start_risk_update and cfg.fine_tune_risk) and global_step % cfg.risk_update_period == 0:
-                    for epoch in range(cfg.num_risk_epochs):
+                    for epoch in tqdm.tqdm(range(cfg.num_risk_epochs)):
                         total_risk_updates += 1
                         print(total_risk_updates)
                         if cfg.finetune_risk_online:
                             print("I am online")
                             data = rb.slice_data(-cfg.risk_batch_size*cfg.num_update_risk, 0)
                         else:
                             data = rb.sample(cfg.risk_batch_size*cfg.num_update_risk)
-                        risk_loss = train_risk(cfg, risk_model, data, criterion, opt_risk, device)
+                        state = torch.cat([data["obs"], data["next_obs"]], axis=0)
+                        actions = torch.cat([data["actions"], torch.zeros_like(data["actions"])], axis=0)
+                        dist_to_fail = torch.cat([data["dist_to_fail"], data["dist_to_fail"]], axis=0)
+                        print(state.size(), actions.size(), dist_to_fail.size())
+                        risk_dataset = RiskyDataset(state.to(device), actions.to(device), dist_to_fail.to(device), True, risk_type=cfg.risk_type,
+                                                    fear_clip=None, fear_radius=cfg.fear_radius, one_hot=True, quantile_size=cfg.quantile_size, quantile_num=cfg.quantile_num)
+                        risk_dataloader = DataLoader(risk_dataset, batch_size=cfg.risk_batch_size, shuffle=True, num_workers=4, generator=torch.Generator(device=device))
+
+                        risk_loss = train_risk(risk_model, risk_dataloader, criterion, opt_risk, 1, device, train_mode="state_action")
+
                         writer.add_scalar("risk/risk_loss", risk_loss, global_step)
 
             # Only print when at least 1 env is done
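
With the in-file `train_risk` removed (see the hunk at -412 above), this branch now builds a state-action `RiskyDataset` directly from the replay buffer: observations are paired with the executed actions, next observations with zero actions, and both share the same distance-to-failure targets. The external `train_risk(model, dataloader, criterion, opt, epochs, device, train_mode=...)` that the diff calls is not shown here; the following is only a hedged sketch of what such a loop could look like, modelled on the removed in-file version.

# Hedged sketch only: the project's actual train_risk (imported elsewhere) is not
# shown in this diff. The batch layout (obs, action, target) is an assumption.
import torch

def train_risk_sketch(model, dataloader, criterion, opt, epochs, device):
    model.train()
    net_loss = 0.0
    for _ in range(epochs):
        for obs, action, target in dataloader:
            pred = model(obs.to(device), action.to(device))
            loss = criterion(pred, torch.argmax(target.squeeze(), dim=1).to(device))
            opt.zero_grad()
            loss.backward()
            opt.step()
            net_loss += loss.item()
    model.eval()
    return net_loss
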
@@ -853,7 +853,7 @@ def train(cfg):
                     if cfg.risk_type == "binary":
                         rb.add(f_obs[i], f_next_obs[i], f_actions[i], f_rewards[i], f_dones[i], f_costs[i], (f_risks <= cfg.fear_radius).float(), e_risks.unsqueeze(1))
                     else:
-                        rb.add(f_obs[i], f_next_obs[i], f_actions[i], f_rewards[i], f_dones[i], f_costs[i], f_risks, f_risks)
+                        rb.add(get_risk_obs(cfg, f_obs[i]), get_risk_obs(cfg, f_next_obs[i]), f_actions[i], f_rewards[i], f_dones[i], f_costs[i], f_risks, f_risks)
 
                     f_obs[i] = None
                     f_next_obs[i] = None
@@ -915,9 +915,9 @@ def train(cfg):
                 mb_inds = b_inds[start:end]
 
                 if cfg.use_risk:
-                    _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_risks[mb_inds], b_actions[mb_inds])
+                    _, newlogprob, entropy, newvalue, cands = agent.get_action_and_value(b_obs[mb_inds], b_risks[mb_inds], b_actions[mb_inds])
                 else:
-                    _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions[mb_inds])
+                    _, newlogprob, entropy, newvalue, cands = agent.get_action_and_value(b_obs[mb_inds], b_actions[mb_inds])
 
                 logratio = newlogprob - b_logprobs[mb_inds]
                 ratio = logratio.exp()