diff --git a/configs/bc_config.yaml b/configs/bc_config.yaml index 8079114d..13547552 100644 --- a/configs/bc_config.yaml +++ b/configs/bc_config.yaml @@ -1,8 +1,8 @@ # Model save_model: true # Save model after training -model_name: "human_policy_data_2_scenes" # Name of saved model +model_name: "human_policy" # Name of saved model save_model_path: ./models/il/ # Path to save model # Train -total_samples: 10_000 # Number of obs-act-next_obs-done pairs to generate -n_epochs: 100 # Training epochs +total_samples: 150_000 # Number of obs-act-next_obs-done pairs to generate +n_epochs: 200 # Training epochs \ No newline at end of file diff --git a/configs/env_config.yaml b/configs/env_config.yaml index f8c20c3a..10644e49 100644 --- a/configs/env_config.yaml +++ b/configs/env_config.yaml @@ -8,7 +8,7 @@ env: my_custom_multi_env_v1 # name of the env, hardcoded for now episode_length: 80 warmup_period: 10 # In the RL setting we use a warmup of 10 steps # How many files of the total dataset to use. -1 indicates to use all of them -num_files: 20 +num_files: 100 fix_file_order: true # If true, always select the SAME files (when creating the environent), if false, pick files at random sample_file_method: "random" # ALTERNATIVES: "no_replacement" (SUPPORTED) / "score-based" (TODO: @Daphne) dt: 0.1 @@ -18,9 +18,9 @@ discretize_actions: true include_head_angle: false # Whether to include the head tilt/angle as part of a vehicle's action accel_discretization: 5 accel_lower_bound: -3 -accel_upper_bound: 3 -steering_lower_bound: -0.7 -steering_upper_bound: 0.7 +accel_upper_bound: 3 +steering_lower_bound: -0.7 # steer right +steering_upper_bound: 0.7 # steer left steering_discretization: 5 max_num_vehicles: 20 randomize_goals: false @@ -109,4 +109,5 @@ subscriber: n_frames_stacked: 1 # Agent memory # Path to folder with traffic scene(s) from which to create an environment -data_path: ./data_full/train \ No newline at end of file +data_path: ./data_full/train +val_data_path: ./data_full/valid \ No newline at end of file diff --git a/configs/exp_config.yaml b/configs/exp_config.yaml index 1ef37674..91946233 100644 --- a/configs/exp_config.yaml +++ b/configs/exp_config.yaml @@ -22,7 +22,7 @@ ma_callback: model_save_freq: 100 # In iterations (one iter ~ (num_agents x n_steps)) save_video: true record_n_scenes: 10 # Number of different scenes to render - video_save_freq: 50 # Make a video every k iterations (100 iters ~ 1M steps) + video_save_freq: 100 # Make a video every k iterations (100 iters ~ 1M steps) video_deterministic: true ppo: @@ -32,7 +32,7 @@ ppo: vf_coef: 0.5 # Default in SB3 is 0.5 learn: - total_timesteps: 10_000_000 + total_timesteps: 2_000_000 progress_bar: false # human-regularized RL diff --git a/evaluation/il_analysis.ipynb b/evaluation/il_analysis.ipynb new file mode 100644 index 00000000..0834727c --- /dev/null +++ b/evaluation/il_analysis.ipynb @@ -0,0 +1,5971 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [], + "source": [ + "# Dependencies\n", + "import numpy as np\n", + "import glob\n", + "import pandas as pd\n", + "import seaborn as sns\n", + "import warnings\n", + "import torch\n", + "import logging\n", + "import os\n", + "import matplotlib.pyplot as plt\n", + "import warnings\n", + "warnings.filterwarnings(\"ignore\")\n", + "\n", + "from typing import Callable\n", + "from gym import spaces\n", + "from stable_baselines3.common.policies import ActorCriticPolicy\n", + "from networks.mlp_late_fusion import LateFusionMLP, LateFusionMLPPolicy\n", + "from collections import Counter\n", + "from utils.plot import plot_agent_trajectory\n", + "from utils.config import load_config_nb\n", + "from utils.eval import EvaluatePolicy\n", + "from utils.policies import load_policy\n", + "\n", + "sns.set('notebook', font_scale=1.1, rc={'figure.figsize': (10, 5)})\n", + "sns.set_style('ticks', rc={'figure.facecolor': 'none', 'axes.facecolor': 'none'})\n", + "%config InlineBackend.figure_format = 'svg'\n", + "warnings.filterwarnings(\"ignore\")\n", + "plt.set_loglevel('WARNING')\n", + "pd.options.display.float_format = \"{:,.2f}\".format\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Configurations" + ] + }, + { + "cell_type": "code", + "execution_count": 82, + "metadata": {}, + "outputs": [], + "source": [ + "MAX_FILES = 1000\n", + "\n", + "# Load config files\n", + "env_config = load_config_nb(\"env_config\")\n", + "exp_config = load_config_nb(\"exp_config\")\n", + "model_config = load_config_nb(\"model_config\")\n", + "\n", + "# Set data path\n", + "env_config.data_path = \"../data_full/train/\"\n", + "env_config.val_data_path = \"../data_full/valid/\"\n", + "env_config.num_files = MAX_FILES\n", + "\n", + "# Logging level set to INFO\n", + "LOGGING_LEVEL = \"INFO\"\n", + "\n", + "# Scenes on which to evaluate the models\n", + "# Make sure file order is fixed\n", + "train_file_paths = glob.glob(f\"{env_config.data_path}\" + \"/tfrecord*\")\n", + "train_eval_files = sorted([os.path.basename(file) for file in train_file_paths])\n", + "\n", + "# Valid\n", + "valid_file_paths = glob.glob(f\"{env_config.val_data_path}\" + \"/tfrecord*\")\n", + "valid_eval_files = sorted([os.path.basename(file) for file in valid_file_paths])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### IL action distributions\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:root:Using 996 file(s)\n" + ] + } + ], + "source": [ + "from utils.imitation_learning.waymo_iterator import TrajectoryIterator\n", + "from torch.utils.data import DataLoader\n", + "\n", + "NUM_IL_FILES = 1000\n", + "\n", + "env_config.num_files = NUM_IL_FILES\n", + "\n", + "# Create iterator\n", + "waymo_iterator = TrajectoryIterator(\n", + " data_path=env_config.data_path,\n", + " env_config=env_config,\n", + " file_limit=env_config.num_files,\n", + ") \n", + "\n", + "# Rollout to get obs-act-obs-done trajectories \n", + "rollouts = next(iter(\n", + " DataLoader(\n", + " waymo_iterator,\n", + " batch_size=10_000, # Number of samples to generate\n", + " pin_memory=True,\n", + ")))\n", + "\n", + "obs, acts, next_obs, dones = rollouts\n", + "action_cats = dict(Counter(acts.numpy()))" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
action_idxcountperc
017148214.82
112277327.73
27141914.19
322132213.22
42133513.35
504634.63
621950.95
71560.56
8131341.34
923900.90
\n", + "
" + ], + "text/plain": [ + " action_idx count perc\n", + "0 17 1482 14.82\n", + "1 12 2773 27.73\n", + "2 7 1419 14.19\n", + "3 22 1322 13.22\n", + "4 2 1335 13.35\n", + "5 0 463 4.63\n", + "6 21 95 0.95\n", + "7 1 56 0.56\n", + "8 13 134 1.34\n", + "9 23 90 0.90" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_acts = pd.DataFrame(\n", + " {\n", + " 'action_idx': list(action_cats.keys()), \n", + " 'count': list(action_cats.values()),\n", + " 'perc': np.array(list(action_cats.values())) / len(acts.numpy()) * 100,\n", + " }\n", + ")\n", + "\n", + "df_acts['action_idx'] = df_acts['action_idx'].astype(int)\n", + "df_acts.head(n=10)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2024-01-02T09:33:00.402924\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.8.0, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "# Create a Seaborn categorical plot\n", + "sns.barplot(data=df_acts, x='action_idx', y='perc')\n", + "\n", + "plt.xlabel('Action indices')\n", + "plt.ylabel('Percentage %')\n", + "plt.title(f'Action Distribution in IL training dataset (N = {NUM_IL_FILES})')\n", + "\n", + "plt.grid(True, alpha=.5)\n", + "sns.despine();" + ] + }, + { + "cell_type": "code", + "execution_count": 88, + "metadata": {}, + "outputs": [], + "source": [ + "BASE_PATH = \"../models/il\"\n", + "\n", + "# Scenes on which to evaluate the models\n", + "il_policy_names = ['human_policy_S10_2024_01_02', 'human_policy_S100_2024_01_02']\n", + "num_scenes = [10, 100]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Evaluate Behavioral Cloning policies on training scenes" + ] + }, + { + "cell_type": "code", + "execution_count": 120, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:root:Evaluating policy on 10 files...\n", + " 0%| | 0/10 [00:00\n", + "#T_a3c6a_row0_col0 {\n", + " width: 10em;\n", + " background: linear-gradient(90deg, transparent 30.0%, #cdecc7 30.0%, #cdecc7 36.8%, transparent 36.8%) no-repeat center;\n", + " background-size: 100% 50.0%;\n", + "}\n", + "#T_a3c6a_row0_col1 {\n", + " width: 10em;\n", + " background: linear-gradient(90deg, transparent 30.0%, #fdcebb 30.0%, #fdcebb 35.5%, transparent 35.5%) no-repeat center;\n", + " background-size: 100% 50.0%;\n", + "}\n", + "#T_a3c6a_row0_col2 {\n", + " width: 10em;\n", + " background: linear-gradient(90deg, transparent 30.0%, #f34935 30.0%, #f34935 47.6%, transparent 47.6%) no-repeat center;\n", + " background-size: 100% 50.0%;\n", + "}\n", + "#T_a3c6a_row0_col3 {\n", + " width: 10em;\n", + " background: linear-gradient(90deg, transparent 30.0%, #ffece3 30.0%, #ffece3 31.7%, transparent 31.7%) no-repeat center;\n", + " background-size: 100% 50.0%;\n", + "}\n", + "#T_a3c6a_row0_col4 {\n", + " width: 10em;\n", + " background: linear-gradient(90deg, transparent 30.0%, #3585bf 30.0%, #3585bf 50.2%, transparent 50.2%) no-repeat center;\n", + " background-size: 100% 50.0%;\n", + "}\n", + "#T_a3c6a_row1_col0 {\n", + " width: 10em;\n", + " background: linear-gradient(90deg, transparent 30.0%, #daf0d4 30.0%, #daf0d4 35.2%, transparent 35.2%) no-repeat center;\n", + " background-size: 100% 50.0%;\n", + "}\n", + "#T_a3c6a_row1_col1 {\n", + " width: 10em;\n", + " background: linear-gradient(90deg, transparent 30.0%, #fdc9b3 30.0%, #fdc9b3 36.1%, transparent 36.1%) no-repeat center;\n", + " background-size: 100% 50.0%;\n", + "}\n", + "#T_a3c6a_row1_col2 {\n", + " width: 10em;\n", + " background: linear-gradient(90deg, transparent 30.0%, #9a0c14 30.0%, #9a0c14 56.8%, transparent 56.8%) no-repeat center;\n", + " background-size: 100% 50.0%;\n", + "}\n", + "#T_a3c6a_row1_col3 {\n", + " width: 10em;\n", + " background: linear-gradient(90deg, transparent 30.0%, #fee3d7 30.0%, #fee3d7 33.2%, transparent 33.2%) no-repeat center;\n", + " background-size: 100% 50.0%;\n", + "}\n", + "#T_a3c6a_row1_col4 {\n", + " width: 10em;\n", + " background: linear-gradient(90deg, transparent 30.0%, #115ca5 30.0%, #115ca5 54.9%, transparent 54.9%) no-repeat center;\n", + " background-size: 100% 50.0%;\n", + "}\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "

Aggregated Behavioral Cloning Human Likeness Scores (train data)

 act_accaccel_val_maespeed_maesteer_val_maepos_rmse
num_scenes     
100.231.6629.340.0850.42
1000.171.8444.740.1562.22
\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 121, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Show aggregate human likeness scores\n", + "hl_df_avg_train = df_il_train.groupby('num_scenes')[['act_acc', 'accel_val_mae', 'speed_mae', 'steer_val_mae', 'pos_rmse']].mean()\n", + "\n", + "hl_df_avg_train.style.format('{:.3f}', na_rep=\"\")\\\n", + " .bar(subset=['act_acc'], align=0, vmin=0, vmax=1, cmap=\"Greens\", height=50, width=60)\\\n", + " .bar(subset=['accel_val_mae'], align=0, vmin=0, vmax=9, cmap=\"Reds\", height=50, width=60)\\\n", + " .bar(subset=['speed_mae'], align=0, vmin=0, vmax=50, cmap=\"Reds\", height=50, width=60)\\\n", + " .bar(subset=['steer_val_mae'], align=0, vmin=0, vmax=1.4, cmap=\"Reds\", height=50, width=60)\\\n", + " .bar(subset=['pos_rmse'], align=0, vmin=0, vmax=75, cmap=\"Blues\", height=50, width=60)\\\n", + " .set_caption(\"

Aggregated Behavioral Cloning Human Likeness Scores (train data)

\").format(\"{:.2f}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 122, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "

Aggregated Behavioral Cloning Performance Metrics (train data)

 goal_rateveh_edge_crveh_veh_cr
num_scenes   
100.300.270.07
1000.160.320.17
\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 122, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "performance_df_avg_train = df_il_train.groupby('num_scenes')[['goal_rate', 'veh_edge_cr', 'veh_veh_cr']].mean()\n", + "\n", + "performance_df_avg_train.style.format('{:.3f}', na_rep=\"\")\\\n", + " .bar(subset=['goal_rate'], align=0, vmin=0, vmax=1, cmap=\"Greens\", height=50, width=60)\\\n", + " .bar(subset=['veh_edge_cr'], align=0, vmin=0, vmax=1, cmap=\"Reds\", height=50, width=60)\\\n", + " .bar(subset=['veh_veh_cr'], align=0, vmin=0, vmax=1, cmap=\"Reds\", height=50, width=60)\\\n", + " .set_caption(\"

Aggregated Behavioral Cloning Performance Metrics (train data)

\").format(\"{:.2f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Evaluate Behavioral Cloning policies on new (unseen) scenes" + ] + }, + { + "cell_type": "code", + "execution_count": 123, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:root:Evaluating policy on 10 files...\n", + "100%|██████████| 10/10 [00:02<00:00, 3.47it/s]\n", + "INFO:root:Evaluating policy on 99 files...\n", + "100%|██████████| 99/99 [00:25<00:00, 3.91it/s]\n" + ] + } + ], + "source": [ + "# Set data path\n", + "env_config.data_path = \"../data_full/valid/\"\n", + "\n", + "df_il_valid = pd.DataFrame()\n", + "\n", + "for trained_policy, num_files in zip(il_policy_names, num_scenes):\n", + "\n", + " eval_files = valid_eval_files[:num_files]\n", + "\n", + " # Load trained human reference policy\n", + " human_policy = load_policy(\n", + " data_path=BASE_PATH,\n", + " file_name=trained_policy, \n", + " )\n", + "\n", + " # Evaluate policy\n", + " evaluator = EvaluatePolicy(\n", + " env_config=env_config, \n", + " exp_config=exp_config,\n", + " policy=human_policy,\n", + " eval_files=eval_files,\n", + " log_to_wandb=False, \n", + " deterministic=False,\n", + " reg_coef=0.0,\n", + " return_trajectories=True,\n", + " )\n", + "\n", + " df_il_res, df_il_trajs = evaluator._get_scores()\n", + "\n", + " df_il_res['num_scenes'] = num_files\n", + " df_il_res['policy'] = trained_policy\n", + " df_il_valid = pd.concat([df_il_valid, df_il_res])\n", + "\n", + "df_il_valid['type'] = 'validation'" + ] + }, + { + "cell_type": "code", + "execution_count": 124, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "

Aggregated Behavioral Cloning Human Likeness Scores (validation data)

 act_accaccel_val_maespeed_maesteer_val_maepos_rmse
num_scenes     
100.082.1716.420.1935.51
1000.102.0434.330.1949.43
\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 124, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Show aggregate human likeness scores\n", + "hl_df_avg_valid = df_il_valid.groupby('num_scenes')[['act_acc', 'accel_val_mae', 'speed_mae', 'steer_val_mae', 'pos_rmse']].mean()\n", + "\n", + "hl_df_avg_valid.style.format('{:.3f}', na_rep=\"\")\\\n", + " .bar(subset=['act_acc'], align=0, vmin=0, vmax=1, cmap=\"Greens\", height=50, width=60)\\\n", + " .bar(subset=['accel_val_mae'], align=0, vmin=0, vmax=9, cmap=\"Reds\", height=50, width=60)\\\n", + " .bar(subset=['speed_mae'], align=0, vmin=0, vmax=50, cmap=\"Reds\", height=50, width=60)\\\n", + " .bar(subset=['steer_val_mae'], align=0, vmin=0, vmax=1.4, cmap=\"Reds\", height=50, width=60)\\\n", + " .bar(subset=['pos_rmse'], align=0, vmin=0, vmax=75, cmap=\"Blues\", height=50, width=60)\\\n", + " .set_caption(\"

Aggregated Behavioral Cloning Human Likeness Scores (validation data)

\").format(\"{:.2f}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 125, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "

Aggregated Behavioral Cloning Performance Metrics (validation data)

 goal_rateveh_edge_crveh_veh_cr
num_scenes   
100.060.550.15
1000.120.330.22
\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 125, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "performance_df_avg_valid = df_il_valid.groupby('num_scenes')[['goal_rate', 'veh_edge_cr', 'veh_veh_cr']].mean()\n", + "\n", + "performance_df_avg_valid.style.format('{:.3f}', na_rep=\"\")\\\n", + " .bar(subset=['goal_rate'], align=0, vmin=0, vmax=1, cmap=\"Greens\", height=50, width=60)\\\n", + " .bar(subset=['veh_edge_cr'], align=0, vmin=0, vmax=1, cmap=\"Reds\", height=50, width=60)\\\n", + " .bar(subset=['veh_veh_cr'], align=0, vmin=0, vmax=1, cmap=\"Reds\", height=50, width=60)\\\n", + " .set_caption(\"

Aggregated Behavioral Cloning Performance Metrics (validation data)

\").format(\"{:.2f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Compare " + ] + }, + { + "cell_type": "code", + "execution_count": 145, + "metadata": {}, + "outputs": [], + "source": [ + "df_il = pd.concat([df_il_train, df_il_valid])" + ] + }, + { + "cell_type": "code", + "execution_count": 150, + "metadata": {}, + "outputs": [], + "source": [ + "# fig, axs = plt.subplots(1, 3, figsize=(15, 5))\n", + "\n", + "# g = sns.violinplot(data=df_il, x='num_scenes', y='act_acc', hue='type', ax=axs[0]);\n", + "# sns.swarmplot(data=df_il, x='num_scenes', y='act_acc', hue='type', color=\"k\", size=3, ax=g.ax)\n", + "\n", + "# g = sns.violinplot(data=df_il, x='num_scenes', y='act_acc', hue='type', ax=axs[0]);" + ] + }, + { + "cell_type": "code", + "execution_count": 173, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2024-01-02T10:41:47.294104\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.8.0, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "fig, axs = plt.subplots(1, 4, figsize=(12, 4))\n", + "\n", + "fig.suptitle('Human likeness scores for IL policies', fontsize=16)\n", + "\n", + "sns.barplot(data=df_il, x='num_scenes', y='act_acc', hue='type', ax=axs[0], legend=True,)\n", + "\n", + "sns.barplot(data=df_il, x='num_scenes', y='accel_val_mae', hue='type', ax=axs[1], legend=False)\n", + "\n", + "sns.barplot(data=df_il, x='num_scenes', y='steer_val_mae', hue='type', ax=axs[2], legend=False)\n", + "\n", + "sns.barplot(data=df_il, x='num_scenes', y='pos_rmse', hue='type', ax=axs[3], legend=False)\n", + "\n", + "fig.tight_layout()\n", + "sns.despine()" + ] + }, + { + "cell_type": "code", + "execution_count": 172, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2024-01-02T10:41:39.069166\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.8.0, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "fig, axs = plt.subplots(1, 3, figsize=(12, 4))\n", + "\n", + "fig.suptitle('Performance scores for IL policies', fontsize=16)\n", + "\n", + "sns.barplot(data=df_il, x='num_scenes', y='goal_rate', hue='type', ax=axs[0], legend=True,)\n", + "\n", + "sns.barplot(data=df_il, x='num_scenes', y='veh_edge_cr', hue='type', ax=axs[1], legend=False)\n", + "\n", + "sns.barplot(data=df_il, x='num_scenes', y='veh_veh_cr', hue='type', ax=axs[2], legend=False)\n", + "\n", + "fig.tight_layout()\n", + "sns.despine()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "nocturne_lab", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/evaluation/policy_performance_analysis.ipynb b/evaluation/policy_performance_analysis.ipynb index d6600319..50e6bf82 100644 --- a/evaluation/policy_performance_analysis.ipynb +++ b/evaluation/policy_performance_analysis.ipynb @@ -24,7 +24,7 @@ }, { "cell_type": "code", - "execution_count": 86, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -64,53 +64,27 @@ }, { "cell_type": "code", - "execution_count": 87, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ + "MAX_FILES = 50\n", + "\n", "# Load config files\n", "env_config = load_config_nb(\"env_config\")\n", "exp_config = load_config_nb(\"exp_config\")\n", "model_config = load_config_nb(\"model_config\")\n", "\n", "# Set data path\n", - "env_config.data_path = \"../data_10/train/\"\n", + "env_config.data_path = \"../data_full/train/\"\n", + "env_config.num_files = MAX_FILES\n", "\n", "# Logging level set to INFO\n", "LOGGING_LEVEL = \"INFO\"\n", "\n", "# Scenes on which to evaluate the models\n", "file_paths = glob.glob(f\"{env_config.data_path}\" + \"/tfrecord*\")\n", - "eval_files = [os.path.basename(file) for file in file_paths]" - ] - }, - { - "cell_type": "code", - "execution_count": 88, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['tfrecord-00004-of-01000_378.json',\n", - " 'tfrecord-00003-of-01000_109.json',\n", - " 'tfrecord-00004-of-01000_61.json',\n", - " 'tfrecord-00012-of-01000_87.json',\n", - " 'tfrecord-00007-of-01000_237.json',\n", - " 'tfrecord-00005-of-01000_423.json',\n", - " 'tfrecord-00012-of-01000_246.json',\n", - " 'tfrecord-00012-of-01000_389.json',\n", - " 'tfrecord-00001-of-01000_307.json',\n", - " 'tfrecord-00004-of-01000_157.json']" - ] - }, - "execution_count": 88, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "eval_files" + "eval_files = [os.path.basename(file) for file in file_paths][:MAX_FILES]" ] }, { @@ -122,7 +96,7 @@ }, { "cell_type": "code", - "execution_count": 89, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -141,9 +115,582 @@ }, { "cell_type": "code", - "execution_count": 131, + "execution_count": 8, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:root:veh 37 at t = 50 returns None action!\n", + "INFO:root:veh 37 at t = 51 returns None action!\n", + "INFO:root:veh 44 at t = 75 returns None action!\n", + "INFO:root:veh 44 at t = 76 returns None action!\n", + "INFO:root:veh 44 at t = 32 returns None action!\n", + "INFO:root:veh 44 at t = 33 returns None action!\n", + "INFO:root:veh 44 at t = 34 returns None action!\n", + "INFO:root:veh 44 at t = 36 returns None action!\n", + "INFO:root:veh 44 at t = 37 returns None action!\n", + "INFO:root:veh 44 at t = 38 returns None action!\n", + "INFO:root:veh 4 at t = 2 returns None action!\n", + "INFO:root:veh 41 at t = 1 returns None action!\n", + "INFO:root:veh 41 at t = 2 returns None action!\n", + "INFO:root:veh 17 at t = 15 returns None action!\n", + "INFO:root:veh 17 at t = 16 returns None action!\n", + "INFO:root:veh 17 at t = 17 returns None action!\n", + "INFO:root:veh 17 at t = 18 returns None action!\n", + "INFO:root:veh 17 at t = 19 returns None action!\n", + "INFO:root:veh 17 at t = 31 returns None action!\n", + "INFO:root:veh 17 at t = 32 returns None action!\n", + "INFO:root:veh 17 at t = 33 returns None action!\n", + "INFO:root:veh 17 at t = 34 returns None action!\n", + "INFO:root:veh 17 at t = 43 returns None action!\n", + "INFO:root:veh 17 at t = 44 returns None action!\n", + "INFO:root:veh 17 at t = 45 returns None action!\n", + "INFO:root:veh 17 at t = 46 returns None action!\n", + "INFO:root:veh 17 at t = 47 returns None action!\n", + "INFO:root:veh 17 at t = 51 returns None action!\n", + "INFO:root:veh 17 at t = 52 returns None action!\n", + "INFO:root:veh 14 at t = 1 returns None action!\n", + "INFO:root:veh 14 at t = 2 returns None action!\n", + "INFO:root:veh 14 at t = 3 returns None action!\n", + "INFO:root:veh 14 at t = 4 returns None action!\n", + "INFO:root:veh 14 at t = 44 returns None action!\n", + "INFO:root:veh 14 at t = 45 returns None action!\n", + "INFO:root:veh 14 at t = 46 returns None action!\n", + "INFO:root:veh 0 at t = 14 returns None action!\n", + "INFO:root:veh 0 at t = 15 returns None action!\n", + "INFO:root:veh 0 at t = 16 returns None action!\n", + "INFO:root:veh 0 at t = 17 returns None action!\n", + "INFO:root:veh 0 at t = 18 returns None action!\n", + "INFO:root:veh 0 at t = 20 returns None action!\n", + "INFO:root:veh 0 at t = 21 returns None action!\n", + "INFO:root:veh 1 at t = 62 returns None action!\n", + "INFO:root:veh 1 at t = 63 returns None action!\n", + "INFO:root:veh 7 at t = 33 returns None action!\n", + "INFO:root:veh 17 at t = 34 returns None action!\n", + "INFO:root:veh 17 at t = 35 returns None action!\n", + "INFO:root:veh 2 at t = 20 returns None action!\n", + "INFO:root:veh 2 at t = 21 returns None action!\n", + "INFO:root:veh 2 at t = 29 returns None action!\n", + "INFO:root:veh 2 at t = 30 returns None action!\n", + "INFO:root:veh 2 at t = 31 returns None action!\n", + "INFO:root:veh 47 at t = 0 returns None action!\n", + "INFO:root:veh 47 at t = 1 returns None action!\n", + "INFO:root:veh 47 at t = 2 returns None action!\n", + "INFO:root:veh 25 at t = 23 returns None action!\n", + "INFO:root:veh 25 at t = 24 returns None action!\n", + "INFO:root:veh 25 at t = 25 returns None action!\n", + "INFO:root:veh 25 at t = 26 returns None action!\n", + "INFO:root:veh 25 at t = 27 returns None action!\n", + "INFO:root:veh 26 at t = 29 returns None action!\n", + "INFO:root:veh 26 at t = 30 returns None action!\n", + "INFO:root:veh 11 at t = 16 returns None action!\n", + "INFO:root:veh 11 at t = 17 returns None action!\n", + "INFO:root:veh 10 at t = 65 returns None action!\n", + "INFO:root:veh 10 at t = 66 returns None action!\n", + "INFO:root:veh 58 at t = 1 returns None action!\n", + "INFO:root:veh 58 at t = 2 returns None action!\n", + "INFO:root:veh 59 at t = 3 returns None action!\n", + "INFO:root:veh 59 at t = 4 returns None action!\n", + "INFO:root:veh 55 at t = 12 returns None action!\n", + "INFO:root:veh 55 at t = 13 returns None action!\n", + "INFO:root:veh 55 at t = 14 returns None action!\n", + "INFO:root:veh 55 at t = 15 returns None action!\n", + "INFO:root:veh 58 at t = 20 returns None action!\n", + "INFO:root:veh 58 at t = 21 returns None action!\n", + "INFO:root:veh 59 at t = 24 returns None action!\n", + "INFO:root:veh 59 at t = 25 returns None action!\n", + "INFO:root:veh 69 at t = 29 returns None action!\n", + "INFO:root:veh 69 at t = 30 returns None action!\n", + "INFO:root:veh 58 at t = 32 returns None action!\n", + "INFO:root:veh 58 at t = 33 returns None action!\n", + "INFO:root:veh 46 at t = 1 returns None action!\n", + "INFO:root:veh 46 at t = 2 returns None action!\n", + "INFO:root:veh 46 at t = 3 returns None action!\n", + "INFO:root:veh 46 at t = 4 returns None action!\n", + "INFO:root:veh 14 at t = 8 returns None action!\n", + "INFO:root:veh 14 at t = 9 returns None action!\n", + "INFO:root:veh 26 at t = 25 returns None action!\n", + "INFO:root:veh 26 at t = 26 returns None action!\n", + "INFO:root:veh 26 at t = 29 returns None action!\n", + "INFO:root:veh 26 at t = 30 returns None action!\n", + "INFO:root:veh 30 at t = 11 returns None action!\n", + "INFO:root:veh 30 at t = 12 returns None action!\n", + "INFO:root:veh 30 at t = 13 returns None action!\n", + "INFO:root:veh 30 at t = 14 returns None action!\n", + "INFO:root:veh 30 at t = 31 returns None action!\n", + "INFO:root:veh 30 at t = 32 returns None action!\n", + "INFO:root:veh 30 at t = 33 returns None action!\n", + "INFO:root:veh 30 at t = 34 returns None action!\n", + "INFO:root:veh 11 at t = 66 returns None action!\n", + "INFO:root:veh 11 at t = 67 returns None action!\n", + "INFO:root:veh 11 at t = 68 returns None action!\n", + "INFO:root:veh 11 at t = 70 returns None action!\n", + "INFO:root:veh 11 at t = 71 returns None action!\n", + "INFO:root:veh 59 at t = 31 returns None action!\n", + "INFO:root:veh 59 at t = 32 returns None action!\n", + "INFO:root:veh 59 at t = 33 returns None action!\n", + "INFO:root:veh 59 at t = 34 returns None action!\n", + "INFO:root:veh 62 at t = 35 returns None action!\n", + "INFO:root:veh 62 at t = 36 returns None action!\n", + "INFO:root:veh 8 at t = 42 returns None action!\n", + "INFO:root:veh 8 at t = 43 returns None action!\n", + "INFO:root:veh 8 at t = 44 returns None action!\n", + "INFO:root:veh 8 at t = 45 returns None action!\n", + "INFO:root:veh 8 at t = 46 returns None action!\n", + "INFO:root:veh 60 at t = 0 returns None action!\n", + "INFO:root:veh 26 at t = 0 returns None action!\n", + "INFO:root:veh 60 at t = 1 returns None action!\n", + "INFO:root:veh 60 at t = 2 returns None action!\n", + "INFO:root:veh 60 at t = 3 returns None action!\n", + "INFO:root:veh 49 at t = 3 returns None action!\n", + "INFO:root:veh 60 at t = 4 returns None action!\n", + "INFO:root:veh 49 at t = 4 returns None action!\n", + "INFO:root:veh 49 at t = 5 returns None action!\n", + "INFO:root:veh 60 at t = 7 returns None action!\n", + "INFO:root:veh 60 at t = 8 returns None action!\n", + "INFO:root:veh 29 at t = 50 returns None action!\n", + "INFO:root:veh 29 at t = 51 returns None action!\n", + "INFO:root:veh 29 at t = 52 returns None action!\n", + "INFO:root:veh 4 at t = 2 returns None action!\n", + "INFO:root:veh 4 at t = 3 returns None action!\n", + "INFO:root:veh 4 at t = 4 returns None action!\n", + "INFO:root:veh 4 at t = 5 returns None action!\n", + "INFO:root:veh 13 at t = 8 returns None action!\n", + "INFO:root:veh 13 at t = 9 returns None action!\n", + "INFO:root:veh 35 at t = 0 returns None action!\n", + "INFO:root:veh 35 at t = 1 returns None action!\n", + "INFO:root:veh 35 at t = 2 returns None action!\n", + "INFO:root:veh 35 at t = 3 returns None action!\n", + "INFO:root:veh 35 at t = 4 returns None action!\n", + "INFO:root:veh 23 at t = 5 returns None action!\n", + "INFO:root:veh 6 at t = 2 returns None action!\n", + "INFO:root:veh 6 at t = 3 returns None action!\n", + "INFO:root:veh 6 at t = 29 returns None action!\n", + "INFO:root:veh 6 at t = 30 returns None action!\n", + "INFO:root:veh 22 at t = 0 returns None action!\n", + "INFO:root:veh 22 at t = 1 returns None action!\n", + "INFO:root:veh 22 at t = 2 returns None action!\n", + "INFO:root:veh 22 at t = 3 returns None action!\n", + "INFO:root:veh 3 at t = 7 returns None action!\n", + "INFO:root:veh 3 at t = 8 returns None action!\n", + "INFO:root:veh 20 at t = 20 returns None action!\n", + "INFO:root:veh 50 at t = 0 returns None action!\n", + "INFO:root:veh 50 at t = 1 returns None action!\n", + "INFO:root:veh 2 at t = 71 returns None action!\n", + "INFO:root:veh 26 at t = 63 returns None action!\n", + "INFO:root:veh 26 at t = 64 returns None action!\n", + "INFO:root:veh 26 at t = 65 returns None action!\n", + "INFO:root:veh 26 at t = 74 returns None action!\n", + "INFO:root:veh 26 at t = 75 returns None action!\n", + "INFO:root:veh 26 at t = 76 returns None action!\n", + "INFO:root:veh 26 at t = 77 returns None action!\n", + "INFO:root:veh 7 at t = 20 returns None action!\n", + "INFO:root:veh 7 at t = 21 returns None action!\n", + "INFO:root:veh 4 at t = 41 returns None action!\n", + "INFO:root:veh 4 at t = 42 returns None action!\n", + "INFO:root:veh 3 at t = 10 returns None action!\n", + "INFO:root:veh 3 at t = 11 returns None action!\n", + "INFO:root:veh 3 at t = 12 returns None action!\n", + "INFO:root:veh 24 at t = 21 returns None action!\n", + "INFO:root:veh 24 at t = 22 returns None action!\n", + "INFO:root:veh 24 at t = 23 returns None action!\n", + "INFO:root:veh 24 at t = 24 returns None action!\n", + "INFO:root:veh 16 at t = 28 returns None action!\n", + "INFO:root:veh 16 at t = 29 returns None action!\n", + "INFO:root:veh 16 at t = 30 returns None action!\n", + "INFO:root:veh 16 at t = 31 returns None action!\n", + "INFO:root:veh 23 at t = 17 returns None action!\n", + "INFO:root:veh 23 at t = 18 returns None action!\n", + "INFO:root:veh 23 at t = 22 returns None action!\n", + "INFO:root:veh 23 at t = 23 returns None action!\n", + "INFO:root:veh 23 at t = 31 returns None action!\n", + "INFO:root:veh 23 at t = 32 returns None action!\n", + "INFO:root:veh 1 at t = 76 returns None action!\n", + "INFO:root:veh 1 at t = 77 returns None action!\n", + "INFO:root:veh 1 at t = 78 returns None action!\n", + "INFO:root:veh 36 at t = 2 returns None action!\n", + "INFO:root:veh 36 at t = 3 returns None action!\n", + "INFO:root:veh 36 at t = 4 returns None action!\n", + "INFO:root:veh 36 at t = 5 returns None action!\n", + "INFO:root:veh 36 at t = 6 returns None action!\n", + "INFO:root:veh 15 at t = 4 returns None action!\n", + "INFO:root:veh 15 at t = 5 returns None action!\n", + "INFO:root:veh 15 at t = 6 returns None action!\n", + "INFO:root:veh 15 at t = 7 returns None action!\n", + "INFO:root:veh 8 at t = 69 returns None action!\n", + "INFO:root:veh 8 at t = 70 returns None action!\n", + "INFO:root:veh 8 at t = 71 returns None action!\n", + "INFO:root:veh 1 at t = 37 returns None action!\n", + "INFO:root:veh 1 at t = 38 returns None action!\n", + "INFO:root:veh 1 at t = 39 returns None action!\n", + "INFO:root:veh 25 at t = 15 returns None action!\n", + "INFO:root:veh 25 at t = 16 returns None action!\n", + "INFO:root:veh 25 at t = 17 returns None action!\n", + "INFO:root:veh 25 at t = 18 returns None action!\n", + "INFO:root:veh 28 at t = 1 returns None action!\n", + "INFO:root:veh 28 at t = 2 returns None action!\n", + "INFO:root:veh 2 at t = 54 returns None action!\n", + "INFO:root:veh 2 at t = 55 returns None action!\n", + "INFO:root:veh 2 at t = 56 returns None action!\n", + "INFO:root:veh 2 at t = 57 returns None action!\n", + "INFO:root:veh 2 at t = 58 returns None action!\n", + "INFO:root:veh 2 at t = 59 returns None action!\n", + "INFO:root:veh 2 at t = 60 returns None action!\n", + "INFO:root:veh 2 at t = 61 returns None action!\n", + "INFO:root:veh 2 at t = 64 returns None action!\n", + "INFO:root:veh 2 at t = 65 returns None action!\n", + "INFO:root:veh 2 at t = 66 returns None action!\n", + "INFO:root:veh 9 at t = 15 returns None action!\n", + "INFO:root:veh 6 at t = 55 returns None action!\n", + "INFO:root:veh 6 at t = 56 returns None action!\n", + "INFO:root:veh 6 at t = 57 returns None action!\n", + "INFO:root:veh 6 at t = 59 returns None action!\n", + "INFO:root:veh 6 at t = 60 returns None action!\n", + "INFO:root:veh 6 at t = 61 returns None action!\n", + "INFO:root:veh 54 at t = 32 returns None action!\n", + "INFO:root:veh 54 at t = 33 returns None action!\n", + "INFO:root:veh 54 at t = 37 returns None action!\n", + "INFO:root:veh 54 at t = 38 returns None action!\n", + "INFO:root:veh 54 at t = 39 returns None action!\n", + "INFO:root:veh 54 at t = 40 returns None action!\n", + "INFO:root:veh 54 at t = 41 returns None action!\n", + "INFO:root:veh 54 at t = 42 returns None action!\n", + "INFO:root:veh 54 at t = 53 returns None action!\n", + "INFO:root:veh 54 at t = 54 returns None action!\n", + "INFO:root:veh 54 at t = 55 returns None action!\n", + "INFO:root:veh 54 at t = 56 returns None action!\n", + "INFO:root:veh 56 at t = 63 returns None action!\n", + "INFO:root:veh 56 at t = 64 returns None action!\n", + "INFO:root:veh 56 at t = 65 returns None action!\n", + "INFO:root:veh 56 at t = 66 returns None action!\n", + "INFO:root:veh 56 at t = 67 returns None action!\n", + "INFO:root:veh 56 at t = 68 returns None action!\n", + "INFO:root:veh 56 at t = 70 returns None action!\n", + "INFO:root:veh 56 at t = 71 returns None action!\n", + "INFO:root:veh 56 at t = 72 returns None action!\n", + "INFO:root:veh 56 at t = 73 returns None action!\n", + "INFO:root:veh 56 at t = 74 returns None action!\n", + "INFO:root:veh 56 at t = 75 returns None action!\n", + "INFO:root:veh 11 at t = 37 returns None action!\n", + "INFO:root:veh 11 at t = 38 returns None action!\n", + "INFO:root:veh 10 at t = 33 returns None action!\n", + "INFO:root:veh 10 at t = 34 returns None action!\n", + "INFO:root:veh 10 at t = 35 returns None action!\n", + "INFO:root:veh 10 at t = 36 returns None action!\n", + "INFO:root:veh 86 at t = 4 returns None action!\n", + "INFO:root:veh 86 at t = 5 returns None action!\n", + "INFO:root:veh 86 at t = 6 returns None action!\n", + "INFO:root:veh 86 at t = 7 returns None action!\n", + "INFO:root:veh 86 at t = 8 returns None action!\n", + "INFO:root:veh 86 at t = 21 returns None action!\n", + "INFO:root:veh 86 at t = 22 returns None action!\n", + "INFO:root:veh 6 at t = 18 returns None action!\n", + "INFO:root:veh 6 at t = 19 returns None action!\n", + "INFO:root:veh 31 at t = 54 returns None action!\n", + "INFO:root:veh 31 at t = 55 returns None action!\n", + "INFO:root:veh 36 at t = 72 returns None action!\n", + "INFO:root:veh 36 at t = 73 returns None action!\n", + "INFO:root:veh 4 at t = 19 returns None action!\n", + "INFO:root:veh 4 at t = 20 returns None action!\n", + "INFO:root:veh 54 at t = 67 returns None action!\n", + "INFO:root:veh 54 at t = 68 returns None action!\n", + "INFO:root:veh 54 at t = 69 returns None action!\n", + "INFO:root:veh 3 at t = 4 returns None action!\n", + "INFO:root:veh 3 at t = 5 returns None action!\n", + "INFO:root:veh 45 at t = 44 returns None action!\n", + "INFO:root:veh 45 at t = 45 returns None action!\n", + "INFO:root:veh 45 at t = 46 returns None action!\n", + "INFO:root:veh 45 at t = 47 returns None action!\n", + "INFO:root:veh 48 at t = 64 returns None action!\n", + "INFO:root:veh 48 at t = 65 returns None action!\n", + "INFO:root:veh 48 at t = 68 returns None action!\n", + "INFO:root:veh 48 at t = 69 returns None action!\n", + "INFO:root:veh 45 at t = 78 returns None action!\n", + "INFO:root:veh 17 at t = 37 returns None action!\n", + "INFO:root:veh 17 at t = 38 returns None action!\n", + "INFO:root:veh 17 at t = 39 returns None action!\n", + "INFO:root:veh 17 at t = 40 returns None action!\n", + "INFO:root:veh 17 at t = 41 returns None action!\n", + "INFO:root:veh 44 at t = 0 returns None action!\n", + "INFO:root:veh 44 at t = 1 returns None action!\n", + "INFO:root:veh 3 at t = 17 returns None action!\n", + "INFO:root:veh 3 at t = 18 returns None action!\n", + "INFO:root:veh 3 at t = 19 returns None action!\n", + "INFO:root:veh 3 at t = 20 returns None action!\n", + "INFO:root:veh 44 at t = 28 returns None action!\n", + "INFO:root:veh 44 at t = 29 returns None action!\n", + "INFO:root:veh 3 at t = 3 returns None action!\n", + "INFO:root:veh 3 at t = 4 returns None action!\n", + "INFO:root:veh 22 at t = 23 returns None action!\n", + "INFO:root:veh 22 at t = 24 returns None action!\n", + "INFO:root:veh 22 at t = 28 returns None action!\n", + "INFO:root:veh 22 at t = 29 returns None action!\n", + "INFO:root:veh 22 at t = 30 returns None action!\n", + "INFO:root:veh 22 at t = 31 returns None action!\n", + "INFO:root:veh 22 at t = 34 returns None action!\n", + "INFO:root:veh 22 at t = 35 returns None action!\n", + "INFO:root:veh 22 at t = 37 returns None action!\n", + "INFO:root:veh 22 at t = 38 returns None action!\n", + "INFO:root:veh 22 at t = 39 returns None action!\n", + "INFO:root:veh 22 at t = 40 returns None action!\n", + "INFO:root:veh 22 at t = 41 returns None action!\n", + "INFO:root:veh 22 at t = 42 returns None action!\n", + "INFO:root:veh 22 at t = 43 returns None action!\n", + "INFO:root:veh 22 at t = 47 returns None action!\n", + "INFO:root:veh 22 at t = 48 returns None action!\n", + "INFO:root:veh 94 at t = 24 returns None action!\n", + "INFO:root:veh 94 at t = 25 returns None action!\n", + "INFO:root:veh 94 at t = 26 returns None action!\n", + "INFO:root:veh 9 at t = 24 returns None action!\n", + "INFO:root:veh 9 at t = 25 returns None action!\n", + "INFO:root:veh 59 at t = 60 returns None action!\n", + "INFO:root:veh 59 at t = 61 returns None action!\n", + "INFO:root:veh 59 at t = 62 returns None action!\n", + "INFO:root:veh 59 at t = 63 returns None action!\n", + "INFO:root:veh 59 at t = 64 returns None action!\n", + "INFO:root:veh 59 at t = 67 returns None action!\n", + "INFO:root:veh 59 at t = 68 returns None action!\n", + "INFO:root:veh 59 at t = 69 returns None action!\n", + "INFO:root:veh 44 at t = 0 returns None action!\n", + "INFO:root:veh 9 at t = 12 returns None action!\n", + "INFO:root:veh 9 at t = 13 returns None action!\n", + "INFO:root:veh 9 at t = 14 returns None action!\n", + "INFO:root:veh 9 at t = 15 returns None action!\n", + "INFO:root:veh 9 at t = 16 returns None action!\n", + "INFO:root:veh 9 at t = 17 returns None action!\n", + "INFO:root:veh 9 at t = 18 returns None action!\n", + "INFO:root:veh 109 at t = 0 returns None action!\n", + "INFO:root:veh 109 at t = 1 returns None action!\n", + "INFO:root:veh 109 at t = 2 returns None action!\n", + "INFO:root:veh 109 at t = 3 returns None action!\n", + "INFO:root:veh 109 at t = 19 returns None action!\n", + "INFO:root:veh 109 at t = 20 returns None action!\n", + "INFO:root:veh 109 at t = 21 returns None action!\n", + "INFO:root:veh 109 at t = 22 returns None action!\n", + "INFO:root:veh 109 at t = 30 returns None action!\n", + "INFO:root:veh 109 at t = 31 returns None action!\n", + "INFO:root:veh 109 at t = 32 returns None action!\n", + "INFO:root:veh 109 at t = 63 returns None action!\n", + "INFO:root:veh 109 at t = 64 returns None action!\n", + "INFO:root:veh 109 at t = 66 returns None action!\n", + "INFO:root:veh 109 at t = 67 returns None action!\n", + "INFO:root:veh 109 at t = 78 returns None action!\n", + "INFO:root:veh 3 at t = 75 returns None action!\n", + "INFO:root:veh 3 at t = 76 returns None action!\n", + "INFO:root:veh 10 at t = 5 returns None action!\n", + "INFO:root:veh 10 at t = 6 returns None action!\n", + "INFO:root:veh 4 at t = 11 returns None action!\n", + "INFO:root:veh 4 at t = 12 returns None action!\n", + "INFO:root:veh 1 at t = 15 returns None action!\n", + "INFO:root:veh 1 at t = 16 returns None action!\n", + "INFO:root:veh 1 at t = 17 returns None action!\n", + "INFO:root:veh 1 at t = 18 returns None action!\n", + "INFO:root:veh 5 at t = 18 returns None action!\n", + "INFO:root:veh 51 at t = 65 returns None action!\n", + "INFO:root:veh 51 at t = 66 returns None action!\n", + "INFO:root:veh 48 at t = 67 returns None action!\n", + "INFO:root:veh 48 at t = 68 returns None action!\n", + "INFO:root:veh 2 at t = 6 returns None action!\n", + "INFO:root:veh 2 at t = 7 returns None action!\n", + "INFO:root:veh 2 at t = 22 returns None action!\n", + "INFO:root:veh 2 at t = 23 returns None action!\n", + "INFO:root:veh 3 at t = 22 returns None action!\n", + "INFO:root:veh 3 at t = 23 returns None action!\n", + "INFO:root:veh 3 at t = 24 returns None action!\n", + "INFO:root:veh 70 at t = 73 returns None action!\n", + "INFO:root:veh 70 at t = 74 returns None action!\n", + "INFO:root:veh 70 at t = 75 returns None action!\n", + "INFO:root:veh 55 at t = 58 returns None action!\n", + "INFO:root:veh 55 at t = 59 returns None action!\n", + "INFO:root:veh 55 at t = 60 returns None action!\n", + "INFO:root:veh 55 at t = 61 returns None action!\n", + "INFO:root:veh 13 at t = 11 returns None action!\n", + "INFO:root:veh 13 at t = 12 returns None action!\n", + "INFO:root:veh 13 at t = 36 returns None action!\n", + "INFO:root:veh 13 at t = 37 returns None action!\n", + "INFO:root:veh 84 at t = 1 returns None action!\n", + "INFO:root:veh 84 at t = 2 returns None action!\n", + "INFO:root:veh 10 at t = 15 returns None action!\n", + "INFO:root:veh 10 at t = 16 returns None action!\n", + "INFO:root:veh 10 at t = 17 returns None action!\n", + "INFO:root:veh 9 at t = 20 returns None action!\n", + "INFO:root:veh 9 at t = 21 returns None action!\n", + "INFO:root:veh 9 at t = 22 returns None action!\n", + "INFO:root:veh 5 at t = 40 returns None action!\n", + "INFO:root:veh 5 at t = 41 returns None action!\n", + "INFO:root:veh 5 at t = 44 returns None action!\n", + "INFO:root:veh 5 at t = 45 returns None action!\n", + "INFO:root:veh 5 at t = 46 returns None action!\n", + "INFO:root:veh 5 at t = 47 returns None action!\n", + "INFO:root:veh 5 at t = 50 returns None action!\n", + "INFO:root:veh 5 at t = 51 returns None action!\n", + "INFO:root:veh 5 at t = 52 returns None action!\n", + "INFO:root:veh 5 at t = 53 returns None action!\n", + "INFO:root:veh 5 at t = 54 returns None action!\n", + "INFO:root:veh 5 at t = 55 returns None action!\n", + "INFO:root:veh 5 at t = 56 returns None action!\n", + "INFO:root:veh 19 at t = 10 returns None action!\n", + "INFO:root:veh 19 at t = 11 returns None action!\n", + "INFO:root:veh 19 at t = 12 returns None action!\n", + "INFO:root:veh 3 at t = 58 returns None action!\n", + "INFO:root:veh 3 at t = 59 returns None action!\n", + "INFO:root:veh 3 at t = 60 returns None action!\n", + "INFO:root:veh 25 at t = 12 returns None action!\n", + "INFO:root:veh 25 at t = 13 returns None action!\n", + "INFO:root:veh 25 at t = 16 returns None action!\n", + "INFO:root:veh 15 at t = 3 returns None action!\n", + "INFO:root:veh 15 at t = 4 returns None action!\n", + "INFO:root:veh 15 at t = 5 returns None action!\n", + "INFO:root:veh 7 at t = 5 returns None action!\n", + "INFO:root:veh 15 at t = 6 returns None action!\n", + "INFO:root:veh 94 at t = 2 returns None action!\n", + "INFO:root:veh 94 at t = 3 returns None action!\n", + "INFO:root:veh 94 at t = 4 returns None action!\n", + "INFO:root:veh 88 at t = 5 returns None action!\n", + "INFO:root:veh 88 at t = 6 returns None action!\n", + "INFO:root:veh 88 at t = 7 returns None action!\n", + "INFO:root:veh 88 at t = 8 returns None action!\n", + "INFO:root:veh 88 at t = 22 returns None action!\n", + "INFO:root:veh 88 at t = 23 returns None action!\n", + "INFO:root:veh 10 at t = 22 returns None action!\n", + "INFO:root:veh 10 at t = 23 returns None action!\n", + "INFO:root:veh 22 at t = 2 returns None action!\n", + "INFO:root:veh 22 at t = 3 returns None action!\n", + "INFO:root:veh 34 at t = 0 returns None action!\n", + "INFO:root:veh 34 at t = 1 returns None action!\n", + "INFO:root:veh 32 at t = 7 returns None action!\n", + "INFO:root:veh 32 at t = 8 returns None action!\n", + "INFO:root:veh 28 at t = 8 returns None action!\n", + "INFO:root:veh 32 at t = 9 returns None action!\n", + "INFO:root:veh 28 at t = 9 returns None action!\n", + "INFO:root:veh 21 at t = 67 returns None action!\n", + "INFO:root:veh 21 at t = 68 returns None action!\n", + "INFO:root:veh 21 at t = 69 returns None action!\n", + "INFO:root:veh 21 at t = 70 returns None action!\n", + "INFO:root:veh 21 at t = 71 returns None action!\n", + "INFO:root:veh 21 at t = 1 returns None action!\n", + "INFO:root:veh 21 at t = 2 returns None action!\n", + "INFO:root:veh 21 at t = 3 returns None action!\n", + "INFO:root:veh 20 at t = 32 returns None action!\n", + "INFO:root:veh 20 at t = 33 returns None action!\n", + "INFO:root:veh 20 at t = 34 returns None action!\n", + "INFO:root:veh 71 at t = 13 returns None action!\n", + "INFO:root:veh 71 at t = 14 returns None action!\n", + "INFO:root:veh 71 at t = 15 returns None action!\n", + "INFO:root:veh 4 at t = 20 returns None action!\n", + "INFO:root:veh 4 at t = 21 returns None action!\n", + "INFO:root:veh 19 at t = 46 returns None action!\n", + "INFO:root:veh 19 at t = 47 returns None action!\n", + "INFO:root:veh 31 at t = 6 returns None action!\n", + "INFO:root:veh 31 at t = 7 returns None action!\n", + "INFO:root:veh 31 at t = 8 returns None action!\n", + "INFO:root:veh 26 at t = 50 returns None action!\n", + "INFO:root:veh 26 at t = 51 returns None action!\n", + "INFO:root:veh 15 at t = 58 returns None action!\n", + "INFO:root:veh 22 at t = 4 returns None action!\n", + "INFO:root:veh 22 at t = 5 returns None action!\n", + "INFO:root:veh 6 at t = 31 returns None action!\n", + "INFO:root:veh 6 at t = 32 returns None action!\n", + "INFO:root:veh 6 at t = 33 returns None action!\n", + "INFO:root:veh 6 at t = 34 returns None action!\n", + "INFO:root:veh 6 at t = 35 returns None action!\n", + "INFO:root:veh 15 at t = 16 returns None action!\n", + "INFO:root:veh 15 at t = 17 returns None action!\n", + "INFO:root:veh 15 at t = 18 returns None action!\n", + "INFO:root:veh 15 at t = 19 returns None action!\n", + "INFO:root:veh 50 at t = 1 returns None action!\n", + "INFO:root:veh 50 at t = 2 returns None action!\n", + "INFO:root:veh 50 at t = 5 returns None action!\n", + "INFO:root:veh 50 at t = 6 returns None action!\n", + "INFO:root:veh 26 at t = 7 returns None action!\n", + "INFO:root:veh 26 at t = 8 returns None action!\n", + "INFO:root:veh 41 at t = 41 returns None action!\n", + "INFO:root:veh 41 at t = 42 returns None action!\n", + "INFO:root:veh 41 at t = 43 returns None action!\n", + "INFO:root:veh 15 at t = 4 returns None action!\n", + "INFO:root:veh 15 at t = 5 returns None action!\n", + "INFO:root:veh 5 at t = 41 returns None action!\n", + "INFO:root:veh 5 at t = 42 returns None action!\n", + "INFO:root:veh 5 at t = 44 returns None action!\n", + "INFO:root:veh 5 at t = 45 returns None action!\n", + "INFO:root:veh 5 at t = 46 returns None action!\n", + "INFO:root:veh 5 at t = 47 returns None action!\n", + "INFO:root:veh 0 at t = 34 returns None action!\n", + "INFO:root:veh 0 at t = 35 returns None action!\n", + "INFO:root:veh 0 at t = 36 returns None action!\n", + "INFO:root:veh 8 at t = 37 returns None action!\n", + "INFO:root:veh 8 at t = 38 returns None action!\n", + "INFO:root:veh 8 at t = 39 returns None action!\n", + "INFO:root:veh 5 at t = 42 returns None action!\n", + "INFO:root:veh 5 at t = 43 returns None action!\n", + "INFO:root:veh 5 at t = 44 returns None action!\n", + "INFO:root:veh 5 at t = 45 returns None action!\n", + "INFO:root:veh 7 at t = 48 returns None action!\n", + "INFO:root:veh 7 at t = 49 returns None action!\n", + "INFO:root:veh 7 at t = 52 returns None action!\n", + "INFO:root:veh 7 at t = 53 returns None action!\n", + "INFO:root:veh 8 at t = 64 returns None action!\n", + "INFO:root:veh 8 at t = 65 returns None action!\n", + "INFO:root:veh 8 at t = 66 returns None action!\n", + "INFO:root:veh 21 at t = 6 returns None action!\n", + "INFO:root:veh 21 at t = 7 returns None action!\n", + "INFO:root:veh 21 at t = 8 returns None action!\n", + "INFO:root:veh 10 at t = 10 returns None action!\n", + "INFO:root:veh 10 at t = 11 returns None action!\n", + "INFO:root:veh 10 at t = 12 returns None action!\n", + "INFO:root:veh 27 at t = 21 returns None action!\n", + "INFO:root:veh 27 at t = 22 returns None action!\n", + "INFO:root:veh 27 at t = 23 returns None action!\n", + "INFO:root:veh 6 at t = 61 returns None action!\n", + "INFO:root:veh 6 at t = 62 returns None action!\n", + "INFO:root:veh 6 at t = 63 returns None action!\n", + "INFO:root:veh 6 at t = 69 returns None action!\n", + "INFO:root:veh 6 at t = 70 returns None action!\n", + "INFO:root:veh 6 at t = 71 returns None action!\n", + "INFO:root:veh 6 at t = 72 returns None action!\n", + "INFO:root:veh 6 at t = 73 returns None action!\n", + "INFO:root:veh 6 at t = 74 returns None action!\n", + "INFO:root:veh 42 at t = 20 returns None action!\n", + "INFO:root:veh 42 at t = 21 returns None action!\n", + "INFO:root:veh 42 at t = 22 returns None action!\n", + "INFO:root:veh 42 at t = 23 returns None action!\n", + "INFO:root:veh 42 at t = 65 returns None action!\n", + "INFO:root:veh 42 at t = 66 returns None action!\n", + "INFO:root:veh 42 at t = 78 returns None action!\n", + "INFO:root:veh 42 at t = 79 returns None action!\n", + "INFO:root:veh 3 at t = 0 returns None action!\n", + "INFO:root:veh 3 at t = 1 returns None action!\n", + "INFO:root:veh 3 at t = 2 returns None action!\n", + "INFO:root:veh 3 at t = 3 returns None action!\n", + "INFO:root:veh 0 at t = 1 returns None action!\n", + "INFO:root:veh 0 at t = 2 returns None action!\n", + "INFO:root:veh 17 at t = 7 returns None action!\n", + "INFO:root:veh 17 at t = 8 returns None action!\n", + "INFO:root:veh 17 at t = 10 returns None action!\n", + "INFO:root:veh 17 at t = 11 returns None action!\n", + "INFO:root:veh 17 at t = 12 returns None action!\n", + "INFO:root:veh 17 at t = 13 returns None action!\n", + "INFO:root:veh 17 at t = 32 returns None action!\n", + "INFO:root:veh 17 at t = 33 returns None action!\n", + "INFO:root:veh 7 at t = 7 returns None action!\n", + "INFO:root:veh 7 at t = 8 returns None action!\n", + "INFO:root:veh 1 at t = 36 returns None action!\n", + "INFO:root:veh 1 at t = 37 returns None action!\n", + "INFO:root:veh 11 at t = 77 returns None action!\n", + "INFO:root:veh 11 at t = 78 returns None action!\n", + "INFO:root:veh 5 at t = 6 returns None action!\n", + "INFO:root:veh 5 at t = 7 returns None action!\n", + "INFO:root:veh 5 at t = 8 returns None action!\n", + "INFO:root:veh 5 at t = 10 returns None action!\n", + "INFO:root:veh 5 at t = 11 returns None action!\n", + "INFO:root:veh 5 at t = 12 returns None action!\n", + "INFO:root:veh 5 at t = 13 returns None action!\n", + "INFO:root:veh 5 at t = 14 returns None action!\n" + ] + } + ], "source": [ "# Load trained human reference policy\n", "human_policy = load_policy(\n", @@ -169,7 +716,7 @@ }, { "cell_type": "code", - "execution_count": 133, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -411,7 +958,7 @@ }, { "cell_type": "code", - "execution_count": 134, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -12571,7 +13118,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -12580,7 +13127,10 @@ "# Scenes on which to evaluate the models\n", "rl_policy_paths = glob.glob(f\"{RL_POLICY_PATH}\" + \"/*.pt\")\n", "rl_policy_names = [os.path.basename(file)[:-3] for file in rl_policy_paths]\n", - "reg_weights = [0.025, 0.0]" + "reg_weights = [0.0]\n", + "\n", + "\n", + "rl_policy_names" ] }, { diff --git a/experiments/hr_rl/run_hr_ppo.py b/experiments/hr_rl/run_hr_ppo.py index 0868c59d..9b1788a9 100644 --- a/experiments/hr_rl/run_hr_ppo.py +++ b/experiments/hr_rl/run_hr_ppo.py @@ -143,10 +143,11 @@ def train(env_config, exp_config, video_config, model_config): # pylint: disabl } ) - lambdas = [0.0] - for lam in lambdas: + num_files_list = [10, 100, 1000] + + for scenes in num_files_list: # Set regularization weight - exp_config.reg_weight = lam + env_config.num_files = scenes # Train train( diff --git a/experiments/hr_rl/run_hr_ppo_cli.py b/experiments/hr_rl/run_hr_ppo_cli.py index abff0554..34ca1dfd 100644 --- a/experiments/hr_rl/run_hr_ppo_cli.py +++ b/experiments/hr_rl/run_hr_ppo_cli.py @@ -43,7 +43,7 @@ def run_hr_ppo( - sweep_name: str = "hr_ppo", + sweep_name: str = exp_config.group, steer_disc: int = 5, accel_disc: int = 5, ent_coef: float = 0.0, @@ -159,6 +159,7 @@ def run_hr_ppo( seed=exp_config.seed, # Seed for the pseudo random generators verbose=exp_config.verbose, tensorboard_log=f"runs/{run_id}" if run_id is not None else None, + device=exp_config.ppo.device, env_config=env_config, mlp_class=LateFusionMLP, diff --git a/experiments/il/run_behavioral_cloning.py b/experiments/il/run_behavioral_cloning.py index 86319163..2812a9de 100644 --- a/experiments/il/run_behavioral_cloning.py +++ b/experiments/il/run_behavioral_cloning.py @@ -7,6 +7,7 @@ from imitation.algorithms import bc from imitation.data.types import Transitions from imitation.util import logger as imit_logger +from networks.mlp_late_fusion import LateFusionMLP, LateFusionMLPPolicy from utils.wrappers import LightNocturneEnvWrapper from utils.config import load_config @@ -16,17 +17,22 @@ from utils.string_utils import date_to_str if __name__ == "__main__": + + MAX_EVAL_FILES = 12 + NUM_TRAIN_FILES = 1000 # Create run run = wandb.init( project="eval_il_policy", sync_tensorboard=True, + group=f"BC_S{NUM_TRAIN_FILES}", ) # Configs video_config = load_config("video_config") - env_config = load_config("env_config") bc_config = load_config("bc_config") + env_config = load_config("env_config") + env_config.num_files = NUM_TRAIN_FILES # Device TODO: Add support for CUDA device = "cpu" @@ -63,23 +69,26 @@ demonstrations=transitions, rng=rng, device=device, + #policy=LateFusionMLPPolicy, ) + print(f'IL policy: \n{bc_trainer.policy}') + # Create evaluation env env = LightNocturneEnvWrapper(env_config) - eval_files = env.files - - # Check random behavior - reward_before_training, _ = evaluate_policy( - model=bc_trainer.policy, - env=LightNocturneEnvWrapper(env_config), - n_steps_per_episode=env_config.episode_length, - n_eval_episodes=1, - eval_files=eval_files, - video_config=video_config, - video_caption="BEFORE training", - render=True, - ) + eval_files = env.files[:MAX_EVAL_FILES] + + # # Check random behavior + # reward_before_training, _ = evaluate_policy( + # model=bc_trainer.policy, + # env=LightNocturneEnvWrapper(env_config), + # n_steps_per_episode=env_config.episode_length, + # n_eval_episodes=1, + # eval_files=eval_files, + # video_config=video_config, + # video_caption="BEFORE training", + # render=True, + # ) # Train bc_trainer.train( @@ -105,4 +114,4 @@ date_ = date_to_str(datetime.now()) if bc_config.save_model: - bc_trainer.policy.save(path=f"{bc_config.save_model_path}{bc_config.model_name}_{date_}.pt") \ No newline at end of file + bc_trainer.policy.save(path=f"{bc_config.save_model_path}{bc_config.model_name}_S{NUM_TRAIN_FILES}_{date_}.pt") \ No newline at end of file diff --git a/models/il/human_policy_2023_11_20.pt b/models/il/human_policy_2023_11_20.pt deleted file mode 100644 index 9f950824..00000000 Binary files a/models/il/human_policy_2023_11_20.pt and /dev/null differ diff --git a/networks/mlp_late_fusion.py b/networks/mlp_late_fusion.py index ed9a7db3..aed85e0b 100644 --- a/networks/mlp_late_fusion.py +++ b/networks/mlp_late_fusion.py @@ -245,7 +245,6 @@ def _build_mlp_extractor(self) -> None: **self.mlp_config, ) - if __name__ == "__main__": # Load config @@ -263,6 +262,21 @@ def _build_mlp_extractor(self) -> None: obs = env.reset() obs = torch.Tensor(obs)[:2] + + # Define model architecture + model_config = Box( + { + "arch_ego_state": [8], + "arch_road_objects": [64], + "arch_road_graph": [128, 64], + "arch_shared_net": [], + "act_func": "tanh", + "dropout": 0.0, + "last_layer_dim_pi": 64, + "last_layer_dim_vf": 64, + } + ) + # Test model = RegularizedPPO( reg_policy=None, @@ -275,6 +289,9 @@ def _build_mlp_extractor(self) -> None: seed=exp_config.seed, # Seed for the pseudo random generators verbose=1, device='cuda', + env_config=env_config, + mlp_class=LateFusionMLP, + mlp_config=model_config, ) # print(model.policy) model.learn(5000) \ No newline at end of file diff --git a/nocturne/envs/base_env.py b/nocturne/envs/base_env.py index 173e38ca..ea089b7d 100644 --- a/nocturne/envs/base_env.py +++ b/nocturne/envs/base_env.py @@ -708,18 +708,27 @@ def _position_as_array(position: Vector2D) -> np.ndarray: env = BaseEnv(config=env_config) # Reset - obs_dict = env.reset() + obs_dict = env.reset(filename='tfrecord-00421-of-01000_364.json') # Get info agent_ids = [agent_id for agent_id in obs_dict.keys()] + veh_objects = {agent.id: agent for agent in env.controlled_vehicles} dead_agent_ids = [] - for step in range(100): + for step in range(80): # Sample actions action_dict = {agent_id: env.action_space.sample() for agent_id in agent_ids if agent_id not in dead_agent_ids} - # Step in env + # Set in expert controlled mode + for obj in env.controlled_vehicles: + obj.expert_control = True + obs_dict, rew_dict, done_dict, info_dict = env.step(action_dict) + + print(f'step: {step}, done: {done_dict[37]}, info:\n {info_dict[37]}') + + expert_action = env.scenario.expert_action(veh_objects[37], step) + print(f'act = {expert_action} \n') # Update dead agents for agent_id, is_done in done_dict.items(): diff --git a/utils/eval.py b/utils/eval.py index 301de528..062a976b 100644 --- a/utils/eval.py +++ b/utils/eval.py @@ -3,6 +3,7 @@ import numpy as np import pandas as pd import logging +from tqdm import tqdm import torch import wandb import glob @@ -48,6 +49,8 @@ def __init__( def _get_scores(self): """Evaluate policy across a set of traffic scenes.""" + logging.info(f'\n Evaluating policy on {len(self.eval_files)} files...') + # Make tables df_eval = pd.DataFrame( columns=[ @@ -56,12 +59,13 @@ def _get_scores(self): "agents_controlled", "reg_coef", "act_acc", + "accel_val_mae", + "steer_val_mae", "pos_rmse", "speed_mae", "goal_rate", "veh_edge_cr", "veh_veh_cr", - "num_violations", ] ) @@ -81,7 +85,7 @@ def _get_scores(self): ] ) - for file in self.eval_files: + for file in tqdm(self.eval_files): logging.debug(f"Evaluating policy on {file}...") @@ -96,23 +100,23 @@ def _get_scores(self): ) # Filter out invalid steps - nonnan_ids = np.logical_not( - np.logical_or( - np.isnan(policy_actions), - np.isnan(expert_actions), - ) - ) + nonnan_ids = ~np.isnan(expert_actions) + # Compute metrics action_accuracy = self.get_action_accuracy( policy_actions, expert_actions, nonnan_ids ) position_rmse = self.get_pos_rmse( - policy_pos, expert_pos, nonnan_ids + policy_pos, expert_pos, ) speed_agent_mae = self.get_speed_mae( - policy_speed, expert_speed, nonnan_ids + policy_speed, expert_speed, + ) + + abs_diff_accel, abs_diff_steer = self.get_action_val_diff( + policy_actions, expert_actions, ) # Violations of the 3-second rule @@ -125,12 +129,13 @@ def _get_scores(self): "agents_controlled": expert_actions.shape[0], "reg_coef": self.exp_config.reg_weight if self.reg_coef is None else self.reg_coef, "act_acc": action_accuracy, + "accel_val_mae": abs_diff_accel, + "steer_val_mae": abs_diff_steer, "pos_rmse": position_rmse, "speed_mae": speed_agent_mae, "goal_rate": policy_gr, "veh_edge_cr": policy_edge_cr, "veh_veh_cr": policy_veh_cr, - "num_violations": num_violations, } df_eval.loc[len(df_eval)] = scene_perf @@ -218,7 +223,9 @@ def _step_through_scene(self, filename: str, mode: str): agent_speed[veh_idx, timestep] = veh_obj.speed action_indices[veh_idx, timestep] = action_idx else: - logging.info(f'veh {veh_obj.id} at t = {timestep} returns None action!') + # Skip None actions (these are invalid) + logging.debug(f'veh {veh_obj.id} at t = {timestep} returns None action!') + continue action_dict = {} @@ -269,42 +276,67 @@ def _step_through_scene(self, filename: str, mode: str): veh_edge_collision/self.num_agents, veh_veh_collision/self.num_agents, ) - - def get_action_accuracy(self, pred_actions, expert_actions, nonnan_ids): - """Get accuracy of agent actions. + + def get_action_val_diff(self, pred_actions, expert_actions): + """Get difference between human action values and predicted action values. Args: pred_actions: (num_agents, num_steps_per_episode) the predicted actions of the agents. expert_actions: (num_agents, num_steps_per_episode) the expert actions of the agents. nonnan_ids: (num_agents, num_steps_per_episode) the indices of non-nan actions. """ - return (expert_actions[nonnan_ids] == pred_actions[nonnan_ids]).sum() / nonnan_ids.flatten().shape[0] + # Filter out invalid actions + nonnan_ids = np.logical_not( + np.logical_or( + np.isnan(pred_actions), + np.isnan(expert_actions), + ) + ) + valid_expert_acts = expert_actions[nonnan_ids] + valid_pred_acts = pred_actions[nonnan_ids] - def get_pos_rmse(self, pred_actions, expert_actions, nonnan_ids): - return np.sqrt(np.linalg.norm(pred_actions[nonnan_ids] - expert_actions[nonnan_ids])).mean() - - def get_speed_mae(self, pred_actions, expert_actions, nonnan_ids): - return np.abs(pred_actions[nonnan_ids] - expert_actions[nonnan_ids]).mean() - - def get_steer_mae(self, pred_actions, expert_actions, nonnan_ids): - return np.abs(pred_actions[nonnan_ids] - expert_actions[nonnan_ids]).mean() + exp_acc_vals, exp_steer_vals = np.zeros_like(valid_pred_acts), np.zeros_like(valid_pred_acts) + pred_acc_vals, pred_steer_vals = np.zeros_like(valid_pred_acts), np.zeros_like(valid_pred_acts) + + for idx in range(valid_expert_acts.shape[0]): + + # Get expert and predicted values + exp_acc_vals[idx], exp_steer_vals[idx] = self.env.idx_to_actions[valid_expert_acts[idx]] + pred_acc_vals[idx], pred_steer_vals[idx] = self.env.idx_to_actions[valid_pred_acts[idx]] + + # Get mean absolute difference + abs_accel_diff = np.abs(exp_acc_vals - pred_acc_vals).mean() + abs_steer_diff = np.abs(exp_steer_vals - pred_steer_vals).mean() + + return abs_accel_diff, abs_steer_diff - def get_action_abs_distance(self, pred_actions, expert_actions, nonnan_ids, action_space_dim): + def get_action_accuracy(self, pred_actions, expert_actions, nonnan_ids): """Get accuracy of agent actions. Args: pred_actions: (num_agents, num_steps_per_episode) the predicted actions of the agents. expert_actions: (num_agents, num_steps_per_episode) the expert actions of the agents. nonnan_ids: (num_agents, num_steps_per_episode) the indices of non-nan actions. """ + return (expert_actions[nonnan_ids] == pred_actions[nonnan_ids]).sum() / nonnan_ids.flatten().shape[0] - num_agents = pred_actions.shape[0] - agg_abs_dist = 0 - - for idx in range(pred_actions.shape[0]): - n_samples = pred_actions[0][nonnan_ids[0]].shape[0] - agent_abs_dist = np.abs(pred_actions[idx][nonnan_ids[idx]] - expert_actions[idx][nonnan_ids[idx]]).sum() / n_samples - agg_abs_dist += agent_abs_dist - - return agg_abs_dist / num_agents + def get_pos_rmse(self, pred_actions, expert_actions): + # Filter out invalid actions + nonnan_ids = np.logical_not( + np.logical_or( + np.isnan(pred_actions), + np.isnan(expert_actions), + ) + ) + return np.sqrt(np.linalg.norm(pred_actions[nonnan_ids] - expert_actions[nonnan_ids])).mean() + + def get_speed_mae(self, pred_actions, expert_actions): + # Filter out invalid actions + nonnan_ids = np.logical_not( + np.logical_or( + np.isnan(pred_actions), + np.isnan(expert_actions), + ) + ) + return np.abs(pred_actions[nonnan_ids] - expert_actions[nonnan_ids]).mean() def get_veh_to_veh_distances(self, positions, velocities, time_gap_in_sec=3): """Calculate distances between vehicles at each time step and track @@ -367,29 +399,11 @@ def _get_files(self, eval_files, file_limit): env_config = load_config("env_config") exp_config = load_config("exp_config") - # env_config.data_path = "./data_10/train" - - # # Load trained human reference policy - # human_policy = load_policy( - # data_path="./models/il", - # file_name="human_policy_10_scenes_2023_11_21", - # ) - - # # Evaluate policy - # evaluator = EvaluatePolicy( - # env_config=env_config, - # exp_config=exp_config, - # policy=human_policy, - # log_to_wandb=False, - # deterministic=True, - # reg_coef=0.0, - # return_trajectories=True, - # ) - - # il_results_check = evaluator._get_scores() + MAX_FILES = 50 - # Set data path - env_config.data_path = "./data/train/" + # Train + train_file_paths = glob.glob(f"{env_config.data_path}" + "/tfrecord*") + train_eval_files = [os.path.basename(file) for file in train_file_paths][:MAX_FILES] # Load human reference policy human_policy = load_policy( @@ -402,7 +416,7 @@ def _get_files(self, eval_files, file_limit): env_config=env_config, exp_config=exp_config, policy=human_policy, - eval_files=["tfrecord-00012-of-01000_389.json"], + eval_files=train_eval_files, log_to_wandb=False, deterministic=True, reg_coef=0.0, diff --git a/utils/imitation_learning/waymo_iterator.py b/utils/imitation_learning/waymo_iterator.py index cde7eca8..a7184711 100644 --- a/utils/imitation_learning/waymo_iterator.py +++ b/utils/imitation_learning/waymo_iterator.py @@ -19,7 +19,7 @@ class TrajectoryIterator(IterableDataset): """Generates trajectories in Waymo scenes: sequences of observations and actions.""" - def __init__(self, data_path, env_config, with_replacement=True, file_limit=None): + def __init__(self, data_path, env_config, with_replacement=True, file_limit=-1): self.data_path = data_path self.config = env_config self.env = BaseEnv(env_config) @@ -32,7 +32,7 @@ def __init__(self, data_path, env_config, with_replacement=True, file_limit=None super(TrajectoryIterator).__init__() - logging.info(f"Using {len(self.file_names)} file(s): {self.file_names}") + logging.info(f"Using {len(self.file_names)} file(s)") def __iter__(self): """Return an (expert_state, expert_action) iterable.""" @@ -93,7 +93,7 @@ def _discretize_expert_actions(self, filename: str): # Get (continuous) expert action expert_action = scenario.expert_action(veh_obj, timestep) - # Check for invalid action (because no value available for taking + # Check for invalid actions (None) (because no value available for taking # derivative) or because the vehicle is at an invalid state if expert_action is None: continue @@ -106,6 +106,9 @@ def _discretize_expert_actions(self, filename: str): expert_action_idx = self.actions_to_joint_idx[accel_grid_val, steering_grid_val][0] + if expert_action_idx is None: + logging.debug("Expert action is None!") + # Store if timestep >= self.config.warmup_period: df_actions.loc[timestep-self.config.warmup_period][veh_obj.getID()] = expert_action_idx @@ -229,6 +232,7 @@ def _find_closest_index(self, action_grid, action): if __name__ == "__main__": env_config = load_config("env_config") + env_config.num_files = 1000 # Create iterator waymo_iterator = TrajectoryIterator( @@ -241,7 +245,11 @@ def _find_closest_index(self, action_grid, action): rollouts = next(iter( DataLoader( waymo_iterator, - batch_size=400, + batch_size=10_000, # Number of samples to generate pin_memory=True, ))) + + obs, acts, next_obs, dones = rollouts + + print('hi') diff --git a/utils/manage_models.py b/utils/manage_models.py index 356f7c1c..f2433cf8 100644 --- a/utils/manage_models.py +++ b/utils/manage_models.py @@ -9,7 +9,7 @@ entity = "daphnecor" project = "scaling_ppo" - collection_name = "nocturne-hr-ppo-12_29_07_55" + collection_name = "nocturne-hr-ppo-12_30_21_43" # Always initialize a W&B run to start tracking wandb.init()