Skip to content

Commit

Permalink
Add eval test
Browse files Browse the repository at this point in the history
  • Loading branch information
daphne-cornelisse committed Jan 28, 2024
1 parent 84f1fec commit d59c271
Showing 1 changed file with 33 additions and 7 deletions.
40 changes: 33 additions & 7 deletions evaluation/policy_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def evaluate_policy(

# Run
for _ in tqdm(range(num_episodes)):

if traffic_files is not None:
# Reset to a new scene
obs_dict = env.reset(
Expand All @@ -77,7 +78,7 @@ def evaluate_policy(
)

else:
obs_dict = env.reset()
obs_dict = env.reset(use_av_only=use_av_only)

agent_ids = list(obs_dict.keys())
dead_agent_ids = []
Expand Down Expand Up @@ -176,7 +177,7 @@ def evaluate_policy(
off_road[agend_idx] += last_info_dicts[agent_id]["veh_edge_collision"] * 1
goal_achieved[agend_idx] += last_info_dicts[agent_id]["goal_achieved"] * 1

logging.info(f"Goal achieved: {last_info_dicts[agent_id]['goal_achieved']}")
logging.debug(f"Goal achieved: {last_info_dicts[agent_id]['goal_achieved']}")

# Get scene info
if scene_path_mapping is not None:
Expand Down Expand Up @@ -239,17 +240,42 @@ def evaluate_policy(


if __name__ == "__main__":

# Logging
logger = logging.getLogger()
logger.setLevel('INFO')

# Load environment config
env_config = load_config("env_config")

# Set data path to new scenes
# Set data path to NEW scenes (with is_av flag)
env_config.data_path = "data_new/train_no_tl"

# Get all scene files
train_file_paths = glob.glob(f"{env_config.data_path}" + "/tfrecord*")
files = sorted([os.path.basename(file) for file in train_file_paths])

logging.info(f"Using {len(files)} scenes")
# Step through scene using expert-teleports
logging.info(f'Evaluating policy using EXPERT-TELEPORT mode {len(files)} new scenes...\n')

df_expert_replay = evaluate_policy(
env_config=env_config,
controlled_agents=500,
data_path=env_config.data_path,
traffic_files=files,
mode="expert-replay",
select_from_k_scenes=1000, # Use all scenes
num_episodes=100,
use_av_only=True,
)

logging.info(f'--- Results: EXPERT-TELEPORT ---')
print(df_expert_replay[["goal_rate", "off_road", "veh_veh_collision"]].mean())


logging.info(f'Evaluating policy using EXPERT-TRAJECTORY ACTIONS {len(files)} new scenes...\n')

df_expert_traj = evaluate_policy(
env_config=env_config,
controlled_agents=500,
data_path=env_config.data_path,
Expand All @@ -259,8 +285,8 @@ def evaluate_policy(
num_episodes=100,
use_av_only=True,
)

logging.info(f'--- Results: EXPERT-TRAJECTORY ACTIONS ---')
print(df_expert_traj[["goal_rate", "off_road", "veh_veh_collision"]].mean())

print(df_expert_replay[["goal_rate", "off_road", "veh_veh_collision"]].mean())

# with open("invalid_train", "wb") as fp: #Pickling
# pickle.dump(inval_scenes, fp)

0 comments on commit d59c271

Please sign in to comment.