diff --git a/configs/exp_config.yaml b/configs/exp_config.yaml
index 91946233..2407da24 100644
--- a/configs/exp_config.yaml
+++ b/configs/exp_config.yaml
@@ -1,5 +1,5 @@
project: scaling_ppo
-group: playground
+group: effect_of_human_reg
env_id: Nocturne
seed: 42
track_wandb: true
@@ -19,10 +19,10 @@ ma_callback:
log_indiv_metrics: false
log_agent_actions: false
save_model: true
- model_save_freq: 100 # In iterations (one iter ~ (num_agents x n_steps))
+ model_save_freq: 300 # In iterations (one iter ~ (num_agents x n_steps))
save_video: true
record_n_scenes: 10 # Number of different scenes to render
- video_save_freq: 100 # Make a video every k iterations (100 iters ~ 1M steps)
+ video_save_freq: 500 # Make a video every k iterations (100 iters ~ 1M steps)
video_deterministic: true
ppo:
@@ -32,7 +32,7 @@ ppo:
vf_coef: 0.5 # Default in SB3 is 0.5
learn:
- total_timesteps: 2_000_000
+ total_timesteps: 10_000_000
progress_bar: false
# human-regularized RL
diff --git a/configs/video_config.yaml b/configs/video_config.yaml
index ddb7bcd8..b33096f7 100644
--- a/configs/video_config.yaml
+++ b/configs/video_config.yaml
@@ -6,5 +6,5 @@ render:
logging:
render_interval: 3
- fps: 4
+ fps: 2
where_am_i: headless_machine
\ No newline at end of file
diff --git a/evaluation/policy_performance_analysis.ipynb b/evaluation/policy_performance_analysis.ipynb
index 50e6bf82..07d01e7d 100644
--- a/evaluation/policy_performance_analysis.ipynb
+++ b/evaluation/policy_performance_analysis.ipynb
@@ -18,13 +18,12 @@
"- **Accuracy** to the expert actions. \n",
"- **Euclidean distance** to the expert positions at a given state and timepoint (the L2 distance between the controlled object's XY\n",
" location and its position in the logged history at the same timestep.) `np.linalg.norm(object_xy - log_xy, axis=-1)`\n",
- "- **Safe distance through the 3-second rule**\n",
- "- **Mean absolute error to the expert speed**"
+    "- **Mean absolute error** relative to the expert's speed, acceleration, and steering wheel angle"
]
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
@@ -37,7 +36,10 @@
"import torch\n",
"import logging\n",
"import os\n",
+ "import wandb\n",
"import matplotlib.pyplot as plt\n",
+ "from utils.evaluation import evaluate_policy\n",
+ "from utils.wrappers import LightNocturneEnvWrapper\n",
"\n",
"from typing import Callable\n",
"from gym import spaces\n",
@@ -52,7 +54,9 @@
"sns.set('notebook', font_scale=1.1, rc={'figure.figsize': (10, 5)})\n",
"sns.set_style('ticks', rc={'figure.facecolor': 'none', 'axes.facecolor': 'none'})\n",
"%config InlineBackend.figure_format = 'svg'\n",
- "warnings.filterwarnings(\"ignore\")"
+ "warnings.filterwarnings(\"ignore\")\n",
+ "plt.set_loglevel('WARNING')\n",
+ "os.environ[\"WANDB_NOTEBOOK_NAME\"] = \"policy_performance_analysis.ipynb\""
]
},
{
@@ -64,27 +68,83 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
- "MAX_FILES = 50\n",
+ "MAX_FILES = 100\n",
+ "DETERMINISTIC = True\n",
"\n",
"# Load config files\n",
"env_config = load_config_nb(\"env_config\")\n",
"exp_config = load_config_nb(\"exp_config\")\n",
+ "video_config = load_config_nb(\"video_config\")\n",
"model_config = load_config_nb(\"model_config\")\n",
"\n",
"# Set data path\n",
"env_config.data_path = \"../data_full/train/\"\n",
+ "env_config.val_data_path = \"../data_full/valid/\"\n",
"env_config.num_files = MAX_FILES\n",
"\n",
"# Logging level set to INFO\n",
"LOGGING_LEVEL = \"INFO\"\n",
"\n",
"# Scenes on which to evaluate the models\n",
- "file_paths = glob.glob(f\"{env_config.data_path}\" + \"/tfrecord*\")\n",
- "eval_files = [os.path.basename(file) for file in file_paths][:MAX_FILES]"
+    "# Sort the file names so the evaluation order is deterministic (glob returns paths in arbitrary order)\n",
+ "train_file_paths = glob.glob(f\"{env_config.data_path}\" + \"/tfrecord*\")\n",
+ "train_eval_files = sorted([os.path.basename(file) for file in train_file_paths])\n",
+ "\n",
+    "# Validation scenes, sorted the same way\n",
+ "valid_file_paths = glob.glob(f\"{env_config.val_data_path}\" + \"/tfrecord*\")\n",
+ "valid_eval_files = sorted([os.path.basename(file) for file in valid_file_paths])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Helper functions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+    "def render_traffic_scenes(\n",
+    "    eval_scenes,\n",
+    "    policy,\n",
+    "    run_name,\n",
+    "    deterministic=True,\n",
+    "    group_name=\"RL_S100\",\n",
+    "    project_name=\"eval_hr_rl_policies\",\n",
+    "):\n",
+    "    \"\"\"Evaluate `policy` on `eval_scenes` with rendering on, inside its own wandb run.\"\"\"\n",
+    "    mode = \"det\" if deterministic else \"stoch\"\n",
+    "\n",
+    "    # The run name is suffixed with the mode (det / stoch)\n",
+    "    run = wandb.init(\n",
+    "        project=project_name,\n",
+    "        sync_tensorboard=True,\n",
+    "        group=group_name,\n",
+    "        name=f\"{run_name}_{mode}\",\n",
+    "    )\n",
+    "\n",
+    "    # Always close the run, even when evaluation raises, so a failed\n",
+    "    # call does not leave a dangling wandb run in the notebook kernel.\n",
+    "    try:\n",
+    "        evaluate_policy(\n",
+    "            model=policy,\n",
+    "            env=LightNocturneEnvWrapper(env_config),\n",
+    "            eval_files=eval_scenes,\n",
+    "            video_config=video_config,\n",
+    "            video_caption=\"\",\n",
+    "            deterministic=deterministic,\n",
+    "            render=True,\n",
+    "        )\n",
+    "    finally:\n",
+    "        run.finish()"
]
},
{
@@ -96,7 +156,7 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@@ -115,579 +175,16 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- "INFO:root:veh 37 at t = 50 returns None action!\n",
- "INFO:root:veh 37 at t = 51 returns None action!\n",
- "INFO:root:veh 44 at t = 75 returns None action!\n",
- "INFO:root:veh 44 at t = 76 returns None action!\n",
- "INFO:root:veh 44 at t = 32 returns None action!\n",
- "INFO:root:veh 44 at t = 33 returns None action!\n",
- "INFO:root:veh 44 at t = 34 returns None action!\n",
- "INFO:root:veh 44 at t = 36 returns None action!\n",
- "INFO:root:veh 44 at t = 37 returns None action!\n",
- "INFO:root:veh 44 at t = 38 returns None action!\n",
- "INFO:root:veh 4 at t = 2 returns None action!\n",
- "INFO:root:veh 41 at t = 1 returns None action!\n",
- "INFO:root:veh 41 at t = 2 returns None action!\n",
- "INFO:root:veh 17 at t = 15 returns None action!\n",
- "INFO:root:veh 17 at t = 16 returns None action!\n",
- "INFO:root:veh 17 at t = 17 returns None action!\n",
- "INFO:root:veh 17 at t = 18 returns None action!\n",
- "INFO:root:veh 17 at t = 19 returns None action!\n",
- "INFO:root:veh 17 at t = 31 returns None action!\n",
- "INFO:root:veh 17 at t = 32 returns None action!\n",
- "INFO:root:veh 17 at t = 33 returns None action!\n",
- "INFO:root:veh 17 at t = 34 returns None action!\n",
- "INFO:root:veh 17 at t = 43 returns None action!\n",
- "INFO:root:veh 17 at t = 44 returns None action!\n",
- "INFO:root:veh 17 at t = 45 returns None action!\n",
- "INFO:root:veh 17 at t = 46 returns None action!\n",
- "INFO:root:veh 17 at t = 47 returns None action!\n",
- "INFO:root:veh 17 at t = 51 returns None action!\n",
- "INFO:root:veh 17 at t = 52 returns None action!\n",
- "INFO:root:veh 14 at t = 1 returns None action!\n",
- "INFO:root:veh 14 at t = 2 returns None action!\n",
- "INFO:root:veh 14 at t = 3 returns None action!\n",
- "INFO:root:veh 14 at t = 4 returns None action!\n",
- "INFO:root:veh 14 at t = 44 returns None action!\n",
- "INFO:root:veh 14 at t = 45 returns None action!\n",
- "INFO:root:veh 14 at t = 46 returns None action!\n",
- "INFO:root:veh 0 at t = 14 returns None action!\n",
- "INFO:root:veh 0 at t = 15 returns None action!\n",
- "INFO:root:veh 0 at t = 16 returns None action!\n",
- "INFO:root:veh 0 at t = 17 returns None action!\n",
- "INFO:root:veh 0 at t = 18 returns None action!\n",
- "INFO:root:veh 0 at t = 20 returns None action!\n",
- "INFO:root:veh 0 at t = 21 returns None action!\n",
- "INFO:root:veh 1 at t = 62 returns None action!\n",
- "INFO:root:veh 1 at t = 63 returns None action!\n",
- "INFO:root:veh 7 at t = 33 returns None action!\n",
- "INFO:root:veh 17 at t = 34 returns None action!\n",
- "INFO:root:veh 17 at t = 35 returns None action!\n",
- "INFO:root:veh 2 at t = 20 returns None action!\n",
- "INFO:root:veh 2 at t = 21 returns None action!\n",
- "INFO:root:veh 2 at t = 29 returns None action!\n",
- "INFO:root:veh 2 at t = 30 returns None action!\n",
- "INFO:root:veh 2 at t = 31 returns None action!\n",
- "INFO:root:veh 47 at t = 0 returns None action!\n",
- "INFO:root:veh 47 at t = 1 returns None action!\n",
- "INFO:root:veh 47 at t = 2 returns None action!\n",
- "INFO:root:veh 25 at t = 23 returns None action!\n",
- "INFO:root:veh 25 at t = 24 returns None action!\n",
- "INFO:root:veh 25 at t = 25 returns None action!\n",
- "INFO:root:veh 25 at t = 26 returns None action!\n",
- "INFO:root:veh 25 at t = 27 returns None action!\n",
- "INFO:root:veh 26 at t = 29 returns None action!\n",
- "INFO:root:veh 26 at t = 30 returns None action!\n",
- "INFO:root:veh 11 at t = 16 returns None action!\n",
- "INFO:root:veh 11 at t = 17 returns None action!\n",
- "INFO:root:veh 10 at t = 65 returns None action!\n",
- "INFO:root:veh 10 at t = 66 returns None action!\n",
- "INFO:root:veh 58 at t = 1 returns None action!\n",
- "INFO:root:veh 58 at t = 2 returns None action!\n",
- "INFO:root:veh 59 at t = 3 returns None action!\n",
- "INFO:root:veh 59 at t = 4 returns None action!\n",
- "INFO:root:veh 55 at t = 12 returns None action!\n",
- "INFO:root:veh 55 at t = 13 returns None action!\n",
- "INFO:root:veh 55 at t = 14 returns None action!\n",
- "INFO:root:veh 55 at t = 15 returns None action!\n",
- "INFO:root:veh 58 at t = 20 returns None action!\n",
- "INFO:root:veh 58 at t = 21 returns None action!\n",
- "INFO:root:veh 59 at t = 24 returns None action!\n",
- "INFO:root:veh 59 at t = 25 returns None action!\n",
- "INFO:root:veh 69 at t = 29 returns None action!\n",
- "INFO:root:veh 69 at t = 30 returns None action!\n",
- "INFO:root:veh 58 at t = 32 returns None action!\n",
- "INFO:root:veh 58 at t = 33 returns None action!\n",
- "INFO:root:veh 46 at t = 1 returns None action!\n",
- "INFO:root:veh 46 at t = 2 returns None action!\n",
- "INFO:root:veh 46 at t = 3 returns None action!\n",
- "INFO:root:veh 46 at t = 4 returns None action!\n",
- "INFO:root:veh 14 at t = 8 returns None action!\n",
- "INFO:root:veh 14 at t = 9 returns None action!\n",
- "INFO:root:veh 26 at t = 25 returns None action!\n",
- "INFO:root:veh 26 at t = 26 returns None action!\n",
- "INFO:root:veh 26 at t = 29 returns None action!\n",
- "INFO:root:veh 26 at t = 30 returns None action!\n",
- "INFO:root:veh 30 at t = 11 returns None action!\n",
- "INFO:root:veh 30 at t = 12 returns None action!\n",
- "INFO:root:veh 30 at t = 13 returns None action!\n",
- "INFO:root:veh 30 at t = 14 returns None action!\n",
- "INFO:root:veh 30 at t = 31 returns None action!\n",
- "INFO:root:veh 30 at t = 32 returns None action!\n",
- "INFO:root:veh 30 at t = 33 returns None action!\n",
- "INFO:root:veh 30 at t = 34 returns None action!\n",
- "INFO:root:veh 11 at t = 66 returns None action!\n",
- "INFO:root:veh 11 at t = 67 returns None action!\n",
- "INFO:root:veh 11 at t = 68 returns None action!\n",
- "INFO:root:veh 11 at t = 70 returns None action!\n",
- "INFO:root:veh 11 at t = 71 returns None action!\n",
- "INFO:root:veh 59 at t = 31 returns None action!\n",
- "INFO:root:veh 59 at t = 32 returns None action!\n",
- "INFO:root:veh 59 at t = 33 returns None action!\n",
- "INFO:root:veh 59 at t = 34 returns None action!\n",
- "INFO:root:veh 62 at t = 35 returns None action!\n",
- "INFO:root:veh 62 at t = 36 returns None action!\n",
- "INFO:root:veh 8 at t = 42 returns None action!\n",
- "INFO:root:veh 8 at t = 43 returns None action!\n",
- "INFO:root:veh 8 at t = 44 returns None action!\n",
- "INFO:root:veh 8 at t = 45 returns None action!\n",
- "INFO:root:veh 8 at t = 46 returns None action!\n",
- "INFO:root:veh 60 at t = 0 returns None action!\n",
- "INFO:root:veh 26 at t = 0 returns None action!\n",
- "INFO:root:veh 60 at t = 1 returns None action!\n",
- "INFO:root:veh 60 at t = 2 returns None action!\n",
- "INFO:root:veh 60 at t = 3 returns None action!\n",
- "INFO:root:veh 49 at t = 3 returns None action!\n",
- "INFO:root:veh 60 at t = 4 returns None action!\n",
- "INFO:root:veh 49 at t = 4 returns None action!\n",
- "INFO:root:veh 49 at t = 5 returns None action!\n",
- "INFO:root:veh 60 at t = 7 returns None action!\n",
- "INFO:root:veh 60 at t = 8 returns None action!\n",
- "INFO:root:veh 29 at t = 50 returns None action!\n",
- "INFO:root:veh 29 at t = 51 returns None action!\n",
- "INFO:root:veh 29 at t = 52 returns None action!\n",
- "INFO:root:veh 4 at t = 2 returns None action!\n",
- "INFO:root:veh 4 at t = 3 returns None action!\n",
- "INFO:root:veh 4 at t = 4 returns None action!\n",
- "INFO:root:veh 4 at t = 5 returns None action!\n",
- "INFO:root:veh 13 at t = 8 returns None action!\n",
- "INFO:root:veh 13 at t = 9 returns None action!\n",
- "INFO:root:veh 35 at t = 0 returns None action!\n",
- "INFO:root:veh 35 at t = 1 returns None action!\n",
- "INFO:root:veh 35 at t = 2 returns None action!\n",
- "INFO:root:veh 35 at t = 3 returns None action!\n",
- "INFO:root:veh 35 at t = 4 returns None action!\n",
- "INFO:root:veh 23 at t = 5 returns None action!\n",
- "INFO:root:veh 6 at t = 2 returns None action!\n",
- "INFO:root:veh 6 at t = 3 returns None action!\n",
- "INFO:root:veh 6 at t = 29 returns None action!\n",
- "INFO:root:veh 6 at t = 30 returns None action!\n",
- "INFO:root:veh 22 at t = 0 returns None action!\n",
- "INFO:root:veh 22 at t = 1 returns None action!\n",
- "INFO:root:veh 22 at t = 2 returns None action!\n",
- "INFO:root:veh 22 at t = 3 returns None action!\n",
- "INFO:root:veh 3 at t = 7 returns None action!\n",
- "INFO:root:veh 3 at t = 8 returns None action!\n",
- "INFO:root:veh 20 at t = 20 returns None action!\n",
- "INFO:root:veh 50 at t = 0 returns None action!\n",
- "INFO:root:veh 50 at t = 1 returns None action!\n",
- "INFO:root:veh 2 at t = 71 returns None action!\n",
- "INFO:root:veh 26 at t = 63 returns None action!\n",
- "INFO:root:veh 26 at t = 64 returns None action!\n",
- "INFO:root:veh 26 at t = 65 returns None action!\n",
- "INFO:root:veh 26 at t = 74 returns None action!\n",
- "INFO:root:veh 26 at t = 75 returns None action!\n",
- "INFO:root:veh 26 at t = 76 returns None action!\n",
- "INFO:root:veh 26 at t = 77 returns None action!\n",
- "INFO:root:veh 7 at t = 20 returns None action!\n",
- "INFO:root:veh 7 at t = 21 returns None action!\n",
- "INFO:root:veh 4 at t = 41 returns None action!\n",
- "INFO:root:veh 4 at t = 42 returns None action!\n",
- "INFO:root:veh 3 at t = 10 returns None action!\n",
- "INFO:root:veh 3 at t = 11 returns None action!\n",
- "INFO:root:veh 3 at t = 12 returns None action!\n",
- "INFO:root:veh 24 at t = 21 returns None action!\n",
- "INFO:root:veh 24 at t = 22 returns None action!\n",
- "INFO:root:veh 24 at t = 23 returns None action!\n",
- "INFO:root:veh 24 at t = 24 returns None action!\n",
- "INFO:root:veh 16 at t = 28 returns None action!\n",
- "INFO:root:veh 16 at t = 29 returns None action!\n",
- "INFO:root:veh 16 at t = 30 returns None action!\n",
- "INFO:root:veh 16 at t = 31 returns None action!\n",
- "INFO:root:veh 23 at t = 17 returns None action!\n",
- "INFO:root:veh 23 at t = 18 returns None action!\n",
- "INFO:root:veh 23 at t = 22 returns None action!\n",
- "INFO:root:veh 23 at t = 23 returns None action!\n",
- "INFO:root:veh 23 at t = 31 returns None action!\n",
- "INFO:root:veh 23 at t = 32 returns None action!\n",
- "INFO:root:veh 1 at t = 76 returns None action!\n",
- "INFO:root:veh 1 at t = 77 returns None action!\n",
- "INFO:root:veh 1 at t = 78 returns None action!\n",
- "INFO:root:veh 36 at t = 2 returns None action!\n",
- "INFO:root:veh 36 at t = 3 returns None action!\n",
- "INFO:root:veh 36 at t = 4 returns None action!\n",
- "INFO:root:veh 36 at t = 5 returns None action!\n",
- "INFO:root:veh 36 at t = 6 returns None action!\n",
- "INFO:root:veh 15 at t = 4 returns None action!\n",
- "INFO:root:veh 15 at t = 5 returns None action!\n",
- "INFO:root:veh 15 at t = 6 returns None action!\n",
- "INFO:root:veh 15 at t = 7 returns None action!\n",
- "INFO:root:veh 8 at t = 69 returns None action!\n",
- "INFO:root:veh 8 at t = 70 returns None action!\n",
- "INFO:root:veh 8 at t = 71 returns None action!\n",
- "INFO:root:veh 1 at t = 37 returns None action!\n",
- "INFO:root:veh 1 at t = 38 returns None action!\n",
- "INFO:root:veh 1 at t = 39 returns None action!\n",
- "INFO:root:veh 25 at t = 15 returns None action!\n",
- "INFO:root:veh 25 at t = 16 returns None action!\n",
- "INFO:root:veh 25 at t = 17 returns None action!\n",
- "INFO:root:veh 25 at t = 18 returns None action!\n",
- "INFO:root:veh 28 at t = 1 returns None action!\n",
- "INFO:root:veh 28 at t = 2 returns None action!\n",
- "INFO:root:veh 2 at t = 54 returns None action!\n",
- "INFO:root:veh 2 at t = 55 returns None action!\n",
- "INFO:root:veh 2 at t = 56 returns None action!\n",
- "INFO:root:veh 2 at t = 57 returns None action!\n",
- "INFO:root:veh 2 at t = 58 returns None action!\n",
- "INFO:root:veh 2 at t = 59 returns None action!\n",
- "INFO:root:veh 2 at t = 60 returns None action!\n",
- "INFO:root:veh 2 at t = 61 returns None action!\n",
- "INFO:root:veh 2 at t = 64 returns None action!\n",
- "INFO:root:veh 2 at t = 65 returns None action!\n",
- "INFO:root:veh 2 at t = 66 returns None action!\n",
- "INFO:root:veh 9 at t = 15 returns None action!\n",
- "INFO:root:veh 6 at t = 55 returns None action!\n",
- "INFO:root:veh 6 at t = 56 returns None action!\n",
- "INFO:root:veh 6 at t = 57 returns None action!\n",
- "INFO:root:veh 6 at t = 59 returns None action!\n",
- "INFO:root:veh 6 at t = 60 returns None action!\n",
- "INFO:root:veh 6 at t = 61 returns None action!\n",
- "INFO:root:veh 54 at t = 32 returns None action!\n",
- "INFO:root:veh 54 at t = 33 returns None action!\n",
- "INFO:root:veh 54 at t = 37 returns None action!\n",
- "INFO:root:veh 54 at t = 38 returns None action!\n",
- "INFO:root:veh 54 at t = 39 returns None action!\n",
- "INFO:root:veh 54 at t = 40 returns None action!\n",
- "INFO:root:veh 54 at t = 41 returns None action!\n",
- "INFO:root:veh 54 at t = 42 returns None action!\n",
- "INFO:root:veh 54 at t = 53 returns None action!\n",
- "INFO:root:veh 54 at t = 54 returns None action!\n",
- "INFO:root:veh 54 at t = 55 returns None action!\n",
- "INFO:root:veh 54 at t = 56 returns None action!\n",
- "INFO:root:veh 56 at t = 63 returns None action!\n",
- "INFO:root:veh 56 at t = 64 returns None action!\n",
- "INFO:root:veh 56 at t = 65 returns None action!\n",
- "INFO:root:veh 56 at t = 66 returns None action!\n",
- "INFO:root:veh 56 at t = 67 returns None action!\n",
- "INFO:root:veh 56 at t = 68 returns None action!\n",
- "INFO:root:veh 56 at t = 70 returns None action!\n",
- "INFO:root:veh 56 at t = 71 returns None action!\n",
- "INFO:root:veh 56 at t = 72 returns None action!\n",
- "INFO:root:veh 56 at t = 73 returns None action!\n",
- "INFO:root:veh 56 at t = 74 returns None action!\n",
- "INFO:root:veh 56 at t = 75 returns None action!\n",
- "INFO:root:veh 11 at t = 37 returns None action!\n",
- "INFO:root:veh 11 at t = 38 returns None action!\n",
- "INFO:root:veh 10 at t = 33 returns None action!\n",
- "INFO:root:veh 10 at t = 34 returns None action!\n",
- "INFO:root:veh 10 at t = 35 returns None action!\n",
- "INFO:root:veh 10 at t = 36 returns None action!\n",
- "INFO:root:veh 86 at t = 4 returns None action!\n",
- "INFO:root:veh 86 at t = 5 returns None action!\n",
- "INFO:root:veh 86 at t = 6 returns None action!\n",
- "INFO:root:veh 86 at t = 7 returns None action!\n",
- "INFO:root:veh 86 at t = 8 returns None action!\n",
- "INFO:root:veh 86 at t = 21 returns None action!\n",
- "INFO:root:veh 86 at t = 22 returns None action!\n",
- "INFO:root:veh 6 at t = 18 returns None action!\n",
- "INFO:root:veh 6 at t = 19 returns None action!\n",
- "INFO:root:veh 31 at t = 54 returns None action!\n",
- "INFO:root:veh 31 at t = 55 returns None action!\n",
- "INFO:root:veh 36 at t = 72 returns None action!\n",
- "INFO:root:veh 36 at t = 73 returns None action!\n",
- "INFO:root:veh 4 at t = 19 returns None action!\n",
- "INFO:root:veh 4 at t = 20 returns None action!\n",
- "INFO:root:veh 54 at t = 67 returns None action!\n",
- "INFO:root:veh 54 at t = 68 returns None action!\n",
- "INFO:root:veh 54 at t = 69 returns None action!\n",
- "INFO:root:veh 3 at t = 4 returns None action!\n",
- "INFO:root:veh 3 at t = 5 returns None action!\n",
- "INFO:root:veh 45 at t = 44 returns None action!\n",
- "INFO:root:veh 45 at t = 45 returns None action!\n",
- "INFO:root:veh 45 at t = 46 returns None action!\n",
- "INFO:root:veh 45 at t = 47 returns None action!\n",
- "INFO:root:veh 48 at t = 64 returns None action!\n",
- "INFO:root:veh 48 at t = 65 returns None action!\n",
- "INFO:root:veh 48 at t = 68 returns None action!\n",
- "INFO:root:veh 48 at t = 69 returns None action!\n",
- "INFO:root:veh 45 at t = 78 returns None action!\n",
- "INFO:root:veh 17 at t = 37 returns None action!\n",
- "INFO:root:veh 17 at t = 38 returns None action!\n",
- "INFO:root:veh 17 at t = 39 returns None action!\n",
- "INFO:root:veh 17 at t = 40 returns None action!\n",
- "INFO:root:veh 17 at t = 41 returns None action!\n",
- "INFO:root:veh 44 at t = 0 returns None action!\n",
- "INFO:root:veh 44 at t = 1 returns None action!\n",
- "INFO:root:veh 3 at t = 17 returns None action!\n",
- "INFO:root:veh 3 at t = 18 returns None action!\n",
- "INFO:root:veh 3 at t = 19 returns None action!\n",
- "INFO:root:veh 3 at t = 20 returns None action!\n",
- "INFO:root:veh 44 at t = 28 returns None action!\n",
- "INFO:root:veh 44 at t = 29 returns None action!\n",
- "INFO:root:veh 3 at t = 3 returns None action!\n",
- "INFO:root:veh 3 at t = 4 returns None action!\n",
- "INFO:root:veh 22 at t = 23 returns None action!\n",
- "INFO:root:veh 22 at t = 24 returns None action!\n",
- "INFO:root:veh 22 at t = 28 returns None action!\n",
- "INFO:root:veh 22 at t = 29 returns None action!\n",
- "INFO:root:veh 22 at t = 30 returns None action!\n",
- "INFO:root:veh 22 at t = 31 returns None action!\n",
- "INFO:root:veh 22 at t = 34 returns None action!\n",
- "INFO:root:veh 22 at t = 35 returns None action!\n",
- "INFO:root:veh 22 at t = 37 returns None action!\n",
- "INFO:root:veh 22 at t = 38 returns None action!\n",
- "INFO:root:veh 22 at t = 39 returns None action!\n",
- "INFO:root:veh 22 at t = 40 returns None action!\n",
- "INFO:root:veh 22 at t = 41 returns None action!\n",
- "INFO:root:veh 22 at t = 42 returns None action!\n",
- "INFO:root:veh 22 at t = 43 returns None action!\n",
- "INFO:root:veh 22 at t = 47 returns None action!\n",
- "INFO:root:veh 22 at t = 48 returns None action!\n",
- "INFO:root:veh 94 at t = 24 returns None action!\n",
- "INFO:root:veh 94 at t = 25 returns None action!\n",
- "INFO:root:veh 94 at t = 26 returns None action!\n",
- "INFO:root:veh 9 at t = 24 returns None action!\n",
- "INFO:root:veh 9 at t = 25 returns None action!\n",
- "INFO:root:veh 59 at t = 60 returns None action!\n",
- "INFO:root:veh 59 at t = 61 returns None action!\n",
- "INFO:root:veh 59 at t = 62 returns None action!\n",
- "INFO:root:veh 59 at t = 63 returns None action!\n",
- "INFO:root:veh 59 at t = 64 returns None action!\n",
- "INFO:root:veh 59 at t = 67 returns None action!\n",
- "INFO:root:veh 59 at t = 68 returns None action!\n",
- "INFO:root:veh 59 at t = 69 returns None action!\n",
- "INFO:root:veh 44 at t = 0 returns None action!\n",
- "INFO:root:veh 9 at t = 12 returns None action!\n",
- "INFO:root:veh 9 at t = 13 returns None action!\n",
- "INFO:root:veh 9 at t = 14 returns None action!\n",
- "INFO:root:veh 9 at t = 15 returns None action!\n",
- "INFO:root:veh 9 at t = 16 returns None action!\n",
- "INFO:root:veh 9 at t = 17 returns None action!\n",
- "INFO:root:veh 9 at t = 18 returns None action!\n",
- "INFO:root:veh 109 at t = 0 returns None action!\n",
- "INFO:root:veh 109 at t = 1 returns None action!\n",
- "INFO:root:veh 109 at t = 2 returns None action!\n",
- "INFO:root:veh 109 at t = 3 returns None action!\n",
- "INFO:root:veh 109 at t = 19 returns None action!\n",
- "INFO:root:veh 109 at t = 20 returns None action!\n",
- "INFO:root:veh 109 at t = 21 returns None action!\n",
- "INFO:root:veh 109 at t = 22 returns None action!\n",
- "INFO:root:veh 109 at t = 30 returns None action!\n",
- "INFO:root:veh 109 at t = 31 returns None action!\n",
- "INFO:root:veh 109 at t = 32 returns None action!\n",
- "INFO:root:veh 109 at t = 63 returns None action!\n",
- "INFO:root:veh 109 at t = 64 returns None action!\n",
- "INFO:root:veh 109 at t = 66 returns None action!\n",
- "INFO:root:veh 109 at t = 67 returns None action!\n",
- "INFO:root:veh 109 at t = 78 returns None action!\n",
- "INFO:root:veh 3 at t = 75 returns None action!\n",
- "INFO:root:veh 3 at t = 76 returns None action!\n",
- "INFO:root:veh 10 at t = 5 returns None action!\n",
- "INFO:root:veh 10 at t = 6 returns None action!\n",
- "INFO:root:veh 4 at t = 11 returns None action!\n",
- "INFO:root:veh 4 at t = 12 returns None action!\n",
- "INFO:root:veh 1 at t = 15 returns None action!\n",
- "INFO:root:veh 1 at t = 16 returns None action!\n",
- "INFO:root:veh 1 at t = 17 returns None action!\n",
- "INFO:root:veh 1 at t = 18 returns None action!\n",
- "INFO:root:veh 5 at t = 18 returns None action!\n",
- "INFO:root:veh 51 at t = 65 returns None action!\n",
- "INFO:root:veh 51 at t = 66 returns None action!\n",
- "INFO:root:veh 48 at t = 67 returns None action!\n",
- "INFO:root:veh 48 at t = 68 returns None action!\n",
- "INFO:root:veh 2 at t = 6 returns None action!\n",
- "INFO:root:veh 2 at t = 7 returns None action!\n",
- "INFO:root:veh 2 at t = 22 returns None action!\n",
- "INFO:root:veh 2 at t = 23 returns None action!\n",
- "INFO:root:veh 3 at t = 22 returns None action!\n",
- "INFO:root:veh 3 at t = 23 returns None action!\n",
- "INFO:root:veh 3 at t = 24 returns None action!\n",
- "INFO:root:veh 70 at t = 73 returns None action!\n",
- "INFO:root:veh 70 at t = 74 returns None action!\n",
- "INFO:root:veh 70 at t = 75 returns None action!\n",
- "INFO:root:veh 55 at t = 58 returns None action!\n",
- "INFO:root:veh 55 at t = 59 returns None action!\n",
- "INFO:root:veh 55 at t = 60 returns None action!\n",
- "INFO:root:veh 55 at t = 61 returns None action!\n",
- "INFO:root:veh 13 at t = 11 returns None action!\n",
- "INFO:root:veh 13 at t = 12 returns None action!\n",
- "INFO:root:veh 13 at t = 36 returns None action!\n",
- "INFO:root:veh 13 at t = 37 returns None action!\n",
- "INFO:root:veh 84 at t = 1 returns None action!\n",
- "INFO:root:veh 84 at t = 2 returns None action!\n",
- "INFO:root:veh 10 at t = 15 returns None action!\n",
- "INFO:root:veh 10 at t = 16 returns None action!\n",
- "INFO:root:veh 10 at t = 17 returns None action!\n",
- "INFO:root:veh 9 at t = 20 returns None action!\n",
- "INFO:root:veh 9 at t = 21 returns None action!\n",
- "INFO:root:veh 9 at t = 22 returns None action!\n",
- "INFO:root:veh 5 at t = 40 returns None action!\n",
- "INFO:root:veh 5 at t = 41 returns None action!\n",
- "INFO:root:veh 5 at t = 44 returns None action!\n",
- "INFO:root:veh 5 at t = 45 returns None action!\n",
- "INFO:root:veh 5 at t = 46 returns None action!\n",
- "INFO:root:veh 5 at t = 47 returns None action!\n",
- "INFO:root:veh 5 at t = 50 returns None action!\n",
- "INFO:root:veh 5 at t = 51 returns None action!\n",
- "INFO:root:veh 5 at t = 52 returns None action!\n",
- "INFO:root:veh 5 at t = 53 returns None action!\n",
- "INFO:root:veh 5 at t = 54 returns None action!\n",
- "INFO:root:veh 5 at t = 55 returns None action!\n",
- "INFO:root:veh 5 at t = 56 returns None action!\n",
- "INFO:root:veh 19 at t = 10 returns None action!\n",
- "INFO:root:veh 19 at t = 11 returns None action!\n",
- "INFO:root:veh 19 at t = 12 returns None action!\n",
- "INFO:root:veh 3 at t = 58 returns None action!\n",
- "INFO:root:veh 3 at t = 59 returns None action!\n",
- "INFO:root:veh 3 at t = 60 returns None action!\n",
- "INFO:root:veh 25 at t = 12 returns None action!\n",
- "INFO:root:veh 25 at t = 13 returns None action!\n",
- "INFO:root:veh 25 at t = 16 returns None action!\n",
- "INFO:root:veh 15 at t = 3 returns None action!\n",
- "INFO:root:veh 15 at t = 4 returns None action!\n",
- "INFO:root:veh 15 at t = 5 returns None action!\n",
- "INFO:root:veh 7 at t = 5 returns None action!\n",
- "INFO:root:veh 15 at t = 6 returns None action!\n",
- "INFO:root:veh 94 at t = 2 returns None action!\n",
- "INFO:root:veh 94 at t = 3 returns None action!\n",
- "INFO:root:veh 94 at t = 4 returns None action!\n",
- "INFO:root:veh 88 at t = 5 returns None action!\n",
- "INFO:root:veh 88 at t = 6 returns None action!\n",
- "INFO:root:veh 88 at t = 7 returns None action!\n",
- "INFO:root:veh 88 at t = 8 returns None action!\n",
- "INFO:root:veh 88 at t = 22 returns None action!\n",
- "INFO:root:veh 88 at t = 23 returns None action!\n",
- "INFO:root:veh 10 at t = 22 returns None action!\n",
- "INFO:root:veh 10 at t = 23 returns None action!\n",
- "INFO:root:veh 22 at t = 2 returns None action!\n",
- "INFO:root:veh 22 at t = 3 returns None action!\n",
- "INFO:root:veh 34 at t = 0 returns None action!\n",
- "INFO:root:veh 34 at t = 1 returns None action!\n",
- "INFO:root:veh 32 at t = 7 returns None action!\n",
- "INFO:root:veh 32 at t = 8 returns None action!\n",
- "INFO:root:veh 28 at t = 8 returns None action!\n",
- "INFO:root:veh 32 at t = 9 returns None action!\n",
- "INFO:root:veh 28 at t = 9 returns None action!\n",
- "INFO:root:veh 21 at t = 67 returns None action!\n",
- "INFO:root:veh 21 at t = 68 returns None action!\n",
- "INFO:root:veh 21 at t = 69 returns None action!\n",
- "INFO:root:veh 21 at t = 70 returns None action!\n",
- "INFO:root:veh 21 at t = 71 returns None action!\n",
- "INFO:root:veh 21 at t = 1 returns None action!\n",
- "INFO:root:veh 21 at t = 2 returns None action!\n",
- "INFO:root:veh 21 at t = 3 returns None action!\n",
- "INFO:root:veh 20 at t = 32 returns None action!\n",
- "INFO:root:veh 20 at t = 33 returns None action!\n",
- "INFO:root:veh 20 at t = 34 returns None action!\n",
- "INFO:root:veh 71 at t = 13 returns None action!\n",
- "INFO:root:veh 71 at t = 14 returns None action!\n",
- "INFO:root:veh 71 at t = 15 returns None action!\n",
- "INFO:root:veh 4 at t = 20 returns None action!\n",
- "INFO:root:veh 4 at t = 21 returns None action!\n",
- "INFO:root:veh 19 at t = 46 returns None action!\n",
- "INFO:root:veh 19 at t = 47 returns None action!\n",
- "INFO:root:veh 31 at t = 6 returns None action!\n",
- "INFO:root:veh 31 at t = 7 returns None action!\n",
- "INFO:root:veh 31 at t = 8 returns None action!\n",
- "INFO:root:veh 26 at t = 50 returns None action!\n",
- "INFO:root:veh 26 at t = 51 returns None action!\n",
- "INFO:root:veh 15 at t = 58 returns None action!\n",
- "INFO:root:veh 22 at t = 4 returns None action!\n",
- "INFO:root:veh 22 at t = 5 returns None action!\n",
- "INFO:root:veh 6 at t = 31 returns None action!\n",
- "INFO:root:veh 6 at t = 32 returns None action!\n",
- "INFO:root:veh 6 at t = 33 returns None action!\n",
- "INFO:root:veh 6 at t = 34 returns None action!\n",
- "INFO:root:veh 6 at t = 35 returns None action!\n",
- "INFO:root:veh 15 at t = 16 returns None action!\n",
- "INFO:root:veh 15 at t = 17 returns None action!\n",
- "INFO:root:veh 15 at t = 18 returns None action!\n",
- "INFO:root:veh 15 at t = 19 returns None action!\n",
- "INFO:root:veh 50 at t = 1 returns None action!\n",
- "INFO:root:veh 50 at t = 2 returns None action!\n",
- "INFO:root:veh 50 at t = 5 returns None action!\n",
- "INFO:root:veh 50 at t = 6 returns None action!\n",
- "INFO:root:veh 26 at t = 7 returns None action!\n",
- "INFO:root:veh 26 at t = 8 returns None action!\n",
- "INFO:root:veh 41 at t = 41 returns None action!\n",
- "INFO:root:veh 41 at t = 42 returns None action!\n",
- "INFO:root:veh 41 at t = 43 returns None action!\n",
- "INFO:root:veh 15 at t = 4 returns None action!\n",
- "INFO:root:veh 15 at t = 5 returns None action!\n",
- "INFO:root:veh 5 at t = 41 returns None action!\n",
- "INFO:root:veh 5 at t = 42 returns None action!\n",
- "INFO:root:veh 5 at t = 44 returns None action!\n",
- "INFO:root:veh 5 at t = 45 returns None action!\n",
- "INFO:root:veh 5 at t = 46 returns None action!\n",
- "INFO:root:veh 5 at t = 47 returns None action!\n",
- "INFO:root:veh 0 at t = 34 returns None action!\n",
- "INFO:root:veh 0 at t = 35 returns None action!\n",
- "INFO:root:veh 0 at t = 36 returns None action!\n",
- "INFO:root:veh 8 at t = 37 returns None action!\n",
- "INFO:root:veh 8 at t = 38 returns None action!\n",
- "INFO:root:veh 8 at t = 39 returns None action!\n",
- "INFO:root:veh 5 at t = 42 returns None action!\n",
- "INFO:root:veh 5 at t = 43 returns None action!\n",
- "INFO:root:veh 5 at t = 44 returns None action!\n",
- "INFO:root:veh 5 at t = 45 returns None action!\n",
- "INFO:root:veh 7 at t = 48 returns None action!\n",
- "INFO:root:veh 7 at t = 49 returns None action!\n",
- "INFO:root:veh 7 at t = 52 returns None action!\n",
- "INFO:root:veh 7 at t = 53 returns None action!\n",
- "INFO:root:veh 8 at t = 64 returns None action!\n",
- "INFO:root:veh 8 at t = 65 returns None action!\n",
- "INFO:root:veh 8 at t = 66 returns None action!\n",
- "INFO:root:veh 21 at t = 6 returns None action!\n",
- "INFO:root:veh 21 at t = 7 returns None action!\n",
- "INFO:root:veh 21 at t = 8 returns None action!\n",
- "INFO:root:veh 10 at t = 10 returns None action!\n",
- "INFO:root:veh 10 at t = 11 returns None action!\n",
- "INFO:root:veh 10 at t = 12 returns None action!\n",
- "INFO:root:veh 27 at t = 21 returns None action!\n",
- "INFO:root:veh 27 at t = 22 returns None action!\n",
- "INFO:root:veh 27 at t = 23 returns None action!\n",
- "INFO:root:veh 6 at t = 61 returns None action!\n",
- "INFO:root:veh 6 at t = 62 returns None action!\n",
- "INFO:root:veh 6 at t = 63 returns None action!\n",
- "INFO:root:veh 6 at t = 69 returns None action!\n",
- "INFO:root:veh 6 at t = 70 returns None action!\n",
- "INFO:root:veh 6 at t = 71 returns None action!\n",
- "INFO:root:veh 6 at t = 72 returns None action!\n",
- "INFO:root:veh 6 at t = 73 returns None action!\n",
- "INFO:root:veh 6 at t = 74 returns None action!\n",
- "INFO:root:veh 42 at t = 20 returns None action!\n",
- "INFO:root:veh 42 at t = 21 returns None action!\n",
- "INFO:root:veh 42 at t = 22 returns None action!\n",
- "INFO:root:veh 42 at t = 23 returns None action!\n",
- "INFO:root:veh 42 at t = 65 returns None action!\n",
- "INFO:root:veh 42 at t = 66 returns None action!\n",
- "INFO:root:veh 42 at t = 78 returns None action!\n",
- "INFO:root:veh 42 at t = 79 returns None action!\n",
- "INFO:root:veh 3 at t = 0 returns None action!\n",
- "INFO:root:veh 3 at t = 1 returns None action!\n",
- "INFO:root:veh 3 at t = 2 returns None action!\n",
- "INFO:root:veh 3 at t = 3 returns None action!\n",
- "INFO:root:veh 0 at t = 1 returns None action!\n",
- "INFO:root:veh 0 at t = 2 returns None action!\n",
- "INFO:root:veh 17 at t = 7 returns None action!\n",
- "INFO:root:veh 17 at t = 8 returns None action!\n",
- "INFO:root:veh 17 at t = 10 returns None action!\n",
- "INFO:root:veh 17 at t = 11 returns None action!\n",
- "INFO:root:veh 17 at t = 12 returns None action!\n",
- "INFO:root:veh 17 at t = 13 returns None action!\n",
- "INFO:root:veh 17 at t = 32 returns None action!\n",
- "INFO:root:veh 17 at t = 33 returns None action!\n",
- "INFO:root:veh 7 at t = 7 returns None action!\n",
- "INFO:root:veh 7 at t = 8 returns None action!\n",
- "INFO:root:veh 1 at t = 36 returns None action!\n",
- "INFO:root:veh 1 at t = 37 returns None action!\n",
- "INFO:root:veh 11 at t = 77 returns None action!\n",
- "INFO:root:veh 11 at t = 78 returns None action!\n",
- "INFO:root:veh 5 at t = 6 returns None action!\n",
- "INFO:root:veh 5 at t = 7 returns None action!\n",
- "INFO:root:veh 5 at t = 8 returns None action!\n",
- "INFO:root:veh 5 at t = 10 returns None action!\n",
- "INFO:root:veh 5 at t = 11 returns None action!\n",
- "INFO:root:veh 5 at t = 12 returns None action!\n",
- "INFO:root:veh 5 at t = 13 returns None action!\n",
- "INFO:root:veh 5 at t = 14 returns None action!\n"
+ "INFO:root:\n",
+ " Evaluating policy on 50 files...\n",
+ "100%|██████████| 50/50 [00:14<00:00, 3.47it/s]\n"
]
}
],
@@ -716,7 +213,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 5,
"metadata": {},
"outputs": [
{
@@ -745,12 +242,13 @@
"
agents_controlled | \n",
" reg_coef | \n",
" act_acc | \n",
+ " accel_val_mae | \n",
+ " steer_val_mae | \n",
" pos_rmse | \n",
" speed_mae | \n",
" goal_rate | \n",
" veh_edge_cr | \n",
" veh_veh_cr | \n",
- " num_violations | \n",
" class | \n",
" \n",
" \n",
@@ -758,207 +256,977 @@
" \n",
" 0 | \n",
" None | \n",
- " tfrecord-00004-of-01000_378.json | \n",
- " 5 | \n",
+ " tfrecord-00350-of-01000_410.json | \n",
+ " 2 | \n",
" 0.0 | \n",
- " 0.140 | \n",
- " 2.581 | \n",
- " 0.288 | \n",
- " 0.400 | \n",
- " 0.600 | \n",
+ " 0.094 | \n",
+ " 1.526 | \n",
+ " 0.127 | \n",
+ " 5.253 | \n",
+ " 1.066 | \n",
+ " 0.000 | \n",
+ " 1.000 | \n",
" 0.000 | \n",
- " 0 | \n",
" IL | \n",
"
\n",
" \n",
" 1 | \n",
" None | \n",
- " tfrecord-00003-of-01000_109.json | \n",
- " 20 | \n",
+ " tfrecord-00349-of-01000_311.json | \n",
+ " 2 | \n",
" 0.0 | \n",
- " 0.112 | \n",
- " 119.303 | \n",
- " 20.169 | \n",
- " 0.350 | \n",
- " 0.250 | \n",
- " 0.100 | \n",
- " 0 | \n",
+ " 0.019 | \n",
+ " 2.409 | \n",
+ " 0.281 | \n",
+ " 3.528 | \n",
+ " 1.018 | \n",
+ " 0.000 | \n",
+ " 0.000 | \n",
+ " 1.000 | \n",
" IL | \n",
"
\n",
" \n",
" 2 | \n",
" None | \n",
- " tfrecord-00004-of-01000_61.json | \n",
- " 6 | \n",
+ " tfrecord-00421-of-01000_364.json | \n",
+ " 9 | \n",
" 0.0 | \n",
- " 0.210 | \n",
- " 4.265 | \n",
- " 0.463 | \n",
- " 0.167 | \n",
- " 0.500 | \n",
+ " 0.067 | \n",
+ " 1.970 | \n",
+ " 0.164 | \n",
+ " 7.574 | \n",
+ " 1.055 | \n",
" 0.000 | \n",
- " 0 | \n",
+ " 0.222 | \n",
+ " 0.667 | \n",
" IL | \n",
"
\n",
" \n",
" 3 | \n",
" None | \n",
- " tfrecord-00012-of-01000_87.json | \n",
- " 9 | \n",
+ " tfrecord-00243-of-01000_280.json | \n",
+ " 2 | \n",
" 0.0 | \n",
- " 0.190 | \n",
- " 181.568 | \n",
- " 162.275 | \n",
- " 0.333 | \n",
+ " 0.012 | \n",
+ " 1.974 | \n",
+ " 0.239 | \n",
+ " 2.759 | \n",
+ " 0.555 | \n",
" 0.000 | \n",
" 0.000 | \n",
- " 0 | \n",
+ " 1.000 | \n",
" IL | \n",
"
\n",
" \n",
" 4 | \n",
" None | \n",
- " tfrecord-00007-of-01000_237.json | \n",
- " 14 | \n",
+ " tfrecord-00275-of-01000_483.json | \n",
+ " 3 | \n",
" 0.0 | \n",
- " 0.278 | \n",
- " 8.479 | \n",
- " 0.513 | \n",
- " 0.286 | \n",
- " 0.071 | \n",
+ " 0.050 | \n",
+ " 1.833 | \n",
+ " 0.220 | \n",
+ " 3.999 | \n",
+ " 0.621 | \n",
" 0.000 | \n",
- " 0 | \n",
+ " 0.667 | \n",
+ " 0.333 | \n",
" IL | \n",
"
\n",
" \n",
" 5 | \n",
" None | \n",
- " tfrecord-00005-of-01000_423.json | \n",
- " 4 | \n",
+ " tfrecord-00201-of-01000_202.json | \n",
+ " 16 | \n",
" 0.0 | \n",
- " 0.391 | \n",
- " 6.674 | \n",
- " 0.776 | \n",
+ " 0.094 | \n",
+ " 2.250 | \n",
+ " 0.003 | \n",
+ " 140.214 | \n",
+ " 57.562 | \n",
+ " 0.312 | \n",
" 0.250 | \n",
" 0.000 | \n",
- " 0.000 | \n",
- " 0 | \n",
" IL | \n",
"
\n",
" \n",
" 6 | \n",
" None | \n",
- " tfrecord-00012-of-01000_246.json | \n",
- " 20 | \n",
+ " tfrecord-00623-of-01000_242.json | \n",
+ " 8 | \n",
" 0.0 | \n",
- " 0.108 | \n",
- " 162.746 | \n",
- " 53.993 | \n",
- " 0.050 | \n",
+ " 0.058 | \n",
+ " 2.032 | \n",
+ " 0.375 | \n",
+ " 205.846 | \n",
+ " 114.555 | \n",
+ " 0.125 | \n",
+ " 0.000 | \n",
+ " 0.375 | \n",
+ " IL | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " None | \n",
+ " tfrecord-00316-of-01000_183.json | \n",
+ " 10 | \n",
+ " 0.0 | \n",
+ " 0.162 | \n",
+ " 2.260 | \n",
+ " 0.013 | \n",
+ " 8.700 | \n",
+ " 0.760 | \n",
+ " 0.200 | \n",
+ " 0.200 | \n",
+ " 0.200 | \n",
+ " IL | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " None | \n",
+ " tfrecord-00443-of-01000_487.json | \n",
+ " 6 | \n",
+ " 0.0 | \n",
+ " 0.025 | \n",
+ " 2.143 | \n",
+ " 0.294 | \n",
+ " 9.119 | \n",
+ " 1.796 | \n",
+ " 0.000 | \n",
+ " 0.333 | \n",
" 0.500 | \n",
+ " IL | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " None | \n",
+ " tfrecord-00908-of-01000_302.json | \n",
+ " 20 | \n",
+ " 0.0 | \n",
+ " 0.149 | \n",
+ " 2.089 | \n",
+ " 0.039 | \n",
+ " 12.949 | \n",
+ " 1.180 | \n",
+ " 0.150 | \n",
+ " 0.200 | \n",
" 0.200 | \n",
- " 0 | \n",
" IL | \n",
"
\n",
" \n",
- " 7 | \n",
+ " 10 | \n",
+ " None | \n",
+ " tfrecord-00395-of-01000_97.json | \n",
+ " 18 | \n",
+ " 0.0 | \n",
+ " 0.076 | \n",
+ " 2.333 | \n",
+ " 0.292 | \n",
+ " 253.132 | \n",
+ " 200.875 | \n",
+ " 0.000 | \n",
+ " 0.833 | \n",
+ " 0.111 | \n",
+ " IL | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " None | \n",
+ " tfrecord-00845-of-01000_433.json | \n",
+ " 3 | \n",
+ " 0.0 | \n",
+ " 0.133 | \n",
+ " 1.311 | \n",
+ " 0.190 | \n",
+ " 3.827 | \n",
+ " 0.630 | \n",
+ " 0.000 | \n",
+ " 0.333 | \n",
+ " 0.667 | \n",
+ " IL | \n",
+ "
\n",
+ " \n",
+ " 12 | \n",
+ " None | \n",
+ " tfrecord-00628-of-01000_338.json | \n",
+ " 3 | \n",
+ " 0.0 | \n",
+ " 0.029 | \n",
+ " 1.485 | \n",
+ " 0.367 | \n",
+ " 4.083 | \n",
+ " 0.509 | \n",
+ " 0.000 | \n",
+ " 0.333 | \n",
+ " 0.333 | \n",
+ " IL | \n",
+ "
\n",
+ " \n",
+ " 13 | \n",
+ " None | \n",
+ " tfrecord-00498-of-01000_229.json | \n",
+ " 7 | \n",
+ " 0.0 | \n",
+ " 0.088 | \n",
+ " 2.199 | \n",
+ " 0.113 | \n",
+ " 195.827 | \n",
+ " 120.754 | \n",
+ " 0.143 | \n",
+ " 0.571 | \n",
+ " 0.143 | \n",
+ " IL | \n",
+ "
\n",
+ " \n",
+ " 14 | \n",
+ " None | \n",
+ " tfrecord-00069-of-01000_193.json | \n",
+ " 5 | \n",
+ " 0.0 | \n",
+ " 0.115 | \n",
+ " 1.891 | \n",
+ " 0.144 | \n",
+ " 169.086 | \n",
+ " 297.963 | \n",
+ " 0.000 | \n",
+ " 0.400 | \n",
+ " 0.000 | \n",
+ " IL | \n",
+ "
\n",
+ " \n",
+ " 15 | \n",
+ " None | \n",
+ " tfrecord-00271-of-01000_435.json | \n",
+ " 18 | \n",
+ " 0.0 | \n",
+ " 0.178 | \n",
+ " 1.959 | \n",
+ " 0.024 | \n",
+ " 14.123 | \n",
+ " 1.407 | \n",
+ " 0.056 | \n",
+ " 0.111 | \n",
+ " 0.389 | \n",
+ " IL | \n",
+ "
\n",
+ " \n",
+ " 16 | \n",
+ " None | \n",
+ " tfrecord-00480-of-01000_25.json | \n",
+ " 3 | \n",
+ " 0.0 | \n",
+ " 0.142 | \n",
+ " 1.460 | \n",
+ " 0.146 | \n",
+ " 9.145 | \n",
+ " 1.885 | \n",
+ " 0.333 | \n",
+ " 0.000 | \n",
+ " 0.000 | \n",
+ " IL | \n",
+ "
\n",
+ " \n",
+ " 17 | \n",
+ " None | \n",
+ " tfrecord-00683-of-01000_42.json | \n",
+ " 2 | \n",
+ " 0.0 | \n",
+ " 0.262 | \n",
+ " 1.287 | \n",
+ " 0.102 | \n",
+ " 8.824 | \n",
+ " 2.010 | \n",
+ " 0.000 | \n",
+ " 0.500 | \n",
+ " 0.000 | \n",
+ " IL | \n",
+ "
\n",
+ " \n",
+ " 18 | \n",
+ " None | \n",
+ " tfrecord-00384-of-01000_192.json | \n",
+ " 3 | \n",
+ " 0.0 | \n",
+ " 0.117 | \n",
+ " 1.701 | \n",
+ " 0.096 | \n",
+ " 7.469 | \n",
+ " 1.049 | \n",
+ " 0.000 | \n",
+ " 0.667 | \n",
+ " 0.000 | \n",
+ " IL | \n",
+ "
\n",
+ " \n",
+ " 19 | \n",
" None | \n",
- " tfrecord-00012-of-01000_389.json | \n",
+ " tfrecord-00073-of-01000_173.json | \n",
" 4 | \n",
" 0.0 | \n",
- " 0.172 | \n",
- " 123.747 | \n",
- " 73.352 | \n",
- " 0.250 | \n",
+ " 0.178 | \n",
+ " 2.130 | \n",
+ " 0.046 | \n",
+ " 8.065 | \n",
+ " 1.571 | \n",
+ " 0.000 | \n",
" 0.500 | \n",
" 0.000 | \n",
- " 0 | \n",
" IL | \n",
"
\n",
" \n",
- " 8 | \n",
+ " 20 | \n",
" None | \n",
- " tfrecord-00001-of-01000_307.json | \n",
+ " tfrecord-00468-of-01000_290.json | \n",
" 3 | \n",
" 0.0 | \n",
- " 0.258 | \n",
- " 4.814 | \n",
- " 0.485 | \n",
+ " 0.017 | \n",
+ " 2.952 | \n",
+ " 0.461 | \n",
+ " 4.592 | \n",
+ " 2.468 | \n",
+ " 0.000 | \n",
+ " 0.333 | \n",
" 0.667 | \n",
+ " IL | \n",
+ "
\n",
+ " \n",
+ " 21 | \n",
+ " None | \n",
+ " tfrecord-00656-of-01000_18.json | \n",
+ " 3 | \n",
+ " 0.0 | \n",
+ " 0.025 | \n",
+ " 2.386 | \n",
+ " 0.392 | \n",
+ " 6.863 | \n",
+ " 3.039 | \n",
+ " 0.000 | \n",
+ " 1.000 | \n",
+ " 0.000 | \n",
+ " IL | \n",
+ "
\n",
+ " \n",
+ " 22 | \n",
+ " None | \n",
+ " tfrecord-00651-of-01000_285.json | \n",
+ " 4 | \n",
+ " 0.0 | \n",
+ " 0.084 | \n",
+ " 1.902 | \n",
+ " 0.050 | \n",
+ " 4.980 | \n",
+ " 0.859 | \n",
+ " 0.250 | \n",
+ " 0.250 | \n",
+ " 0.250 | \n",
+ " IL | \n",
+ "
\n",
+ " \n",
+ " 23 | \n",
+ " None | \n",
+ " tfrecord-00364-of-01000_319.json | \n",
+ " 2 | \n",
+ " 0.0 | \n",
+ " 0.088 | \n",
+ " 1.411 | \n",
+ " 0.288 | \n",
+ " 4.417 | \n",
+ " 0.519 | \n",
+ " 0.000 | \n",
+ " 1.000 | \n",
+ " 0.000 | \n",
+ " IL | \n",
+ "
\n",
+ " \n",
+ " 24 | \n",
+ " None | \n",
+ " tfrecord-00615-of-01000_417.json | \n",
+ " 13 | \n",
+ " 0.0 | \n",
+ " 0.034 | \n",
+ " 2.531 | \n",
+ " 0.350 | \n",
+ " 11.520 | \n",
+ " 2.748 | \n",
+ " 0.154 | \n",
+ " 0.385 | \n",
+ " 0.385 | \n",
+ " IL | \n",
+ "
\n",
+ " \n",
+ " 25 | \n",
+ " None | \n",
+ " tfrecord-00795-of-01000_89.json | \n",
+ " 3 | \n",
+ " 0.0 | \n",
+ " 0.042 | \n",
+ " 2.039 | \n",
+ " 0.180 | \n",
+ " 171.441 | \n",
+ " 274.731 | \n",
" 0.333 | \n",
+ " 0.667 | \n",
" 0.000 | \n",
- " 0 | \n",
" IL | \n",
"
\n",
" \n",
- " 9 | \n",
+ " 26 | \n",
+ " None | \n",
+ " tfrecord-00880-of-01000_341.json | \n",
+ " 20 | \n",
+ " 0.0 | \n",
+ " 0.034 | \n",
+ " 2.709 | \n",
+ " 0.451 | \n",
+ " 168.273 | \n",
+ " 81.749 | \n",
+ " 0.000 | \n",
+ " 0.500 | \n",
+ " 0.500 | \n",
+ " IL | \n",
+ "
\n",
+ " \n",
+ " 27 | \n",
+ " None | \n",
+ " tfrecord-00187-of-01000_350.json | \n",
+ " 3 | \n",
+ " 0.0 | \n",
+ " 0.104 | \n",
+ " 1.837 | \n",
+ " 0.169 | \n",
+ " 6.644 | \n",
+ " 1.400 | \n",
+ " 0.000 | \n",
+ " 0.667 | \n",
+ " 0.000 | \n",
+ " IL | \n",
+ "
\n",
+ " \n",
+ " 28 | \n",
" None | \n",
- " tfrecord-00004-of-01000_157.json | \n",
- " 11 | \n",
+ " tfrecord-00428-of-01000_234.json | \n",
+ " 8 | \n",
" 0.0 | \n",
+ " 0.059 | \n",
+ " 1.558 | \n",
+ " 0.465 | \n",
+ " 8.172 | \n",
+ " 1.171 | \n",
+ " 0.000 | \n",
+ " 0.875 | \n",
" 0.125 | \n",
- " 260.832 | \n",
- " 343.664 | \n",
- " 0.182 | \n",
+ " IL | \n",
+ "
\n",
+ " \n",
+ " 29 | \n",
+ " None | \n",
+ " tfrecord-00748-of-01000_178.json | \n",
+ " 3 | \n",
+ " 0.0 | \n",
" 0.000 | \n",
- " 0.545 | \n",
- " 0 | \n",
+ " 2.302 | \n",
+ " 0.334 | \n",
+ " 3.596 | \n",
+ " 1.461 | \n",
+ " 0.000 | \n",
+ " 0.000 | \n",
+ " 1.000 | \n",
" IL | \n",
"
\n",
- " \n",
- "\n",
- ""
- ],
- "text/plain": [
- " run_id traffic_scene agents_controlled reg_coef \\\n",
- "0 None tfrecord-00004-of-01000_378.json 5 0.0 \n",
- "1 None tfrecord-00003-of-01000_109.json 20 0.0 \n",
- "2 None tfrecord-00004-of-01000_61.json 6 0.0 \n",
- "3 None tfrecord-00012-of-01000_87.json 9 0.0 \n",
- "4 None tfrecord-00007-of-01000_237.json 14 0.0 \n",
- "5 None tfrecord-00005-of-01000_423.json 4 0.0 \n",
- "6 None tfrecord-00012-of-01000_246.json 20 0.0 \n",
- "7 None tfrecord-00012-of-01000_389.json 4 0.0 \n",
- "8 None tfrecord-00001-of-01000_307.json 3 0.0 \n",
- "9 None tfrecord-00004-of-01000_157.json 11 0.0 \n",
- "\n",
- " act_acc pos_rmse speed_mae goal_rate veh_edge_cr veh_veh_cr \\\n",
- "0 0.140 2.581 0.288 0.400 0.600 0.000 \n",
- "1 0.112 119.303 20.169 0.350 0.250 0.100 \n",
- "2 0.210 4.265 0.463 0.167 0.500 0.000 \n",
- "3 0.190 181.568 162.275 0.333 0.000 0.000 \n",
- "4 0.278 8.479 0.513 0.286 0.071 0.000 \n",
- "5 0.391 6.674 0.776 0.250 0.000 0.000 \n",
- "6 0.108 162.746 53.993 0.050 0.500 0.200 \n",
- "7 0.172 123.747 73.352 0.250 0.500 0.000 \n",
- "8 0.258 4.814 0.485 0.667 0.333 0.000 \n",
- "9 0.125 260.832 343.664 0.182 0.000 0.545 \n",
- "\n",
- " num_violations class \n",
- "0 0 IL \n",
- "1 0 IL \n",
- "2 0 IL \n",
- "3 0 IL \n",
- "4 0 IL \n",
- "5 0 IL \n",
- "6 0 IL \n",
- "7 0 IL \n",
- "8 0 IL \n",
- "9 0 IL "
- ]
- },
- "execution_count": 133,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "df_il_res.round(3)"
- ]
- },
+ " \n",
+ " 30 | \n",
+ " None | \n",
+ " tfrecord-00481-of-01000_309.json | \n",
+ " 2 | \n",
+ " 0.0 | \n",
+ " 0.106 | \n",
+ " 2.680 | \n",
+ " 0.135 | \n",
+ " 10.895 | \n",
+ " 3.673 | \n",
+ " 0.000 | \n",
+ " 1.000 | \n",
+ " 0.000 | \n",
+ " IL | \n",
+ "
\n",
+ " \n",
+ " 31 | \n",
+ " None | \n",
+ " tfrecord-00246-of-01000_104.json | \n",
+ " 5 | \n",
+ " 0.0 | \n",
+ " 0.202 | \n",
+ " 2.069 | \n",
+ " 0.028 | \n",
+ " 13.658 | \n",
+ " 2.539 | \n",
+ " 0.000 | \n",
+ " 0.600 | \n",
+ " 0.000 | \n",
+ " IL | \n",
+ "
\n",
+ " \n",
+ " 32 | \n",
+ " None | \n",
+ " tfrecord-00880-of-01000_456.json | \n",
+ " 20 | \n",
+ " 0.0 | \n",
+ " 0.090 | \n",
+ " 2.430 | \n",
+ " 0.159 | \n",
+ " 141.539 | \n",
+ " 35.880 | \n",
+ " 0.150 | \n",
+ " 0.450 | \n",
+ " 0.200 | \n",
+ " IL | \n",
+ "
\n",
+ " \n",
+ " 33 | \n",
+ " None | \n",
+ " tfrecord-00576-of-01000_258.json | \n",
+ " 5 | \n",
+ " 0.0 | \n",
+ " 0.105 | \n",
+ " 1.928 | \n",
+ " 0.166 | \n",
+ " 6.807 | \n",
+ " 1.368 | \n",
+ " 0.200 | \n",
+ " 0.600 | \n",
+ " 0.000 | \n",
+ " IL | \n",
+ "
\n",
+ " \n",
+ " 34 | \n",
+ " None | \n",
+ " tfrecord-00342-of-01000_476.json | \n",
+ " 20 | \n",
+ " 0.0 | \n",
+ " 0.034 | \n",
+ " 2.557 | \n",
+ " 0.232 | \n",
+ " 221.795 | \n",
+ " 177.375 | \n",
+ " 0.300 | \n",
+ " 0.050 | \n",
+ " 0.600 | \n",
+ " IL | \n",
+ "
\n",
+ " \n",
+ " 35 | \n",
+ " None | \n",
+ " tfrecord-00445-of-01000_61.json | \n",
+ " 18 | \n",
+ " 0.0 | \n",
+ " 0.124 | \n",
+ " 2.418 | \n",
+ " 0.056 | \n",
+ " 181.373 | \n",
+ " 72.507 | \n",
+ " 0.056 | \n",
+ " 0.111 | \n",
+ " 0.333 | \n",
+ " IL | \n",
+ "
\n",
+ " \n",
+ " 36 | \n",
+ " None | \n",
+ " tfrecord-00430-of-01000_75.json | \n",
+ " 1 | \n",
+ " 0.0 | \n",
+ " 0.062 | \n",
+ " 1.224 | \n",
+ " 0.203 | \n",
+ " 4.086 | \n",
+ " 1.092 | \n",
+ " 0.000 | \n",
+ " 1.000 | \n",
+ " 0.000 | \n",
+ " IL | \n",
+ "
\n",
+ " \n",
+ " 37 | \n",
+ " None | \n",
+ " tfrecord-00518-of-01000_117.json | \n",
+ " 4 | \n",
+ " 0.0 | \n",
+ " 0.078 | \n",
+ " 2.130 | \n",
+ " 0.054 | \n",
+ " 8.837 | \n",
+ " 1.530 | \n",
+ " 0.250 | \n",
+ " 0.500 | \n",
+ " 0.000 | \n",
+ " IL | \n",
+ "
\n",
+ " \n",
+ " 38 | \n",
+ " None | \n",
+ " tfrecord-00444-of-01000_226.json | \n",
+ " 20 | \n",
+ " 0.0 | \n",
+ " 0.038 | \n",
+ " 2.863 | \n",
+ " 0.268 | \n",
+ " 255.838 | \n",
+ " 181.061 | \n",
+ " 0.150 | \n",
+ " 0.650 | \n",
+ " 0.100 | \n",
+ " IL | \n",
+ "
\n",
+ " \n",
+ " 39 | \n",
+ " None | \n",
+ " tfrecord-00083-of-01000_389.json | \n",
+ " 20 | \n",
+ " 0.0 | \n",
+ " 0.011 | \n",
+ " 2.175 | \n",
+ " 0.456 | \n",
+ " 33.302 | \n",
+ " 2.794 | \n",
+ " 0.150 | \n",
+ " 0.300 | \n",
+ " 0.350 | \n",
+ " IL | \n",
+ "
\n",
+ " \n",
+ " 40 | \n",
+ " None | \n",
+ " tfrecord-00119-of-01000_361.json | \n",
+ " 20 | \n",
+ " 0.0 | \n",
+ " 0.124 | \n",
+ " 2.261 | \n",
+ " 0.016 | \n",
+ " 139.079 | \n",
+ " 33.757 | \n",
+ " 0.300 | \n",
+ " 0.100 | \n",
+ " 0.400 | \n",
+ " IL | \n",
+ "
\n",
+ " \n",
+ " 41 | \n",
+ " None | \n",
+ " tfrecord-00559-of-01000_436.json | \n",
+ " 6 | \n",
+ " 0.0 | \n",
+ " 0.029 | \n",
+ " 2.778 | \n",
+ " 0.236 | \n",
+ " 148.869 | \n",
+ " 316.467 | \n",
+ " 0.333 | \n",
+ " 0.333 | \n",
+ " 0.167 | \n",
+ " IL | \n",
+ "
\n",
+ " \n",
+ " 42 | \n",
+ " None | \n",
+ " tfrecord-00172-of-01000_326.json | \n",
+ " 5 | \n",
+ " 0.0 | \n",
+ " 0.022 | \n",
+ " 2.604 | \n",
+ " 0.361 | \n",
+ " 6.835 | \n",
+ " 2.275 | \n",
+ " 0.000 | \n",
+ " 0.800 | \n",
+ " 0.000 | \n",
+ " IL | \n",
+ "
\n",
+ " \n",
+ " 43 | \n",
+ " None | \n",
+ " tfrecord-00444-of-01000_42.json | \n",
+ " 8 | \n",
+ " 0.0 | \n",
+ " 0.142 | \n",
+ " 1.849 | \n",
+ " 0.063 | \n",
+ " 220.961 | \n",
+ " 209.358 | \n",
+ " 0.000 | \n",
+ " 0.125 | \n",
+ " 0.375 | \n",
+ " IL | \n",
+ "
\n",
+ " \n",
+ " 44 | \n",
+ " None | \n",
+ " tfrecord-00569-of-01000_314.json | \n",
+ " 13 | \n",
+ " 0.0 | \n",
+ " 0.016 | \n",
+ " 3.279 | \n",
+ " 0.481 | \n",
+ " 15.319 | \n",
+ " 5.635 | \n",
+ " 0.154 | \n",
+ " 0.385 | \n",
+ " 0.462 | \n",
+ " IL | \n",
+ "
\n",
+ " \n",
+ " 45 | \n",
+ " None | \n",
+ " tfrecord-00496-of-01000_43.json | \n",
+ " 9 | \n",
+ " 0.0 | \n",
+ " 0.094 | \n",
+ " 2.482 | \n",
+ " 0.000 | \n",
+ " 155.690 | \n",
+ " 256.373 | \n",
+ " 0.333 | \n",
+ " 0.222 | \n",
+ " 0.000 | \n",
+ " IL | \n",
+ "
\n",
+ " \n",
+ " 46 | \n",
+ " None | \n",
+ " tfrecord-00601-of-01000_213.json | \n",
+ " 3 | \n",
+ " 0.0 | \n",
+ " 0.050 | \n",
+ " 1.667 | \n",
+ " 0.194 | \n",
+ " 2.774 | \n",
+ " 0.411 | \n",
+ " 0.000 | \n",
+ " 0.333 | \n",
+ " 0.667 | \n",
+ " IL | \n",
+ "
\n",
+ " \n",
+ " 47 | \n",
+ " None | \n",
+ " tfrecord-00537-of-01000_284.json | \n",
+ " 6 | \n",
+ " 0.0 | \n",
+ " 0.040 | \n",
+ " 1.947 | \n",
+ " 0.313 | \n",
+ " 14.913 | \n",
+ " 3.075 | \n",
+ " 0.000 | \n",
+ " 0.333 | \n",
+ " 0.500 | \n",
+ " IL | \n",
+ "
\n",
+ " \n",
+ " 48 | \n",
+ " None | \n",
+ " tfrecord-00736-of-01000_303.json | \n",
+ " 13 | \n",
+ " 0.0 | \n",
+ " 0.066 | \n",
+ " 1.944 | \n",
+ " 0.195 | \n",
+ " 180.957 | \n",
+ " 115.203 | \n",
+ " 0.154 | \n",
+ " 0.231 | \n",
+ " 0.231 | \n",
+ " IL | \n",
+ "
\n",
+ " \n",
+ " 49 | \n",
+ " None | \n",
+ " tfrecord-00646-of-01000_105.json | \n",
+ " 2 | \n",
+ " 0.0 | \n",
+ " 0.038 | \n",
+ " 3.075 | \n",
+ " 0.070 | \n",
+ " 6.076 | \n",
+ " 2.975 | \n",
+ " 0.000 | \n",
+ " 0.000 | \n",
+ " 1.000 | \n",
+ " IL | \n",
+ "
\n",
+ " \n",
+ "\n",
+ ""
+ ],
+ "text/plain": [
+ " run_id traffic_scene agents_controlled reg_coef \\\n",
+ "0 None tfrecord-00350-of-01000_410.json 2 0.0 \n",
+ "1 None tfrecord-00349-of-01000_311.json 2 0.0 \n",
+ "2 None tfrecord-00421-of-01000_364.json 9 0.0 \n",
+ "3 None tfrecord-00243-of-01000_280.json 2 0.0 \n",
+ "4 None tfrecord-00275-of-01000_483.json 3 0.0 \n",
+ "5 None tfrecord-00201-of-01000_202.json 16 0.0 \n",
+ "6 None tfrecord-00623-of-01000_242.json 8 0.0 \n",
+ "7 None tfrecord-00316-of-01000_183.json 10 0.0 \n",
+ "8 None tfrecord-00443-of-01000_487.json 6 0.0 \n",
+ "9 None tfrecord-00908-of-01000_302.json 20 0.0 \n",
+ "10 None tfrecord-00395-of-01000_97.json 18 0.0 \n",
+ "11 None tfrecord-00845-of-01000_433.json 3 0.0 \n",
+ "12 None tfrecord-00628-of-01000_338.json 3 0.0 \n",
+ "13 None tfrecord-00498-of-01000_229.json 7 0.0 \n",
+ "14 None tfrecord-00069-of-01000_193.json 5 0.0 \n",
+ "15 None tfrecord-00271-of-01000_435.json 18 0.0 \n",
+ "16 None tfrecord-00480-of-01000_25.json 3 0.0 \n",
+ "17 None tfrecord-00683-of-01000_42.json 2 0.0 \n",
+ "18 None tfrecord-00384-of-01000_192.json 3 0.0 \n",
+ "19 None tfrecord-00073-of-01000_173.json 4 0.0 \n",
+ "20 None tfrecord-00468-of-01000_290.json 3 0.0 \n",
+ "21 None tfrecord-00656-of-01000_18.json 3 0.0 \n",
+ "22 None tfrecord-00651-of-01000_285.json 4 0.0 \n",
+ "23 None tfrecord-00364-of-01000_319.json 2 0.0 \n",
+ "24 None tfrecord-00615-of-01000_417.json 13 0.0 \n",
+ "25 None tfrecord-00795-of-01000_89.json 3 0.0 \n",
+ "26 None tfrecord-00880-of-01000_341.json 20 0.0 \n",
+ "27 None tfrecord-00187-of-01000_350.json 3 0.0 \n",
+ "28 None tfrecord-00428-of-01000_234.json 8 0.0 \n",
+ "29 None tfrecord-00748-of-01000_178.json 3 0.0 \n",
+ "30 None tfrecord-00481-of-01000_309.json 2 0.0 \n",
+ "31 None tfrecord-00246-of-01000_104.json 5 0.0 \n",
+ "32 None tfrecord-00880-of-01000_456.json 20 0.0 \n",
+ "33 None tfrecord-00576-of-01000_258.json 5 0.0 \n",
+ "34 None tfrecord-00342-of-01000_476.json 20 0.0 \n",
+ "35 None tfrecord-00445-of-01000_61.json 18 0.0 \n",
+ "36 None tfrecord-00430-of-01000_75.json 1 0.0 \n",
+ "37 None tfrecord-00518-of-01000_117.json 4 0.0 \n",
+ "38 None tfrecord-00444-of-01000_226.json 20 0.0 \n",
+ "39 None tfrecord-00083-of-01000_389.json 20 0.0 \n",
+ "40 None tfrecord-00119-of-01000_361.json 20 0.0 \n",
+ "41 None tfrecord-00559-of-01000_436.json 6 0.0 \n",
+ "42 None tfrecord-00172-of-01000_326.json 5 0.0 \n",
+ "43 None tfrecord-00444-of-01000_42.json 8 0.0 \n",
+ "44 None tfrecord-00569-of-01000_314.json 13 0.0 \n",
+ "45 None tfrecord-00496-of-01000_43.json 9 0.0 \n",
+ "46 None tfrecord-00601-of-01000_213.json 3 0.0 \n",
+ "47 None tfrecord-00537-of-01000_284.json 6 0.0 \n",
+ "48 None tfrecord-00736-of-01000_303.json 13 0.0 \n",
+ "49 None tfrecord-00646-of-01000_105.json 2 0.0 \n",
+ "\n",
+ " act_acc accel_val_mae steer_val_mae pos_rmse speed_mae goal_rate \\\n",
+ "0 0.094 1.526 0.127 5.253 1.066 0.000 \n",
+ "1 0.019 2.409 0.281 3.528 1.018 0.000 \n",
+ "2 0.067 1.970 0.164 7.574 1.055 0.000 \n",
+ "3 0.012 1.974 0.239 2.759 0.555 0.000 \n",
+ "4 0.050 1.833 0.220 3.999 0.621 0.000 \n",
+ "5 0.094 2.250 0.003 140.214 57.562 0.312 \n",
+ "6 0.058 2.032 0.375 205.846 114.555 0.125 \n",
+ "7 0.162 2.260 0.013 8.700 0.760 0.200 \n",
+ "8 0.025 2.143 0.294 9.119 1.796 0.000 \n",
+ "9 0.149 2.089 0.039 12.949 1.180 0.150 \n",
+ "10 0.076 2.333 0.292 253.132 200.875 0.000 \n",
+ "11 0.133 1.311 0.190 3.827 0.630 0.000 \n",
+ "12 0.029 1.485 0.367 4.083 0.509 0.000 \n",
+ "13 0.088 2.199 0.113 195.827 120.754 0.143 \n",
+ "14 0.115 1.891 0.144 169.086 297.963 0.000 \n",
+ "15 0.178 1.959 0.024 14.123 1.407 0.056 \n",
+ "16 0.142 1.460 0.146 9.145 1.885 0.333 \n",
+ "17 0.262 1.287 0.102 8.824 2.010 0.000 \n",
+ "18 0.117 1.701 0.096 7.469 1.049 0.000 \n",
+ "19 0.178 2.130 0.046 8.065 1.571 0.000 \n",
+ "20 0.017 2.952 0.461 4.592 2.468 0.000 \n",
+ "21 0.025 2.386 0.392 6.863 3.039 0.000 \n",
+ "22 0.084 1.902 0.050 4.980 0.859 0.250 \n",
+ "23 0.088 1.411 0.288 4.417 0.519 0.000 \n",
+ "24 0.034 2.531 0.350 11.520 2.748 0.154 \n",
+ "25 0.042 2.039 0.180 171.441 274.731 0.333 \n",
+ "26 0.034 2.709 0.451 168.273 81.749 0.000 \n",
+ "27 0.104 1.837 0.169 6.644 1.400 0.000 \n",
+ "28 0.059 1.558 0.465 8.172 1.171 0.000 \n",
+ "29 0.000 2.302 0.334 3.596 1.461 0.000 \n",
+ "30 0.106 2.680 0.135 10.895 3.673 0.000 \n",
+ "31 0.202 2.069 0.028 13.658 2.539 0.000 \n",
+ "32 0.090 2.430 0.159 141.539 35.880 0.150 \n",
+ "33 0.105 1.928 0.166 6.807 1.368 0.200 \n",
+ "34 0.034 2.557 0.232 221.795 177.375 0.300 \n",
+ "35 0.124 2.418 0.056 181.373 72.507 0.056 \n",
+ "36 0.062 1.224 0.203 4.086 1.092 0.000 \n",
+ "37 0.078 2.130 0.054 8.837 1.530 0.250 \n",
+ "38 0.038 2.863 0.268 255.838 181.061 0.150 \n",
+ "39 0.011 2.175 0.456 33.302 2.794 0.150 \n",
+ "40 0.124 2.261 0.016 139.079 33.757 0.300 \n",
+ "41 0.029 2.778 0.236 148.869 316.467 0.333 \n",
+ "42 0.022 2.604 0.361 6.835 2.275 0.000 \n",
+ "43 0.142 1.849 0.063 220.961 209.358 0.000 \n",
+ "44 0.016 3.279 0.481 15.319 5.635 0.154 \n",
+ "45 0.094 2.482 0.000 155.690 256.373 0.333 \n",
+ "46 0.050 1.667 0.194 2.774 0.411 0.000 \n",
+ "47 0.040 1.947 0.313 14.913 3.075 0.000 \n",
+ "48 0.066 1.944 0.195 180.957 115.203 0.154 \n",
+ "49 0.038 3.075 0.070 6.076 2.975 0.000 \n",
+ "\n",
+ " veh_edge_cr veh_veh_cr class \n",
+ "0 1.000 0.000 IL \n",
+ "1 0.000 1.000 IL \n",
+ "2 0.222 0.667 IL \n",
+ "3 0.000 1.000 IL \n",
+ "4 0.667 0.333 IL \n",
+ "5 0.250 0.000 IL \n",
+ "6 0.000 0.375 IL \n",
+ "7 0.200 0.200 IL \n",
+ "8 0.333 0.500 IL \n",
+ "9 0.200 0.200 IL \n",
+ "10 0.833 0.111 IL \n",
+ "11 0.333 0.667 IL \n",
+ "12 0.333 0.333 IL \n",
+ "13 0.571 0.143 IL \n",
+ "14 0.400 0.000 IL \n",
+ "15 0.111 0.389 IL \n",
+ "16 0.000 0.000 IL \n",
+ "17 0.500 0.000 IL \n",
+ "18 0.667 0.000 IL \n",
+ "19 0.500 0.000 IL \n",
+ "20 0.333 0.667 IL \n",
+ "21 1.000 0.000 IL \n",
+ "22 0.250 0.250 IL \n",
+ "23 1.000 0.000 IL \n",
+ "24 0.385 0.385 IL \n",
+ "25 0.667 0.000 IL \n",
+ "26 0.500 0.500 IL \n",
+ "27 0.667 0.000 IL \n",
+ "28 0.875 0.125 IL \n",
+ "29 0.000 1.000 IL \n",
+ "30 1.000 0.000 IL \n",
+ "31 0.600 0.000 IL \n",
+ "32 0.450 0.200 IL \n",
+ "33 0.600 0.000 IL \n",
+ "34 0.050 0.600 IL \n",
+ "35 0.111 0.333 IL \n",
+ "36 1.000 0.000 IL \n",
+ "37 0.500 0.000 IL \n",
+ "38 0.650 0.100 IL \n",
+ "39 0.300 0.350 IL \n",
+ "40 0.100 0.400 IL \n",
+ "41 0.333 0.167 IL \n",
+ "42 0.800 0.000 IL \n",
+ "43 0.125 0.375 IL \n",
+ "44 0.385 0.462 IL \n",
+ "45 0.222 0.000 IL \n",
+ "46 0.333 0.667 IL \n",
+ "47 0.333 0.500 IL \n",
+ "48 0.231 0.231 IL \n",
+ "49 0.000 1.000 IL "
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df_il_res.round(3)"
+ ]
+ },
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 6,
"metadata": {},
"outputs": [
{
@@ -967,7282 +1235,624 @@
"text": [
"Scene tfrecord-00012-of-01000_389.json\n"
]
+ }
+ ],
+ "source": [
+ "scene_id = 'tfrecord-00012-of-01000_389.json'\n",
+ "\n",
+ "df_scene = df_il_trajs[df_il_trajs.traffic_scene == scene_id]\n",
+ "print(f\"Scene {scene_id}\")\n",
+ "for agent_id in df_scene.agent_id.unique():\n",
+ " agent_df = df_scene[df_scene.agent_id == agent_id]\n",
+ " plot_agent_trajectory(agent_df, evaluator.env.action_space.n)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 2. **Human-regularized RL**"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Select policies to evaluate\n",
+ "- Add the model folder path and models to use\n",
+ "- Use `manage_models.py` to download models from `wandb` first"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "RL_POLICY_PATH = \"../models/hr_rl\"\n",
+ "\n",
+ "# Policies to evaluate (checkpoint file names without the .pt extension)\n",
+ "# rl_policy_paths = glob.glob(f\"{RL_POLICY_PATH}\" + \"/*.pt\")\n",
+ "# rl_policy_names = [os.path.basename(file)[:-3] for file in rl_policy_paths]\n",
+ "reg_weights = [0.0]\n",
+ "\n",
+ "rl_policy_names = ['nocturne-hr-ppo-01_02_11_16_0.00_S100']"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 2.1 Evaluate policies on train dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Evaluating policy nocturne-hr-ppo-01_02_11_16_0.00_S100\n"
+ ]
},
{
- "data": {
- "image/svg+xml": [
- "\n",
- "\n",
- "\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- },
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n",
- "\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- },
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n",
- "\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- },
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n",
- "\n",
- " \n",
- " \n",
- " \n",
- " \n",
- " 2023-12-29T13:41:42.448435\n",
- " image/svg+xml\n",
- " \n",
- " \n",
- " Matplotlib v3.8.0, https://matplotlib.org/\n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ "\" clip-path=\"url(#pc7241d3713)\" style=\"fill: #cc8963; stroke: #ffffff; stroke-linejoin: miter\"/>\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
" \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ "\" clip-path=\"url(#pd9cc63e097)\" style=\"fill: #cc8963; stroke: #ffffff; stroke-linejoin: miter\"/>\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
" \n",
" \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
" \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
" \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
- " \n",
" \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
" \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
" \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
" \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
" \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
- " \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
" \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ "\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "fig, axs = plt.subplots(1, 5, figsize=(12, 4))\n",
+ "\n",
+ "fig.suptitle('Human likeness scores for PPO policies', fontsize=16)\n",
+ "\n",
+ "sns.barplot(data=df_hr_rl, x='num_scenes', y='act_acc', ax=axs[0], hue='Type', legend=False)\n",
+ "\n",
+ "sns.barplot(data=df_hr_rl, x='num_scenes', y='accel_val_mae', ax=axs[1], hue='Type', legend=False)\n",
+ "\n",
+ "sns.barplot(data=df_hr_rl, x='num_scenes', y='steer_val_mae', ax=axs[2], hue='Type', legend=False)\n",
+ "\n",
+ "sns.barplot(data=df_hr_rl, x='num_scenes', y='speed_mae', ax=axs[3], hue='Type', legend=False)\n",
+ "\n",
+ "sns.barplot(data=df_hr_rl, x='num_scenes', y='pos_rmse', ax=axs[4], hue='Type', legend=True)\n",
+ "\n",
+ "fig.tight_layout()\n",
+ "sns.despine()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": [
+ "\n",
+ "\n",
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " 2024-01-03T10:32:59.762425\n",
+ " image/svg+xml\n",
+ " \n",
+ " \n",
+ " Matplotlib v3.8.0, https://matplotlib.org/\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
" \n",
" \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
" \n",
- " \n",
- " \n",
+ " \n",
" \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
+ " \n",
" \n",
" \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
" \n",
" \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
" \n",
" \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
" \n",
" \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
" \n",
" \n",
" \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
" \n",
" \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
- " \n",
+ " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
" \n",
" \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
" \n",
- " \n",
- " \n",
+ " \n",
" \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
" \n",
" \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
- " \n",
+ " \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
" \n",
" \n",
" \n",
- " \n",
+ " \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
" \n",
" \n",
" \n",
- " \n",
+ " \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
" \n",
" \n",
" \n",
- " \n",
+ " \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
" \n",
" \n",
" \n",
- " \n",
+ " \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
" \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
" \n",
" \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
- " \n",
" \n",
- " \n",
+ " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
- " \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
" \n",
" \n",
" \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
" \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
" \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
" \n",
" \n",
"\n"
@@ -13100,95 +4892,63 @@
}
],
"source": [
- "scene_id = 'tfrecord-00012-of-01000_389.json'\n",
+ "fig, axs = plt.subplots(1, 3, figsize=(12, 4))\n",
"\n",
- "df_scene = df_il_trajs[df_il_trajs.traffic_scene == scene_id]\n",
- "print(f\"Scene {scene_id}\")\n",
- "for agent_id in df_scene.agent_id.unique():\n",
- " agent_df = df_scene[df_scene.agent_id == agent_id]\n",
- " plot_agent_trajectory(agent_df, evaluator.env.action_space.n)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### 2. **Human-regularized RL**"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "RL_POLICY_PATH = \"../models/hr_rl\"\n",
+ "fig.suptitle(f'Performance scores for PPO policies | DET = {DETERMINISTIC}', fontsize=16)\n",
"\n",
- "# Scenes on which to evaluate the models\n",
- "rl_policy_paths = glob.glob(f\"{RL_POLICY_PATH}\" + \"/*.pt\")\n",
- "rl_policy_names = [os.path.basename(file)[:-3] for file in rl_policy_paths]\n",
- "reg_weights = [0.0]\n",
+ "sns.barplot(data=df_hr_rl, x='num_scenes', y='goal_rate', ax=axs[0], hue='Type', legend=False)\n",
"\n",
+ "sns.barplot(data=df_hr_rl, x='num_scenes', y='veh_edge_cr', ax=axs[1], hue='Type', legend=False)\n",
"\n",
- "rl_policy_names"
+ "sns.barplot(data=df_hr_rl, x='num_scenes', y='veh_veh_cr', ax=axs[2], hue='Type', legend=False)\n",
+ "\n",
+ "fig.tight_layout()\n",
+ "sns.despine()"
]
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": 35,
"metadata": {},
"outputs": [
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Evaluating policy nocturne-hr-ppo-12_29_07_55_0.025\n",
- "Evaluating policy nocturne-hr-ppo-12_29_07_52_0.00\n"
- ]
+ "data": {
+ "text/plain": [
+ "199"
+ ]
+ },
+ "execution_count": 35,
+ "metadata": {},
+ "output_type": "execute_result"
}
],
"source": [
- "df_hr_rl_all = pd.DataFrame()\n",
- "\n",
- "for idx, policy in enumerate(rl_policy_names):\n",
- "\n",
- " print(f'Evaluating policy {policy}')\n",
- "\n",
- " # Load trained model from artifact dir\n",
- " checkpoint = torch.load(f\"{RL_POLICY_PATH}/{policy}.pt\")\n",
- " policy = LateFusionMLPPolicy(\n",
- " observation_space=checkpoint['data']['observation_space'],\n",
- " action_space=checkpoint['data']['action_space'],\n",
- " lr_schedule=checkpoint['data']['lr_schedule'],\n",
- " use_sde=checkpoint['data']['use_sde'],\n",
- " env_config=env_config,\n",
- " mlp_class=LateFusionMLP,\n",
- " #mlp_config=checkpoint['model'],\n",
- " )\n",
- " policy.load_state_dict(checkpoint['state_dict'])\n",
- " policy.eval();\n",
- "\n",
- " # Evaluate on scenes\n",
- " evaluator = EvaluatePolicy(\n",
- " env_config=env_config, \n",
- " exp_config=exp_config,\n",
- " policy=policy,\n",
- " eval_files=eval_files,\n",
- " log_to_wandb=False, \n",
- " deterministic=True,\n",
- " reg_coef=reg_weights[idx],\n",
- " return_trajectories=True,\n",
- " )\n",
- "\n",
- " df_rl_res, _ = evaluator._get_scores()\n",
- " df_rl_res['class'] = f\"HR_RL: {reg_weights[idx]}\"\n",
- "\n",
- " df_hr_rl_all = pd.concat([df_hr_rl_all, df_rl_res])"
+ "len(df_hr_rl['traffic_scene'].unique())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 39,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "0.04001020194409743"
+ ]
+ },
+ "execution_count": 39,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "sum(df_hr_rl[df_hr_rl['Type'] == 'train'].groupby('traffic_scene')['goal_rate'].mean()) / 199"
]
},
{
"cell_type": "code",
- "execution_count": 14,
+ "execution_count": 25,
"metadata": {},
"outputs": [
{
@@ -13213,494 +4973,155 @@
" \n",
" | \n",
" run_id | \n",
- " traffic_scene | \n",
- " agents_controlled | \n",
" reg_coef | \n",
+ " traffic_scene | \n",
+ " agent_id | \n",
" act_acc | \n",
+ " accel_val_mae | \n",
+ " steer_val_mae | \n",
" pos_rmse | \n",
" speed_mae | \n",
" goal_rate | \n",
" veh_edge_cr | \n",
" veh_veh_cr | \n",
- " num_violations | \n",
- " class | \n",
+ " Type | \n",
+ " num_scenes | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" None | \n",
- " tfrecord-00004-of-01000_378.json | \n",
- " 5 | \n",
- " 0.025 | \n",
- " 0.142500 | \n",
- " 8.358297 | \n",
- " 1.440388 | \n",
- " 0.400000 | \n",
- " 0.000000 | \n",
- " 0.200000 | \n",
- " 0 | \n",
- " HR_RL: 0.025 | \n",
- "
\n",
- " \n",
- " 1 | \n",
- " None | \n",
- " tfrecord-00003-of-01000_109.json | \n",
- " 20 | \n",
- " 0.025 | \n",
- " 0.105000 | \n",
- " 119.291848 | \n",
- " 21.604723 | \n",
- " 0.800000 | \n",
- " 0.000000 | \n",
- " 0.200000 | \n",
- " 0 | \n",
- " HR_RL: 0.025 | \n",
- "
\n",
- " \n",
- " 2 | \n",
- " None | \n",
- " tfrecord-00004-of-01000_61.json | \n",
- " 6 | \n",
- " 0.025 | \n",
- " 0.164583 | \n",
- " 6.606689 | \n",
- " 1.097389 | \n",
- " 0.666667 | \n",
- " 0.333333 | \n",
- " 0.000000 | \n",
- " 0 | \n",
- " HR_RL: 0.025 | \n",
- "
\n",
- " \n",
- " 3 | \n",
- " None | \n",
- " tfrecord-00012-of-01000_87.json | \n",
- " 9 | \n",
- " 0.025 | \n",
- " 0.133333 | \n",
- " 164.076051 | \n",
- " 139.820955 | \n",
- " 0.777778 | \n",
- " 0.000000 | \n",
- " 0.222222 | \n",
- " 0 | \n",
- " HR_RL: 0.025 | \n",
- "
\n",
- " \n",
- " 4 | \n",
- " None | \n",
- " tfrecord-00007-of-01000_237.json | \n",
- " 14 | \n",
- " 0.025 | \n",
- " 0.184821 | \n",
- " 13.654639 | \n",
- " 1.436978 | \n",
- " 0.714286 | \n",
- " 0.142857 | \n",
- " 0.142857 | \n",
- " 0 | \n",
- " HR_RL: 0.025 | \n",
- "
\n",
- " \n",
- " 5 | \n",
- " None | \n",
- " tfrecord-00005-of-01000_423.json | \n",
- " 4 | \n",
- " 0.025 | \n",
- " 0.312500 | \n",
- " 13.714191 | \n",
- " 1.655821 | \n",
- " 0.500000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " 0 | \n",
- " HR_RL: 0.025 | \n",
- "
\n",
- " \n",
- " 6 | \n",
- " None | \n",
- " tfrecord-00012-of-01000_246.json | \n",
- " 20 | \n",
- " 0.025 | \n",
- " 0.113125 | \n",
- " 193.501743 | \n",
- " 98.045055 | \n",
- " 0.750000 | \n",
- " 0.150000 | \n",
- " 0.100000 | \n",
- " 0 | \n",
- " HR_RL: 0.025 | \n",
- "
\n",
- " \n",
- " 7 | \n",
- " None | \n",
- " tfrecord-00012-of-01000_389.json | \n",
- " 4 | \n",
- " 0.025 | \n",
- " 0.121875 | \n",
- " 123.742438 | \n",
- " 64.752366 | \n",
- " 0.750000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " 0 | \n",
- " HR_RL: 0.025 | \n",
- "
\n",
- " \n",
- " 8 | \n",
- " None | \n",
+ " 0.0 | \n",
" tfrecord-00001-of-01000_307.json | \n",
- " 3 | \n",
- " 0.025 | \n",
- " 0.133333 | \n",
- " 7.316937 | \n",
- " 1.267483 | \n",
- " 0.333333 | \n",
- " 0.333333 | \n",
- " 0.000000 | \n",
- " 0 | \n",
- " HR_RL: 0.025 | \n",
- "
\n",
- " \n",
- " 9 | \n",
- " None | \n",
- " tfrecord-00004-of-01000_157.json | \n",
- " 11 | \n",
- " 0.025 | \n",
- " 0.122727 | \n",
- " 260.829870 | \n",
- " 315.457867 | \n",
- " 0.545455 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " 0 | \n",
- " HR_RL: 0.025 | \n",
- "
\n",
- " \n",
- " 0 | \n",
- " None | \n",
- " tfrecord-00004-of-01000_378.json | \n",
- " 5 | \n",
- " 0.000 | \n",
- " 0.025000 | \n",
- " 8.214314 | \n",
- " 1.658601 | \n",
- " 0.600000 | \n",
- " 0.200000 | \n",
- " 0.000000 | \n",
" 0 | \n",
- " HR_RL: 0.0 | \n",
+ " 0.0750 | \n",
+ " 2.326923 | \n",
+ " 0.314103 | \n",
+ " 7.041383 | \n",
+ " 1.820937 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " train | \n",
+ " 100 | \n",
"
\n",
" \n",
" 1 | \n",
" None | \n",
- " tfrecord-00003-of-01000_109.json | \n",
- " 20 | \n",
- " 0.000 | \n",
- " 0.059375 | \n",
- " 119.300123 | \n",
- " 23.499673 | \n",
- " 0.850000 | \n",
- " 0.050000 | \n",
- " 0.100000 | \n",
- " 0 | \n",
- " HR_RL: 0.0 | \n",
+ " 0.0 | \n",
+ " tfrecord-00001-of-01000_307.json | \n",
+ " 1 | \n",
+ " 0.0125 | \n",
+ " 1.973684 | \n",
+ " 0.350000 | \n",
+ " 3.827434 | \n",
+ " 1.104908 | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " train | \n",
+ " 100 | \n",
"
\n",
" \n",
" 2 | \n",
" None | \n",
- " tfrecord-00004-of-01000_61.json | \n",
- " 6 | \n",
- " 0.000 | \n",
- " 0.029167 | \n",
- " 8.344646 | \n",
- " 1.510286 | \n",
- " 0.666667 | \n",
- " 0.333333 | \n",
- " 0.000000 | \n",
- " 0 | \n",
- " HR_RL: 0.0 | \n",
- "
\n",
- " \n",
- " 3 | \n",
- " None | \n",
- " tfrecord-00012-of-01000_87.json | \n",
- " 9 | \n",
- " 0.000 | \n",
- " 0.044444 | \n",
- " 164.078821 | \n",
- " 127.039760 | \n",
- " 0.777778 | \n",
- " 0.000000 | \n",
- " 0.222222 | \n",
- " 0 | \n",
- " HR_RL: 0.0 | \n",
- "
\n",
- " \n",
- " 4 | \n",
- " None | \n",
- " tfrecord-00007-of-01000_237.json | \n",
- " 14 | \n",
- " 0.000 | \n",
- " 0.067857 | \n",
- " 14.361484 | \n",
- " 1.978035 | \n",
- " 0.642857 | \n",
- " 0.142857 | \n",
- " 0.142857 | \n",
- " 0 | \n",
- " HR_RL: 0.0 | \n",
- "
\n",
- " \n",
- " 5 | \n",
- " None | \n",
- " tfrecord-00005-of-01000_423.json | \n",
- " 4 | \n",
- " 0.000 | \n",
- " 0.096875 | \n",
- " 12.347980 | \n",
- " 2.126997 | \n",
- " 0.750000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " 0 | \n",
- " HR_RL: 0.0 | \n",
- "
\n",
- " \n",
- " 6 | \n",
- " None | \n",
- " tfrecord-00012-of-01000_246.json | \n",
- " 20 | \n",
- " 0.000 | \n",
- " 0.034375 | \n",
- " 193.515814 | \n",
- " 112.653371 | \n",
- " 0.750000 | \n",
- " 0.050000 | \n",
- " 0.200000 | \n",
- " 0 | \n",
- " HR_RL: 0.0 | \n",
- "
\n",
- " \n",
- " 7 | \n",
- " None | \n",
- " tfrecord-00012-of-01000_389.json | \n",
- " 4 | \n",
- " 0.000 | \n",
- " 0.034375 | \n",
- " 123.741883 | \n",
- " 102.708780 | \n",
- " 0.500000 | \n",
- " 0.000000 | \n",
- " 0.500000 | \n",
- " 0 | \n",
- " HR_RL: 0.0 | \n",
- "
\n",
- " \n",
- " 8 | \n",
- " None | \n",
+ " 0.0 | \n",
" tfrecord-00001-of-01000_307.json | \n",
- " 3 | \n",
- " 0.000 | \n",
- " 0.025000 | \n",
- " 8.964723 | \n",
- " 2.219035 | \n",
- " 1.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " 0 | \n",
- " HR_RL: 0.0 | \n",
- "
\n",
- " \n",
- " 9 | \n",
- " None | \n",
- " tfrecord-00004-of-01000_157.json | \n",
- " 11 | \n",
- " 0.000 | \n",
- " 0.057955 | \n",
- " 260.830741 | \n",
- " 314.463060 | \n",
- " 0.545455 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " 0 | \n",
- " HR_RL: 0.0 | \n",
- "
\n",
- " \n",
- "\n",
- ""
- ],
- "text/plain": [
- " run_id traffic_scene agents_controlled reg_coef \\\n",
- "0 None tfrecord-00004-of-01000_378.json 5 0.025 \n",
- "1 None tfrecord-00003-of-01000_109.json 20 0.025 \n",
- "2 None tfrecord-00004-of-01000_61.json 6 0.025 \n",
- "3 None tfrecord-00012-of-01000_87.json 9 0.025 \n",
- "4 None tfrecord-00007-of-01000_237.json 14 0.025 \n",
- "5 None tfrecord-00005-of-01000_423.json 4 0.025 \n",
- "6 None tfrecord-00012-of-01000_246.json 20 0.025 \n",
- "7 None tfrecord-00012-of-01000_389.json 4 0.025 \n",
- "8 None tfrecord-00001-of-01000_307.json 3 0.025 \n",
- "9 None tfrecord-00004-of-01000_157.json 11 0.025 \n",
- "0 None tfrecord-00004-of-01000_378.json 5 0.000 \n",
- "1 None tfrecord-00003-of-01000_109.json 20 0.000 \n",
- "2 None tfrecord-00004-of-01000_61.json 6 0.000 \n",
- "3 None tfrecord-00012-of-01000_87.json 9 0.000 \n",
- "4 None tfrecord-00007-of-01000_237.json 14 0.000 \n",
- "5 None tfrecord-00005-of-01000_423.json 4 0.000 \n",
- "6 None tfrecord-00012-of-01000_246.json 20 0.000 \n",
- "7 None tfrecord-00012-of-01000_389.json 4 0.000 \n",
- "8 None tfrecord-00001-of-01000_307.json 3 0.000 \n",
- "9 None tfrecord-00004-of-01000_157.json 11 0.000 \n",
- "\n",
- " act_acc pos_rmse speed_mae goal_rate veh_edge_cr veh_veh_cr \\\n",
- "0 0.142500 8.358297 1.440388 0.400000 0.000000 0.200000 \n",
- "1 0.105000 119.291848 21.604723 0.800000 0.000000 0.200000 \n",
- "2 0.164583 6.606689 1.097389 0.666667 0.333333 0.000000 \n",
- "3 0.133333 164.076051 139.820955 0.777778 0.000000 0.222222 \n",
- "4 0.184821 13.654639 1.436978 0.714286 0.142857 0.142857 \n",
- "5 0.312500 13.714191 1.655821 0.500000 0.000000 0.000000 \n",
- "6 0.113125 193.501743 98.045055 0.750000 0.150000 0.100000 \n",
- "7 0.121875 123.742438 64.752366 0.750000 0.000000 0.000000 \n",
- "8 0.133333 7.316937 1.267483 0.333333 0.333333 0.000000 \n",
- "9 0.122727 260.829870 315.457867 0.545455 0.000000 0.000000 \n",
- "0 0.025000 8.214314 1.658601 0.600000 0.200000 0.000000 \n",
- "1 0.059375 119.300123 23.499673 0.850000 0.050000 0.100000 \n",
- "2 0.029167 8.344646 1.510286 0.666667 0.333333 0.000000 \n",
- "3 0.044444 164.078821 127.039760 0.777778 0.000000 0.222222 \n",
- "4 0.067857 14.361484 1.978035 0.642857 0.142857 0.142857 \n",
- "5 0.096875 12.347980 2.126997 0.750000 0.000000 0.000000 \n",
- "6 0.034375 193.515814 112.653371 0.750000 0.050000 0.200000 \n",
- "7 0.034375 123.741883 102.708780 0.500000 0.000000 0.500000 \n",
- "8 0.025000 8.964723 2.219035 1.000000 0.000000 0.000000 \n",
- "9 0.057955 260.830741 314.463060 0.545455 0.000000 0.000000 \n",
- "\n",
- " num_violations class \n",
- "0 0 HR_RL: 0.025 \n",
- "1 0 HR_RL: 0.025 \n",
- "2 0 HR_RL: 0.025 \n",
- "3 0 HR_RL: 0.025 \n",
- "4 0 HR_RL: 0.025 \n",
- "5 0 HR_RL: 0.025 \n",
- "6 0 HR_RL: 0.025 \n",
- "7 0 HR_RL: 0.025 \n",
- "8 0 HR_RL: 0.025 \n",
- "9 0 HR_RL: 0.025 \n",
- "0 0 HR_RL: 0.0 \n",
- "1 0 HR_RL: 0.0 \n",
- "2 0 HR_RL: 0.0 \n",
- "3 0 HR_RL: 0.0 \n",
- "4 0 HR_RL: 0.0 \n",
- "5 0 HR_RL: 0.0 \n",
- "6 0 HR_RL: 0.0 \n",
- "7 0 HR_RL: 0.0 \n",
- "8 0 HR_RL: 0.0 \n",
- "9 0 HR_RL: 0.0 "
- ]
- },
- "execution_count": 14,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "df_hr_rl_all"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### 4. **Summary figures**"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### 4.1 What is the overall performance, and how is it distributed across scenes?\n",
- "\n",
- "- Error distribution"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 57,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " act_acc | \n",
- " goal_rate | \n",
- " veh_edge_cr | \n",
- " veh_veh_cr | \n",
+ " 9 | \n",
+ " 0.0000 | \n",
+ " 2.088608 | \n",
+ " 0.412025 | \n",
+ " 8.245358 | \n",
+ " 2.618147 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " train | \n",
+ " 100 | \n",
"
\n",
" \n",
- " reg_coef | \n",
- " | \n",
- " | \n",
- " | \n",
- " | \n",
+ " 3 | \n",
+ " None | \n",
+ " 0.0 | \n",
+ " tfrecord-00004-of-01000_378.json | \n",
+ " 0 | \n",
+ " 0.0000 | \n",
+ " 2.421053 | \n",
+ " 0.503509 | \n",
+ " 5.061539 | \n",
+ " 1.413594 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " train | \n",
+ " 100 | \n",
"
\n",
- " \n",
- " \n",
" \n",
- " 0.0 | \n",
- " 0.198 | \n",
- " 0.283 | \n",
- " 0.265 | \n",
- " 0.085 | \n",
+ " 4 | \n",
+ " None | \n",
+ " 0.0 | \n",
+ " tfrecord-00004-of-01000_378.json | \n",
+ " 2 | \n",
+ " 0.0000 | \n",
+ " 1.500000 | \n",
+ " 0.425000 | \n",
+ " 1.942767 | \n",
+ " 0.621620 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " train | \n",
+ " 100 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
- " act_acc goal_rate veh_edge_cr veh_veh_cr\n",
- "reg_coef \n",
- "0.0 0.198 0.283 0.265 0.085"
+ " run_id reg_coef traffic_scene agent_id act_acc \\\n",
+ "0 None 0.0 tfrecord-00001-of-01000_307.json 0 0.0750 \n",
+ "1 None 0.0 tfrecord-00001-of-01000_307.json 1 0.0125 \n",
+ "2 None 0.0 tfrecord-00001-of-01000_307.json 9 0.0000 \n",
+ "3 None 0.0 tfrecord-00004-of-01000_378.json 0 0.0000 \n",
+ "4 None 0.0 tfrecord-00004-of-01000_378.json 2 0.0000 \n",
+ "\n",
+ " accel_val_mae steer_val_mae pos_rmse speed_mae goal_rate veh_edge_cr \\\n",
+ "0 2.326923 0.314103 7.041383 1.820937 0.0 0.0 \n",
+ "1 1.973684 0.350000 3.827434 1.104908 0.0 1.0 \n",
+ "2 2.088608 0.412025 8.245358 2.618147 0.0 0.0 \n",
+ "3 2.421053 0.503509 5.061539 1.413594 0.0 0.0 \n",
+ "4 1.500000 0.425000 1.942767 0.621620 1.0 0.0 \n",
+ "\n",
+ " veh_veh_cr Type num_scenes \n",
+ "0 0.0 train 100 \n",
+ "1 0.0 train 100 \n",
+ "2 0.0 train 100 \n",
+ "3 0.0 train 100 \n",
+ "4 0.0 train 100 "
]
},
- "execution_count": 57,
+ "execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "metrics = ['act_acc', 'goal_rate', 'veh_edge_cr', 'veh_veh_cr', 'reg_coef']\n",
- "\n",
- "# Aggregate performance HR-PPO\n",
- "df_hr_rl_agg = df_hr_rl_all[metrics].groupby('reg_coef').mean().round(3)\n",
- "\n",
- "# Aggregate performance IL\n",
- "df_il_agg = df_il_res[metrics].groupby('reg_coef').mean().round(3)\n",
- "\n",
- "df_il_agg"
+ "df_hr_rl.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "df_hr_rl.groupby('traffic_scene')"
]
},
{
"cell_type": "code",
- "execution_count": 58,
+ "execution_count": 27,
"metadata": {},
"outputs": [
{
@@ -13724,14 +5145,12 @@
" \n",
" \n",
" | \n",
- " act_acc | \n",
" goal_rate | \n",
" veh_edge_cr | \n",
" veh_veh_cr | \n",
"
\n",
" \n",
- " reg_coef | \n",
- " | \n",
+ " traffic_scene | \n",
" | \n",
" | \n",
" | \n",
@@ -13739,42 +5158,106 @@
"
\n",
" \n",
" \n",
- " 0.000 | \n",
- " 0.047 | \n",
- " 0.708 | \n",
- " 0.078 | \n",
- " 0.117 | \n",
+ " tfrecord-00001-of-01000_307.json | \n",
+ " 0.000000 | \n",
+ " 0.333333 | \n",
+ " 0.000000 | \n",
"
\n",
" \n",
- " 0.025 | \n",
- " 0.153 | \n",
- " 0.624 | \n",
- " 0.096 | \n",
- " 0.087 | \n",
+ " tfrecord-00004-of-00150_246.json | \n",
+ " 0.333333 | \n",
+ " 0.666667 | \n",
+ " 0.000000 | \n",
+ "
\n",
+ " \n",
+ " tfrecord-00004-of-01000_378.json | \n",
+ " 0.200000 | \n",
+ " 0.400000 | \n",
+ " 0.000000 | \n",
+ "
\n",
+ " \n",
+ " tfrecord-00005-of-00150_192.json | \n",
+ " 0.111111 | \n",
+ " 0.333333 | \n",
+ " 0.222222 | \n",
+ "
\n",
+ " \n",
+ " tfrecord-00005-of-01000_423.json | \n",
+ " 0.000000 | \n",
+ " 1.000000 | \n",
+ " 0.000000 | \n",
+ "
\n",
+ " \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ "
\n",
+ " \n",
+ " tfrecord-00144-of-00150_65.json | \n",
+ " 0.000000 | \n",
+ " 0.857143 | \n",
+ " 0.000000 | \n",
+ "
\n",
+ " \n",
+ " tfrecord-00146-of-00150_228.json | \n",
+ " 0.000000 | \n",
+ " 0.571429 | \n",
+ " 0.428571 | \n",
+ "
\n",
+ " \n",
+ " tfrecord-00147-of-00150_191.json | \n",
+ " 0.000000 | \n",
+ " 0.500000 | \n",
+ " 0.000000 | \n",
+ "
\n",
+ " \n",
+ " tfrecord-00147-of-00150_6.json | \n",
+ " 0.111111 | \n",
+ " 0.333333 | \n",
+ " 0.222222 | \n",
+ "
\n",
+ " \n",
+ " tfrecord-00149-of-00150_66.json | \n",
+ " 0.000000 | \n",
+ " 1.000000 | \n",
+ " 0.000000 | \n",
"
\n",
" \n",
"\n",
+ "199 rows × 3 columns
\n",
""
],
"text/plain": [
- " act_acc goal_rate veh_edge_cr veh_veh_cr\n",
- "reg_coef \n",
- "0.000 0.047 0.708 0.078 0.117\n",
- "0.025 0.153 0.624 0.096 0.087"
+ " goal_rate veh_edge_cr veh_veh_cr\n",
+ "traffic_scene \n",
+ "tfrecord-00001-of-01000_307.json 0.000000 0.333333 0.000000\n",
+ "tfrecord-00004-of-00150_246.json 0.333333 0.666667 0.000000\n",
+ "tfrecord-00004-of-01000_378.json 0.200000 0.400000 0.000000\n",
+ "tfrecord-00005-of-00150_192.json 0.111111 0.333333 0.222222\n",
+ "tfrecord-00005-of-01000_423.json 0.000000 1.000000 0.000000\n",
+ "... ... ... ...\n",
+ "tfrecord-00144-of-00150_65.json 0.000000 0.857143 0.000000\n",
+ "tfrecord-00146-of-00150_228.json 0.000000 0.571429 0.428571\n",
+ "tfrecord-00147-of-00150_191.json 0.000000 0.500000 0.000000\n",
+ "tfrecord-00147-of-00150_6.json 0.111111 0.333333 0.222222\n",
+ "tfrecord-00149-of-00150_66.json 0.000000 1.000000 0.000000\n",
+ "\n",
+ "[199 rows x 3 columns]"
]
},
- "execution_count": 58,
+ "execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "df_hr_rl_agg"
+ "df_hr_rl.groupby('traffic_scene')[performance_metrics].mean()"
]
},
{
"cell_type": "code",
- "execution_count": 80,
+ "execution_count": 23,
"metadata": {},
"outputs": [
{
@@ -13783,12 +5266,12 @@
"\n",
"\n",
- "\n",
+ "\n",
" \n",
" \n",
" \n",
" \n",
- " 2023-12-29T12:03:00.410283\n",
+ " 2024-01-03T10:35:04.178353\n",
" image/svg+xml\n",
" \n",
" \n",
@@ -13803,80 +5286,255 @@
" \n",
" \n",
" \n",
- " \n",
+ "\" style=\"fill: none\"/>\n",
" \n",
" \n",
" \n",
- " \n",
+ "\" style=\"fill: none\"/>\n",
" \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
" \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
" \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
- " \n",
- " \n",
" \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
" \n",
" \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
" \n",
" \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
" \n",
" \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
" \n",
" \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
" \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
" \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
" \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
" \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
- " \n",
+ " \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
" \n",
" \n",
" \n",
" \n",
" \n",
" \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
" \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
" \n",
- " \n",
+ " \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
- " \n",
" \n",
- " \n",
" \n",
" \n",
" \n",
- " \n",
+ "\" clip-path=\"url(#p1273599569)\" style=\"fill: #4c72b0; fill-opacity: 0.75; stroke: #ffffff; stroke-linejoin: miter\"/>\n",
" \n",
" \n",
- " \n",
+ "\" clip-path=\"url(#p1273599569)\" style=\"fill: #4c72b0; fill-opacity: 0.75; stroke: #ffffff; stroke-linejoin: miter\"/>\n",
" \n",
" \n",
- " \n",
+ "\" clip-path=\"url(#p1273599569)\" style=\"fill: #4c72b0; fill-opacity: 0.75; stroke: #ffffff; stroke-linejoin: miter\"/>\n",
" \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ "\" clip-path=\"url(#p1273599569)\" style=\"fill: #4c72b0; fill-opacity: 0.75; stroke: #ffffff; stroke-linejoin: miter\"/>\n",
+ " \n",
+ " \n",
+ " \n",
" \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
+ " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ "\" clip-path=\"url(#p1273599569)\" style=\"fill: #4c72b0; fill-opacity: 0.75; stroke: #ffffff; stroke-linejoin: miter\"/>\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
" \n",
" \n",
" \n",
" \n",
- " \n",
- " \n",
+ " \n",
+ " \n",
" \n",
" \n",
"\n"
],
"text/plain": [
- ""
+ ""
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "sns.histplot(data=df_hr_rl, y='goal_rate',);"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#plot_agent_trajectory(df_rl_trajs_sub[df_rl_trajs_sub['agent_id'] == 0], evaluator.env.action_space.n)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#plot_agent_trajectory(df_rl_trajs_sub[df_rl_trajs_sub['agent_id'] == 1], evaluator.env.action_space.n)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 3.2 Make videos of top 10 most difficult scenes / top 10 easiest scenes\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " run_id | \n",
+ " traffic_scene | \n",
+ " agents_controlled | \n",
+ " reg_coef | \n",
+ " act_acc | \n",
+ " accel_val_mae | \n",
+ " steer_val_mae | \n",
+ " pos_rmse | \n",
+ " speed_mae | \n",
+ " goal_rate | \n",
+ " veh_edge_cr | \n",
+ " veh_veh_cr | \n",
+ " class | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 48 | \n",
+ " None | \n",
+ " tfrecord-00061-of-01000_223.json | \n",
+ " 2 | \n",
+ " 0.0 | \n",
+ " 0.031250 | \n",
+ " 1.807229 | \n",
+ " 0.345783 | \n",
+ " 5.234700 | \n",
+ " 1.430212 | \n",
+ " 0.0 | \n",
+ " 0.5 | \n",
+ " 0.0 | \n",
+ " HR_RL: 0.0 | \n",
+ "
\n",
+ " \n",
+ " 59 | \n",
+ " None | \n",
+ " tfrecord-00072-of-01000_20.json | \n",
+ " 4 | \n",
+ " 0.0 | \n",
+ " 0.003125 | \n",
+ " 2.354651 | \n",
+ " 0.415116 | \n",
+ " 6.995587 | \n",
+ " 2.723312 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ " HR_RL: 0.0 | \n",
+ "
\n",
+ " \n",
+ " 57 | \n",
+ " None | \n",
+ " tfrecord-00070-of-01000_158.json | \n",
+ " 3 | \n",
+ " 0.0 | \n",
+ " 0.000000 | \n",
+ " 1.330189 | \n",
+ " 0.587736 | \n",
+ " 5.233993 | \n",
+ " 0.280898 | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " HR_RL: 0.0 | \n",
+ "
\n",
+ " \n",
+ " 56 | \n",
+ " None | \n",
+ " tfrecord-00069-of-01000_193.json | \n",
+ " 5 | \n",
+ " 0.0 | \n",
+ " 0.020000 | \n",
+ " 1.956522 | \n",
+ " 0.281522 | \n",
+ " 169.067436 | \n",
+ " 768.979722 | \n",
+ " 0.0 | \n",
+ " 0.8 | \n",
+ " 0.0 | \n",
+ " HR_RL: 0.0 | \n",
+ "
\n",
+ " \n",
+ " 84 | \n",
+ " None | \n",
+ " tfrecord-00095-of-01000_204.json | \n",
+ " 2 | \n",
+ " 0.0 | \n",
+ " 0.006250 | \n",
+ " 2.000000 | \n",
+ " 0.449383 | \n",
+ " 5.101935 | \n",
+ " 0.509049 | \n",
+ " 0.0 | \n",
+ " 0.5 | \n",
+ " 0.0 | \n",
+ " HR_RL: 0.0 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " run_id traffic_scene agents_controlled reg_coef \\\n",
+ "48 None tfrecord-00061-of-01000_223.json 2 0.0 \n",
+ "59 None tfrecord-00072-of-01000_20.json 4 0.0 \n",
+ "57 None tfrecord-00070-of-01000_158.json 3 0.0 \n",
+ "56 None tfrecord-00069-of-01000_193.json 5 0.0 \n",
+ "84 None tfrecord-00095-of-01000_204.json 2 0.0 \n",
+ "\n",
+ " act_acc accel_val_mae steer_val_mae pos_rmse speed_mae goal_rate \\\n",
+ "48 0.031250 1.807229 0.345783 5.234700 1.430212 0.0 \n",
+ "59 0.003125 2.354651 0.415116 6.995587 2.723312 0.0 \n",
+ "57 0.000000 1.330189 0.587736 5.233993 0.280898 0.0 \n",
+ "56 0.020000 1.956522 0.281522 169.067436 768.979722 0.0 \n",
+ "84 0.006250 2.000000 0.449383 5.101935 0.509049 0.0 \n",
+ "\n",
+ " veh_edge_cr veh_veh_cr class \n",
+ "48 0.5 0.0 HR_RL: 0.0 \n",
+ "59 0.0 1.0 HR_RL: 0.0 \n",
+ "57 1.0 0.0 HR_RL: 0.0 \n",
+ "56 0.8 0.0 HR_RL: 0.0 \n",
+ "84 0.5 0.0 HR_RL: 0.0 "
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "NUM_VIDEOS = 5\n",
+ "\n",
+ "df_hr_rl.sort_values('goal_rate').head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "Tracking run with wandb version 0.16.1"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "Run data is saved locally in /home/emerge/daphne/nocturne_lab/evaluation/wandb/run-20240103_070328-k4hdxw1b
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "Syncing run RL_S100_worst_stoch to Weights & Biases (docs)
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ " View project at https://wandb.ai/daphnecor/eval_hr_rl_policies"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ " View run at https://wandb.ai/daphnecor/eval_hr_rl_policies/runs/k4hdxw1b"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Evaluating policy on tfrecord-00061-of-01000_223.json...\n",
+ "Evaluating policy on tfrecord-00072-of-01000_20.json...\n",
+ "Evaluating policy on tfrecord-00070-of-01000_158.json...\n",
+ "Evaluating policy on tfrecord-00069-of-01000_193.json...\n",
+ "Evaluating policy on tfrecord-00095-of-01000_204.json...\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ " View run RL_S100_worst_stoch at: https://wandb.ai/daphnecor/eval_hr_rl_policies/runs/k4hdxw1b
View job at https://wandb.ai/daphnecor/eval_hr_rl_policies/jobs/QXJ0aWZhY3RDb2xsZWN0aW9uOjEyNzEwNDI4OA==/version_details/v0
Synced 5 W&B file(s), 10 media file(s), 0 artifact file(s) and 0 other file(s)"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "Find logs at: ./wandb/run-20240103_070328-k4hdxw1b/logs
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "worst_scenes = df_hr_rl_all.sort_values('goal_rate')['traffic_scene'].values[:NUM_VIDEOS]\n",
+ "\n",
+ "render_traffic_scenes(\n",
+ " eval_scenes=worst_scenes,\n",
+ " policy=policy,\n",
+ " run_name=\"RL_S100_worst\",\n",
+ " deterministic=False,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "Tracking run with wandb version 0.16.1"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "Run data is saved locally in /home/emerge/daphne/nocturne_lab/evaluation/wandb/run-20240103_070415-o4cve8xh
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "Syncing run RL_S100_worst_det to Weights & Biases (docs)
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ " View project at https://wandb.ai/daphnecor/eval_hr_rl_policies"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ " View run at https://wandb.ai/daphnecor/eval_hr_rl_policies/runs/o4cve8xh"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Evaluating policy on tfrecord-00061-of-01000_223.json...\n",
+ "Evaluating policy on tfrecord-00072-of-01000_20.json...\n",
+ "Evaluating policy on tfrecord-00070-of-01000_158.json...\n",
+ "Evaluating policy on tfrecord-00069-of-01000_193.json...\n",
+ "Evaluating policy on tfrecord-00095-of-01000_204.json...\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ " View run RL_S100_worst_det at: https://wandb.ai/daphnecor/eval_hr_rl_policies/runs/o4cve8xh
View job at https://wandb.ai/daphnecor/eval_hr_rl_policies/jobs/QXJ0aWZhY3RDb2xsZWN0aW9uOjEyNzEwNDI4OA==/version_details/v0
Synced 5 W&B file(s), 10 media file(s), 0 artifact file(s) and 0 other file(s)"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "Find logs at: ./wandb/run-20240103_070415-o4cve8xh/logs
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "ename": "",
+ "evalue": "",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[1;31mThe Kernel crashed while executing code in the the current cell or a previous cell. Please review the code in the cell(s) to identify a possible cause of the failure. Click here for more info. View Jupyter log for further details."
+ ]
+ }
+ ],
+ "source": [
+ "worst_scenes = df_hr_rl_all.sort_values('goal_rate')['traffic_scene'].values[:NUM_VIDEOS]\n",
+ "\n",
+ "render_traffic_scenes(\n",
+ " eval_scenes=worst_scenes,\n",
+ " policy=policy,\n",
+ " run_name=\"RL_S100_worst\",\n",
+ " deterministic=True,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "Tracking run with wandb version 0.16.1"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "Run data is saved locally in /home/emerge/daphne/nocturne_lab/evaluation/wandb/run-20240103_073548-bspg5o1a
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "Syncing run RL_S100_best_stoch to Weights & Biases (docs)
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ " View project at https://wandb.ai/daphnecor/eval_hr_rl_policies"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ " View run at https://wandb.ai/daphnecor/eval_hr_rl_policies/runs/bspg5o1a"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Evaluating policy on tfrecord-00019-of-01000_117.json...\n",
+ "Evaluating policy on tfrecord-00059-of-01000_12.json...\n",
+ "Evaluating policy on tfrecord-00111-of-01000_359.json...\n",
+ "Evaluating policy on tfrecord-00079-of-01000_207.json...\n",
+ "Evaluating policy on tfrecord-00024-of-01000_33.json...\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ " View run RL_S100_best_stoch at: https://wandb.ai/daphnecor/eval_hr_rl_policies/runs/bspg5o1a
View job at https://wandb.ai/daphnecor/eval_hr_rl_policies/jobs/QXJ0aWZhY3RDb2xsZWN0aW9uOjEyNzEwNDI4OA==/version_details/v1
Synced 5 W&B file(s), 10 media file(s), 0 artifact file(s) and 0 other file(s)"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "Find logs at: ./wandb/run-20240103_073548-bspg5o1a/logs
"
+ ],
+ "text/plain": [
+ ""
]
},
"metadata": {},
@@ -14565,7 +6594,28 @@
}
],
"source": [
- "sns.barplot(data=df_hr_rl_all, y='goal_rate', hue='reg_coef', palette=\"deep\");"
+ "best_scenes = df_hr_rl_all.sort_values('goal_rate', ascending=False)['traffic_scene'].values[:NUM_VIDEOS]\n",
+ "\n",
+ "render_traffic_scenes(\n",
+ " eval_scenes=best_scenes,\n",
+ " policy=policy,\n",
+ " run_name=\"RL_S100_best\",\n",
+ " deterministic=False,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Is there a trade-off between human likeness and performance? | **Pareto plots**"
]
},
{
diff --git a/experiments/hr_rl/run_hr_ppo.py b/experiments/hr_rl/run_hr_ppo.py
index 9b1788a9..09e90da1 100644
--- a/experiments/hr_rl/run_hr_ppo.py
+++ b/experiments/hr_rl/run_hr_ppo.py
@@ -37,7 +37,6 @@ def train(env_config, exp_config, video_config, model_config): # pylint: disabl
env = MultiAgentAsVecEnv(
config=env_config,
num_envs=env_config.max_num_vehicles,
- train_on_single_scene=exp_config.train_on_single_scene,
)
# Set up run
@@ -62,6 +61,9 @@ def train(env_config, exp_config, video_config, model_config): # pylint: disabl
logging.info(f"Learning in {len(env.env.files)} scene(s): {env.env.files} | using {exp_config.ppo.device}")
logging.info(f"--- obs_space: {env.observation_space.shape[0]} ---")
logging.info(f"Action_space\n: {env.env.idx_to_actions}")
+
+ if exp_config.reg_weight > 0.0:
+ logging.info(f"Regularization weight: {exp_config.reg_weight} with policy: {exp_config.human_policy_path}")
# Initialize custom callback
custom_callback = CustomMultiAgentCallback(
@@ -135,7 +137,7 @@ def train(env_config, exp_config, video_config, model_config): # pylint: disabl
"arch_ego_state": [8],
"arch_road_objects": [64],
"arch_road_graph": [128, 64],
- "arch_shared_net": [],
+ "arch_shared_net": [128],
"act_func": "tanh",
"dropout": 0.0,
"last_layer_dim_pi": 64,
@@ -143,16 +145,26 @@ def train(env_config, exp_config, video_config, model_config): # pylint: disabl
}
)
- num_files_list = [10, 100, 1000]
+ num_files_list = [10]
+ #MEMORY = [4, 2]
+ MEMORY = [1]
- for scenes in num_files_list:
- # Set regularization weight
- env_config.num_files = scenes
-
- # Train
- train(
- env_config=env_config,
- exp_config=exp_config,
- video_config=video_config,
- model_config=model_config,
- )
+ for mem in MEMORY:
+ for num_scenes in num_files_list:
+
+ # Set memory (the number of stacked frames)
+ env_config.subscriber.n_frames_stacked = mem
+
+ # Set regularization weight
+ #exp_config.reg_weight = lam
+
+ exp_config.human_policy_path = f"models/il/human_policy_S{num_scenes}_2024_01_02.pt"
+ env_config.num_files = num_scenes
+
+ # Train
+ train(
+ env_config=env_config,
+ exp_config=exp_config,
+ video_config=video_config,
+ model_config=model_config,
+ )
diff --git a/experiments/hr_rl/run_hr_ppo_cli.py b/experiments/hr_rl/run_hr_ppo_cli.py
index 34ca1dfd..3fb542dd 100644
--- a/experiments/hr_rl/run_hr_ppo_cli.py
+++ b/experiments/hr_rl/run_hr_ppo_cli.py
@@ -41,7 +41,6 @@
"large": [256, 128, 64],
}
-
def run_hr_ppo(
sweep_name: str = exp_config.group,
steer_disc: int = 5,
diff --git a/experiments/il/run_behavioral_cloning.py b/experiments/il/run_behavioral_cloning.py
index 2812a9de..84f90880 100644
--- a/experiments/il/run_behavioral_cloning.py
+++ b/experiments/il/run_behavioral_cloning.py
@@ -19,7 +19,7 @@
if __name__ == "__main__":
MAX_EVAL_FILES = 12
- NUM_TRAIN_FILES = 1000
+ NUM_TRAIN_FILES = 50
# Create run
run = wandb.init(
diff --git a/networks/mlp_late_fusion.py b/networks/mlp_late_fusion.py
index aed85e0b..c625628c 100644
--- a/networks/mlp_late_fusion.py
+++ b/networks/mlp_late_fusion.py
@@ -49,10 +49,10 @@ def __init__(
self.arch_shared_net = arch_shared_net
self.dropout = dropout
- #TODO: write function that gets this information from config
- self.input_dim_ego = 10
- self.input_dim_road_graph = 6500
- self.input_dim_road_objects = 220
+ # TODO(Daphne): write a function that derives these input dimensions from the config
+ self.input_dim_ego = 10 * self.config.subscriber.n_frames_stacked
+ self.input_dim_road_graph = 6500 * self.config.subscriber.n_frames_stacked
+ self.input_dim_road_objects = 220 * self.config.subscriber.n_frames_stacked
# IMPORTANT:Save output dimensions, used to create the distributions
self.latent_dim_pi = last_layer_dim_pi
@@ -198,10 +198,10 @@ def _unflatten_obs(self, obs_flat):
# Visible state object order: road_objects, road_points, traffic_lights, stop_signs
# Find the ends of each section
- ROAD_OBJECTS_END = 13 * self.config.scenario.max_visible_objects
- ROAD_POINTS_END = ROAD_OBJECTS_END + (13 * self.config.scenario.max_visible_road_points)
- TL_END = ROAD_POINTS_END + (12 * self.config.scenario.max_visible_traffic_lights)
- STOP_SIGN_END = TL_END + (3 * self.config.scenario.max_visible_stop_signs)
+ ROAD_OBJECTS_END = (13 * self.config.scenario.max_visible_objects) * self.config.subscriber.n_frames_stacked
+ ROAD_POINTS_END = (ROAD_OBJECTS_END + (13 * self.config.scenario.max_visible_road_points)) * self.config.subscriber.n_frames_stacked
+ TL_END = (ROAD_POINTS_END + (12 * self.config.scenario.max_visible_traffic_lights)) * self.config.subscriber.n_frames_stacked
+ STOP_SIGN_END = (TL_END + (3 * self.config.scenario.max_visible_stop_signs)) * self.config.subscriber.n_frames_stacked
# Unflatten
road_objects = vis_state[:, :ROAD_OBJECTS_END]
@@ -251,12 +251,14 @@ def _build_mlp_extractor(self) -> None:
# Load environment and experiment configurations
env_config = load_config("env_config")
exp_config = load_config("exp_config")
+
+ env_config.subscriber.n_frames_stacked = 2
# Make environment
env = MultiAgentAsVecEnv(
config=env_config,
num_envs=env_config.max_num_vehicles,
- train_on_single_scene=exp_config.train_on_single_scene,
+
)
obs = env.reset()
diff --git a/nocturne/envs/base_env.py b/nocturne/envs/base_env.py
index ea089b7d..14eb7aff 100644
--- a/nocturne/envs/base_env.py
+++ b/nocturne/envs/base_env.py
@@ -43,10 +43,10 @@ def __init__( # pylint: disable=too-many-arguments
self,
config: Dict[str, Any],
*,
- img_width=1600,
- img_height=1600,
+ img_width=1200,
+ img_height=1200,
draw_target_positions=True,
- padding=50.0,
+ padding=10.0,
) -> None:
"""Initialize a Nocturne environment.
@@ -488,6 +488,10 @@ def _get_obs_space_dim(self, config, base=0):
(3 * self.config.scenario.max_visible_stop_signs) +
(12 * self.config.scenario.max_visible_traffic_lights)
)
+
+ # Multiply by the number of stacked frames (memory) to get the final dimension
+ obs_space_dim = obs_space_dim * self.config.subscriber.n_frames_stacked
+
return (obs_space_dim,)
def normalize_ego_state_by_cat(self, state):
@@ -520,6 +524,8 @@ def render(self, mode: Optional[bool] = None) -> Optional[RenderType]: # pylint
Optional[RenderType]: Rendered image.
"""
return self.scenario.getImage(**self._render_settings)
+
+ env.scenario.getImage(**video_config.render)
def render_ego(self, mode: Optional[bool] = None) -> Optional[RenderType]: # pylint: disable=unused-argument
"""Render the ego vehicles.
diff --git a/nocturne/envs/vec_env_ma.py b/nocturne/envs/vec_env_ma.py
index ea2e38e9..e5056b77 100644
--- a/nocturne/envs/vec_env_ma.py
+++ b/nocturne/envs/vec_env_ma.py
@@ -26,7 +26,7 @@ class MultiAgentAsVecEnv(VecEnv):
VecEnv (SB3 VecEnv): SB3 VecEnv base class.
"""
- def __init__(self, config, num_envs, psr=False, train_on_single_scene=False):
+ def __init__(self, config, num_envs, psr=False):
# Create Nocturne env
self.env = BaseEnv(config)
@@ -44,7 +44,7 @@ def __init__(self, config, num_envs, psr=False, train_on_single_scene=False):
self.frac_collided = [] # Log fraction of agents that collided
self.frac_goal_achieved = [] # Log fraction of agents that achieved their goal
self.agents_in_scene = []
- self.filename = self.env.files[0] if train_on_single_scene else None # If provided, always use the same file
+ self.filename = None # If provided, always use the same file
def _reset_seeds(self) -> None:
"""Reset all environments' seeds."""
@@ -172,6 +172,11 @@ def reset_scene_dict(self):
def step_num(self) -> List[int]:
"""The episodic timestep."""
return self.env.step_num
+ @property
+ def render(self) -> List[int]:
+ """Return the image rendered by the wrapped environment (``self.env.render()``)."""
+ img = self.env.render()
+ return img
def seed(self, seed=None):
"""Set the random seeds for all environments."""
diff --git a/utils/eval.py b/utils/eval.py
index 062a976b..9a2c76db 100644
--- a/utils/eval.py
+++ b/utils/eval.py
@@ -51,13 +51,13 @@ def _get_scores(self):
logging.info(f'\n Evaluating policy on {len(self.eval_files)} files...')
- # Make tables
- df_eval = pd.DataFrame(
+ # Create tables
+ df_eval = pd.DataFrame(
columns=[
"run_id",
- "traffic_scene",
- "agents_controlled",
"reg_coef",
+ "traffic_scene",
+ "agent_id",
"act_acc",
"accel_val_mae",
"steer_val_mae",
@@ -103,7 +103,7 @@ def _get_scores(self):
nonnan_ids = ~np.isnan(expert_actions)
# Compute metrics
- action_accuracy = self.get_action_accuracy(
+ action_accuracies = self.get_action_accuracy(
policy_actions, expert_actions, nonnan_ids
)
@@ -120,15 +120,15 @@ def _get_scores(self):
)
# Violations of the 3-second rule
- violations_matrix, num_violations = self.get_veh_to_veh_distances(policy_pos, policy_speed)
+ #violations_matrix, num_violations = self.get_veh_to_veh_distances(policy_pos, policy_speed)
# Store metrics
- scene_perf = {
+ scene_perf = pd.DataFrame({
"run_id": self.run.id if self.log_to_wandb else None,
+ "reg_coef": np.repeat(self.reg_coef, len(self.agent_names)),
"traffic_scene": file,
- "agents_controlled": expert_actions.shape[0],
- "reg_coef": self.exp_config.reg_weight if self.reg_coef is None else self.reg_coef,
- "act_acc": action_accuracy,
+ "agent_id": self.agent_names,
+ "act_acc": action_accuracies,
"accel_val_mae": abs_diff_accel,
"steer_val_mae": abs_diff_steer,
"pos_rmse": position_rmse,
@@ -136,8 +136,8 @@ def _get_scores(self):
"goal_rate": policy_gr,
"veh_edge_cr": policy_edge_cr,
"veh_veh_cr": policy_veh_cr,
- }
- df_eval.loc[len(df_eval)] = scene_perf
+ })
+ df_eval = pd.concat([df_eval, scene_perf], ignore_index=True)
if self.return_trajectories:
scene_trajs = pd.DataFrame({
@@ -180,6 +180,7 @@ def _step_through_scene(self, filename: str, mode: str):
# Make sure the agent ids are in the same order
agent_ids = np.sort([veh.id for veh in self.env.controlled_vehicles])
+ self.agent_names = agent_ids
agent_id_to_idx_dict = {agent_id: idx for idx, agent_id in enumerate(agent_ids)}
last_info_dicts = {agent_id: {} for agent_id in agent_ids}
dead_agent_ids = []
@@ -188,7 +189,7 @@ def _step_through_scene(self, filename: str, mode: str):
action_indices = np.full(fill_value=np.nan, shape=(self.num_agents, num_steps))
agent_positions = np.full(fill_value=np.nan, shape=(self.num_agents, num_steps, 2))
agent_speed = np.full(fill_value=np.nan, shape=(self.num_agents, num_steps))
- goal_achieved, veh_edge_collision, veh_veh_collision = 0, 0, 0
+ goal_achieved, veh_edge_collision, veh_veh_collision = np.zeros(self.num_agents), np.zeros(self.num_agents), np.zeros(self.num_agents)
# Set control mode
if mode == "expert":
@@ -263,18 +264,19 @@ def _step_through_scene(self, filename: str, mode: str):
if done_dict["__all__"]:
for agent_id in agent_ids:
- goal_achieved += last_info_dicts[agent_id]["goal_achieved"]
- veh_edge_collision += last_info_dicts[agent_id]["veh_edge_collision"]
- veh_veh_collision += last_info_dicts[agent_id]["veh_veh_collision"]
+ agent_idx = agent_id_to_idx_dict[agent_id]
+ goal_achieved[agent_idx] += last_info_dicts[agent_id]["goal_achieved"]
+ veh_edge_collision[agent_idx] += last_info_dicts[agent_id]["veh_edge_collision"]
+ veh_veh_collision[agent_idx] += last_info_dicts[agent_id]["veh_veh_collision"]
break
return (
action_indices,
agent_positions,
agent_speed,
- goal_achieved/self.num_agents,
- veh_edge_collision/self.num_agents,
- veh_veh_collision/self.num_agents,
+ goal_achieved,
+ veh_edge_collision,
+ veh_veh_collision,
)
def get_action_val_diff(self, pred_actions, expert_actions):
@@ -291,23 +293,31 @@ def get_action_val_diff(self, pred_actions, expert_actions):
np.isnan(expert_actions),
)
)
- valid_expert_acts = expert_actions[nonnan_ids]
- valid_pred_acts = pred_actions[nonnan_ids]
+ num_agents = expert_actions.shape[0]
+ arr = np.zeros((num_agents, 2))
+ for agent_idx in range(num_agents):
+ not_nan = nonnan_ids[agent_idx, :]
- exp_acc_vals, exp_steer_vals = np.zeros_like(valid_pred_acts), np.zeros_like(valid_pred_acts)
- pred_acc_vals, pred_steer_vals = np.zeros_like(valid_pred_acts), np.zeros_like(valid_pred_acts)
+ valid_expert_acts = expert_actions[agent_idx, :][not_nan]
+ valid_pred_acts = pred_actions[agent_idx, :][not_nan]
- for idx in range(valid_expert_acts.shape[0]):
+ exp_acc_vals, exp_steer_vals = np.zeros_like(valid_pred_acts), np.zeros_like(valid_pred_acts)
+ pred_acc_vals, pred_steer_vals = np.zeros_like(valid_pred_acts), np.zeros_like(valid_pred_acts)
- # Get expert and predicted values
- exp_acc_vals[idx], exp_steer_vals[idx] = self.env.idx_to_actions[valid_expert_acts[idx]]
- pred_acc_vals[idx], pred_steer_vals[idx] = self.env.idx_to_actions[valid_pred_acts[idx]]
+ for idx in range(valid_expert_acts.shape[0]):
+ # Get expert and predicted values
+ exp_acc_vals[idx], exp_steer_vals[idx] = self.env.idx_to_actions[valid_expert_acts[idx]]
+ pred_acc_vals[idx], pred_steer_vals[idx] = self.env.idx_to_actions[valid_pred_acts[idx]]
- # Get mean absolute difference
- abs_accel_diff = np.abs(exp_acc_vals - pred_acc_vals).mean()
- abs_steer_diff = np.abs(exp_steer_vals - pred_steer_vals).mean()
+ # Get mean absolute difference
+ abs_accel_diff = np.abs(exp_acc_vals - pred_acc_vals).mean()
+ abs_steer_diff = np.abs(exp_steer_vals - pred_steer_vals).mean()
- return abs_accel_diff, abs_steer_diff
+ # Store
+ arr[agent_idx, 0] = abs_accel_diff
+ arr[agent_idx, 1] = abs_steer_diff
+
+ return arr[:, 0], arr[:, 1] # abs_diff_accel, abs_diff_steer
def get_action_accuracy(self, pred_actions, expert_actions, nonnan_ids):
"""Get accuracy of agent actions.
@@ -316,9 +326,15 @@ def get_action_accuracy(self, pred_actions, expert_actions, nonnan_ids):
expert_actions: (num_agents, num_steps_per_episode) the expert actions of the agents.
nonnan_ids: (num_agents, num_steps_per_episode) the indices of non-nan actions.
"""
- return (expert_actions[nonnan_ids] == pred_actions[nonnan_ids]).sum() / nonnan_ids.flatten().shape[0]
+ num_agents = expert_actions.shape[0]
+ arr = np.zeros(num_agents)
+ for agent_idx in range(num_agents):
+ not_nan = nonnan_ids[agent_idx, :]
+ arr[agent_idx] = (expert_actions[agent_idx, :][not_nan] == pred_actions[agent_idx, :][not_nan]).sum() / not_nan.shape[0]
+ return arr
def get_pos_rmse(self, pred_actions, expert_actions):
+
# Filter out invalid actions
nonnan_ids = np.logical_not(
np.logical_or(
@@ -326,7 +342,12 @@ def get_pos_rmse(self, pred_actions, expert_actions):
np.isnan(expert_actions),
)
)
- return np.sqrt(np.linalg.norm(pred_actions[nonnan_ids] - expert_actions[nonnan_ids])).mean()
+ num_agents = expert_actions.shape[0]
+ arr = np.zeros(num_agents)
+ for agent_idx in range(num_agents):
+ not_nan = nonnan_ids[agent_idx, :]
+ arr[agent_idx] = (np.sqrt(np.linalg.norm(pred_actions[agent_idx, :][not_nan] - expert_actions[agent_idx, :][not_nan]))).mean()
+ return arr
def get_speed_mae(self, pred_actions, expert_actions):
# Filter out invalid actions
@@ -336,7 +357,12 @@ def get_speed_mae(self, pred_actions, expert_actions):
np.isnan(expert_actions),
)
)
- return np.abs(pred_actions[nonnan_ids] - expert_actions[nonnan_ids]).mean()
+ num_agents = expert_actions.shape[0]
+ arr = np.zeros(num_agents)
+ for agent_idx in range(num_agents):
+ not_nan = nonnan_ids[agent_idx, :]
+ arr[agent_idx] = (np.abs(pred_actions[agent_idx, :][not_nan] - expert_actions[agent_idx, :][not_nan])).mean()
+ return arr
def get_veh_to_veh_distances(self, positions, velocities, time_gap_in_sec=3):
"""Calculate distances between vehicles at each time step and track
diff --git a/utils/evaluation.py b/utils/evaluation.py
index 5ec2402c..9b105f73 100644
--- a/utils/evaluation.py
+++ b/utils/evaluation.py
@@ -2,12 +2,15 @@
import wandb
from pyvirtualdisplay import Display
+from utils.render import make_video
+
def evaluate_policy(
model,
env,
- n_steps_per_episode,
- n_eval_episodes,
+ n_steps_per_episode=80,
+ n_eval_episodes=1,
eval_files=None,
+ eval_modes=['expert', 'policy'],
deterministic = True,
render = False,
video_caption = None,
@@ -20,8 +23,8 @@ def evaluate_policy(
Args:
-----
- model: The RL agent you want to evaluate. This can be any object
- that implements a `predict` method, such as an RL algorithm (``BaseAlgorithm``)
+ model: The IL/RL policy to evaluate. This can be any object that implements
+ a `predict` method, such as an RL algorithm (``BaseAlgorithm``)
or policy (``BasePolicy``).
env: The gym environment or ``VecEnv`` environment.
n_eval_episodes: Number of different traffic scenes in which to evaluate the agent
@@ -39,19 +42,25 @@ def evaluate_policy(
if verbose == 1:
print(f"Evaluating policy on {traffic_scene}...")
- for episode_i in range(n_eval_episodes):
+ for eval_mode in eval_modes:
# Reset env
observations = env.reset(filename=traffic_scene)
num_agents_controlled = len(env.agent_ids)
curr_rewards = np.zeros(num_agents_controlled)
frames = []
- for step_j in range(n_steps_per_episode):
+ for timestep in range(n_steps_per_episode):
+
+ if eval_mode == 'policy':
+ # Predict actions
+ actions, _ = model.predict(
+ observations,
+ deterministic=deterministic,
+ )
+ elif eval_mode == 'expert':
+ actions = None
- actions, _ = model.predict(
- observations,
- deterministic=deterministic,
- )
+ # Step environment
new_observations, rewards, dones, infos = env.step(actions)
for agent_idx, agent_id in enumerate(env.agent_ids):
@@ -62,7 +71,7 @@ def evaluate_policy(
# Render
if render:
- if step_j % video_config.logging.render_interval == 0:
+ if timestep % video_config.logging.render_interval == 0:
if video_config.logging.where_am_i == "headless_machine":
with Display() as disp:
render_scene = env.env.scenario.getImage(**video_config.render)
@@ -71,9 +80,9 @@ def evaluate_policy(
render_scene = env.scenario.getImage(**video_config)
frames.append(render_scene.T)
- if sum(dones) == env.num_agents or step_j == (n_steps_per_episode-1):
- episode_rewards[episode_i] = curr_rewards.sum() / env.num_agents
- episode_lengths[episode_i] = step_j
+ if sum(dones) == env.num_agents or timestep == (n_steps_per_episode-1):
+ episode_rewards = curr_rewards.sum() / env.num_agents
+ episode_lengths = timestep
break
# Log video to wandb
@@ -84,7 +93,7 @@ def evaluate_policy(
{
video_key: wandb.Video(movie_frames,
fps=video_config.logging.fps,
- caption=f'{video_caption} | norm_ep_return = {episode_rewards[episode_i]:.2f}'),
+ caption=f'{video_caption}_mode_:{eval_mode}'),
},
)
diff --git a/utils/manage_models.py b/utils/manage_models.py
index f2433cf8..a686ea69 100644
--- a/utils/manage_models.py
+++ b/utils/manage_models.py
@@ -9,7 +9,7 @@
entity = "daphnecor"
project = "scaling_ppo"
- collection_name = "nocturne-hr-ppo-12_30_21_43"
+ collection_name = "nocturne-hr-ppo-01_02_11_16"
# Always initialize a W&B run to start tracking
wandb.init()
diff --git a/utils/render.py b/utils/render.py
index b611bcd4..bc8a6702 100644
--- a/utils/render.py
+++ b/utils/render.py
@@ -25,8 +25,8 @@ def make_video(
filenames = None,
deterministic: bool = True,
max_steps: int = 80,
- snap_interval: int = 4,
- frames_per_second: int = 4,
+ snap_interval: int = 3,
+ frames_per_second: int = 3,
) -> Tuple[np.ndarray, pd.DataFrame]:
"""Make a video of policy in traffic scene.
diff --git a/utils/wrappers.py b/utils/wrappers.py
index 67d956e2..c149c4ce 100644
--- a/utils/wrappers.py
+++ b/utils/wrappers.py
@@ -17,6 +17,7 @@ def __init__(self, config):
self.observation_space = gym.spaces.Box(-np.inf, np.inf, self.env.observation_space.shape, np.float32)
def step(self, actions=None):
+ """If actions is None, vehicles are stepped in expert control mode."""
obs = np.zeros((self.num_agents, self.observation_space.shape[0]))
rews, dones, infos = np.zeros((self.num_agents)), np.zeros((self.num_agents)), []
@@ -26,7 +27,9 @@ def step(self, actions=None):
agent_id: actions[idx] for idx, agent_id in enumerate(self.agent_ids)
if agent_id not in self.dead_agent_ids
}
- else:
+ else: # Set in expert control mode
+ for veh_obj in self.env.controlled_vehicles:
+ veh_obj.expert_control = True
agent_actions = {}
# Take a step to obtain dicts