Skip to content

Commit 858438e

Browse files
committed
bug fix
1 parent e6fd43f commit 858438e

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

cleanrl/dqn.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
205205
np.random.seed(args.seed)
206206
torch.manual_seed(args.seed)
207207
torch.backends.cudnn.deterministic = args.torch_deterministic
208-
torch.set_default_tensor_type('torch.cuda.FloatTensor')
208+
# torch.set_default_tensor_type('torch.cuda.FloatTensor')
209209

210210
device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
211211
args.use_risk = False if args.risk_model_path == "None" else True
@@ -220,7 +220,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
220220
"mlp": {"continuous": RiskEst, "binary": RiskEst}}
221221

222222
risk_size_dict = {"continuous": 1, "binary": 2, "quantile": args.quantile_num}
223-
risk_size = risk_size_dict[args.risk_type]
223+
risk_size = risk_size_dict[args.risk_type] if args.use_risk else 0
224224
risk_bins = np.array([i*args.quantile_size for i in range(args.quantile_num)])
225225

226226
if args.use_risk:
@@ -287,7 +287,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
287287
# TRY NOT TO MODIFY: execute the game and log data.
288288
next_obs, rewards, terminated, truncated, infos = envs.step(actions)
289289
cost = int(terminated) and (rewards == 0)
290-
if (args.fine_tune_risk != "None" and args.use_risk) or args.collect_data:
290+
if (args.fine_tune_risk != "None" and args.use_risk):
291291
for i in range(args.num_envs):
292292
f_obs[i] = torch.Tensor(obs["image"][i]).reshape(1, -1).to(device) if f_obs[i] is None else torch.concat([f_obs[i], torch.Tensor(obs["image"][i]).reshape(1, -1).to(device)], axis=0)
293293
f_next_obs[i] = torch.Tensor(next_obs["image"][i]).reshape(1, -1).to(device) if f_next_obs[i] is None else torch.concat([f_next_obs[i], torch.Tensor(next_obs["image"][i]).reshape(1, -1).to(device)], axis=0)

0 commit comments

Comments
 (0)