diff --git a/configs/bc_config.yaml b/configs/bc_config.yaml
index 8079114d..13547552 100644
--- a/configs/bc_config.yaml
+++ b/configs/bc_config.yaml
@@ -1,8 +1,8 @@
# Model
save_model: true # Save model after training
-model_name: "human_policy_data_2_scenes" # Name of saved model
+model_name: "human_policy" # Name of saved model
save_model_path: ./models/il/ # Path to save model
# Train
-total_samples: 10_000 # Number of obs-act-next_obs-done pairs to generate
-n_epochs: 100 # Training epochs
+total_samples: 150_000 # Number of obs-act-next_obs-done pairs to generate
+n_epochs: 200 # Training epochs
\ No newline at end of file
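Note on the raised sampling budget: below is a minimal sketch of how total_samples is typically turned into behavioral-cloning demonstrations, following the TrajectoryIterator and DataLoader pattern used in evaluation/il_analysis.ipynb. The exact wiring inside experiments/il/run_behavioral_cloning.py may differ; attribute-style config access is assumed.

# Sketch (assumption: configs load as attribute-accessible objects, as elsewhere in the repo)
from torch.utils.data import DataLoader
from utils.config import load_config
from utils.imitation_learning.waymo_iterator import TrajectoryIterator

bc_config = load_config("bc_config")
env_config = load_config("env_config")

# Iterator over logged human trajectories in the training scenes
waymo_iterator = TrajectoryIterator(
    data_path=env_config.data_path,
    env_config=env_config,
    file_limit=env_config.num_files,
)

# One batch of size total_samples yields the (obs, act, next_obs, done)
# tuples used as BC demonstrations
obs, acts, next_obs, dones = next(iter(
    DataLoader(waymo_iterator, batch_size=bc_config.total_samples, pin_memory=True)
))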
diff --git a/configs/env_config.yaml b/configs/env_config.yaml
index f8c20c3a..10644e49 100644
--- a/configs/env_config.yaml
+++ b/configs/env_config.yaml
@@ -8,7 +8,7 @@ env: my_custom_multi_env_v1 # name of the env, hardcoded for now
episode_length: 80
warmup_period: 10 # In the RL setting we use a warmup of 10 steps
# How many files of the total dataset to use. -1 indicates to use all of them
-num_files: 20
+num_files: 100
fix_file_order: true # If true, always select the SAME files (when creating the environment), if false, pick files at random
sample_file_method: "random" # ALTERNATIVES: "no_replacement" (SUPPORTED) / "score-based" (TODO: @Daphne)
dt: 0.1
@@ -18,9 +18,9 @@ discretize_actions: true
include_head_angle: false # Whether to include the head tilt/angle as part of a vehicle's action
accel_discretization: 5
accel_lower_bound: -3
-accel_upper_bound: 3
-steering_lower_bound: -0.7
-steering_upper_bound: 0.7
+accel_upper_bound: 3
+steering_lower_bound: -0.7 # steer right
+steering_upper_bound: 0.7 # steer left
steering_discretization: 5
max_num_vehicles: 20
randomize_goals: false
@@ -109,4 +109,5 @@ subscriber:
n_frames_stacked: 1 # Agent memory
# Path to folder with traffic scene(s) from which to create an environment
-data_path: ./data_full/train
\ No newline at end of file
+data_path: ./data_full/train
+val_data_path: ./data_full/valid
\ No newline at end of file
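For reference, the new val_data_path key is consumed in evaluation/il_analysis.ipynb to build a held-out evaluation set next to the training scenes; a condensed sketch of that pattern follows (paths and glob pattern taken from this diff, nothing else assumed).

import glob
import os

from utils.config import load_config

env_config = load_config("env_config")

# Training scenes
train_files = sorted(
    os.path.basename(f) for f in glob.glob(f"{env_config.data_path}/tfrecord*")
)
# Unseen validation scenes
valid_files = sorted(
    os.path.basename(f) for f in glob.glob(f"{env_config.val_data_path}/tfrecord*")
)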
diff --git a/configs/exp_config.yaml b/configs/exp_config.yaml
index 1ef37674..91946233 100644
--- a/configs/exp_config.yaml
+++ b/configs/exp_config.yaml
@@ -22,7 +22,7 @@ ma_callback:
model_save_freq: 100 # In iterations (one iter ~ (num_agents x n_steps))
save_video: true
record_n_scenes: 10 # Number of different scenes to render
- video_save_freq: 50 # Make a video every k iterations (100 iters ~ 1M steps)
+ video_save_freq: 100 # Make a video every k iterations (100 iters ~ 1M steps)
video_deterministic: true
ppo:
@@ -32,7 +32,7 @@ ppo:
vf_coef: 0.5 # Default in SB3 is 0.5
learn:
- total_timesteps: 10_000_000
+ total_timesteps: 2_000_000
progress_bar: false
# human-regularized RL
diff --git a/evaluation/il_analysis.ipynb b/evaluation/il_analysis.ipynb
new file mode 100644
index 00000000..0834727c
--- /dev/null
+++ b/evaluation/il_analysis.ipynb
@@ -0,0 +1,5971 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 33,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Dependencies\n",
+ "import numpy as np\n",
+ "import glob\n",
+ "import pandas as pd\n",
+ "import seaborn as sns\n",
+ "import warnings\n",
+ "import torch\n",
+ "import logging\n",
+ "import os\n",
+ "import matplotlib.pyplot as plt\n",
+ "import warnings\n",
+ "warnings.filterwarnings(\"ignore\")\n",
+ "\n",
+ "from typing import Callable\n",
+ "from gym import spaces\n",
+ "from stable_baselines3.common.policies import ActorCriticPolicy\n",
+ "from networks.mlp_late_fusion import LateFusionMLP, LateFusionMLPPolicy\n",
+ "from collections import Counter\n",
+ "from utils.plot import plot_agent_trajectory\n",
+ "from utils.config import load_config_nb\n",
+ "from utils.eval import EvaluatePolicy\n",
+ "from utils.policies import load_policy\n",
+ "\n",
+ "sns.set('notebook', font_scale=1.1, rc={'figure.figsize': (10, 5)})\n",
+ "sns.set_style('ticks', rc={'figure.facecolor': 'none', 'axes.facecolor': 'none'})\n",
+ "%config InlineBackend.figure_format = 'svg'\n",
+ "warnings.filterwarnings(\"ignore\")\n",
+ "plt.set_loglevel('WARNING')\n",
+ "pd.options.display.float_format = \"{:,.2f}\".format\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Configurations"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 82,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "MAX_FILES = 1000\n",
+ "\n",
+ "# Load config files\n",
+ "env_config = load_config_nb(\"env_config\")\n",
+ "exp_config = load_config_nb(\"exp_config\")\n",
+ "model_config = load_config_nb(\"model_config\")\n",
+ "\n",
+ "# Set data path\n",
+ "env_config.data_path = \"../data_full/train/\"\n",
+ "env_config.val_data_path = \"../data_full/valid/\"\n",
+ "env_config.num_files = MAX_FILES\n",
+ "\n",
+ "# Logging level set to INFO\n",
+ "LOGGING_LEVEL = \"INFO\"\n",
+ "\n",
+ "# Scenes on which to evaluate the models\n",
+ "# Make sure file order is fixed\n",
+ "train_file_paths = glob.glob(f\"{env_config.data_path}\" + \"/tfrecord*\")\n",
+ "train_eval_files = sorted([os.path.basename(file) for file in train_file_paths])\n",
+ "\n",
+ "# Valid\n",
+ "valid_file_paths = glob.glob(f\"{env_config.val_data_path}\" + \"/tfrecord*\")\n",
+ "valid_eval_files = sorted([os.path.basename(file) for file in valid_file_paths])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### IL action distributions\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:root:Using 996 file(s)\n"
+ ]
+ }
+ ],
+ "source": [
+ "from utils.imitation_learning.waymo_iterator import TrajectoryIterator\n",
+ "from torch.utils.data import DataLoader\n",
+ "\n",
+ "NUM_IL_FILES = 1000\n",
+ "\n",
+ "env_config.num_files = NUM_IL_FILES\n",
+ "\n",
+ "# Create iterator\n",
+ "waymo_iterator = TrajectoryIterator(\n",
+ " data_path=env_config.data_path,\n",
+ " env_config=env_config,\n",
+ " file_limit=env_config.num_files,\n",
+ ") \n",
+ "\n",
+ "# Rollout to get obs-act-obs-done trajectories \n",
+ "rollouts = next(iter(\n",
+ " DataLoader(\n",
+ " waymo_iterator,\n",
+ " batch_size=10_000, # Number of samples to generate\n",
+ " pin_memory=True,\n",
+ ")))\n",
+ "\n",
+ "obs, acts, next_obs, dones = rollouts\n",
+ "action_cats = dict(Counter(acts.numpy()))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " action_idx | \n",
+ " count | \n",
+ " perc | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 17 | \n",
+ " 1482 | \n",
+ " 14.82 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 12 | \n",
+ " 2773 | \n",
+ " 27.73 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 7 | \n",
+ " 1419 | \n",
+ " 14.19 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 22 | \n",
+ " 1322 | \n",
+ " 13.22 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 2 | \n",
+ " 1335 | \n",
+ " 13.35 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0 | \n",
+ " 463 | \n",
+ " 4.63 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 21 | \n",
+ " 95 | \n",
+ " 0.95 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 1 | \n",
+ " 56 | \n",
+ " 0.56 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 13 | \n",
+ " 134 | \n",
+ " 1.34 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 23 | \n",
+ " 90 | \n",
+ " 0.90 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " action_idx count perc\n",
+ "0 17 1482 14.82\n",
+ "1 12 2773 27.73\n",
+ "2 7 1419 14.19\n",
+ "3 22 1322 13.22\n",
+ "4 2 1335 13.35\n",
+ "5 0 463 4.63\n",
+ "6 21 95 0.95\n",
+ "7 1 56 0.56\n",
+ "8 13 134 1.34\n",
+ "9 23 90 0.90"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df_acts = pd.DataFrame(\n",
+ " {\n",
+ " 'action_idx': list(action_cats.keys()), \n",
+ " 'count': list(action_cats.values()),\n",
+ " 'perc': np.array(list(action_cats.values())) / len(acts.numpy()) * 100,\n",
+ " }\n",
+ ")\n",
+ "\n",
+ "df_acts['action_idx'] = df_acts['action_idx'].astype(int)\n",
+ "df_acts.head(n=10)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": [
+ "\n",
+ "\n",
+ "\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# Create a Seaborn categorical plot\n",
+ "sns.barplot(data=df_acts, x='action_idx', y='perc')\n",
+ "\n",
+ "plt.xlabel('Action indices')\n",
+ "plt.ylabel('Percentage %')\n",
+ "plt.title(f'Action Distribution in IL training dataset (N = {NUM_IL_FILES})')\n",
+ "\n",
+ "plt.grid(True, alpha=.5)\n",
+ "sns.despine();"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 88,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "BASE_PATH = \"../models/il\"\n",
+ "\n",
+ "# Scenes on which to evaluate the models\n",
+ "il_policy_names = ['human_policy_S10_2024_01_02', 'human_policy_S100_2024_01_02']\n",
+ "num_scenes = [10, 100]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Evaluate Behavioral Cloning policies on training scenes"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 120,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:root:Evaluating policy on 10 files...\n",
+ " 0%| | 0/10 [00:00, ?it/s]"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 10/10 [00:02<00:00, 3.92it/s]\n",
+ "INFO:root:Evaluating policy on 100 files...\n",
+ "100%|██████████| 100/100 [00:25<00:00, 3.98it/s]\n"
+ ]
+ }
+ ],
+ "source": [
+ "env_config.data_path = \"../data_full/train/\"\n",
+ "\n",
+ "df_il_train = pd.DataFrame()\n",
+ "\n",
+ "for trained_policy, num_files in zip(il_policy_names, num_scenes):\n",
+ "\n",
+ " eval_files = train_eval_files[:num_files]\n",
+ "\n",
+ " # Load trained human reference policy\n",
+ " human_policy = load_policy(\n",
+ " data_path=BASE_PATH,\n",
+ " file_name=trained_policy, \n",
+ " )\n",
+ "\n",
+ " # Evaluate policy\n",
+ " evaluator = EvaluatePolicy(\n",
+ " env_config=env_config, \n",
+ " exp_config=exp_config,\n",
+ " policy=human_policy,\n",
+ " eval_files=eval_files,\n",
+ " log_to_wandb=False, \n",
+ " deterministic=False,\n",
+ " reg_coef=0.0,\n",
+ " return_trajectories=True,\n",
+ " )\n",
+ "\n",
+ " df_il_res, df_il_trajs = evaluator._get_scores()\n",
+ "\n",
+ " df_il_res['num_scenes'] = num_files\n",
+ " df_il_res['policy'] = trained_policy\n",
+ " df_il_train = pd.concat([df_il_train, df_il_res])\n",
+ "\n",
+ "df_il_train['type'] = 'train'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 121,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ " Aggregated Behavioral Cloning Human Likeness Scores (train data)
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " act_acc | \n",
+ " accel_val_mae | \n",
+ " speed_mae | \n",
+ " steer_val_mae | \n",
+ " pos_rmse | \n",
+ "
\n",
+ " \n",
+ " num_scenes | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 10 | \n",
+ " 0.23 | \n",
+ " 1.66 | \n",
+ " 29.34 | \n",
+ " 0.08 | \n",
+ " 50.42 | \n",
+ "
\n",
+ " \n",
+ " 100 | \n",
+ " 0.17 | \n",
+ " 1.84 | \n",
+ " 44.74 | \n",
+ " 0.15 | \n",
+ " 62.22 | \n",
+ "
\n",
+ " \n",
+ "
\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 121,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Show aggregate human likeness scores\n",
+ "hl_df_avg_train = df_il_train.groupby('num_scenes')[['act_acc', 'accel_val_mae', 'speed_mae', 'steer_val_mae', 'pos_rmse']].mean()\n",
+ "\n",
+ "hl_df_avg_train.style.format('{:.3f}', na_rep=\"\")\\\n",
+ " .bar(subset=['act_acc'], align=0, vmin=0, vmax=1, cmap=\"Greens\", height=50, width=60)\\\n",
+ " .bar(subset=['accel_val_mae'], align=0, vmin=0, vmax=9, cmap=\"Reds\", height=50, width=60)\\\n",
+ " .bar(subset=['speed_mae'], align=0, vmin=0, vmax=50, cmap=\"Reds\", height=50, width=60)\\\n",
+ " .bar(subset=['steer_val_mae'], align=0, vmin=0, vmax=1.4, cmap=\"Reds\", height=50, width=60)\\\n",
+ " .bar(subset=['pos_rmse'], align=0, vmin=0, vmax=75, cmap=\"Blues\", height=50, width=60)\\\n",
+ " .set_caption(\"Aggregated Behavioral Cloning Human Likeness Scores (train data)
\").format(\"{:.2f}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 122,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ " Aggregated Behavioral Cloning Performance Metrics (train data)
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " goal_rate | \n",
+ " veh_edge_cr | \n",
+ " veh_veh_cr | \n",
+ "
\n",
+ " \n",
+ " num_scenes | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 10 | \n",
+ " 0.30 | \n",
+ " 0.27 | \n",
+ " 0.07 | \n",
+ "
\n",
+ " \n",
+ " 100 | \n",
+ " 0.16 | \n",
+ " 0.32 | \n",
+ " 0.17 | \n",
+ "
\n",
+ " \n",
+ "
\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 122,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "performance_df_avg_train = df_il_train.groupby('num_scenes')[['goal_rate', 'veh_edge_cr', 'veh_veh_cr']].mean()\n",
+ "\n",
+ "performance_df_avg_train.style.format('{:.3f}', na_rep=\"\")\\\n",
+ " .bar(subset=['goal_rate'], align=0, vmin=0, vmax=1, cmap=\"Greens\", height=50, width=60)\\\n",
+ " .bar(subset=['veh_edge_cr'], align=0, vmin=0, vmax=1, cmap=\"Reds\", height=50, width=60)\\\n",
+ " .bar(subset=['veh_veh_cr'], align=0, vmin=0, vmax=1, cmap=\"Reds\", height=50, width=60)\\\n",
+ " .set_caption(\"Aggregated Behavioral Cloning Performance Metrics (train data)
\").format(\"{:.2f}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Evaluate Behavioral Cloning policies on new (unseen) scenes"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 123,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:root:Evaluating policy on 10 files...\n",
+ "100%|██████████| 10/10 [00:02<00:00, 3.47it/s]\n",
+ "INFO:root:Evaluating policy on 99 files...\n",
+ "100%|██████████| 99/99 [00:25<00:00, 3.91it/s]\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Set data path\n",
+ "env_config.data_path = \"../data_full/valid/\"\n",
+ "\n",
+ "df_il_valid = pd.DataFrame()\n",
+ "\n",
+ "for trained_policy, num_files in zip(il_policy_names, num_scenes):\n",
+ "\n",
+ " eval_files = valid_eval_files[:num_files]\n",
+ "\n",
+ " # Load trained human reference policy\n",
+ " human_policy = load_policy(\n",
+ " data_path=BASE_PATH,\n",
+ " file_name=trained_policy, \n",
+ " )\n",
+ "\n",
+ " # Evaluate policy\n",
+ " evaluator = EvaluatePolicy(\n",
+ " env_config=env_config, \n",
+ " exp_config=exp_config,\n",
+ " policy=human_policy,\n",
+ " eval_files=eval_files,\n",
+ " log_to_wandb=False, \n",
+ " deterministic=False,\n",
+ " reg_coef=0.0,\n",
+ " return_trajectories=True,\n",
+ " )\n",
+ "\n",
+ " df_il_res, df_il_trajs = evaluator._get_scores()\n",
+ "\n",
+ " df_il_res['num_scenes'] = num_files\n",
+ " df_il_res['policy'] = trained_policy\n",
+ " df_il_valid = pd.concat([df_il_valid, df_il_res])\n",
+ "\n",
+ "df_il_valid['type'] = 'validation'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 124,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ " Aggregated Behavioral Cloning Human Likeness Scores (validation data)
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " act_acc | \n",
+ " accel_val_mae | \n",
+ " speed_mae | \n",
+ " steer_val_mae | \n",
+ " pos_rmse | \n",
+ "
\n",
+ " \n",
+ " num_scenes | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 10 | \n",
+ " 0.08 | \n",
+ " 2.17 | \n",
+ " 16.42 | \n",
+ " 0.19 | \n",
+ " 35.51 | \n",
+ "
\n",
+ " \n",
+ " 100 | \n",
+ " 0.10 | \n",
+ " 2.04 | \n",
+ " 34.33 | \n",
+ " 0.19 | \n",
+ " 49.43 | \n",
+ "
\n",
+ " \n",
+ "
\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 124,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Show aggregate human likeness scores\n",
+ "hl_df_avg_valid = df_il_valid.groupby('num_scenes')[['act_acc', 'accel_val_mae', 'speed_mae', 'steer_val_mae', 'pos_rmse']].mean()\n",
+ "\n",
+ "hl_df_avg_valid.style.format('{:.3f}', na_rep=\"\")\\\n",
+ " .bar(subset=['act_acc'], align=0, vmin=0, vmax=1, cmap=\"Greens\", height=50, width=60)\\\n",
+ " .bar(subset=['accel_val_mae'], align=0, vmin=0, vmax=9, cmap=\"Reds\", height=50, width=60)\\\n",
+ " .bar(subset=['speed_mae'], align=0, vmin=0, vmax=50, cmap=\"Reds\", height=50, width=60)\\\n",
+ " .bar(subset=['steer_val_mae'], align=0, vmin=0, vmax=1.4, cmap=\"Reds\", height=50, width=60)\\\n",
+ " .bar(subset=['pos_rmse'], align=0, vmin=0, vmax=75, cmap=\"Blues\", height=50, width=60)\\\n",
+ " .set_caption(\"Aggregated Behavioral Cloning Human Likeness Scores (validation data)
\").format(\"{:.2f}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 125,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ " Aggregated Behavioral Cloning Performance Metrics (validation data)
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " goal_rate | \n",
+ " veh_edge_cr | \n",
+ " veh_veh_cr | \n",
+ "
\n",
+ " \n",
+ " num_scenes | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 10 | \n",
+ " 0.06 | \n",
+ " 0.55 | \n",
+ " 0.15 | \n",
+ "
\n",
+ " \n",
+ " 100 | \n",
+ " 0.12 | \n",
+ " 0.33 | \n",
+ " 0.22 | \n",
+ "
\n",
+ " \n",
+ "
\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 125,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "performance_df_avg_valid = df_il_valid.groupby('num_scenes')[['goal_rate', 'veh_edge_cr', 'veh_veh_cr']].mean()\n",
+ "\n",
+ "performance_df_avg_valid.style.format('{:.3f}', na_rep=\"\")\\\n",
+ " .bar(subset=['goal_rate'], align=0, vmin=0, vmax=1, cmap=\"Greens\", height=50, width=60)\\\n",
+ " .bar(subset=['veh_edge_cr'], align=0, vmin=0, vmax=1, cmap=\"Reds\", height=50, width=60)\\\n",
+ " .bar(subset=['veh_veh_cr'], align=0, vmin=0, vmax=1, cmap=\"Reds\", height=50, width=60)\\\n",
+ " .set_caption(\"Aggregated Behavioral Cloning Performance Metrics (validation data)
\").format(\"{:.2f}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Compare "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 145,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "df_il = pd.concat([df_il_train, df_il_valid])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 150,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# fig, axs = plt.subplots(1, 3, figsize=(15, 5))\n",
+ "\n",
+ "# g = sns.violinplot(data=df_il, x='num_scenes', y='act_acc', hue='type', ax=axs[0]);\n",
+ "# sns.swarmplot(data=df_il, x='num_scenes', y='act_acc', hue='type', color=\"k\", size=3, ax=g.ax)\n",
+ "\n",
+ "# g = sns.violinplot(data=df_il, x='num_scenes', y='act_acc', hue='type', ax=axs[0]);"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 173,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": [
+ "\n",
+ "\n",
+ "\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "fig, axs = plt.subplots(1, 4, figsize=(12, 4))\n",
+ "\n",
+ "fig.suptitle('Human likeness scores for IL policies', fontsize=16)\n",
+ "\n",
+ "sns.barplot(data=df_il, x='num_scenes', y='act_acc', hue='type', ax=axs[0], legend=True,)\n",
+ "\n",
+ "sns.barplot(data=df_il, x='num_scenes', y='accel_val_mae', hue='type', ax=axs[1], legend=False)\n",
+ "\n",
+ "sns.barplot(data=df_il, x='num_scenes', y='steer_val_mae', hue='type', ax=axs[2], legend=False)\n",
+ "\n",
+ "sns.barplot(data=df_il, x='num_scenes', y='pos_rmse', hue='type', ax=axs[3], legend=False)\n",
+ "\n",
+ "fig.tight_layout()\n",
+ "sns.despine()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 172,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": [
+ "\n",
+ "\n",
+ "\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "fig, axs = plt.subplots(1, 3, figsize=(12, 4))\n",
+ "\n",
+ "fig.suptitle('Performance scores for IL policies', fontsize=16)\n",
+ "\n",
+ "sns.barplot(data=df_il, x='num_scenes', y='goal_rate', hue='type', ax=axs[0], legend=True,)\n",
+ "\n",
+ "sns.barplot(data=df_il, x='num_scenes', y='veh_edge_cr', hue='type', ax=axs[1], legend=False)\n",
+ "\n",
+ "sns.barplot(data=df_il, x='num_scenes', y='veh_veh_cr', hue='type', ax=axs[2], legend=False)\n",
+ "\n",
+ "fig.tight_layout()\n",
+ "sns.despine()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "nocturne_lab",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/evaluation/policy_performance_analysis.ipynb b/evaluation/policy_performance_analysis.ipynb
index d6600319..50e6bf82 100644
--- a/evaluation/policy_performance_analysis.ipynb
+++ b/evaluation/policy_performance_analysis.ipynb
@@ -24,7 +24,7 @@
},
{
"cell_type": "code",
- "execution_count": 86,
+ "execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
@@ -64,53 +64,27 @@
},
{
"cell_type": "code",
- "execution_count": 87,
+ "execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
+ "MAX_FILES = 50\n",
+ "\n",
"# Load config files\n",
"env_config = load_config_nb(\"env_config\")\n",
"exp_config = load_config_nb(\"exp_config\")\n",
"model_config = load_config_nb(\"model_config\")\n",
"\n",
"# Set data path\n",
- "env_config.data_path = \"../data_10/train/\"\n",
+ "env_config.data_path = \"../data_full/train/\"\n",
+ "env_config.num_files = MAX_FILES\n",
"\n",
"# Logging level set to INFO\n",
"LOGGING_LEVEL = \"INFO\"\n",
"\n",
"# Scenes on which to evaluate the models\n",
"file_paths = glob.glob(f\"{env_config.data_path}\" + \"/tfrecord*\")\n",
- "eval_files = [os.path.basename(file) for file in file_paths]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 88,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "['tfrecord-00004-of-01000_378.json',\n",
- " 'tfrecord-00003-of-01000_109.json',\n",
- " 'tfrecord-00004-of-01000_61.json',\n",
- " 'tfrecord-00012-of-01000_87.json',\n",
- " 'tfrecord-00007-of-01000_237.json',\n",
- " 'tfrecord-00005-of-01000_423.json',\n",
- " 'tfrecord-00012-of-01000_246.json',\n",
- " 'tfrecord-00012-of-01000_389.json',\n",
- " 'tfrecord-00001-of-01000_307.json',\n",
- " 'tfrecord-00004-of-01000_157.json']"
- ]
- },
- "execution_count": 88,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "eval_files"
+ "eval_files = [os.path.basename(file) for file in file_paths][:MAX_FILES]"
]
},
{
@@ -122,7 +96,7 @@
},
{
"cell_type": "code",
- "execution_count": 89,
+ "execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
@@ -141,9 +115,582 @@
},
{
"cell_type": "code",
- "execution_count": 131,
+ "execution_count": 8,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:root:veh 37 at t = 50 returns None action!\n",
+ "INFO:root:veh 37 at t = 51 returns None action!\n",
+ "INFO:root:veh 44 at t = 75 returns None action!\n",
+ "INFO:root:veh 44 at t = 76 returns None action!\n",
+ "INFO:root:veh 44 at t = 32 returns None action!\n",
+ "INFO:root:veh 44 at t = 33 returns None action!\n",
+ "INFO:root:veh 44 at t = 34 returns None action!\n",
+ "INFO:root:veh 44 at t = 36 returns None action!\n",
+ "INFO:root:veh 44 at t = 37 returns None action!\n",
+ "INFO:root:veh 44 at t = 38 returns None action!\n",
+ "INFO:root:veh 4 at t = 2 returns None action!\n",
+ "INFO:root:veh 41 at t = 1 returns None action!\n",
+ "INFO:root:veh 41 at t = 2 returns None action!\n",
+ "INFO:root:veh 17 at t = 15 returns None action!\n",
+ "INFO:root:veh 17 at t = 16 returns None action!\n",
+ "INFO:root:veh 17 at t = 17 returns None action!\n",
+ "INFO:root:veh 17 at t = 18 returns None action!\n",
+ "INFO:root:veh 17 at t = 19 returns None action!\n",
+ "INFO:root:veh 17 at t = 31 returns None action!\n",
+ "INFO:root:veh 17 at t = 32 returns None action!\n",
+ "INFO:root:veh 17 at t = 33 returns None action!\n",
+ "INFO:root:veh 17 at t = 34 returns None action!\n",
+ "INFO:root:veh 17 at t = 43 returns None action!\n",
+ "INFO:root:veh 17 at t = 44 returns None action!\n",
+ "INFO:root:veh 17 at t = 45 returns None action!\n",
+ "INFO:root:veh 17 at t = 46 returns None action!\n",
+ "INFO:root:veh 17 at t = 47 returns None action!\n",
+ "INFO:root:veh 17 at t = 51 returns None action!\n",
+ "INFO:root:veh 17 at t = 52 returns None action!\n",
+ "INFO:root:veh 14 at t = 1 returns None action!\n",
+ "INFO:root:veh 14 at t = 2 returns None action!\n",
+ "INFO:root:veh 14 at t = 3 returns None action!\n",
+ "INFO:root:veh 14 at t = 4 returns None action!\n",
+ "INFO:root:veh 14 at t = 44 returns None action!\n",
+ "INFO:root:veh 14 at t = 45 returns None action!\n",
+ "INFO:root:veh 14 at t = 46 returns None action!\n",
+ "INFO:root:veh 0 at t = 14 returns None action!\n",
+ "INFO:root:veh 0 at t = 15 returns None action!\n",
+ "INFO:root:veh 0 at t = 16 returns None action!\n",
+ "INFO:root:veh 0 at t = 17 returns None action!\n",
+ "INFO:root:veh 0 at t = 18 returns None action!\n",
+ "INFO:root:veh 0 at t = 20 returns None action!\n",
+ "INFO:root:veh 0 at t = 21 returns None action!\n",
+ "INFO:root:veh 1 at t = 62 returns None action!\n",
+ "INFO:root:veh 1 at t = 63 returns None action!\n",
+ "INFO:root:veh 7 at t = 33 returns None action!\n",
+ "INFO:root:veh 17 at t = 34 returns None action!\n",
+ "INFO:root:veh 17 at t = 35 returns None action!\n",
+ "INFO:root:veh 2 at t = 20 returns None action!\n",
+ "INFO:root:veh 2 at t = 21 returns None action!\n",
+ "INFO:root:veh 2 at t = 29 returns None action!\n",
+ "INFO:root:veh 2 at t = 30 returns None action!\n",
+ "INFO:root:veh 2 at t = 31 returns None action!\n",
+ "INFO:root:veh 47 at t = 0 returns None action!\n",
+ "INFO:root:veh 47 at t = 1 returns None action!\n",
+ "INFO:root:veh 47 at t = 2 returns None action!\n",
+ "INFO:root:veh 25 at t = 23 returns None action!\n",
+ "INFO:root:veh 25 at t = 24 returns None action!\n",
+ "INFO:root:veh 25 at t = 25 returns None action!\n",
+ "INFO:root:veh 25 at t = 26 returns None action!\n",
+ "INFO:root:veh 25 at t = 27 returns None action!\n",
+ "INFO:root:veh 26 at t = 29 returns None action!\n",
+ "INFO:root:veh 26 at t = 30 returns None action!\n",
+ "INFO:root:veh 11 at t = 16 returns None action!\n",
+ "INFO:root:veh 11 at t = 17 returns None action!\n",
+ "INFO:root:veh 10 at t = 65 returns None action!\n",
+ "INFO:root:veh 10 at t = 66 returns None action!\n",
+ "INFO:root:veh 58 at t = 1 returns None action!\n",
+ "INFO:root:veh 58 at t = 2 returns None action!\n",
+ "INFO:root:veh 59 at t = 3 returns None action!\n",
+ "INFO:root:veh 59 at t = 4 returns None action!\n",
+ "INFO:root:veh 55 at t = 12 returns None action!\n",
+ "INFO:root:veh 55 at t = 13 returns None action!\n",
+ "INFO:root:veh 55 at t = 14 returns None action!\n",
+ "INFO:root:veh 55 at t = 15 returns None action!\n",
+ "INFO:root:veh 58 at t = 20 returns None action!\n",
+ "INFO:root:veh 58 at t = 21 returns None action!\n",
+ "INFO:root:veh 59 at t = 24 returns None action!\n",
+ "INFO:root:veh 59 at t = 25 returns None action!\n",
+ "INFO:root:veh 69 at t = 29 returns None action!\n",
+ "INFO:root:veh 69 at t = 30 returns None action!\n",
+ "INFO:root:veh 58 at t = 32 returns None action!\n",
+ "INFO:root:veh 58 at t = 33 returns None action!\n",
+ "INFO:root:veh 46 at t = 1 returns None action!\n",
+ "INFO:root:veh 46 at t = 2 returns None action!\n",
+ "INFO:root:veh 46 at t = 3 returns None action!\n",
+ "INFO:root:veh 46 at t = 4 returns None action!\n",
+ "INFO:root:veh 14 at t = 8 returns None action!\n",
+ "INFO:root:veh 14 at t = 9 returns None action!\n",
+ "INFO:root:veh 26 at t = 25 returns None action!\n",
+ "INFO:root:veh 26 at t = 26 returns None action!\n",
+ "INFO:root:veh 26 at t = 29 returns None action!\n",
+ "INFO:root:veh 26 at t = 30 returns None action!\n",
+ "INFO:root:veh 30 at t = 11 returns None action!\n",
+ "INFO:root:veh 30 at t = 12 returns None action!\n",
+ "INFO:root:veh 30 at t = 13 returns None action!\n",
+ "INFO:root:veh 30 at t = 14 returns None action!\n",
+ "INFO:root:veh 30 at t = 31 returns None action!\n",
+ "INFO:root:veh 30 at t = 32 returns None action!\n",
+ "INFO:root:veh 30 at t = 33 returns None action!\n",
+ "INFO:root:veh 30 at t = 34 returns None action!\n",
+ "INFO:root:veh 11 at t = 66 returns None action!\n",
+ "INFO:root:veh 11 at t = 67 returns None action!\n",
+ "INFO:root:veh 11 at t = 68 returns None action!\n",
+ "INFO:root:veh 11 at t = 70 returns None action!\n",
+ "INFO:root:veh 11 at t = 71 returns None action!\n",
+ "INFO:root:veh 59 at t = 31 returns None action!\n",
+ "INFO:root:veh 59 at t = 32 returns None action!\n",
+ "INFO:root:veh 59 at t = 33 returns None action!\n",
+ "INFO:root:veh 59 at t = 34 returns None action!\n",
+ "INFO:root:veh 62 at t = 35 returns None action!\n",
+ "INFO:root:veh 62 at t = 36 returns None action!\n",
+ "INFO:root:veh 8 at t = 42 returns None action!\n",
+ "INFO:root:veh 8 at t = 43 returns None action!\n",
+ "INFO:root:veh 8 at t = 44 returns None action!\n",
+ "INFO:root:veh 8 at t = 45 returns None action!\n",
+ "INFO:root:veh 8 at t = 46 returns None action!\n",
+ "INFO:root:veh 60 at t = 0 returns None action!\n",
+ "INFO:root:veh 26 at t = 0 returns None action!\n",
+ "INFO:root:veh 60 at t = 1 returns None action!\n",
+ "INFO:root:veh 60 at t = 2 returns None action!\n",
+ "INFO:root:veh 60 at t = 3 returns None action!\n",
+ "INFO:root:veh 49 at t = 3 returns None action!\n",
+ "INFO:root:veh 60 at t = 4 returns None action!\n",
+ "INFO:root:veh 49 at t = 4 returns None action!\n",
+ "INFO:root:veh 49 at t = 5 returns None action!\n",
+ "INFO:root:veh 60 at t = 7 returns None action!\n",
+ "INFO:root:veh 60 at t = 8 returns None action!\n",
+ "INFO:root:veh 29 at t = 50 returns None action!\n",
+ "INFO:root:veh 29 at t = 51 returns None action!\n",
+ "INFO:root:veh 29 at t = 52 returns None action!\n",
+ "INFO:root:veh 4 at t = 2 returns None action!\n",
+ "INFO:root:veh 4 at t = 3 returns None action!\n",
+ "INFO:root:veh 4 at t = 4 returns None action!\n",
+ "INFO:root:veh 4 at t = 5 returns None action!\n",
+ "INFO:root:veh 13 at t = 8 returns None action!\n",
+ "INFO:root:veh 13 at t = 9 returns None action!\n",
+ "INFO:root:veh 35 at t = 0 returns None action!\n",
+ "INFO:root:veh 35 at t = 1 returns None action!\n",
+ "INFO:root:veh 35 at t = 2 returns None action!\n",
+ "INFO:root:veh 35 at t = 3 returns None action!\n",
+ "INFO:root:veh 35 at t = 4 returns None action!\n",
+ "INFO:root:veh 23 at t = 5 returns None action!\n",
+ "INFO:root:veh 6 at t = 2 returns None action!\n",
+ "INFO:root:veh 6 at t = 3 returns None action!\n",
+ "INFO:root:veh 6 at t = 29 returns None action!\n",
+ "INFO:root:veh 6 at t = 30 returns None action!\n",
+ "INFO:root:veh 22 at t = 0 returns None action!\n",
+ "INFO:root:veh 22 at t = 1 returns None action!\n",
+ "INFO:root:veh 22 at t = 2 returns None action!\n",
+ "INFO:root:veh 22 at t = 3 returns None action!\n",
+ "INFO:root:veh 3 at t = 7 returns None action!\n",
+ "INFO:root:veh 3 at t = 8 returns None action!\n",
+ "INFO:root:veh 20 at t = 20 returns None action!\n",
+ "INFO:root:veh 50 at t = 0 returns None action!\n",
+ "INFO:root:veh 50 at t = 1 returns None action!\n",
+ "INFO:root:veh 2 at t = 71 returns None action!\n",
+ "INFO:root:veh 26 at t = 63 returns None action!\n",
+ "INFO:root:veh 26 at t = 64 returns None action!\n",
+ "INFO:root:veh 26 at t = 65 returns None action!\n",
+ "INFO:root:veh 26 at t = 74 returns None action!\n",
+ "INFO:root:veh 26 at t = 75 returns None action!\n",
+ "INFO:root:veh 26 at t = 76 returns None action!\n",
+ "INFO:root:veh 26 at t = 77 returns None action!\n",
+ "INFO:root:veh 7 at t = 20 returns None action!\n",
+ "INFO:root:veh 7 at t = 21 returns None action!\n",
+ "INFO:root:veh 4 at t = 41 returns None action!\n",
+ "INFO:root:veh 4 at t = 42 returns None action!\n",
+ "INFO:root:veh 3 at t = 10 returns None action!\n",
+ "INFO:root:veh 3 at t = 11 returns None action!\n",
+ "INFO:root:veh 3 at t = 12 returns None action!\n",
+ "INFO:root:veh 24 at t = 21 returns None action!\n",
+ "INFO:root:veh 24 at t = 22 returns None action!\n",
+ "INFO:root:veh 24 at t = 23 returns None action!\n",
+ "INFO:root:veh 24 at t = 24 returns None action!\n",
+ "INFO:root:veh 16 at t = 28 returns None action!\n",
+ "INFO:root:veh 16 at t = 29 returns None action!\n",
+ "INFO:root:veh 16 at t = 30 returns None action!\n",
+ "INFO:root:veh 16 at t = 31 returns None action!\n",
+ "INFO:root:veh 23 at t = 17 returns None action!\n",
+ "INFO:root:veh 23 at t = 18 returns None action!\n",
+ "INFO:root:veh 23 at t = 22 returns None action!\n",
+ "INFO:root:veh 23 at t = 23 returns None action!\n",
+ "INFO:root:veh 23 at t = 31 returns None action!\n",
+ "INFO:root:veh 23 at t = 32 returns None action!\n",
+ "INFO:root:veh 1 at t = 76 returns None action!\n",
+ "INFO:root:veh 1 at t = 77 returns None action!\n",
+ "INFO:root:veh 1 at t = 78 returns None action!\n",
+ "INFO:root:veh 36 at t = 2 returns None action!\n",
+ "INFO:root:veh 36 at t = 3 returns None action!\n",
+ "INFO:root:veh 36 at t = 4 returns None action!\n",
+ "INFO:root:veh 36 at t = 5 returns None action!\n",
+ "INFO:root:veh 36 at t = 6 returns None action!\n",
+ "INFO:root:veh 15 at t = 4 returns None action!\n",
+ "INFO:root:veh 15 at t = 5 returns None action!\n",
+ "INFO:root:veh 15 at t = 6 returns None action!\n",
+ "INFO:root:veh 15 at t = 7 returns None action!\n",
+ "INFO:root:veh 8 at t = 69 returns None action!\n",
+ "INFO:root:veh 8 at t = 70 returns None action!\n",
+ "INFO:root:veh 8 at t = 71 returns None action!\n",
+ "INFO:root:veh 1 at t = 37 returns None action!\n",
+ "INFO:root:veh 1 at t = 38 returns None action!\n",
+ "INFO:root:veh 1 at t = 39 returns None action!\n",
+ "INFO:root:veh 25 at t = 15 returns None action!\n",
+ "INFO:root:veh 25 at t = 16 returns None action!\n",
+ "INFO:root:veh 25 at t = 17 returns None action!\n",
+ "INFO:root:veh 25 at t = 18 returns None action!\n",
+ "INFO:root:veh 28 at t = 1 returns None action!\n",
+ "INFO:root:veh 28 at t = 2 returns None action!\n",
+ "INFO:root:veh 2 at t = 54 returns None action!\n",
+ "INFO:root:veh 2 at t = 55 returns None action!\n",
+ "INFO:root:veh 2 at t = 56 returns None action!\n",
+ "INFO:root:veh 2 at t = 57 returns None action!\n",
+ "INFO:root:veh 2 at t = 58 returns None action!\n",
+ "INFO:root:veh 2 at t = 59 returns None action!\n",
+ "INFO:root:veh 2 at t = 60 returns None action!\n",
+ "INFO:root:veh 2 at t = 61 returns None action!\n",
+ "INFO:root:veh 2 at t = 64 returns None action!\n",
+ "INFO:root:veh 2 at t = 65 returns None action!\n",
+ "INFO:root:veh 2 at t = 66 returns None action!\n",
+ "INFO:root:veh 9 at t = 15 returns None action!\n",
+ "INFO:root:veh 6 at t = 55 returns None action!\n",
+ "INFO:root:veh 6 at t = 56 returns None action!\n",
+ "INFO:root:veh 6 at t = 57 returns None action!\n",
+ "INFO:root:veh 6 at t = 59 returns None action!\n",
+ "INFO:root:veh 6 at t = 60 returns None action!\n",
+ "INFO:root:veh 6 at t = 61 returns None action!\n",
+ "INFO:root:veh 54 at t = 32 returns None action!\n",
+ "INFO:root:veh 54 at t = 33 returns None action!\n",
+ "INFO:root:veh 54 at t = 37 returns None action!\n",
+ "INFO:root:veh 54 at t = 38 returns None action!\n",
+ "INFO:root:veh 54 at t = 39 returns None action!\n",
+ "INFO:root:veh 54 at t = 40 returns None action!\n",
+ "INFO:root:veh 54 at t = 41 returns None action!\n",
+ "INFO:root:veh 54 at t = 42 returns None action!\n",
+ "INFO:root:veh 54 at t = 53 returns None action!\n",
+ "INFO:root:veh 54 at t = 54 returns None action!\n",
+ "INFO:root:veh 54 at t = 55 returns None action!\n",
+ "INFO:root:veh 54 at t = 56 returns None action!\n",
+ "INFO:root:veh 56 at t = 63 returns None action!\n",
+ "INFO:root:veh 56 at t = 64 returns None action!\n",
+ "INFO:root:veh 56 at t = 65 returns None action!\n",
+ "INFO:root:veh 56 at t = 66 returns None action!\n",
+ "INFO:root:veh 56 at t = 67 returns None action!\n",
+ "INFO:root:veh 56 at t = 68 returns None action!\n",
+ "INFO:root:veh 56 at t = 70 returns None action!\n",
+ "INFO:root:veh 56 at t = 71 returns None action!\n",
+ "INFO:root:veh 56 at t = 72 returns None action!\n",
+ "INFO:root:veh 56 at t = 73 returns None action!\n",
+ "INFO:root:veh 56 at t = 74 returns None action!\n",
+ "INFO:root:veh 56 at t = 75 returns None action!\n",
+ "INFO:root:veh 11 at t = 37 returns None action!\n",
+ "INFO:root:veh 11 at t = 38 returns None action!\n",
+ "INFO:root:veh 10 at t = 33 returns None action!\n",
+ "INFO:root:veh 10 at t = 34 returns None action!\n",
+ "INFO:root:veh 10 at t = 35 returns None action!\n",
+ "INFO:root:veh 10 at t = 36 returns None action!\n",
+ "INFO:root:veh 86 at t = 4 returns None action!\n",
+ "INFO:root:veh 86 at t = 5 returns None action!\n",
+ "INFO:root:veh 86 at t = 6 returns None action!\n",
+ "INFO:root:veh 86 at t = 7 returns None action!\n",
+ "INFO:root:veh 86 at t = 8 returns None action!\n",
+ "INFO:root:veh 86 at t = 21 returns None action!\n",
+ "INFO:root:veh 86 at t = 22 returns None action!\n",
+ "INFO:root:veh 6 at t = 18 returns None action!\n",
+ "INFO:root:veh 6 at t = 19 returns None action!\n",
+ "INFO:root:veh 31 at t = 54 returns None action!\n",
+ "INFO:root:veh 31 at t = 55 returns None action!\n",
+ "INFO:root:veh 36 at t = 72 returns None action!\n",
+ "INFO:root:veh 36 at t = 73 returns None action!\n",
+ "INFO:root:veh 4 at t = 19 returns None action!\n",
+ "INFO:root:veh 4 at t = 20 returns None action!\n",
+ "INFO:root:veh 54 at t = 67 returns None action!\n",
+ "INFO:root:veh 54 at t = 68 returns None action!\n",
+ "INFO:root:veh 54 at t = 69 returns None action!\n",
+ "INFO:root:veh 3 at t = 4 returns None action!\n",
+ "INFO:root:veh 3 at t = 5 returns None action!\n",
+ "INFO:root:veh 45 at t = 44 returns None action!\n",
+ "INFO:root:veh 45 at t = 45 returns None action!\n",
+ "INFO:root:veh 45 at t = 46 returns None action!\n",
+ "INFO:root:veh 45 at t = 47 returns None action!\n",
+ "INFO:root:veh 48 at t = 64 returns None action!\n",
+ "INFO:root:veh 48 at t = 65 returns None action!\n",
+ "INFO:root:veh 48 at t = 68 returns None action!\n",
+ "INFO:root:veh 48 at t = 69 returns None action!\n",
+ "INFO:root:veh 45 at t = 78 returns None action!\n",
+ "INFO:root:veh 17 at t = 37 returns None action!\n",
+ "INFO:root:veh 17 at t = 38 returns None action!\n",
+ "INFO:root:veh 17 at t = 39 returns None action!\n",
+ "INFO:root:veh 17 at t = 40 returns None action!\n",
+ "INFO:root:veh 17 at t = 41 returns None action!\n",
+ "INFO:root:veh 44 at t = 0 returns None action!\n",
+ "INFO:root:veh 44 at t = 1 returns None action!\n",
+ "INFO:root:veh 3 at t = 17 returns None action!\n",
+ "INFO:root:veh 3 at t = 18 returns None action!\n",
+ "INFO:root:veh 3 at t = 19 returns None action!\n",
+ "INFO:root:veh 3 at t = 20 returns None action!\n",
+ "INFO:root:veh 44 at t = 28 returns None action!\n",
+ "INFO:root:veh 44 at t = 29 returns None action!\n",
+ "INFO:root:veh 3 at t = 3 returns None action!\n",
+ "INFO:root:veh 3 at t = 4 returns None action!\n",
+ "INFO:root:veh 22 at t = 23 returns None action!\n",
+ "INFO:root:veh 22 at t = 24 returns None action!\n",
+ "INFO:root:veh 22 at t = 28 returns None action!\n",
+ "INFO:root:veh 22 at t = 29 returns None action!\n",
+ "INFO:root:veh 22 at t = 30 returns None action!\n",
+ "INFO:root:veh 22 at t = 31 returns None action!\n",
+ "INFO:root:veh 22 at t = 34 returns None action!\n",
+ "INFO:root:veh 22 at t = 35 returns None action!\n",
+ "INFO:root:veh 22 at t = 37 returns None action!\n",
+ "INFO:root:veh 22 at t = 38 returns None action!\n",
+ "INFO:root:veh 22 at t = 39 returns None action!\n",
+ "INFO:root:veh 22 at t = 40 returns None action!\n",
+ "INFO:root:veh 22 at t = 41 returns None action!\n",
+ "INFO:root:veh 22 at t = 42 returns None action!\n",
+ "INFO:root:veh 22 at t = 43 returns None action!\n",
+ "INFO:root:veh 22 at t = 47 returns None action!\n",
+ "INFO:root:veh 22 at t = 48 returns None action!\n",
+ "INFO:root:veh 94 at t = 24 returns None action!\n",
+ "INFO:root:veh 94 at t = 25 returns None action!\n",
+ "INFO:root:veh 94 at t = 26 returns None action!\n",
+ "INFO:root:veh 9 at t = 24 returns None action!\n",
+ "INFO:root:veh 9 at t = 25 returns None action!\n",
+ "INFO:root:veh 59 at t = 60 returns None action!\n",
+ "INFO:root:veh 59 at t = 61 returns None action!\n",
+ "INFO:root:veh 59 at t = 62 returns None action!\n",
+ "INFO:root:veh 59 at t = 63 returns None action!\n",
+ "INFO:root:veh 59 at t = 64 returns None action!\n",
+ "INFO:root:veh 59 at t = 67 returns None action!\n",
+ "INFO:root:veh 59 at t = 68 returns None action!\n",
+ "INFO:root:veh 59 at t = 69 returns None action!\n",
+ "INFO:root:veh 44 at t = 0 returns None action!\n",
+ "INFO:root:veh 9 at t = 12 returns None action!\n",
+ "INFO:root:veh 9 at t = 13 returns None action!\n",
+ "INFO:root:veh 9 at t = 14 returns None action!\n",
+ "INFO:root:veh 9 at t = 15 returns None action!\n",
+ "INFO:root:veh 9 at t = 16 returns None action!\n",
+ "INFO:root:veh 9 at t = 17 returns None action!\n",
+ "INFO:root:veh 9 at t = 18 returns None action!\n",
+ "INFO:root:veh 109 at t = 0 returns None action!\n",
+ "INFO:root:veh 109 at t = 1 returns None action!\n",
+ "INFO:root:veh 109 at t = 2 returns None action!\n",
+ "INFO:root:veh 109 at t = 3 returns None action!\n",
+ "INFO:root:veh 109 at t = 19 returns None action!\n",
+ "INFO:root:veh 109 at t = 20 returns None action!\n",
+ "INFO:root:veh 109 at t = 21 returns None action!\n",
+ "INFO:root:veh 109 at t = 22 returns None action!\n",
+ "INFO:root:veh 109 at t = 30 returns None action!\n",
+ "INFO:root:veh 109 at t = 31 returns None action!\n",
+ "INFO:root:veh 109 at t = 32 returns None action!\n",
+ "INFO:root:veh 109 at t = 63 returns None action!\n",
+ "INFO:root:veh 109 at t = 64 returns None action!\n",
+ "INFO:root:veh 109 at t = 66 returns None action!\n",
+ "INFO:root:veh 109 at t = 67 returns None action!\n",
+ "INFO:root:veh 109 at t = 78 returns None action!\n",
+ "INFO:root:veh 3 at t = 75 returns None action!\n",
+ "INFO:root:veh 3 at t = 76 returns None action!\n",
+ "INFO:root:veh 10 at t = 5 returns None action!\n",
+ "INFO:root:veh 10 at t = 6 returns None action!\n",
+ "INFO:root:veh 4 at t = 11 returns None action!\n",
+ "INFO:root:veh 4 at t = 12 returns None action!\n",
+ "INFO:root:veh 1 at t = 15 returns None action!\n",
+ "INFO:root:veh 1 at t = 16 returns None action!\n",
+ "INFO:root:veh 1 at t = 17 returns None action!\n",
+ "INFO:root:veh 1 at t = 18 returns None action!\n",
+ "INFO:root:veh 5 at t = 18 returns None action!\n",
+ "INFO:root:veh 51 at t = 65 returns None action!\n",
+ "INFO:root:veh 51 at t = 66 returns None action!\n",
+ "INFO:root:veh 48 at t = 67 returns None action!\n",
+ "INFO:root:veh 48 at t = 68 returns None action!\n",
+ "INFO:root:veh 2 at t = 6 returns None action!\n",
+ "INFO:root:veh 2 at t = 7 returns None action!\n",
+ "INFO:root:veh 2 at t = 22 returns None action!\n",
+ "INFO:root:veh 2 at t = 23 returns None action!\n",
+ "INFO:root:veh 3 at t = 22 returns None action!\n",
+ "INFO:root:veh 3 at t = 23 returns None action!\n",
+ "INFO:root:veh 3 at t = 24 returns None action!\n",
+ "INFO:root:veh 70 at t = 73 returns None action!\n",
+ "INFO:root:veh 70 at t = 74 returns None action!\n",
+ "INFO:root:veh 70 at t = 75 returns None action!\n",
+ "INFO:root:veh 55 at t = 58 returns None action!\n",
+ "INFO:root:veh 55 at t = 59 returns None action!\n",
+ "INFO:root:veh 55 at t = 60 returns None action!\n",
+ "INFO:root:veh 55 at t = 61 returns None action!\n",
+ "INFO:root:veh 13 at t = 11 returns None action!\n",
+ "INFO:root:veh 13 at t = 12 returns None action!\n",
+ "INFO:root:veh 13 at t = 36 returns None action!\n",
+ "INFO:root:veh 13 at t = 37 returns None action!\n",
+ "INFO:root:veh 84 at t = 1 returns None action!\n",
+ "INFO:root:veh 84 at t = 2 returns None action!\n",
+ "INFO:root:veh 10 at t = 15 returns None action!\n",
+ "INFO:root:veh 10 at t = 16 returns None action!\n",
+ "INFO:root:veh 10 at t = 17 returns None action!\n",
+ "INFO:root:veh 9 at t = 20 returns None action!\n",
+ "INFO:root:veh 9 at t = 21 returns None action!\n",
+ "INFO:root:veh 9 at t = 22 returns None action!\n",
+ "INFO:root:veh 5 at t = 40 returns None action!\n",
+ "INFO:root:veh 5 at t = 41 returns None action!\n",
+ "INFO:root:veh 5 at t = 44 returns None action!\n",
+ "INFO:root:veh 5 at t = 45 returns None action!\n",
+ "INFO:root:veh 5 at t = 46 returns None action!\n",
+ "INFO:root:veh 5 at t = 47 returns None action!\n",
+ "INFO:root:veh 5 at t = 50 returns None action!\n",
+ "INFO:root:veh 5 at t = 51 returns None action!\n",
+ "INFO:root:veh 5 at t = 52 returns None action!\n",
+ "INFO:root:veh 5 at t = 53 returns None action!\n",
+ "INFO:root:veh 5 at t = 54 returns None action!\n",
+ "INFO:root:veh 5 at t = 55 returns None action!\n",
+ "INFO:root:veh 5 at t = 56 returns None action!\n",
+ "INFO:root:veh 19 at t = 10 returns None action!\n",
+ "INFO:root:veh 19 at t = 11 returns None action!\n",
+ "INFO:root:veh 19 at t = 12 returns None action!\n",
+ "INFO:root:veh 3 at t = 58 returns None action!\n",
+ "INFO:root:veh 3 at t = 59 returns None action!\n",
+ "INFO:root:veh 3 at t = 60 returns None action!\n",
+ "INFO:root:veh 25 at t = 12 returns None action!\n",
+ "INFO:root:veh 25 at t = 13 returns None action!\n",
+ "INFO:root:veh 25 at t = 16 returns None action!\n",
+ "INFO:root:veh 15 at t = 3 returns None action!\n",
+ "INFO:root:veh 15 at t = 4 returns None action!\n",
+ "INFO:root:veh 15 at t = 5 returns None action!\n",
+ "INFO:root:veh 7 at t = 5 returns None action!\n",
+ "INFO:root:veh 15 at t = 6 returns None action!\n",
+ "INFO:root:veh 94 at t = 2 returns None action!\n",
+ "INFO:root:veh 94 at t = 3 returns None action!\n",
+ "INFO:root:veh 94 at t = 4 returns None action!\n",
+ "INFO:root:veh 88 at t = 5 returns None action!\n",
+ "INFO:root:veh 88 at t = 6 returns None action!\n",
+ "INFO:root:veh 88 at t = 7 returns None action!\n",
+ "INFO:root:veh 88 at t = 8 returns None action!\n",
+ "INFO:root:veh 88 at t = 22 returns None action!\n",
+ "INFO:root:veh 88 at t = 23 returns None action!\n",
+ "INFO:root:veh 10 at t = 22 returns None action!\n",
+ "INFO:root:veh 10 at t = 23 returns None action!\n",
+ "INFO:root:veh 22 at t = 2 returns None action!\n",
+ "INFO:root:veh 22 at t = 3 returns None action!\n",
+ "INFO:root:veh 34 at t = 0 returns None action!\n",
+ "INFO:root:veh 34 at t = 1 returns None action!\n",
+ "INFO:root:veh 32 at t = 7 returns None action!\n",
+ "INFO:root:veh 32 at t = 8 returns None action!\n",
+ "INFO:root:veh 28 at t = 8 returns None action!\n",
+ "INFO:root:veh 32 at t = 9 returns None action!\n",
+ "INFO:root:veh 28 at t = 9 returns None action!\n",
+ "INFO:root:veh 21 at t = 67 returns None action!\n",
+ "INFO:root:veh 21 at t = 68 returns None action!\n",
+ "INFO:root:veh 21 at t = 69 returns None action!\n",
+ "INFO:root:veh 21 at t = 70 returns None action!\n",
+ "INFO:root:veh 21 at t = 71 returns None action!\n",
+ "INFO:root:veh 21 at t = 1 returns None action!\n",
+ "INFO:root:veh 21 at t = 2 returns None action!\n",
+ "INFO:root:veh 21 at t = 3 returns None action!\n",
+ "INFO:root:veh 20 at t = 32 returns None action!\n",
+ "INFO:root:veh 20 at t = 33 returns None action!\n",
+ "INFO:root:veh 20 at t = 34 returns None action!\n",
+ "INFO:root:veh 71 at t = 13 returns None action!\n",
+ "INFO:root:veh 71 at t = 14 returns None action!\n",
+ "INFO:root:veh 71 at t = 15 returns None action!\n",
+ "INFO:root:veh 4 at t = 20 returns None action!\n",
+ "INFO:root:veh 4 at t = 21 returns None action!\n",
+ "INFO:root:veh 19 at t = 46 returns None action!\n",
+ "INFO:root:veh 19 at t = 47 returns None action!\n",
+ "INFO:root:veh 31 at t = 6 returns None action!\n",
+ "INFO:root:veh 31 at t = 7 returns None action!\n",
+ "INFO:root:veh 31 at t = 8 returns None action!\n",
+ "INFO:root:veh 26 at t = 50 returns None action!\n",
+ "INFO:root:veh 26 at t = 51 returns None action!\n",
+ "INFO:root:veh 15 at t = 58 returns None action!\n",
+ "INFO:root:veh 22 at t = 4 returns None action!\n",
+ "INFO:root:veh 22 at t = 5 returns None action!\n",
+ "INFO:root:veh 6 at t = 31 returns None action!\n",
+ "INFO:root:veh 6 at t = 32 returns None action!\n",
+ "INFO:root:veh 6 at t = 33 returns None action!\n",
+ "INFO:root:veh 6 at t = 34 returns None action!\n",
+ "INFO:root:veh 6 at t = 35 returns None action!\n",
+ "INFO:root:veh 15 at t = 16 returns None action!\n",
+ "INFO:root:veh 15 at t = 17 returns None action!\n",
+ "INFO:root:veh 15 at t = 18 returns None action!\n",
+ "INFO:root:veh 15 at t = 19 returns None action!\n",
+ "INFO:root:veh 50 at t = 1 returns None action!\n",
+ "INFO:root:veh 50 at t = 2 returns None action!\n",
+ "INFO:root:veh 50 at t = 5 returns None action!\n",
+ "INFO:root:veh 50 at t = 6 returns None action!\n",
+ "INFO:root:veh 26 at t = 7 returns None action!\n",
+ "INFO:root:veh 26 at t = 8 returns None action!\n",
+ "INFO:root:veh 41 at t = 41 returns None action!\n",
+ "INFO:root:veh 41 at t = 42 returns None action!\n",
+ "INFO:root:veh 41 at t = 43 returns None action!\n",
+ "INFO:root:veh 15 at t = 4 returns None action!\n",
+ "INFO:root:veh 15 at t = 5 returns None action!\n",
+ "INFO:root:veh 5 at t = 41 returns None action!\n",
+ "INFO:root:veh 5 at t = 42 returns None action!\n",
+ "INFO:root:veh 5 at t = 44 returns None action!\n",
+ "INFO:root:veh 5 at t = 45 returns None action!\n",
+ "INFO:root:veh 5 at t = 46 returns None action!\n",
+ "INFO:root:veh 5 at t = 47 returns None action!\n",
+ "INFO:root:veh 0 at t = 34 returns None action!\n",
+ "INFO:root:veh 0 at t = 35 returns None action!\n",
+ "INFO:root:veh 0 at t = 36 returns None action!\n",
+ "INFO:root:veh 8 at t = 37 returns None action!\n",
+ "INFO:root:veh 8 at t = 38 returns None action!\n",
+ "INFO:root:veh 8 at t = 39 returns None action!\n",
+ "INFO:root:veh 5 at t = 42 returns None action!\n",
+ "INFO:root:veh 5 at t = 43 returns None action!\n",
+ "INFO:root:veh 5 at t = 44 returns None action!\n",
+ "INFO:root:veh 5 at t = 45 returns None action!\n",
+ "INFO:root:veh 7 at t = 48 returns None action!\n",
+ "INFO:root:veh 7 at t = 49 returns None action!\n",
+ "INFO:root:veh 7 at t = 52 returns None action!\n",
+ "INFO:root:veh 7 at t = 53 returns None action!\n",
+ "INFO:root:veh 8 at t = 64 returns None action!\n",
+ "INFO:root:veh 8 at t = 65 returns None action!\n",
+ "INFO:root:veh 8 at t = 66 returns None action!\n",
+ "INFO:root:veh 21 at t = 6 returns None action!\n",
+ "INFO:root:veh 21 at t = 7 returns None action!\n",
+ "INFO:root:veh 21 at t = 8 returns None action!\n",
+ "INFO:root:veh 10 at t = 10 returns None action!\n",
+ "INFO:root:veh 10 at t = 11 returns None action!\n",
+ "INFO:root:veh 10 at t = 12 returns None action!\n",
+ "INFO:root:veh 27 at t = 21 returns None action!\n",
+ "INFO:root:veh 27 at t = 22 returns None action!\n",
+ "INFO:root:veh 27 at t = 23 returns None action!\n",
+ "INFO:root:veh 6 at t = 61 returns None action!\n",
+ "INFO:root:veh 6 at t = 62 returns None action!\n",
+ "INFO:root:veh 6 at t = 63 returns None action!\n",
+ "INFO:root:veh 6 at t = 69 returns None action!\n",
+ "INFO:root:veh 6 at t = 70 returns None action!\n",
+ "INFO:root:veh 6 at t = 71 returns None action!\n",
+ "INFO:root:veh 6 at t = 72 returns None action!\n",
+ "INFO:root:veh 6 at t = 73 returns None action!\n",
+ "INFO:root:veh 6 at t = 74 returns None action!\n",
+ "INFO:root:veh 42 at t = 20 returns None action!\n",
+ "INFO:root:veh 42 at t = 21 returns None action!\n",
+ "INFO:root:veh 42 at t = 22 returns None action!\n",
+ "INFO:root:veh 42 at t = 23 returns None action!\n",
+ "INFO:root:veh 42 at t = 65 returns None action!\n",
+ "INFO:root:veh 42 at t = 66 returns None action!\n",
+ "INFO:root:veh 42 at t = 78 returns None action!\n",
+ "INFO:root:veh 42 at t = 79 returns None action!\n",
+ "INFO:root:veh 3 at t = 0 returns None action!\n",
+ "INFO:root:veh 3 at t = 1 returns None action!\n",
+ "INFO:root:veh 3 at t = 2 returns None action!\n",
+ "INFO:root:veh 3 at t = 3 returns None action!\n",
+ "INFO:root:veh 0 at t = 1 returns None action!\n",
+ "INFO:root:veh 0 at t = 2 returns None action!\n",
+ "INFO:root:veh 17 at t = 7 returns None action!\n",
+ "INFO:root:veh 17 at t = 8 returns None action!\n",
+ "INFO:root:veh 17 at t = 10 returns None action!\n",
+ "INFO:root:veh 17 at t = 11 returns None action!\n",
+ "INFO:root:veh 17 at t = 12 returns None action!\n",
+ "INFO:root:veh 17 at t = 13 returns None action!\n",
+ "INFO:root:veh 17 at t = 32 returns None action!\n",
+ "INFO:root:veh 17 at t = 33 returns None action!\n",
+ "INFO:root:veh 7 at t = 7 returns None action!\n",
+ "INFO:root:veh 7 at t = 8 returns None action!\n",
+ "INFO:root:veh 1 at t = 36 returns None action!\n",
+ "INFO:root:veh 1 at t = 37 returns None action!\n",
+ "INFO:root:veh 11 at t = 77 returns None action!\n",
+ "INFO:root:veh 11 at t = 78 returns None action!\n",
+ "INFO:root:veh 5 at t = 6 returns None action!\n",
+ "INFO:root:veh 5 at t = 7 returns None action!\n",
+ "INFO:root:veh 5 at t = 8 returns None action!\n",
+ "INFO:root:veh 5 at t = 10 returns None action!\n",
+ "INFO:root:veh 5 at t = 11 returns None action!\n",
+ "INFO:root:veh 5 at t = 12 returns None action!\n",
+ "INFO:root:veh 5 at t = 13 returns None action!\n",
+ "INFO:root:veh 5 at t = 14 returns None action!\n"
+ ]
+ }
+ ],
"source": [
"# Load trained human reference policy\n",
"human_policy = load_policy(\n",
@@ -169,7 +716,7 @@
},
{
"cell_type": "code",
- "execution_count": 133,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -411,7 +958,7 @@
},
{
"cell_type": "code",
- "execution_count": 134,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -12571,7 +13118,7 @@
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -12580,7 +13127,10 @@
"# Scenes on which to evaluate the models\n",
"rl_policy_paths = glob.glob(f\"{RL_POLICY_PATH}\" + \"/*.pt\")\n",
"rl_policy_names = [os.path.basename(file)[:-3] for file in rl_policy_paths]\n",
- "reg_weights = [0.025, 0.0]"
+ "reg_weights = [0.0]\n",
+ "\n",
+ "\n",
+ "rl_policy_names"
]
},
{
diff --git a/experiments/hr_rl/run_hr_ppo.py b/experiments/hr_rl/run_hr_ppo.py
index 0868c59d..9b1788a9 100644
--- a/experiments/hr_rl/run_hr_ppo.py
+++ b/experiments/hr_rl/run_hr_ppo.py
@@ -143,10 +143,11 @@ def train(env_config, exp_config, video_config, model_config): # pylint: disabl
}
)
- lambdas = [0.0]
- for lam in lambdas:
+ num_files_list = [10, 100, 1000]
+
+ for scenes in num_files_list:
- # Set regularization weight
+ # Set number of training scenes
- exp_config.reg_weight = lam
+ env_config.num_files = scenes
# Train
train(
diff --git a/experiments/hr_rl/run_hr_ppo_cli.py b/experiments/hr_rl/run_hr_ppo_cli.py
index abff0554..34ca1dfd 100644
--- a/experiments/hr_rl/run_hr_ppo_cli.py
+++ b/experiments/hr_rl/run_hr_ppo_cli.py
@@ -43,7 +43,7 @@
def run_hr_ppo(
- sweep_name: str = "hr_ppo",
+ sweep_name: str = exp_config.group,
steer_disc: int = 5,
accel_disc: int = 5,
ent_coef: float = 0.0,
@@ -159,6 +159,7 @@ def run_hr_ppo(
seed=exp_config.seed, # Seed for the pseudo random generators
verbose=exp_config.verbose,
tensorboard_log=f"runs/{run_id}" if run_id is not None else None,
+
device=exp_config.ppo.device,
env_config=env_config,
mlp_class=LateFusionMLP,
diff --git a/experiments/il/run_behavioral_cloning.py b/experiments/il/run_behavioral_cloning.py
index 86319163..2812a9de 100644
--- a/experiments/il/run_behavioral_cloning.py
+++ b/experiments/il/run_behavioral_cloning.py
@@ -7,6 +7,7 @@
from imitation.algorithms import bc
from imitation.data.types import Transitions
from imitation.util import logger as imit_logger
+from networks.mlp_late_fusion import LateFusionMLP, LateFusionMLPPolicy
from utils.wrappers import LightNocturneEnvWrapper
from utils.config import load_config
@@ -16,17 +17,22 @@
from utils.string_utils import date_to_str
if __name__ == "__main__":
+
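+ # Cap the number of evaluation scenes and set the number of BC training scenes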
+ MAX_EVAL_FILES = 12
+ NUM_TRAIN_FILES = 1000
# Create run
run = wandb.init(
project="eval_il_policy",
sync_tensorboard=True,
+ group=f"BC_S{NUM_TRAIN_FILES}",
)
# Configs
video_config = load_config("video_config")
- env_config = load_config("env_config")
bc_config = load_config("bc_config")
+ env_config = load_config("env_config")
+ env_config.num_files = NUM_TRAIN_FILES
# Device TODO: Add support for CUDA
device = "cpu"
@@ -63,23 +69,26 @@
demonstrations=transitions,
rng=rng,
device=device,
+ #policy=LateFusionMLPPolicy,
)
+ print(f'IL policy: \n{bc_trainer.policy}')
+
# Create evaluation env
env = LightNocturneEnvWrapper(env_config)
- eval_files = env.files
-
- # Check random behavior
- reward_before_training, _ = evaluate_policy(
- model=bc_trainer.policy,
- env=LightNocturneEnvWrapper(env_config),
- n_steps_per_episode=env_config.episode_length,
- n_eval_episodes=1,
- eval_files=eval_files,
- video_config=video_config,
- video_caption="BEFORE training",
- render=True,
- )
+ eval_files = env.files[:MAX_EVAL_FILES]
+
+ # # Check random behavior
+ # reward_before_training, _ = evaluate_policy(
+ # model=bc_trainer.policy,
+ # env=LightNocturneEnvWrapper(env_config),
+ # n_steps_per_episode=env_config.episode_length,
+ # n_eval_episodes=1,
+ # eval_files=eval_files,
+ # video_config=video_config,
+ # video_caption="BEFORE training",
+ # render=True,
+ # )
# Train
bc_trainer.train(
@@ -105,4 +114,4 @@
date_ = date_to_str(datetime.now())
if bc_config.save_model:
- bc_trainer.policy.save(path=f"{bc_config.save_model_path}{bc_config.model_name}_{date_}.pt")
\ No newline at end of file
+ bc_trainer.policy.save(path=f"{bc_config.save_model_path}{bc_config.model_name}_S{NUM_TRAIN_FILES}_{date_}.pt")
\ No newline at end of file
diff --git a/models/il/human_policy_2023_11_20.pt b/models/il/human_policy_2023_11_20.pt
deleted file mode 100644
index 9f950824..00000000
Binary files a/models/il/human_policy_2023_11_20.pt and /dev/null differ
diff --git a/networks/mlp_late_fusion.py b/networks/mlp_late_fusion.py
index ed9a7db3..aed85e0b 100644
--- a/networks/mlp_late_fusion.py
+++ b/networks/mlp_late_fusion.py
@@ -245,7 +245,6 @@ def _build_mlp_extractor(self) -> None:
**self.mlp_config,
)
-
if __name__ == "__main__":
# Load config
@@ -263,6 +262,21 @@ def _build_mlp_extractor(self) -> None:
obs = env.reset()
obs = torch.Tensor(obs)[:2]
+
+ # Define model architecture
+ model_config = Box(
+ {
+ "arch_ego_state": [8],
+ "arch_road_objects": [64],
+ "arch_road_graph": [128, 64],
+ "arch_shared_net": [],
+ "act_func": "tanh",
+ "dropout": 0.0,
+ "last_layer_dim_pi": 64,
+ "last_layer_dim_vf": 64,
+ }
+ )
+
# Test
model = RegularizedPPO(
reg_policy=None,
@@ -275,6 +289,9 @@ def _build_mlp_extractor(self) -> None:
seed=exp_config.seed, # Seed for the pseudo random generators
verbose=1,
device='cuda',
+ env_config=env_config,
+ mlp_class=LateFusionMLP,
+ mlp_config=model_config,
)
# print(model.policy)
model.learn(5000)
\ No newline at end of file
diff --git a/nocturne/envs/base_env.py b/nocturne/envs/base_env.py
index 173e38ca..ea089b7d 100644
--- a/nocturne/envs/base_env.py
+++ b/nocturne/envs/base_env.py
@@ -708,18 +708,27 @@ def _position_as_array(position: Vector2D) -> np.ndarray:
env = BaseEnv(config=env_config)
# Reset
- obs_dict = env.reset()
+ obs_dict = env.reset(filename='tfrecord-00421-of-01000_364.json')
# Get info
agent_ids = [agent_id for agent_id in obs_dict.keys()]
+ veh_objects = {agent.id: agent for agent in env.controlled_vehicles}
dead_agent_ids = []
- for step in range(100):
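+ # Roll out a single episode (episode_length is 80 in env_config)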
+ for step in range(80):
# Sample actions
action_dict = {agent_id: env.action_space.sample() for agent_id in agent_ids if agent_id not in dead_agent_ids}
- # Step in env
+ # Put all controlled vehicles in expert-control mode (the simulator replays the logged human trajectories)
+ for obj in env.controlled_vehicles:
+ obj.expert_control = True
+
obs_dict, rew_dict, done_dict, info_dict = env.step(action_dict)
+
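+ # NOTE: vehicle id 37 is hard-coded for the scene loaded above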
+ print(f'step: {step}, done: {done_dict[37]}, info:\n {info_dict[37]}')
+
+ expert_action = env.scenario.expert_action(veh_objects[37], step)
+ print(f'act = {expert_action} \n')
# Update dead agents
for agent_id, is_done in done_dict.items():
diff --git a/utils/eval.py b/utils/eval.py
index 301de528..062a976b 100644
--- a/utils/eval.py
+++ b/utils/eval.py
@@ -3,6 +3,7 @@
import numpy as np
import pandas as pd
import logging
+from tqdm import tqdm
import torch
import wandb
import glob
@@ -48,6 +49,8 @@ def __init__(
def _get_scores(self):
"""Evaluate policy across a set of traffic scenes."""
+ logging.info(f'\n Evaluating policy on {len(self.eval_files)} files...')
+
# Make tables
df_eval = pd.DataFrame(
columns=[
@@ -56,12 +59,13 @@ def _get_scores(self):
"agents_controlled",
"reg_coef",
"act_acc",
+ "accel_val_mae",
+ "steer_val_mae",
"pos_rmse",
"speed_mae",
"goal_rate",
"veh_edge_cr",
"veh_veh_cr",
- "num_violations",
]
)
@@ -81,7 +85,7 @@ def _get_scores(self):
]
)
- for file in self.eval_files:
+ for file in tqdm(self.eval_files):
logging.debug(f"Evaluating policy on {file}...")
@@ -96,23 +100,23 @@ def _get_scores(self):
)
# Filter out invalid steps
- nonnan_ids = np.logical_not(
- np.logical_or(
- np.isnan(policy_actions),
- np.isnan(expert_actions),
- )
- )
+ nonnan_ids = ~np.isnan(expert_actions)
+
# Compute metrics
action_accuracy = self.get_action_accuracy(
policy_actions, expert_actions, nonnan_ids
)
position_rmse = self.get_pos_rmse(
- policy_pos, expert_pos, nonnan_ids
+ policy_pos, expert_pos,
)
speed_agent_mae = self.get_speed_mae(
- policy_speed, expert_speed, nonnan_ids
+ policy_speed, expert_speed,
+ )
+
+ abs_diff_accel, abs_diff_steer = self.get_action_val_diff(
+ policy_actions, expert_actions,
)
# Violations of the 3-second rule
@@ -125,12 +129,13 @@ def _get_scores(self):
"agents_controlled": expert_actions.shape[0],
"reg_coef": self.exp_config.reg_weight if self.reg_coef is None else self.reg_coef,
"act_acc": action_accuracy,
+ "accel_val_mae": abs_diff_accel,
+ "steer_val_mae": abs_diff_steer,
"pos_rmse": position_rmse,
"speed_mae": speed_agent_mae,
"goal_rate": policy_gr,
"veh_edge_cr": policy_edge_cr,
"veh_veh_cr": policy_veh_cr,
- "num_violations": num_violations,
}
df_eval.loc[len(df_eval)] = scene_perf
@@ -218,7 +223,9 @@ def _step_through_scene(self, filename: str, mode: str):
agent_speed[veh_idx, timestep] = veh_obj.speed
action_indices[veh_idx, timestep] = action_idx
else:
- logging.info(f'veh {veh_obj.id} at t = {timestep} returns None action!')
+ # Skip None actions (these are invalid)
+ logging.debug(f'veh {veh_obj.id} at t = {timestep} returns None action!')
+ continue
action_dict = {}
@@ -269,42 +276,67 @@ def _step_through_scene(self, filename: str, mode: str):
veh_edge_collision/self.num_agents,
veh_veh_collision/self.num_agents,
)
-
- def get_action_accuracy(self, pred_actions, expert_actions, nonnan_ids):
- """Get accuracy of agent actions.
+
+ def get_action_val_diff(self, pred_actions, expert_actions):
+ """Get difference between human action values and predicted action values.
Args:
pred_actions: (num_agents, num_steps_per_episode) the predicted actions of the agents.
expert_actions: (num_agents, num_steps_per_episode) the expert actions of the agents.
- nonnan_ids: (num_agents, num_steps_per_episode) the indices of non-nan actions.
"""
- return (expert_actions[nonnan_ids] == pred_actions[nonnan_ids]).sum() / nonnan_ids.flatten().shape[0]
+ # Filter out invalid actions
+ nonnan_ids = np.logical_not(
+ np.logical_or(
+ np.isnan(pred_actions),
+ np.isnan(expert_actions),
+ )
+ )
+ valid_expert_acts = expert_actions[nonnan_ids]
+ valid_pred_acts = pred_actions[nonnan_ids]
- def get_pos_rmse(self, pred_actions, expert_actions, nonnan_ids):
- return np.sqrt(np.linalg.norm(pred_actions[nonnan_ids] - expert_actions[nonnan_ids])).mean()
-
- def get_speed_mae(self, pred_actions, expert_actions, nonnan_ids):
- return np.abs(pred_actions[nonnan_ids] - expert_actions[nonnan_ids]).mean()
-
- def get_steer_mae(self, pred_actions, expert_actions, nonnan_ids):
- return np.abs(pred_actions[nonnan_ids] - expert_actions[nonnan_ids]).mean()
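+ # Map discrete action indices back to continuous (acceleration, steering) values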
+ exp_acc_vals, exp_steer_vals = np.zeros_like(valid_pred_acts), np.zeros_like(valid_pred_acts)
+ pred_acc_vals, pred_steer_vals = np.zeros_like(valid_pred_acts), np.zeros_like(valid_pred_acts)
+
+ for idx in range(valid_expert_acts.shape[0]):
+
+ # Get expert and predicted values
+ exp_acc_vals[idx], exp_steer_vals[idx] = self.env.idx_to_actions[valid_expert_acts[idx]]
+ pred_acc_vals[idx], pred_steer_vals[idx] = self.env.idx_to_actions[valid_pred_acts[idx]]
+
+ # Get mean absolute difference
+ abs_accel_diff = np.abs(exp_acc_vals - pred_acc_vals).mean()
+ abs_steer_diff = np.abs(exp_steer_vals - pred_steer_vals).mean()
+
+ return abs_accel_diff, abs_steer_diff
- def get_action_abs_distance(self, pred_actions, expert_actions, nonnan_ids, action_space_dim):
+ def get_action_accuracy(self, pred_actions, expert_actions, nonnan_ids):
"""Get accuracy of agent actions.
Args:
pred_actions: (num_agents, num_steps_per_episode) the predicted actions of the agents.
expert_actions: (num_agents, num_steps_per_episode) the expert actions of the agents.
nonnan_ids: (num_agents, num_steps_per_episode) the indices of non-nan actions.
"""
+ return (expert_actions[nonnan_ids] == pred_actions[nonnan_ids]).sum() / nonnan_ids.flatten().shape[0]
- num_agents = pred_actions.shape[0]
- agg_abs_dist = 0
-
- for idx in range(pred_actions.shape[0]):
- n_samples = pred_actions[0][nonnan_ids[0]].shape[0]
- agent_abs_dist = np.abs(pred_actions[idx][nonnan_ids[idx]] - expert_actions[idx][nonnan_ids[idx]]).sum() / n_samples
- agg_abs_dist += agent_abs_dist
-
- return agg_abs_dist / num_agents
+ def get_pos_rmse(self, pred_actions, expert_actions):
+ # Filter out invalid actions
+ nonnan_ids = np.logical_not(
+ np.logical_or(
+ np.isnan(pred_actions),
+ np.isnan(expert_actions),
+ )
+ )
+ # Root-mean-squared error over the valid (non-NaN) position entries
+ return np.sqrt(np.mean((pred_actions[nonnan_ids] - expert_actions[nonnan_ids]) ** 2))
+
+ def get_speed_mae(self, pred_actions, expert_actions):
+ # Filter out invalid actions
+ nonnan_ids = np.logical_not(
+ np.logical_or(
+ np.isnan(pred_actions),
+ np.isnan(expert_actions),
+ )
+ )
+ return np.abs(pred_actions[nonnan_ids] - expert_actions[nonnan_ids]).mean()
def get_veh_to_veh_distances(self, positions, velocities, time_gap_in_sec=3):
"""Calculate distances between vehicles at each time step and track
@@ -367,29 +399,11 @@ def _get_files(self, eval_files, file_limit):
env_config = load_config("env_config")
exp_config = load_config("exp_config")
- # env_config.data_path = "./data_10/train"
-
- # # Load trained human reference policy
- # human_policy = load_policy(
- # data_path="./models/il",
- # file_name="human_policy_10_scenes_2023_11_21",
- # )
-
- # # Evaluate policy
- # evaluator = EvaluatePolicy(
- # env_config=env_config,
- # exp_config=exp_config,
- # policy=human_policy,
- # log_to_wandb=False,
- # deterministic=True,
- # reg_coef=0.0,
- # return_trajectories=True,
- # )
-
- # il_results_check = evaluator._get_scores()
+ MAX_FILES = 50
- # Set data path
- env_config.data_path = "./data/train/"
+ # Select a subset of training scenes to evaluate on
+ train_file_paths = glob.glob(f"{env_config.data_path}/tfrecord*")
+ train_eval_files = [os.path.basename(file) for file in train_file_paths][:MAX_FILES]
# Load human reference policy
human_policy = load_policy(
@@ -402,7 +416,7 @@ def _get_files(self, eval_files, file_limit):
env_config=env_config,
exp_config=exp_config,
policy=human_policy,
- eval_files=["tfrecord-00012-of-01000_389.json"],
+ eval_files=train_eval_files,
log_to_wandb=False,
deterministic=True,
reg_coef=0.0,
diff --git a/utils/imitation_learning/waymo_iterator.py b/utils/imitation_learning/waymo_iterator.py
index cde7eca8..a7184711 100644
--- a/utils/imitation_learning/waymo_iterator.py
+++ b/utils/imitation_learning/waymo_iterator.py
@@ -19,7 +19,7 @@
class TrajectoryIterator(IterableDataset):
"""Generates trajectories in Waymo scenes: sequences of observations and actions."""
- def __init__(self, data_path, env_config, with_replacement=True, file_limit=None):
+ def __init__(self, data_path, env_config, with_replacement=True, file_limit=-1):
self.data_path = data_path
self.config = env_config
self.env = BaseEnv(env_config)
@@ -32,7 +32,7 @@ def __init__(self, data_path, env_config, with_replacement=True, file_limit=None
super(TrajectoryIterator).__init__()
- logging.info(f"Using {len(self.file_names)} file(s): {self.file_names}")
+ logging.info(f"Using {len(self.file_names)} file(s)")
def __iter__(self):
"""Return an (expert_state, expert_action) iterable."""
@@ -93,7 +93,7 @@ def _discretize_expert_actions(self, filename: str):
# Get (continuous) expert action
expert_action = scenario.expert_action(veh_obj, timestep)
- # Check for invalid action (because no value available for taking
+ # Check for invalid (None) actions (because no value is available for taking the
# derivative) or because the vehicle is at an invalid state
if expert_action is None:
continue
@@ -106,6 +106,9 @@ def _discretize_expert_actions(self, filename: str):
expert_action_idx = self.actions_to_joint_idx[accel_grid_val, steering_grid_val][0]
+ if expert_action_idx is None:
+ logging.debug("Expert action is None!")
+
# Store
if timestep >= self.config.warmup_period:
df_actions.loc[timestep-self.config.warmup_period][veh_obj.getID()] = expert_action_idx
@@ -229,6 +232,7 @@ def _find_closest_index(self, action_grid, action):
if __name__ == "__main__":
env_config = load_config("env_config")
+ env_config.num_files = 1000
# Create iterator
waymo_iterator = TrajectoryIterator(
@@ -241,7 +245,11 @@ def _find_closest_index(self, action_grid, action):
rollouts = next(iter(
DataLoader(
waymo_iterator,
- batch_size=400,
+ batch_size=10_000, # Number of samples to generate
pin_memory=True,
)))
+
+ obs, acts, next_obs, dones = rollouts
+
+ # Quick sanity check on the generated batch
+ print(f'obs: {obs.shape}, acts: {acts.shape}, next_obs: {next_obs.shape}, dones: {dones.shape}')
diff --git a/utils/manage_models.py b/utils/manage_models.py
index 356f7c1c..f2433cf8 100644
--- a/utils/manage_models.py
+++ b/utils/manage_models.py
@@ -9,7 +9,7 @@
entity = "daphnecor"
project = "scaling_ppo"
- collection_name = "nocturne-hr-ppo-12_29_07_55"
+ collection_name = "nocturne-hr-ppo-12_30_21_43"
# Always initialize a W&B run to start tracking
wandb.init()