Minor improvements
daphne-cornelisse committed Dec 6, 2023
1 parent 7501ddc commit 0082243
Showing 5 changed files with 93 additions and 27 deletions.
44 changes: 36 additions & 8 deletions utils/eval.py
@@ -2,6 +2,7 @@
import math
import numpy as np
import pandas as pd
import logging
import torch
import wandb
import glob
@@ -82,7 +83,7 @@ def _get_scores(self):

for file in self.eval_files:

print(f"Evaluating policy on {file}...")
logging.info(f"Evaluating policy on {file}...")

# Step through scene in expert control mode to obtain ground truth
expert_actions, expert_pos, expert_speed, expert_gr, expert_edge_cr, expert_veh_cr = self._step_through_scene(
@@ -186,12 +187,16 @@ def _step_through_scene(self, filename: str, mode: str):

# Set control mode
if mode == "expert":
logging.debug(f'EXPERT MODE')
for obj in self.env.controlled_vehicles:
obj.expert_control = True
if mode == "policy":
logging.debug(f'POLICY MODE')
for obj in self.env.controlled_vehicles:
obj.expert_control = False


logging.debug(f'agent_ids: {agent_ids}')

# Step through scene
for timestep in range(num_steps):

@@ -213,7 +218,7 @@ def _step_through_scene(self, filename: str, mode: str):
agent_speed[veh_idx, timestep] = veh_obj.speed
action_indices[veh_idx, timestep] = action_idx
else:
print(f'veh {veh_obj.id} at t = {timestep} returns None action!')
logging.debug(f'veh {veh_obj.id} at t = {timestep} returns None action!')

action_dict = {}

@@ -328,7 +333,7 @@ def get_veh_to_veh_distances(self, positions, velocities, time_gap_in_sec=3):
veh_distances_per_step[veh_i, veh_j, step] = distance_between_veh_ij

if distance_between_veh_ij < safe_distance:
print(f"Vehicles {veh_i + 1} and {veh_j + 1} are too close!")
logging.debug(f"Vehicles {veh_i + 1} and {veh_j + 1} are too close!")

# Aggregate
distance_violations_matrix = (veh_distances_per_step < safe_distances_per_step).sum(axis=2)
@@ -362,23 +367,46 @@ def _get_files(self, eval_files, file_limit):
env_config = load_config("env_config")
exp_config = load_config("exp_config")

env_config.data_path = "./data_10/train"
# env_config.data_path = "./data_10/train"

# # Load trained human reference policy
# human_policy = load_policy(
# data_path="./models/il",
# file_name="human_policy_10_scenes_2023_11_21",
# )

# # Evaluate policy
# evaluator = EvaluatePolicy(
# env_config=env_config,
# exp_config=exp_config,
# policy=human_policy,
# log_to_wandb=False,
# deterministic=True,
# reg_coef=0.0,
# return_trajectories=True,
# )

# il_results_check = evaluator._get_scores()

# Set data path
env_config.data_path = "./data/train/"

# Load trained human reference policy
# Load human reference policy
human_policy = load_policy(
data_path="./models/il",
file_name="human_policy_10_scenes_2023_11_21",
file_name="human_policy_2_scenes_2023_11_22",
)

# Evaluate policy
evaluator = EvaluatePolicy(
env_config=env_config,
exp_config=exp_config,
policy=human_policy,
eval_files=["tfrecord-00012-of-01000_389.json"],
log_to_wandb=False,
deterministic=True,
reg_coef=0.0,
return_trajectories=True,
)

il_results_check = evaluator._get_scores()
df_il_res_2, df_il_trajs_2 = evaluator._get_scores()
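
The switch from print to logging in eval.py only produces visible output once the logging level is configured in the entry point. A minimal sketch, assuming the standard-library root logger (which the module-level logging.info / logging.debug calls above use); the format string is an illustrative choice, not part of the commit:

import logging

# Show INFO-level evaluation progress; switch to DEBUG to also see the
# per-mode, per-vehicle, and distance-violation messages added here.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(message)s",
)

With level=logging.INFO, the "Evaluating policy on ..." line is emitted once per evaluation file while the noisier debug messages stay hidden.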
22 changes: 22 additions & 0 deletions utils/manage_models.py
@@ -0,0 +1,22 @@
"""Download trained policies from W&B."""

import wandb


if __name__ == "__main__":

ROOT = "models/hr_rl"

entity = "daphnecor"
project = "hr_ppo_2scenes"
collection_name = "nocturne-hr-ppo-11_28_21_01"

# Always initialize a W&B run to start tracking
wandb.init()

# Download model version files
path = wandb.use_artifact(f"{entity}/{project}/{collection_name}:latest").download(
root=ROOT,
)

print('Downloaded model files to: ', path)
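
manage_models.py only downloads the artifact files; a natural follow-up is loading the checkpoint they contain. A minimal sketch, assuming the artifact holds a single .pt file written by the callback below — the glob pattern and the CPU map_location are assumptions, not part of this commit:

import glob
import torch

# The checkpoint name depends on the W&B run id, so discover it rather than
# hard-coding it.
checkpoint_files = glob.glob("models/hr_rl/*.pt")
assert len(checkpoint_files) == 1, "Expected exactly one downloaded checkpoint"

checkpoint = torch.load(checkpoint_files[0], map_location="cpu")
print(checkpoint.keys())  # e.g. state_dict, data, model, train (see callbacks.py below)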
6 changes: 4 additions & 2 deletions utils/plot.py
@@ -13,7 +13,9 @@ def plot_agent_trajectory(agent_df, act_space_dim):
)
acc = (agent_df.policy_act.values[nonnan_ids] == agent_df.expert_act.values[nonnan_ids]).sum() / nonnan_ids.shape[0]

fig, axs = plt.subplots(1, 3, figsize=(15, 4))
fig, axs = plt.subplots(1, 3, figsize=(12, 4))

fig.suptitle(f'Scene: {agent_df.traffic_scene.iloc[0][9:27]} | Agent # {agent_df.agent_id.iloc[0]}')

# Plot expert and agent positions
axs[0].plot(agent_df.expert_pos_x, agent_df.expert_pos_y, '.-', color='g', label='Expert')
@@ -37,7 +39,7 @@ def plot_agent_trajectory(agent_df, act_space_dim):
axs[2].legend(facecolor='white', framealpha=1)
axs[2].set_xlabel(r'$t$')
axs[2].set_ylabel('Joint action index')
axs[2].set_title(f'Action accuracy: {acc*100} % ($D^A$ = {act_space_dim})')
axs[2].set_title(f'Action accuracy: {np.round(acc*100, 2)} % ($D^A$ = {act_space_dim})')

# Adding grids with a specific alpha value to both subplots
axs[0].grid(alpha=0.5) # Grid for axs[0] with alpha value
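
The new suptitle uses a fixed slice, agent_df.traffic_scene.iloc[0][9:27], to pull the scene id out of the file name. A quick illustration of what that slice yields, assuming scene names follow the tfrecord pattern used in eval.py above:

scene = "tfrecord-00012-of-01000_389.json"
print(scene[9:27])  # -> "00012-of-01000_389"

Because the slice is position-based, it only produces a clean id for names that match this exact pattern.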
46 changes: 30 additions & 16 deletions utils/sb3/callbacks.py
@@ -1,5 +1,6 @@
import logging
import torch
import torch.nn as nn
import numpy as np
from stable_baselines3.common.callbacks import BaseCallback
import os
@@ -73,7 +74,7 @@ def _on_rollout_end(self) -> None:

# Average normalized by the number of agents in the scene
num_agents_per_step = np.array(self.locals["env"].agents_in_scene)
ep_rewards_avg_norm = sum(rewards.sum(axis=1) / num_agents_per_step) / self.n_episodes
self.ep_rewards_avg_norm = sum(rewards.sum(axis=1) / num_agents_per_step) / self.n_episodes

# Obtain the sum of reward per episode (across all agents)
sum_rewards = rewards.sum() / self.n_episodes
@@ -105,7 +106,7 @@ def _on_rollout_end(self) -> None:

# Log aggregate performance measures
self.logger.record("rollout/avg_num_agents_controlled", np.mean(num_agents_per_step))
self.logger.record("rollout/ep_rew_mean_norm", ep_rewards_avg_norm)
self.logger.record("rollout/ep_rew_mean_norm", self.ep_rewards_avg_norm)
self.logger.record("rollout/ep_rew_sum", sum_rewards)
self.logger.record("rollout/ep_len_mean", avg_ep_len)
self.logger.record("rollout/perc_goal_achieved", self.avg_frac_goal_achieved)
@@ -142,12 +143,15 @@ def _on_rollout_end(self) -> None:
# Save model
if self.exp_config.ma_callback.save_model:
if self.iteration % self.exp_config.ma_callback.model_save_freq == 0:
self.save_model()
self._save_model()

def _on_training_end(self) -> None:
"""
This event is triggered before exiting the `learn()` method.
"""
# Save model to wandb
self._save_model()

if self.exp_config.ma_callback.save_video:
logging.info(f"Making video at last iter = {self.iteration} in deterministic mode | global_step = {self.num_timesteps}")
# Set deterministic to True
@@ -162,9 +166,6 @@ def _on_training_end(self) -> None:
deterministic=self.exp_config.ma_callback.video_deterministic,
)

if self.model_path is not None:
self.save_model()

if self.exp_config.ma_callback.log_human_metrics:
evaluator = EvaluatePolicy(
env_config=self.env_config,
@@ -174,9 +175,10 @@ def _on_training_end(self) -> None:
)
table = evaluator._get_scores()

def save_model(self) -> None:
def _save_model(self) -> None:
"""Save model to wandb."""
self.model_name = f"ppo_policy_net_{self.num_timesteps}"

self.model_name = f"nocturne-hr-ppo-{wandb.run.id}"
self.model_path = os.path.join(wandb.run.dir, f"{self.model_name}.pt")

# Create model artifact
@@ -189,19 +191,31 @@ def save_model(self) -> None:
# Save torch model
torch.save(
obj={
"iter": self.iteration,
"model_state_dict": self.locals["self"].policy.state_dict(),
"obs_space_dim": self.locals["env"].observation_space.shape[0],
"act_space_dim": self.locals["env"].action_space.n,
"norm_reward": self.ep_advantage_avg_norm,
"collision_rate": self.avg_frac_collided,
"goal_rate": self.avg_frac_collided,
"state_dict": self.locals["self"].policy.state_dict(),
"data": self.locals["self"].policy._get_constructor_parameters(),
"model": {
"model_cls": self.locals["self"].policy.__class__,
"feat_dim": self.locals["env"].observation_space.shape[0], # Input dimension
"act_func": self.locals["self"].policy.mlp_extractor.act_func, # Activation function used
"arch_ego_state": self.locals["self"].policy.mlp_extractor.arch_ego_state,
"arch_road_objects": self.locals["self"].policy.mlp_extractor.arch_road_objects, # Network layers
"arch_road_graph": self.locals["self"].policy.mlp_extractor.arch_road_graph,
"arch_shared": self.locals["self"].policy.mlp_extractor.arch_shared,
},
"train": {
"global_step": self.num_timesteps,
"trained_in_k_scenes": len(self.locals["env"].env.files),
"act_space_dim": self.locals["env"].action_space.n,
"norm_reward": self.ep_rewards_avg_norm,
"coll_rate": self.avg_frac_collided,
"goal_rate": self.avg_frac_goal_achieved,
},
},
f=self.model_path,
)

# Save model artifact
model_artifact.add_file(local_path=self.model_path)
wandb.save(self.model_path, base_path=wandb.run.dir)
self.wandb_run.log_artifact(model_artifact)
self.wandb_run.log_artifact(model_artifact, aliases=[f"lam_{self.exp_config.reg_weight}"])
logging.info(f"Saving model checkpoint to {self.model_path} | Global_step: {self.num_timesteps}")
2 changes: 1 addition & 1 deletion utils/string_utils.py
@@ -11,7 +11,7 @@ def datetime_to_str(dt: datetime) -> str:
Returns:
str: String representation of the datetime object.
"""
return dt.strftime("%Y_%m_%d__%H_%M_%S")
return dt.strftime("%m_%d_%H_%S")


def date_to_str(date_: date) -> str:
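
For reference, a quick example of the shortened timestamp format; the chosen datetime is arbitrary, and note that the %S directive is seconds, not minutes:

from datetime import datetime

print(datetime(2023, 11, 28, 21, 30, 1).strftime("%m_%d_%H_%S"))  # -> "11_28_21_01"

The output resembles the suffix of the artifact collection name in manage_models.py above (nocturne-hr-ppo-11_28_21_01), which appears to be the naming convention this shorter format supports.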
