diff --git a/cleanrl/dqn.py b/cleanrl/dqn.py
index 8e958a46c..d7663512f 100644
--- a/cleanrl/dqn.py
+++ b/cleanrl/dqn.py
@@ -74,6 +74,10 @@ def parse_args():
         help="timestep to start learning")
     parser.add_argument("--train-frequency", type=int, default=10,
         help="the frequency of training")
+    parser.add_argument("--mode", type=str, default="train",
+        help="whether to train or evaluate the policy")
+    parser.add_argument("--pretrained-policy-path", type=str, default="None",
+        help="path to a pretrained policy checkpoint to load")
     ## Arguments related to risk model
     parser.add_argument("--use-risk", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
@@ -252,6 +256,9 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
     q_network = QNetwork(envs, risk_size=risk_size).to(device)
+
+    if os.path.exists(args.pretrained_policy_path):
+        q_network.load_state_dict(torch.load(args.pretrained_policy_path, map_location=device))
     optimizer = optim.Adam(q_network.parameters(), lr=args.learning_rate)
     target_network = QNetwork(envs, risk_size=risk_size).to(device)
     target_network.load_state_dict(q_network.state_dict())
@@ -272,10 +279,11 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
     obs = obs
     total_cost = 0
     scores = []
+    total_goals = 0
     for global_step in range(args.total_timesteps):
         # ALGO LOGIC: put action logic here
         epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step)
-        if random.random() < epsilon:
+        if random.random() < epsilon and args.mode == "train":
             actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
         else:
             obs_in = torch.Tensor(obs["image"]).reshape(args.num_envs, -1).to(device)
@@ -305,10 +313,11 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
                 continue
             total_cost += cost
             ep_len = info["episode"]["l"]
-            e_risks = np.array(list(reversed(range(int(ep_len))))) if cost > 0 else np.array([int(ep_len)]*int(ep_len))
-            e_risks_quant = torch.Tensor(np.apply_along_axis(lambda x: np.histogram(x, bins=risk_bins)[0], 1, np.expand_dims(e_risks, 1)))
-            e_risks = torch.Tensor(e_risks)
+            if args.use_risk and args.fine_tune_risk != "None":
+                e_risks = np.array(list(reversed(range(int(ep_len))))) if cost > 0 else np.array([int(ep_len)]*int(ep_len))
+                e_risks_quant = torch.Tensor(np.apply_along_axis(lambda x: np.histogram(x, bins=risk_bins)[0], 1, np.expand_dims(e_risks, 1)))
+                e_risks = torch.Tensor(e_risks)
             if args.risk_type == "binary":
                 risk_rb.add(f_obs[i], f_next_obs[i], f_actions[i], None, None, None, (e_risks <= args.fear_radius).float(), e_risks.unsqueeze(1))
             else:
@@ -336,7 +345,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
         # ALGO LOGIC: training.
         if global_step > args.learning_starts:
-            if global_step % args.train_frequency == 0:
+            if global_step % args.train_frequency == 0 and args.mode == "train":
                 data = rb.sample(args.batch_size)
                 with torch.no_grad():
                     next_risk = risk_model(data.next_observations.reshape(args.batch_size, -1).float()) if args.use_risk else None
@@ -361,7 +370,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
                     torch.save(q_network.state_dict(), os.path.join(wandb.run.dir, "qnet.pt"))
                     wandb.save("qnet.pt")
                 ## Update Risk Network
-                if args.use_risk and args.fine_tune_risk != "None" and global_step % args.risk_update_period == 0:
+                if args.use_risk and args.fine_tune_risk != "None" and global_step % args.risk_update_period == 0 and args.mode == "train":
                     risk_model.train()
                     risk_data = risk_rb.sample(args.risk_batch_size)
                     pred = risk_model(risk_data["next_obs"].to(device))
@@ -375,7 +384,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
                     wandb.save("risk_model.pt")
             # update target network
-            if global_step % args.target_network_frequency == 0:
+            if global_step % args.target_network_frequency == 0 and args.mode == "train":
                 for target_network_param, q_network_param in zip(target_network.parameters(), q_network.parameters()):
                     target_network_param.data.copy_(
                         args.tau * q_network_param.data + (1.0 - args.tau) * target_network_param.data
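
Note on the new flags (not part of the patch itself): the sketch below mirrors the checkpoint guard added above, to show why the string default "None" for --pretrained-policy-path leaves the Q-network untrained at startup. The small QNetwork stand-in and the maybe_load helper are illustrative only; the real QNetwork in cleanrl/dqn.py is constructed from (envs, risk_size).

# Minimal, self-contained sketch of the loading guard; assumes nothing beyond
# PyTorch. "qnet.pt" and the layer sizes below are placeholders.
import os

import torch
import torch.nn as nn


class QNetwork(nn.Module):
    # Stand-in for the QNetwork in cleanrl/dqn.py, which takes (envs, risk_size).
    def __init__(self, in_dim: int = 4, n_actions: int = 2):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(in_dim, 64), nn.ReLU(), nn.Linear(64, n_actions))

    def forward(self, x):
        return self.net(x)


def maybe_load(q_network: nn.Module, pretrained_policy_path: str, device: str = "cpu") -> bool:
    """Mirror the patch's guard: load weights only if the path exists on disk.

    Because --pretrained-policy-path defaults to the literal string "None",
    os.path.exists() returns False and training starts from scratch.
    """
    if os.path.exists(pretrained_policy_path):
        q_network.load_state_dict(torch.load(pretrained_policy_path, map_location=device))
        return True
    return False


if __name__ == "__main__":
    q_network = QNetwork()
    loaded = maybe_load(q_network, "None")          # default value: nothing is loaded
    print(f"loaded pretrained weights: {loaded}")   # -> False

For an evaluation run, one would point --pretrained-policy-path at a previously saved qnet.pt (the file the training loop writes via wandb) and pass any --mode value other than "train", since epsilon-greedy exploration, Q-network updates, risk-model fine-tuning, and target-network updates are all gated on args.mode == "train" in the patch.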