@@ -205,7 +205,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
205
205
np .random .seed (args .seed )
206
206
torch .manual_seed (args .seed )
207
207
torch .backends .cudnn .deterministic = args .torch_deterministic
208
- torch .set_default_tensor_type ('torch.cuda.FloatTensor' )
208
+ # torch.set_default_tensor_type('torch.cuda.FloatTensor')
209
209
210
210
device = torch .device ("cuda" if torch .cuda .is_available () and args .cuda else "cpu" )
211
211
args .use_risk = False if args .risk_model_path == "None" else True
@@ -220,7 +220,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
220
220
"mlp" : {"continuous" : RiskEst , "binary" : RiskEst }}
221
221
222
222
risk_size_dict = {"continuous" : 1 , "binary" : 2 , "quantile" : args .quantile_num }
223
- risk_size = risk_size_dict [args .risk_type ]
223
+ risk_size = risk_size_dict [args .risk_type ] if args . use_risk else 0
224
224
risk_bins = np .array ([i * args .quantile_size for i in range (args .quantile_num )])
225
225
226
226
if args .use_risk :
@@ -287,7 +287,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
287
287
# TRY NOT TO MODIFY: execute the game and log data.
288
288
next_obs , rewards , terminated , truncated , infos = envs .step (actions )
289
289
cost = int (terminated ) and (rewards == 0 )
290
- if (args .fine_tune_risk != "None" and args .use_risk ) or args . collect_data :
290
+ if (args .fine_tune_risk != "None" and args .use_risk ):
291
291
for i in range (args .num_envs ):
292
292
f_obs [i ] = torch .Tensor (obs ["image" ][i ]).reshape (1 , - 1 ).to (device ) if f_obs [i ] is None else torch .concat ([f_obs [i ], torch .Tensor (obs ["image" ][i ]).reshape (1 , - 1 ).to (device )], axis = 0 )
293
293
f_next_obs [i ] = torch .Tensor (next_obs ["image" ][i ]).reshape (1 , - 1 ).to (device ) if f_next_obs [i ] is None else torch .concat ([f_next_obs [i ], torch .Tensor (next_obs ["image" ][i ]).reshape (1 , - 1 ).to (device )], axis = 0 )
0 commit comments