adding pretrained policy loader
manila95 committed Nov 22, 2023
1 parent f46ed15 commit d5050ae
Showing 1 changed file with 16 additions and 7 deletions.
23 changes: 16 additions & 7 deletions cleanrl/dqn.py
@@ -74,6 +74,10 @@ def parse_args():
help="timestep to start learning")
parser.add_argument("--train-frequency", type=int, default=10,
help="the frequency of training")
parser.add_argument("--mode", type=str, default="train",
help="whether to train or evaluate the policy")
parser.add_argument("--pretrained-policy-path", type=str, default="None",
help="the id of the environment")

## Arguments related to risk model
parser.add_argument("--use-risk", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
@@ -252,6 +256,9 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):


q_network = QNetwork(envs, risk_size=risk_size).to(device)

if os.path.exists(args.pretrained_policy_path):
q_network.load_state_dict(torch.load(args.pretrained_policy_path, map_location=device))
optimizer = optim.Adam(q_network.parameters(), lr=args.learning_rate)
target_network = QNetwork(envs, risk_size=risk_size).to(device)
target_network.load_state_dict(q_network.state_dict())
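
Note on the new loading logic: os.path.exists is the guard, so the default string "None" never matches an existing file and training simply starts from scratch. A self-contained sketch of the same load-if-present pattern (the SmallQNet class and the checkpoint path below are made up for illustration, not taken from this file):

    import os
    import torch
    import torch.nn as nn

    class SmallQNet(nn.Module):
        """Stand-in for QNetwork, for illustration only."""
        def __init__(self, obs_dim: int = 8, n_actions: int = 4):
            super().__init__()
            self.net = nn.Sequential(nn.Linear(obs_dim, 64), nn.ReLU(), nn.Linear(64, n_actions))

        def forward(self, x):
            return self.net(x)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    q_network = SmallQNet().to(device)
    ckpt = "runs/example/qnet.pt"  # hypothetical path; the script's default is the literal string "None"
    if os.path.exists(ckpt):       # "None" is never an existing path, so the fresh network is kept
        q_network.load_state_dict(torch.load(ckpt, map_location=device))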
@@ -272,10 +279,11 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
obs = obs
total_cost = 0
scores = []
total_goals = 0
for global_step in range(args.total_timesteps):
# ALGO LOGIC: put action logic here
epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step)
if random.random() < epsilon:
if random.random() < epsilon and args.mode == "train":
actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
else:
obs_in = torch.Tensor(obs["image"]).reshape(args.num_envs, -1).to(device)
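
For context, linear_schedule (named in the hunk headers) produces the annealed epsilon used above; with --mode eval the random branch is never taken, so actions always come from the greedy argmax. A sketch of the usual linear annealing under that signature (an assumption, not copied from this file):

    def linear_schedule(start_e: float, end_e: float, duration: int, t: int) -> float:
        # Anneal epsilon linearly from start_e to end_e over `duration` steps, then hold at end_e.
        slope = (end_e - start_e) / duration
        return max(slope * t + start_e, end_e)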
@@ -305,10 +313,11 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
continue
total_cost += cost
ep_len = info["episode"]["l"]
e_risks = np.array(list(reversed(range(int(ep_len))))) if cost > 0 else np.array([int(ep_len)]*int(ep_len))
e_risks_quant = torch.Tensor(np.apply_along_axis(lambda x: np.histogram(x, bins=risk_bins)[0], 1, np.expand_dims(e_risks, 1)))
e_risks = torch.Tensor(e_risks)

if args.use_risk and args.fine_tune_risk != "None":
e_risks = np.array(list(reversed(range(int(ep_len))))) if cost > 0 else np.array([int(ep_len)]*int(ep_len))
e_risks_quant = torch.Tensor(np.apply_along_axis(lambda x: np.histogram(x, bins=risk_bins)[0], 1, np.expand_dims(e_risks, 1)))
e_risks = torch.Tensor(e_risks)
if args.risk_type == "binary":
risk_rb.add(f_obs[i], f_next_obs[i], f_actions[i], None, None, None, (e_risks <= args.fear_radius).float(), e_risks.unsqueeze(1))
else:
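
A small worked example (separate from the diff) of what the risk targets above encode: each step gets its remaining distance to the episode's terminal cost, and np.histogram turns that scalar into a one-hot bin vector; the bin edges here are hypothetical:

    import numpy as np
    import torch

    ep_len, cost = 5, 1.0                         # toy episode of 5 steps that ends with a cost
    risk_bins = np.array([0.0, 1.0, 2.0, 10.0])   # hypothetical bin edges
    # Distance to failure counted backwards; without a cost, every step would get ep_len instead.
    e_risks = np.array(list(reversed(range(ep_len)))) if cost > 0 else np.array([ep_len] * ep_len)
    # One histogram per step -> a one-hot bin-membership row per step.
    e_risks_quant = torch.Tensor(
        np.apply_along_axis(lambda x: np.histogram(x, bins=risk_bins)[0], 1, np.expand_dims(e_risks, 1))
    )
    print(e_risks)        # [4 3 2 1 0]
    print(e_risks_quant)  # rows: [0,0,1], [0,0,1], [0,0,1], [0,1,0], [1,0,0]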
@@ -336,7 +345,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):

# ALGO LOGIC: training.
if global_step > args.learning_starts:
if global_step % args.train_frequency == 0:
if global_step % args.train_frequency == 0 and args.mode=="train":
data = rb.sample(args.batch_size)
with torch.no_grad():
next_risk = risk_model(data.next_observations.reshape(args.batch_size, -1).float()) if args.use_risk else None
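
The rest of the training step is collapsed in this view; it presumably follows the standard DQN update, with the risk estimate fed to the networks as an extra input when --use-risk is set. A hedged sketch of that target computation as a standalone helper (the batch keys and the MSE loss are assumptions):

    import torch
    import torch.nn.functional as F

    def dqn_td_loss(q_network, target_network, batch, gamma):
        # batch: dict of tensors with keys "obs", "next_obs", "actions", "rewards", "dones" (assumed layout).
        with torch.no_grad():
            target_max, _ = target_network(batch["next_obs"]).max(dim=1)
            td_target = batch["rewards"].flatten() + gamma * target_max * (1.0 - batch["dones"].flatten())
        old_val = q_network(batch["obs"]).gather(1, batch["actions"].long()).squeeze()
        return F.mse_loss(old_val, td_target)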
@@ -361,7 +370,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
torch.save(q_network.state_dict(), os.path.join(wandb.run.dir, "qnet.pt"))
wandb.save("qnet.pt")
## Update Risk Network
if args.use_risk and args.fine_tune_risk != "None" and global_step % args.risk_update_period == 0:
if args.use_risk and args.fine_tune_risk != "None" and global_step % args.risk_update_period == 0 and args.mode=="train":
risk_model.train()
risk_data = risk_rb.sample(args.risk_batch_size)
pred = risk_model(risk_data["next_obs"].to(device))
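
The remainder of the risk-network update is also collapsed; a sketch of a typical fine-tuning step for a classifier-style risk model (the loss choice and the "risks" buffer key are assumptions for illustration):

    import torch
    import torch.nn as nn

    def risk_update(risk_model, risk_data, risk_optimizer, device):
        # risk_data: dict with "next_obs" and one-hot quantized targets under "risks" (assumed key).
        criterion = nn.CrossEntropyLoss()
        pred = risk_model(risk_data["next_obs"].to(device))
        loss = criterion(pred, torch.argmax(risk_data["risks"].to(device), dim=1))
        risk_optimizer.zero_grad()
        loss.backward()
        risk_optimizer.step()
        return loss.item()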
@@ -375,7 +384,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
wandb.save("risk_model.pt")

# update target network
if global_step % args.target_network_frequency == 0:
if global_step % args.target_network_frequency == 0 and args.mode=="train":
for target_network_param, q_network_param in zip(target_network.parameters(), q_network.parameters()):
target_network_param.data.copy_(
args.tau * q_network_param.data + (1.0 - args.tau) * target_network_param.data
