diff --git a/configs/exp_config.yaml b/configs/exp_config.yaml index 91946233..2407da24 100644 --- a/configs/exp_config.yaml +++ b/configs/exp_config.yaml @@ -1,5 +1,5 @@ project: scaling_ppo -group: playground +group: effect_of_human_reg env_id: Nocturne seed: 42 track_wandb: true @@ -19,10 +19,10 @@ ma_callback: log_indiv_metrics: false log_agent_actions: false save_model: true - model_save_freq: 100 # In iterations (one iter ~ (num_agents x n_steps)) + model_save_freq: 300 # In iterations (one iter ~ (num_agents x n_steps)) save_video: true record_n_scenes: 10 # Number of different scenes to render - video_save_freq: 100 # Make a video every k iterations (100 iters ~ 1M steps) + video_save_freq: 500 # Make a video every k iterations (100 iters ~ 1M steps) video_deterministic: true ppo: @@ -32,7 +32,7 @@ ppo: vf_coef: 0.5 # Default in SB3 is 0.5 learn: - total_timesteps: 2_000_000 + total_timesteps: 10_000_000 progress_bar: false # human-regularized RL diff --git a/configs/video_config.yaml b/configs/video_config.yaml index ddb7bcd8..b33096f7 100644 --- a/configs/video_config.yaml +++ b/configs/video_config.yaml @@ -6,5 +6,5 @@ render: logging: render_interval: 3 - fps: 4 + fps: 2 where_am_i: headless_machine \ No newline at end of file diff --git a/evaluation/policy_performance_analysis.ipynb b/evaluation/policy_performance_analysis.ipynb index 50e6bf82..07d01e7d 100644 --- a/evaluation/policy_performance_analysis.ipynb +++ b/evaluation/policy_performance_analysis.ipynb @@ -18,13 +18,12 @@ "- **Accuracy** to the expert actions. \n", "- **Euclidean distance** to the expert positions at a given state and timepoint (the L2 distance between the controlled object's XY\n", " location and its position in the logged history at the same timestep.) `np.linalg.norm(object_xy - log_xy, axis=-1)`\n", - "- **Safe distance through the 3-second rule**\n", - "- **Mean absolute error to the expert speed**" + "- **Mean absolute error to the expert speed, acceleration and steering wheel angle**" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -37,7 +36,10 @@ "import torch\n", "import logging\n", "import os\n", + "import wandb\n", "import matplotlib.pyplot as plt\n", + "from utils.evaluation import evaluate_policy\n", + "from utils.wrappers import LightNocturneEnvWrapper\n", "\n", "from typing import Callable\n", "from gym import spaces\n", @@ -52,7 +54,9 @@ "sns.set('notebook', font_scale=1.1, rc={'figure.figsize': (10, 5)})\n", "sns.set_style('ticks', rc={'figure.facecolor': 'none', 'axes.facecolor': 'none'})\n", "%config InlineBackend.figure_format = 'svg'\n", - "warnings.filterwarnings(\"ignore\")" + "warnings.filterwarnings(\"ignore\")\n", + "plt.set_loglevel('WARNING')\n", + "os.environ[\"WANDB_NOTEBOOK_NAME\"] = \"policy_performance_analysis.ipynb\"" ] }, { @@ -64,27 +68,83 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ - "MAX_FILES = 50\n", + "MAX_FILES = 100\n", + "DETERMINISTIC = True\n", "\n", "# Load config files\n", "env_config = load_config_nb(\"env_config\")\n", "exp_config = load_config_nb(\"exp_config\")\n", + "video_config = load_config_nb(\"video_config\")\n", "model_config = load_config_nb(\"model_config\")\n", "\n", "# Set data path\n", "env_config.data_path = \"../data_full/train/\"\n", + "env_config.val_data_path = \"../data_full/valid/\"\n", "env_config.num_files = MAX_FILES\n", "\n", "# Logging level set to INFO\n", "LOGGING_LEVEL = \"INFO\"\n", "\n", "# Scenes on which to evaluate the models\n", - "file_paths = glob.glob(f\"{env_config.data_path}\" + \"/tfrecord*\")\n", - "eval_files = [os.path.basename(file) for file in file_paths][:MAX_FILES]" + "# Make sure file order is fixed\n", + "train_file_paths = glob.glob(f\"{env_config.data_path}\" + \"/tfrecord*\")\n", + "train_eval_files = sorted([os.path.basename(file) for file in train_file_paths])\n", + "\n", + "# Valid\n", + "valid_file_paths = glob.glob(f\"{env_config.val_data_path}\" + \"/tfrecord*\")\n", + "valid_eval_files = sorted([os.path.basename(file) for file in valid_file_paths])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Helper functions" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "def render_traffic_scenes(\n", + " eval_scenes, \n", + " policy, \n", + " run_name,\n", + " deterministic=True, \n", + " group_name=\"RL_S100\", \n", + " project_name=\"eval_hr_rl_policies\"\n", + " ):\n", + "\n", + " if deterministic:\n", + " mode = \"det\"\n", + " else:\n", + " mode = \"stoch\"\n", + "\n", + " # Create run\n", + " run = wandb.init(\n", + " project=project_name,\n", + " sync_tensorboard=True,\n", + " group=group_name,\n", + " name=f\"{run_name}_{mode}\",\n", + " )\n", + "\n", + " avg_rew, std_rew = evaluate_policy(\n", + " model=policy, \n", + " env=LightNocturneEnvWrapper(env_config),\n", + " eval_files=eval_scenes,\n", + " video_config=video_config,\n", + " video_caption=f\"\",\n", + " deterministic=deterministic,\n", + " render=True,\n", + " )\n", + "\n", + " run.finish()" ] }, { @@ -96,7 +156,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -115,579 +175,16 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "INFO:root:veh 37 at t = 50 returns None action!\n", - "INFO:root:veh 37 at t = 51 returns None action!\n", - "INFO:root:veh 44 at t = 75 returns None action!\n", - "INFO:root:veh 44 at t = 76 returns None action!\n", - "INFO:root:veh 44 at t = 32 returns None action!\n", - "INFO:root:veh 44 at t = 33 returns None action!\n", - "INFO:root:veh 44 at t = 34 returns None action!\n", - "INFO:root:veh 44 at t = 36 returns None action!\n", - "INFO:root:veh 44 at t = 37 returns None action!\n", - "INFO:root:veh 44 at t = 38 returns None action!\n", - "INFO:root:veh 4 at t = 2 returns None action!\n", - "INFO:root:veh 41 at t = 1 returns None action!\n", - "INFO:root:veh 41 at t = 2 returns None action!\n", - "INFO:root:veh 17 at t = 15 returns None action!\n", - "INFO:root:veh 17 at t = 16 returns None action!\n", - "INFO:root:veh 17 at t = 17 returns None action!\n", - "INFO:root:veh 17 at t = 18 returns None action!\n", - "INFO:root:veh 17 at t = 19 returns None action!\n", - "INFO:root:veh 17 at t = 31 returns None action!\n", - "INFO:root:veh 17 at t = 32 returns None action!\n", - "INFO:root:veh 17 at t = 33 returns None action!\n", - "INFO:root:veh 17 at t = 34 returns None action!\n", - "INFO:root:veh 17 at t = 43 returns None action!\n", - "INFO:root:veh 17 at t = 44 returns None action!\n", - "INFO:root:veh 17 at t = 45 returns None action!\n", - "INFO:root:veh 17 at t = 46 returns None action!\n", - "INFO:root:veh 17 at t = 47 returns None action!\n", - "INFO:root:veh 17 at t = 51 returns None action!\n", - "INFO:root:veh 17 at t = 52 returns None action!\n", - "INFO:root:veh 14 at t = 1 returns None action!\n", - "INFO:root:veh 14 at t = 2 returns None action!\n", - "INFO:root:veh 14 at t = 3 returns None action!\n", - "INFO:root:veh 14 at t = 4 returns None action!\n", - "INFO:root:veh 14 at t = 44 returns None action!\n", - "INFO:root:veh 14 at t = 45 returns None action!\n", - "INFO:root:veh 14 at t = 46 returns None action!\n", - "INFO:root:veh 0 at t = 14 returns None action!\n", - "INFO:root:veh 0 at t = 15 returns None action!\n", - "INFO:root:veh 0 at t = 16 returns None action!\n", - "INFO:root:veh 0 at t = 17 returns None action!\n", - "INFO:root:veh 0 at t = 18 returns None action!\n", - "INFO:root:veh 0 at t = 20 returns None action!\n", - "INFO:root:veh 0 at t = 21 returns None action!\n", - "INFO:root:veh 1 at t = 62 returns None action!\n", - "INFO:root:veh 1 at t = 63 returns None action!\n", - "INFO:root:veh 7 at t = 33 returns None action!\n", - "INFO:root:veh 17 at t = 34 returns None action!\n", - "INFO:root:veh 17 at t = 35 returns None action!\n", - "INFO:root:veh 2 at t = 20 returns None action!\n", - "INFO:root:veh 2 at t = 21 returns None action!\n", - "INFO:root:veh 2 at t = 29 returns None action!\n", - "INFO:root:veh 2 at t = 30 returns None action!\n", - "INFO:root:veh 2 at t = 31 returns None action!\n", - "INFO:root:veh 47 at t = 0 returns None action!\n", - "INFO:root:veh 47 at t = 1 returns None action!\n", - "INFO:root:veh 47 at t = 2 returns None action!\n", - "INFO:root:veh 25 at t = 23 returns None action!\n", - "INFO:root:veh 25 at t = 24 returns None action!\n", - "INFO:root:veh 25 at t = 25 returns None action!\n", - "INFO:root:veh 25 at t = 26 returns None action!\n", - "INFO:root:veh 25 at t = 27 returns None action!\n", - "INFO:root:veh 26 at t = 29 returns None action!\n", - "INFO:root:veh 26 at t = 30 returns None action!\n", - "INFO:root:veh 11 at t = 16 returns None action!\n", - "INFO:root:veh 11 at t = 17 returns None action!\n", - "INFO:root:veh 10 at t = 65 returns None action!\n", - "INFO:root:veh 10 at t = 66 returns None action!\n", - "INFO:root:veh 58 at t = 1 returns None action!\n", - "INFO:root:veh 58 at t = 2 returns None action!\n", - "INFO:root:veh 59 at t = 3 returns None action!\n", - "INFO:root:veh 59 at t = 4 returns None action!\n", - "INFO:root:veh 55 at t = 12 returns None action!\n", - "INFO:root:veh 55 at t = 13 returns None action!\n", - "INFO:root:veh 55 at t = 14 returns None action!\n", - "INFO:root:veh 55 at t = 15 returns None action!\n", - "INFO:root:veh 58 at t = 20 returns None action!\n", - "INFO:root:veh 58 at t = 21 returns None action!\n", - "INFO:root:veh 59 at t = 24 returns None action!\n", - "INFO:root:veh 59 at t = 25 returns None action!\n", - "INFO:root:veh 69 at t = 29 returns None action!\n", - "INFO:root:veh 69 at t = 30 returns None action!\n", - "INFO:root:veh 58 at t = 32 returns None action!\n", - "INFO:root:veh 58 at t = 33 returns None action!\n", - "INFO:root:veh 46 at t = 1 returns None action!\n", - "INFO:root:veh 46 at t = 2 returns None action!\n", - "INFO:root:veh 46 at t = 3 returns None action!\n", - "INFO:root:veh 46 at t = 4 returns None action!\n", - "INFO:root:veh 14 at t = 8 returns None action!\n", - "INFO:root:veh 14 at t = 9 returns None action!\n", - "INFO:root:veh 26 at t = 25 returns None action!\n", - "INFO:root:veh 26 at t = 26 returns None action!\n", - "INFO:root:veh 26 at t = 29 returns None action!\n", - "INFO:root:veh 26 at t = 30 returns None action!\n", - "INFO:root:veh 30 at t = 11 returns None action!\n", - "INFO:root:veh 30 at t = 12 returns None action!\n", - "INFO:root:veh 30 at t = 13 returns None action!\n", - "INFO:root:veh 30 at t = 14 returns None action!\n", - "INFO:root:veh 30 at t = 31 returns None action!\n", - "INFO:root:veh 30 at t = 32 returns None action!\n", - "INFO:root:veh 30 at t = 33 returns None action!\n", - "INFO:root:veh 30 at t = 34 returns None action!\n", - "INFO:root:veh 11 at t = 66 returns None action!\n", - "INFO:root:veh 11 at t = 67 returns None action!\n", - "INFO:root:veh 11 at t = 68 returns None action!\n", - "INFO:root:veh 11 at t = 70 returns None action!\n", - "INFO:root:veh 11 at t = 71 returns None action!\n", - "INFO:root:veh 59 at t = 31 returns None action!\n", - "INFO:root:veh 59 at t = 32 returns None action!\n", - "INFO:root:veh 59 at t = 33 returns None action!\n", - "INFO:root:veh 59 at t = 34 returns None action!\n", - "INFO:root:veh 62 at t = 35 returns None action!\n", - "INFO:root:veh 62 at t = 36 returns None action!\n", - "INFO:root:veh 8 at t = 42 returns None action!\n", - "INFO:root:veh 8 at t = 43 returns None action!\n", - "INFO:root:veh 8 at t = 44 returns None action!\n", - "INFO:root:veh 8 at t = 45 returns None action!\n", - "INFO:root:veh 8 at t = 46 returns None action!\n", - "INFO:root:veh 60 at t = 0 returns None action!\n", - "INFO:root:veh 26 at t = 0 returns None action!\n", - "INFO:root:veh 60 at t = 1 returns None action!\n", - "INFO:root:veh 60 at t = 2 returns None action!\n", - "INFO:root:veh 60 at t = 3 returns None action!\n", - "INFO:root:veh 49 at t = 3 returns None action!\n", - "INFO:root:veh 60 at t = 4 returns None action!\n", - "INFO:root:veh 49 at t = 4 returns None action!\n", - "INFO:root:veh 49 at t = 5 returns None action!\n", - "INFO:root:veh 60 at t = 7 returns None action!\n", - "INFO:root:veh 60 at t = 8 returns None action!\n", - "INFO:root:veh 29 at t = 50 returns None action!\n", - "INFO:root:veh 29 at t = 51 returns None action!\n", - "INFO:root:veh 29 at t = 52 returns None action!\n", - "INFO:root:veh 4 at t = 2 returns None action!\n", - "INFO:root:veh 4 at t = 3 returns None action!\n", - "INFO:root:veh 4 at t = 4 returns None action!\n", - "INFO:root:veh 4 at t = 5 returns None action!\n", - "INFO:root:veh 13 at t = 8 returns None action!\n", - "INFO:root:veh 13 at t = 9 returns None action!\n", - "INFO:root:veh 35 at t = 0 returns None action!\n", - "INFO:root:veh 35 at t = 1 returns None action!\n", - "INFO:root:veh 35 at t = 2 returns None action!\n", - "INFO:root:veh 35 at t = 3 returns None action!\n", - "INFO:root:veh 35 at t = 4 returns None action!\n", - "INFO:root:veh 23 at t = 5 returns None action!\n", - "INFO:root:veh 6 at t = 2 returns None action!\n", - "INFO:root:veh 6 at t = 3 returns None action!\n", - "INFO:root:veh 6 at t = 29 returns None action!\n", - "INFO:root:veh 6 at t = 30 returns None action!\n", - "INFO:root:veh 22 at t = 0 returns None action!\n", - "INFO:root:veh 22 at t = 1 returns None action!\n", - "INFO:root:veh 22 at t = 2 returns None action!\n", - "INFO:root:veh 22 at t = 3 returns None action!\n", - "INFO:root:veh 3 at t = 7 returns None action!\n", - "INFO:root:veh 3 at t = 8 returns None action!\n", - "INFO:root:veh 20 at t = 20 returns None action!\n", - "INFO:root:veh 50 at t = 0 returns None action!\n", - "INFO:root:veh 50 at t = 1 returns None action!\n", - "INFO:root:veh 2 at t = 71 returns None action!\n", - "INFO:root:veh 26 at t = 63 returns None action!\n", - "INFO:root:veh 26 at t = 64 returns None action!\n", - "INFO:root:veh 26 at t = 65 returns None action!\n", - "INFO:root:veh 26 at t = 74 returns None action!\n", - "INFO:root:veh 26 at t = 75 returns None action!\n", - "INFO:root:veh 26 at t = 76 returns None action!\n", - "INFO:root:veh 26 at t = 77 returns None action!\n", - "INFO:root:veh 7 at t = 20 returns None action!\n", - "INFO:root:veh 7 at t = 21 returns None action!\n", - "INFO:root:veh 4 at t = 41 returns None action!\n", - "INFO:root:veh 4 at t = 42 returns None action!\n", - "INFO:root:veh 3 at t = 10 returns None action!\n", - "INFO:root:veh 3 at t = 11 returns None action!\n", - "INFO:root:veh 3 at t = 12 returns None action!\n", - "INFO:root:veh 24 at t = 21 returns None action!\n", - "INFO:root:veh 24 at t = 22 returns None action!\n", - "INFO:root:veh 24 at t = 23 returns None action!\n", - "INFO:root:veh 24 at t = 24 returns None action!\n", - "INFO:root:veh 16 at t = 28 returns None action!\n", - "INFO:root:veh 16 at t = 29 returns None action!\n", - "INFO:root:veh 16 at t = 30 returns None action!\n", - "INFO:root:veh 16 at t = 31 returns None action!\n", - "INFO:root:veh 23 at t = 17 returns None action!\n", - "INFO:root:veh 23 at t = 18 returns None action!\n", - "INFO:root:veh 23 at t = 22 returns None action!\n", - "INFO:root:veh 23 at t = 23 returns None action!\n", - "INFO:root:veh 23 at t = 31 returns None action!\n", - "INFO:root:veh 23 at t = 32 returns None action!\n", - "INFO:root:veh 1 at t = 76 returns None action!\n", - "INFO:root:veh 1 at t = 77 returns None action!\n", - "INFO:root:veh 1 at t = 78 returns None action!\n", - "INFO:root:veh 36 at t = 2 returns None action!\n", - "INFO:root:veh 36 at t = 3 returns None action!\n", - "INFO:root:veh 36 at t = 4 returns None action!\n", - "INFO:root:veh 36 at t = 5 returns None action!\n", - "INFO:root:veh 36 at t = 6 returns None action!\n", - "INFO:root:veh 15 at t = 4 returns None action!\n", - "INFO:root:veh 15 at t = 5 returns None action!\n", - "INFO:root:veh 15 at t = 6 returns None action!\n", - "INFO:root:veh 15 at t = 7 returns None action!\n", - "INFO:root:veh 8 at t = 69 returns None action!\n", - "INFO:root:veh 8 at t = 70 returns None action!\n", - "INFO:root:veh 8 at t = 71 returns None action!\n", - "INFO:root:veh 1 at t = 37 returns None action!\n", - "INFO:root:veh 1 at t = 38 returns None action!\n", - "INFO:root:veh 1 at t = 39 returns None action!\n", - "INFO:root:veh 25 at t = 15 returns None action!\n", - "INFO:root:veh 25 at t = 16 returns None action!\n", - "INFO:root:veh 25 at t = 17 returns None action!\n", - "INFO:root:veh 25 at t = 18 returns None action!\n", - "INFO:root:veh 28 at t = 1 returns None action!\n", - "INFO:root:veh 28 at t = 2 returns None action!\n", - "INFO:root:veh 2 at t = 54 returns None action!\n", - "INFO:root:veh 2 at t = 55 returns None action!\n", - "INFO:root:veh 2 at t = 56 returns None action!\n", - "INFO:root:veh 2 at t = 57 returns None action!\n", - "INFO:root:veh 2 at t = 58 returns None action!\n", - "INFO:root:veh 2 at t = 59 returns None action!\n", - "INFO:root:veh 2 at t = 60 returns None action!\n", - "INFO:root:veh 2 at t = 61 returns None action!\n", - "INFO:root:veh 2 at t = 64 returns None action!\n", - "INFO:root:veh 2 at t = 65 returns None action!\n", - "INFO:root:veh 2 at t = 66 returns None action!\n", - "INFO:root:veh 9 at t = 15 returns None action!\n", - "INFO:root:veh 6 at t = 55 returns None action!\n", - "INFO:root:veh 6 at t = 56 returns None action!\n", - "INFO:root:veh 6 at t = 57 returns None action!\n", - "INFO:root:veh 6 at t = 59 returns None action!\n", - "INFO:root:veh 6 at t = 60 returns None action!\n", - "INFO:root:veh 6 at t = 61 returns None action!\n", - "INFO:root:veh 54 at t = 32 returns None action!\n", - "INFO:root:veh 54 at t = 33 returns None action!\n", - "INFO:root:veh 54 at t = 37 returns None action!\n", - "INFO:root:veh 54 at t = 38 returns None action!\n", - "INFO:root:veh 54 at t = 39 returns None action!\n", - "INFO:root:veh 54 at t = 40 returns None action!\n", - "INFO:root:veh 54 at t = 41 returns None action!\n", - "INFO:root:veh 54 at t = 42 returns None action!\n", - "INFO:root:veh 54 at t = 53 returns None action!\n", - "INFO:root:veh 54 at t = 54 returns None action!\n", - "INFO:root:veh 54 at t = 55 returns None action!\n", - "INFO:root:veh 54 at t = 56 returns None action!\n", - "INFO:root:veh 56 at t = 63 returns None action!\n", - "INFO:root:veh 56 at t = 64 returns None action!\n", - "INFO:root:veh 56 at t = 65 returns None action!\n", - "INFO:root:veh 56 at t = 66 returns None action!\n", - "INFO:root:veh 56 at t = 67 returns None action!\n", - "INFO:root:veh 56 at t = 68 returns None action!\n", - "INFO:root:veh 56 at t = 70 returns None action!\n", - "INFO:root:veh 56 at t = 71 returns None action!\n", - "INFO:root:veh 56 at t = 72 returns None action!\n", - "INFO:root:veh 56 at t = 73 returns None action!\n", - "INFO:root:veh 56 at t = 74 returns None action!\n", - "INFO:root:veh 56 at t = 75 returns None action!\n", - "INFO:root:veh 11 at t = 37 returns None action!\n", - "INFO:root:veh 11 at t = 38 returns None action!\n", - "INFO:root:veh 10 at t = 33 returns None action!\n", - "INFO:root:veh 10 at t = 34 returns None action!\n", - "INFO:root:veh 10 at t = 35 returns None action!\n", - "INFO:root:veh 10 at t = 36 returns None action!\n", - "INFO:root:veh 86 at t = 4 returns None action!\n", - "INFO:root:veh 86 at t = 5 returns None action!\n", - "INFO:root:veh 86 at t = 6 returns None action!\n", - "INFO:root:veh 86 at t = 7 returns None action!\n", - "INFO:root:veh 86 at t = 8 returns None action!\n", - "INFO:root:veh 86 at t = 21 returns None action!\n", - "INFO:root:veh 86 at t = 22 returns None action!\n", - "INFO:root:veh 6 at t = 18 returns None action!\n", - "INFO:root:veh 6 at t = 19 returns None action!\n", - "INFO:root:veh 31 at t = 54 returns None action!\n", - "INFO:root:veh 31 at t = 55 returns None action!\n", - "INFO:root:veh 36 at t = 72 returns None action!\n", - "INFO:root:veh 36 at t = 73 returns None action!\n", - "INFO:root:veh 4 at t = 19 returns None action!\n", - "INFO:root:veh 4 at t = 20 returns None action!\n", - "INFO:root:veh 54 at t = 67 returns None action!\n", - "INFO:root:veh 54 at t = 68 returns None action!\n", - "INFO:root:veh 54 at t = 69 returns None action!\n", - "INFO:root:veh 3 at t = 4 returns None action!\n", - "INFO:root:veh 3 at t = 5 returns None action!\n", - "INFO:root:veh 45 at t = 44 returns None action!\n", - "INFO:root:veh 45 at t = 45 returns None action!\n", - "INFO:root:veh 45 at t = 46 returns None action!\n", - "INFO:root:veh 45 at t = 47 returns None action!\n", - "INFO:root:veh 48 at t = 64 returns None action!\n", - "INFO:root:veh 48 at t = 65 returns None action!\n", - "INFO:root:veh 48 at t = 68 returns None action!\n", - "INFO:root:veh 48 at t = 69 returns None action!\n", - "INFO:root:veh 45 at t = 78 returns None action!\n", - "INFO:root:veh 17 at t = 37 returns None action!\n", - "INFO:root:veh 17 at t = 38 returns None action!\n", - "INFO:root:veh 17 at t = 39 returns None action!\n", - "INFO:root:veh 17 at t = 40 returns None action!\n", - "INFO:root:veh 17 at t = 41 returns None action!\n", - "INFO:root:veh 44 at t = 0 returns None action!\n", - "INFO:root:veh 44 at t = 1 returns None action!\n", - "INFO:root:veh 3 at t = 17 returns None action!\n", - "INFO:root:veh 3 at t = 18 returns None action!\n", - "INFO:root:veh 3 at t = 19 returns None action!\n", - "INFO:root:veh 3 at t = 20 returns None action!\n", - "INFO:root:veh 44 at t = 28 returns None action!\n", - "INFO:root:veh 44 at t = 29 returns None action!\n", - "INFO:root:veh 3 at t = 3 returns None action!\n", - "INFO:root:veh 3 at t = 4 returns None action!\n", - "INFO:root:veh 22 at t = 23 returns None action!\n", - "INFO:root:veh 22 at t = 24 returns None action!\n", - "INFO:root:veh 22 at t = 28 returns None action!\n", - "INFO:root:veh 22 at t = 29 returns None action!\n", - "INFO:root:veh 22 at t = 30 returns None action!\n", - "INFO:root:veh 22 at t = 31 returns None action!\n", - "INFO:root:veh 22 at t = 34 returns None action!\n", - "INFO:root:veh 22 at t = 35 returns None action!\n", - "INFO:root:veh 22 at t = 37 returns None action!\n", - "INFO:root:veh 22 at t = 38 returns None action!\n", - "INFO:root:veh 22 at t = 39 returns None action!\n", - "INFO:root:veh 22 at t = 40 returns None action!\n", - "INFO:root:veh 22 at t = 41 returns None action!\n", - "INFO:root:veh 22 at t = 42 returns None action!\n", - "INFO:root:veh 22 at t = 43 returns None action!\n", - "INFO:root:veh 22 at t = 47 returns None action!\n", - "INFO:root:veh 22 at t = 48 returns None action!\n", - "INFO:root:veh 94 at t = 24 returns None action!\n", - "INFO:root:veh 94 at t = 25 returns None action!\n", - "INFO:root:veh 94 at t = 26 returns None action!\n", - "INFO:root:veh 9 at t = 24 returns None action!\n", - "INFO:root:veh 9 at t = 25 returns None action!\n", - "INFO:root:veh 59 at t = 60 returns None action!\n", - "INFO:root:veh 59 at t = 61 returns None action!\n", - "INFO:root:veh 59 at t = 62 returns None action!\n", - "INFO:root:veh 59 at t = 63 returns None action!\n", - "INFO:root:veh 59 at t = 64 returns None action!\n", - "INFO:root:veh 59 at t = 67 returns None action!\n", - "INFO:root:veh 59 at t = 68 returns None action!\n", - "INFO:root:veh 59 at t = 69 returns None action!\n", - "INFO:root:veh 44 at t = 0 returns None action!\n", - "INFO:root:veh 9 at t = 12 returns None action!\n", - "INFO:root:veh 9 at t = 13 returns None action!\n", - "INFO:root:veh 9 at t = 14 returns None action!\n", - "INFO:root:veh 9 at t = 15 returns None action!\n", - "INFO:root:veh 9 at t = 16 returns None action!\n", - "INFO:root:veh 9 at t = 17 returns None action!\n", - "INFO:root:veh 9 at t = 18 returns None action!\n", - "INFO:root:veh 109 at t = 0 returns None action!\n", - "INFO:root:veh 109 at t = 1 returns None action!\n", - "INFO:root:veh 109 at t = 2 returns None action!\n", - "INFO:root:veh 109 at t = 3 returns None action!\n", - "INFO:root:veh 109 at t = 19 returns None action!\n", - "INFO:root:veh 109 at t = 20 returns None action!\n", - "INFO:root:veh 109 at t = 21 returns None action!\n", - "INFO:root:veh 109 at t = 22 returns None action!\n", - "INFO:root:veh 109 at t = 30 returns None action!\n", - "INFO:root:veh 109 at t = 31 returns None action!\n", - "INFO:root:veh 109 at t = 32 returns None action!\n", - "INFO:root:veh 109 at t = 63 returns None action!\n", - "INFO:root:veh 109 at t = 64 returns None action!\n", - "INFO:root:veh 109 at t = 66 returns None action!\n", - "INFO:root:veh 109 at t = 67 returns None action!\n", - "INFO:root:veh 109 at t = 78 returns None action!\n", - "INFO:root:veh 3 at t = 75 returns None action!\n", - "INFO:root:veh 3 at t = 76 returns None action!\n", - "INFO:root:veh 10 at t = 5 returns None action!\n", - "INFO:root:veh 10 at t = 6 returns None action!\n", - "INFO:root:veh 4 at t = 11 returns None action!\n", - "INFO:root:veh 4 at t = 12 returns None action!\n", - "INFO:root:veh 1 at t = 15 returns None action!\n", - "INFO:root:veh 1 at t = 16 returns None action!\n", - "INFO:root:veh 1 at t = 17 returns None action!\n", - "INFO:root:veh 1 at t = 18 returns None action!\n", - "INFO:root:veh 5 at t = 18 returns None action!\n", - "INFO:root:veh 51 at t = 65 returns None action!\n", - "INFO:root:veh 51 at t = 66 returns None action!\n", - "INFO:root:veh 48 at t = 67 returns None action!\n", - "INFO:root:veh 48 at t = 68 returns None action!\n", - "INFO:root:veh 2 at t = 6 returns None action!\n", - "INFO:root:veh 2 at t = 7 returns None action!\n", - "INFO:root:veh 2 at t = 22 returns None action!\n", - "INFO:root:veh 2 at t = 23 returns None action!\n", - "INFO:root:veh 3 at t = 22 returns None action!\n", - "INFO:root:veh 3 at t = 23 returns None action!\n", - "INFO:root:veh 3 at t = 24 returns None action!\n", - "INFO:root:veh 70 at t = 73 returns None action!\n", - "INFO:root:veh 70 at t = 74 returns None action!\n", - "INFO:root:veh 70 at t = 75 returns None action!\n", - "INFO:root:veh 55 at t = 58 returns None action!\n", - "INFO:root:veh 55 at t = 59 returns None action!\n", - "INFO:root:veh 55 at t = 60 returns None action!\n", - "INFO:root:veh 55 at t = 61 returns None action!\n", - "INFO:root:veh 13 at t = 11 returns None action!\n", - "INFO:root:veh 13 at t = 12 returns None action!\n", - "INFO:root:veh 13 at t = 36 returns None action!\n", - "INFO:root:veh 13 at t = 37 returns None action!\n", - "INFO:root:veh 84 at t = 1 returns None action!\n", - "INFO:root:veh 84 at t = 2 returns None action!\n", - "INFO:root:veh 10 at t = 15 returns None action!\n", - "INFO:root:veh 10 at t = 16 returns None action!\n", - "INFO:root:veh 10 at t = 17 returns None action!\n", - "INFO:root:veh 9 at t = 20 returns None action!\n", - "INFO:root:veh 9 at t = 21 returns None action!\n", - "INFO:root:veh 9 at t = 22 returns None action!\n", - "INFO:root:veh 5 at t = 40 returns None action!\n", - "INFO:root:veh 5 at t = 41 returns None action!\n", - "INFO:root:veh 5 at t = 44 returns None action!\n", - "INFO:root:veh 5 at t = 45 returns None action!\n", - "INFO:root:veh 5 at t = 46 returns None action!\n", - "INFO:root:veh 5 at t = 47 returns None action!\n", - "INFO:root:veh 5 at t = 50 returns None action!\n", - "INFO:root:veh 5 at t = 51 returns None action!\n", - "INFO:root:veh 5 at t = 52 returns None action!\n", - "INFO:root:veh 5 at t = 53 returns None action!\n", - "INFO:root:veh 5 at t = 54 returns None action!\n", - "INFO:root:veh 5 at t = 55 returns None action!\n", - "INFO:root:veh 5 at t = 56 returns None action!\n", - "INFO:root:veh 19 at t = 10 returns None action!\n", - "INFO:root:veh 19 at t = 11 returns None action!\n", - "INFO:root:veh 19 at t = 12 returns None action!\n", - "INFO:root:veh 3 at t = 58 returns None action!\n", - "INFO:root:veh 3 at t = 59 returns None action!\n", - "INFO:root:veh 3 at t = 60 returns None action!\n", - "INFO:root:veh 25 at t = 12 returns None action!\n", - "INFO:root:veh 25 at t = 13 returns None action!\n", - "INFO:root:veh 25 at t = 16 returns None action!\n", - "INFO:root:veh 15 at t = 3 returns None action!\n", - "INFO:root:veh 15 at t = 4 returns None action!\n", - "INFO:root:veh 15 at t = 5 returns None action!\n", - "INFO:root:veh 7 at t = 5 returns None action!\n", - "INFO:root:veh 15 at t = 6 returns None action!\n", - "INFO:root:veh 94 at t = 2 returns None action!\n", - "INFO:root:veh 94 at t = 3 returns None action!\n", - "INFO:root:veh 94 at t = 4 returns None action!\n", - "INFO:root:veh 88 at t = 5 returns None action!\n", - "INFO:root:veh 88 at t = 6 returns None action!\n", - "INFO:root:veh 88 at t = 7 returns None action!\n", - "INFO:root:veh 88 at t = 8 returns None action!\n", - "INFO:root:veh 88 at t = 22 returns None action!\n", - "INFO:root:veh 88 at t = 23 returns None action!\n", - "INFO:root:veh 10 at t = 22 returns None action!\n", - "INFO:root:veh 10 at t = 23 returns None action!\n", - "INFO:root:veh 22 at t = 2 returns None action!\n", - "INFO:root:veh 22 at t = 3 returns None action!\n", - "INFO:root:veh 34 at t = 0 returns None action!\n", - "INFO:root:veh 34 at t = 1 returns None action!\n", - "INFO:root:veh 32 at t = 7 returns None action!\n", - "INFO:root:veh 32 at t = 8 returns None action!\n", - "INFO:root:veh 28 at t = 8 returns None action!\n", - "INFO:root:veh 32 at t = 9 returns None action!\n", - "INFO:root:veh 28 at t = 9 returns None action!\n", - "INFO:root:veh 21 at t = 67 returns None action!\n", - "INFO:root:veh 21 at t = 68 returns None action!\n", - "INFO:root:veh 21 at t = 69 returns None action!\n", - "INFO:root:veh 21 at t = 70 returns None action!\n", - "INFO:root:veh 21 at t = 71 returns None action!\n", - "INFO:root:veh 21 at t = 1 returns None action!\n", - "INFO:root:veh 21 at t = 2 returns None action!\n", - "INFO:root:veh 21 at t = 3 returns None action!\n", - "INFO:root:veh 20 at t = 32 returns None action!\n", - "INFO:root:veh 20 at t = 33 returns None action!\n", - "INFO:root:veh 20 at t = 34 returns None action!\n", - "INFO:root:veh 71 at t = 13 returns None action!\n", - "INFO:root:veh 71 at t = 14 returns None action!\n", - "INFO:root:veh 71 at t = 15 returns None action!\n", - "INFO:root:veh 4 at t = 20 returns None action!\n", - "INFO:root:veh 4 at t = 21 returns None action!\n", - "INFO:root:veh 19 at t = 46 returns None action!\n", - "INFO:root:veh 19 at t = 47 returns None action!\n", - "INFO:root:veh 31 at t = 6 returns None action!\n", - "INFO:root:veh 31 at t = 7 returns None action!\n", - "INFO:root:veh 31 at t = 8 returns None action!\n", - "INFO:root:veh 26 at t = 50 returns None action!\n", - "INFO:root:veh 26 at t = 51 returns None action!\n", - "INFO:root:veh 15 at t = 58 returns None action!\n", - "INFO:root:veh 22 at t = 4 returns None action!\n", - "INFO:root:veh 22 at t = 5 returns None action!\n", - "INFO:root:veh 6 at t = 31 returns None action!\n", - "INFO:root:veh 6 at t = 32 returns None action!\n", - "INFO:root:veh 6 at t = 33 returns None action!\n", - "INFO:root:veh 6 at t = 34 returns None action!\n", - "INFO:root:veh 6 at t = 35 returns None action!\n", - "INFO:root:veh 15 at t = 16 returns None action!\n", - "INFO:root:veh 15 at t = 17 returns None action!\n", - "INFO:root:veh 15 at t = 18 returns None action!\n", - "INFO:root:veh 15 at t = 19 returns None action!\n", - "INFO:root:veh 50 at t = 1 returns None action!\n", - "INFO:root:veh 50 at t = 2 returns None action!\n", - "INFO:root:veh 50 at t = 5 returns None action!\n", - "INFO:root:veh 50 at t = 6 returns None action!\n", - "INFO:root:veh 26 at t = 7 returns None action!\n", - "INFO:root:veh 26 at t = 8 returns None action!\n", - "INFO:root:veh 41 at t = 41 returns None action!\n", - "INFO:root:veh 41 at t = 42 returns None action!\n", - "INFO:root:veh 41 at t = 43 returns None action!\n", - "INFO:root:veh 15 at t = 4 returns None action!\n", - "INFO:root:veh 15 at t = 5 returns None action!\n", - "INFO:root:veh 5 at t = 41 returns None action!\n", - "INFO:root:veh 5 at t = 42 returns None action!\n", - "INFO:root:veh 5 at t = 44 returns None action!\n", - "INFO:root:veh 5 at t = 45 returns None action!\n", - "INFO:root:veh 5 at t = 46 returns None action!\n", - "INFO:root:veh 5 at t = 47 returns None action!\n", - "INFO:root:veh 0 at t = 34 returns None action!\n", - "INFO:root:veh 0 at t = 35 returns None action!\n", - "INFO:root:veh 0 at t = 36 returns None action!\n", - "INFO:root:veh 8 at t = 37 returns None action!\n", - "INFO:root:veh 8 at t = 38 returns None action!\n", - "INFO:root:veh 8 at t = 39 returns None action!\n", - "INFO:root:veh 5 at t = 42 returns None action!\n", - "INFO:root:veh 5 at t = 43 returns None action!\n", - "INFO:root:veh 5 at t = 44 returns None action!\n", - "INFO:root:veh 5 at t = 45 returns None action!\n", - "INFO:root:veh 7 at t = 48 returns None action!\n", - "INFO:root:veh 7 at t = 49 returns None action!\n", - "INFO:root:veh 7 at t = 52 returns None action!\n", - "INFO:root:veh 7 at t = 53 returns None action!\n", - "INFO:root:veh 8 at t = 64 returns None action!\n", - "INFO:root:veh 8 at t = 65 returns None action!\n", - "INFO:root:veh 8 at t = 66 returns None action!\n", - "INFO:root:veh 21 at t = 6 returns None action!\n", - "INFO:root:veh 21 at t = 7 returns None action!\n", - "INFO:root:veh 21 at t = 8 returns None action!\n", - "INFO:root:veh 10 at t = 10 returns None action!\n", - "INFO:root:veh 10 at t = 11 returns None action!\n", - "INFO:root:veh 10 at t = 12 returns None action!\n", - "INFO:root:veh 27 at t = 21 returns None action!\n", - "INFO:root:veh 27 at t = 22 returns None action!\n", - "INFO:root:veh 27 at t = 23 returns None action!\n", - "INFO:root:veh 6 at t = 61 returns None action!\n", - "INFO:root:veh 6 at t = 62 returns None action!\n", - "INFO:root:veh 6 at t = 63 returns None action!\n", - "INFO:root:veh 6 at t = 69 returns None action!\n", - "INFO:root:veh 6 at t = 70 returns None action!\n", - "INFO:root:veh 6 at t = 71 returns None action!\n", - "INFO:root:veh 6 at t = 72 returns None action!\n", - "INFO:root:veh 6 at t = 73 returns None action!\n", - "INFO:root:veh 6 at t = 74 returns None action!\n", - "INFO:root:veh 42 at t = 20 returns None action!\n", - "INFO:root:veh 42 at t = 21 returns None action!\n", - "INFO:root:veh 42 at t = 22 returns None action!\n", - "INFO:root:veh 42 at t = 23 returns None action!\n", - "INFO:root:veh 42 at t = 65 returns None action!\n", - "INFO:root:veh 42 at t = 66 returns None action!\n", - "INFO:root:veh 42 at t = 78 returns None action!\n", - "INFO:root:veh 42 at t = 79 returns None action!\n", - "INFO:root:veh 3 at t = 0 returns None action!\n", - "INFO:root:veh 3 at t = 1 returns None action!\n", - "INFO:root:veh 3 at t = 2 returns None action!\n", - "INFO:root:veh 3 at t = 3 returns None action!\n", - "INFO:root:veh 0 at t = 1 returns None action!\n", - "INFO:root:veh 0 at t = 2 returns None action!\n", - "INFO:root:veh 17 at t = 7 returns None action!\n", - "INFO:root:veh 17 at t = 8 returns None action!\n", - "INFO:root:veh 17 at t = 10 returns None action!\n", - "INFO:root:veh 17 at t = 11 returns None action!\n", - "INFO:root:veh 17 at t = 12 returns None action!\n", - "INFO:root:veh 17 at t = 13 returns None action!\n", - "INFO:root:veh 17 at t = 32 returns None action!\n", - "INFO:root:veh 17 at t = 33 returns None action!\n", - "INFO:root:veh 7 at t = 7 returns None action!\n", - "INFO:root:veh 7 at t = 8 returns None action!\n", - "INFO:root:veh 1 at t = 36 returns None action!\n", - "INFO:root:veh 1 at t = 37 returns None action!\n", - "INFO:root:veh 11 at t = 77 returns None action!\n", - "INFO:root:veh 11 at t = 78 returns None action!\n", - "INFO:root:veh 5 at t = 6 returns None action!\n", - "INFO:root:veh 5 at t = 7 returns None action!\n", - "INFO:root:veh 5 at t = 8 returns None action!\n", - "INFO:root:veh 5 at t = 10 returns None action!\n", - "INFO:root:veh 5 at t = 11 returns None action!\n", - "INFO:root:veh 5 at t = 12 returns None action!\n", - "INFO:root:veh 5 at t = 13 returns None action!\n", - "INFO:root:veh 5 at t = 14 returns None action!\n" + "INFO:root:\n", + " Evaluating policy on 50 files...\n", + "100%|██████████| 50/50 [00:14<00:00, 3.47it/s]\n" ] } ], @@ -716,7 +213,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -745,12 +242,13 @@ " agents_controlled\n", " reg_coef\n", " act_acc\n", + " accel_val_mae\n", + " steer_val_mae\n", " pos_rmse\n", " speed_mae\n", " goal_rate\n", " veh_edge_cr\n", " veh_veh_cr\n", - " num_violations\n", " class\n", " \n", " \n", @@ -758,207 +256,977 @@ " \n", " 0\n", " None\n", - " tfrecord-00004-of-01000_378.json\n", - " 5\n", + " tfrecord-00350-of-01000_410.json\n", + " 2\n", " 0.0\n", - " 0.140\n", - " 2.581\n", - " 0.288\n", - " 0.400\n", - " 0.600\n", + " 0.094\n", + " 1.526\n", + " 0.127\n", + " 5.253\n", + " 1.066\n", + " 0.000\n", + " 1.000\n", " 0.000\n", - " 0\n", " IL\n", " \n", " \n", " 1\n", " None\n", - " tfrecord-00003-of-01000_109.json\n", - " 20\n", + " tfrecord-00349-of-01000_311.json\n", + " 2\n", " 0.0\n", - " 0.112\n", - " 119.303\n", - " 20.169\n", - " 0.350\n", - " 0.250\n", - " 0.100\n", - " 0\n", + " 0.019\n", + " 2.409\n", + " 0.281\n", + " 3.528\n", + " 1.018\n", + " 0.000\n", + " 0.000\n", + " 1.000\n", " IL\n", " \n", " \n", " 2\n", " None\n", - " tfrecord-00004-of-01000_61.json\n", - " 6\n", + " tfrecord-00421-of-01000_364.json\n", + " 9\n", " 0.0\n", - " 0.210\n", - " 4.265\n", - " 0.463\n", - " 0.167\n", - " 0.500\n", + " 0.067\n", + " 1.970\n", + " 0.164\n", + " 7.574\n", + " 1.055\n", " 0.000\n", - " 0\n", + " 0.222\n", + " 0.667\n", " IL\n", " \n", " \n", " 3\n", " None\n", - " tfrecord-00012-of-01000_87.json\n", - " 9\n", + " tfrecord-00243-of-01000_280.json\n", + " 2\n", " 0.0\n", - " 0.190\n", - " 181.568\n", - " 162.275\n", - " 0.333\n", + " 0.012\n", + " 1.974\n", + " 0.239\n", + " 2.759\n", + " 0.555\n", " 0.000\n", " 0.000\n", - " 0\n", + " 1.000\n", " IL\n", " \n", " \n", " 4\n", " None\n", - " tfrecord-00007-of-01000_237.json\n", - " 14\n", + " tfrecord-00275-of-01000_483.json\n", + " 3\n", " 0.0\n", - " 0.278\n", - " 8.479\n", - " 0.513\n", - " 0.286\n", - " 0.071\n", + " 0.050\n", + " 1.833\n", + " 0.220\n", + " 3.999\n", + " 0.621\n", " 0.000\n", - " 0\n", + " 0.667\n", + " 0.333\n", " IL\n", " \n", " \n", " 5\n", " None\n", - " tfrecord-00005-of-01000_423.json\n", - " 4\n", + " tfrecord-00201-of-01000_202.json\n", + " 16\n", " 0.0\n", - " 0.391\n", - " 6.674\n", - " 0.776\n", + " 0.094\n", + " 2.250\n", + " 0.003\n", + " 140.214\n", + " 57.562\n", + " 0.312\n", " 0.250\n", " 0.000\n", - " 0.000\n", - " 0\n", " IL\n", " \n", " \n", " 6\n", " None\n", - " tfrecord-00012-of-01000_246.json\n", - " 20\n", + " tfrecord-00623-of-01000_242.json\n", + " 8\n", " 0.0\n", - " 0.108\n", - " 162.746\n", - " 53.993\n", - " 0.050\n", + " 0.058\n", + " 2.032\n", + " 0.375\n", + " 205.846\n", + " 114.555\n", + " 0.125\n", + " 0.000\n", + " 0.375\n", + " IL\n", + " \n", + " \n", + " 7\n", + " None\n", + " tfrecord-00316-of-01000_183.json\n", + " 10\n", + " 0.0\n", + " 0.162\n", + " 2.260\n", + " 0.013\n", + " 8.700\n", + " 0.760\n", + " 0.200\n", + " 0.200\n", + " 0.200\n", + " IL\n", + " \n", + " \n", + " 8\n", + " None\n", + " tfrecord-00443-of-01000_487.json\n", + " 6\n", + " 0.0\n", + " 0.025\n", + " 2.143\n", + " 0.294\n", + " 9.119\n", + " 1.796\n", + " 0.000\n", + " 0.333\n", " 0.500\n", + " IL\n", + " \n", + " \n", + " 9\n", + " None\n", + " tfrecord-00908-of-01000_302.json\n", + " 20\n", + " 0.0\n", + " 0.149\n", + " 2.089\n", + " 0.039\n", + " 12.949\n", + " 1.180\n", + " 0.150\n", + " 0.200\n", " 0.200\n", - " 0\n", " IL\n", " \n", " \n", - " 7\n", + " 10\n", + " None\n", + " tfrecord-00395-of-01000_97.json\n", + " 18\n", + " 0.0\n", + " 0.076\n", + " 2.333\n", + " 0.292\n", + " 253.132\n", + " 200.875\n", + " 0.000\n", + " 0.833\n", + " 0.111\n", + " IL\n", + " \n", + " \n", + " 11\n", + " None\n", + " tfrecord-00845-of-01000_433.json\n", + " 3\n", + " 0.0\n", + " 0.133\n", + " 1.311\n", + " 0.190\n", + " 3.827\n", + " 0.630\n", + " 0.000\n", + " 0.333\n", + " 0.667\n", + " IL\n", + " \n", + " \n", + " 12\n", + " None\n", + " tfrecord-00628-of-01000_338.json\n", + " 3\n", + " 0.0\n", + " 0.029\n", + " 1.485\n", + " 0.367\n", + " 4.083\n", + " 0.509\n", + " 0.000\n", + " 0.333\n", + " 0.333\n", + " IL\n", + " \n", + " \n", + " 13\n", + " None\n", + " tfrecord-00498-of-01000_229.json\n", + " 7\n", + " 0.0\n", + " 0.088\n", + " 2.199\n", + " 0.113\n", + " 195.827\n", + " 120.754\n", + " 0.143\n", + " 0.571\n", + " 0.143\n", + " IL\n", + " \n", + " \n", + " 14\n", + " None\n", + " tfrecord-00069-of-01000_193.json\n", + " 5\n", + " 0.0\n", + " 0.115\n", + " 1.891\n", + " 0.144\n", + " 169.086\n", + " 297.963\n", + " 0.000\n", + " 0.400\n", + " 0.000\n", + " IL\n", + " \n", + " \n", + " 15\n", + " None\n", + " tfrecord-00271-of-01000_435.json\n", + " 18\n", + " 0.0\n", + " 0.178\n", + " 1.959\n", + " 0.024\n", + " 14.123\n", + " 1.407\n", + " 0.056\n", + " 0.111\n", + " 0.389\n", + " IL\n", + " \n", + " \n", + " 16\n", + " None\n", + " tfrecord-00480-of-01000_25.json\n", + " 3\n", + " 0.0\n", + " 0.142\n", + " 1.460\n", + " 0.146\n", + " 9.145\n", + " 1.885\n", + " 0.333\n", + " 0.000\n", + " 0.000\n", + " IL\n", + " \n", + " \n", + " 17\n", + " None\n", + " tfrecord-00683-of-01000_42.json\n", + " 2\n", + " 0.0\n", + " 0.262\n", + " 1.287\n", + " 0.102\n", + " 8.824\n", + " 2.010\n", + " 0.000\n", + " 0.500\n", + " 0.000\n", + " IL\n", + " \n", + " \n", + " 18\n", + " None\n", + " tfrecord-00384-of-01000_192.json\n", + " 3\n", + " 0.0\n", + " 0.117\n", + " 1.701\n", + " 0.096\n", + " 7.469\n", + " 1.049\n", + " 0.000\n", + " 0.667\n", + " 0.000\n", + " IL\n", + " \n", + " \n", + " 19\n", " None\n", - " tfrecord-00012-of-01000_389.json\n", + " tfrecord-00073-of-01000_173.json\n", " 4\n", " 0.0\n", - " 0.172\n", - " 123.747\n", - " 73.352\n", - " 0.250\n", + " 0.178\n", + " 2.130\n", + " 0.046\n", + " 8.065\n", + " 1.571\n", + " 0.000\n", " 0.500\n", " 0.000\n", - " 0\n", " IL\n", " \n", " \n", - " 8\n", + " 20\n", " None\n", - " tfrecord-00001-of-01000_307.json\n", + " tfrecord-00468-of-01000_290.json\n", " 3\n", " 0.0\n", - " 0.258\n", - " 4.814\n", - " 0.485\n", + " 0.017\n", + " 2.952\n", + " 0.461\n", + " 4.592\n", + " 2.468\n", + " 0.000\n", + " 0.333\n", " 0.667\n", + " IL\n", + " \n", + " \n", + " 21\n", + " None\n", + " tfrecord-00656-of-01000_18.json\n", + " 3\n", + " 0.0\n", + " 0.025\n", + " 2.386\n", + " 0.392\n", + " 6.863\n", + " 3.039\n", + " 0.000\n", + " 1.000\n", + " 0.000\n", + " IL\n", + " \n", + " \n", + " 22\n", + " None\n", + " tfrecord-00651-of-01000_285.json\n", + " 4\n", + " 0.0\n", + " 0.084\n", + " 1.902\n", + " 0.050\n", + " 4.980\n", + " 0.859\n", + " 0.250\n", + " 0.250\n", + " 0.250\n", + " IL\n", + " \n", + " \n", + " 23\n", + " None\n", + " tfrecord-00364-of-01000_319.json\n", + " 2\n", + " 0.0\n", + " 0.088\n", + " 1.411\n", + " 0.288\n", + " 4.417\n", + " 0.519\n", + " 0.000\n", + " 1.000\n", + " 0.000\n", + " IL\n", + " \n", + " \n", + " 24\n", + " None\n", + " tfrecord-00615-of-01000_417.json\n", + " 13\n", + " 0.0\n", + " 0.034\n", + " 2.531\n", + " 0.350\n", + " 11.520\n", + " 2.748\n", + " 0.154\n", + " 0.385\n", + " 0.385\n", + " IL\n", + " \n", + " \n", + " 25\n", + " None\n", + " tfrecord-00795-of-01000_89.json\n", + " 3\n", + " 0.0\n", + " 0.042\n", + " 2.039\n", + " 0.180\n", + " 171.441\n", + " 274.731\n", " 0.333\n", + " 0.667\n", " 0.000\n", - " 0\n", " IL\n", " \n", " \n", - " 9\n", + " 26\n", + " None\n", + " tfrecord-00880-of-01000_341.json\n", + " 20\n", + " 0.0\n", + " 0.034\n", + " 2.709\n", + " 0.451\n", + " 168.273\n", + " 81.749\n", + " 0.000\n", + " 0.500\n", + " 0.500\n", + " IL\n", + " \n", + " \n", + " 27\n", + " None\n", + " tfrecord-00187-of-01000_350.json\n", + " 3\n", + " 0.0\n", + " 0.104\n", + " 1.837\n", + " 0.169\n", + " 6.644\n", + " 1.400\n", + " 0.000\n", + " 0.667\n", + " 0.000\n", + " IL\n", + " \n", + " \n", + " 28\n", " None\n", - " tfrecord-00004-of-01000_157.json\n", - " 11\n", + " tfrecord-00428-of-01000_234.json\n", + " 8\n", " 0.0\n", + " 0.059\n", + " 1.558\n", + " 0.465\n", + " 8.172\n", + " 1.171\n", + " 0.000\n", + " 0.875\n", " 0.125\n", - " 260.832\n", - " 343.664\n", - " 0.182\n", + " IL\n", + " \n", + " \n", + " 29\n", + " None\n", + " tfrecord-00748-of-01000_178.json\n", + " 3\n", + " 0.0\n", " 0.000\n", - " 0.545\n", - " 0\n", + " 2.302\n", + " 0.334\n", + " 3.596\n", + " 1.461\n", + " 0.000\n", + " 0.000\n", + " 1.000\n", " IL\n", " \n", - " \n", - "\n", - "" - ], - "text/plain": [ - " run_id traffic_scene agents_controlled reg_coef \\\n", - "0 None tfrecord-00004-of-01000_378.json 5 0.0 \n", - "1 None tfrecord-00003-of-01000_109.json 20 0.0 \n", - "2 None tfrecord-00004-of-01000_61.json 6 0.0 \n", - "3 None tfrecord-00012-of-01000_87.json 9 0.0 \n", - "4 None tfrecord-00007-of-01000_237.json 14 0.0 \n", - "5 None tfrecord-00005-of-01000_423.json 4 0.0 \n", - "6 None tfrecord-00012-of-01000_246.json 20 0.0 \n", - "7 None tfrecord-00012-of-01000_389.json 4 0.0 \n", - "8 None tfrecord-00001-of-01000_307.json 3 0.0 \n", - "9 None tfrecord-00004-of-01000_157.json 11 0.0 \n", - "\n", - " act_acc pos_rmse speed_mae goal_rate veh_edge_cr veh_veh_cr \\\n", - "0 0.140 2.581 0.288 0.400 0.600 0.000 \n", - "1 0.112 119.303 20.169 0.350 0.250 0.100 \n", - "2 0.210 4.265 0.463 0.167 0.500 0.000 \n", - "3 0.190 181.568 162.275 0.333 0.000 0.000 \n", - "4 0.278 8.479 0.513 0.286 0.071 0.000 \n", - "5 0.391 6.674 0.776 0.250 0.000 0.000 \n", - "6 0.108 162.746 53.993 0.050 0.500 0.200 \n", - "7 0.172 123.747 73.352 0.250 0.500 0.000 \n", - "8 0.258 4.814 0.485 0.667 0.333 0.000 \n", - "9 0.125 260.832 343.664 0.182 0.000 0.545 \n", - "\n", - " num_violations class \n", - "0 0 IL \n", - "1 0 IL \n", - "2 0 IL \n", - "3 0 IL \n", - "4 0 IL \n", - "5 0 IL \n", - "6 0 IL \n", - "7 0 IL \n", - "8 0 IL \n", - "9 0 IL " - ] - }, - "execution_count": 133, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df_il_res.round(3)" - ] - }, + " \n", + " 30\n", + " None\n", + " tfrecord-00481-of-01000_309.json\n", + " 2\n", + " 0.0\n", + " 0.106\n", + " 2.680\n", + " 0.135\n", + " 10.895\n", + " 3.673\n", + " 0.000\n", + " 1.000\n", + " 0.000\n", + " IL\n", + " \n", + " \n", + " 31\n", + " None\n", + " tfrecord-00246-of-01000_104.json\n", + " 5\n", + " 0.0\n", + " 0.202\n", + " 2.069\n", + " 0.028\n", + " 13.658\n", + " 2.539\n", + " 0.000\n", + " 0.600\n", + " 0.000\n", + " IL\n", + " \n", + " \n", + " 32\n", + " None\n", + " tfrecord-00880-of-01000_456.json\n", + " 20\n", + " 0.0\n", + " 0.090\n", + " 2.430\n", + " 0.159\n", + " 141.539\n", + " 35.880\n", + " 0.150\n", + " 0.450\n", + " 0.200\n", + " IL\n", + " \n", + " \n", + " 33\n", + " None\n", + " tfrecord-00576-of-01000_258.json\n", + " 5\n", + " 0.0\n", + " 0.105\n", + " 1.928\n", + " 0.166\n", + " 6.807\n", + " 1.368\n", + " 0.200\n", + " 0.600\n", + " 0.000\n", + " IL\n", + " \n", + " \n", + " 34\n", + " None\n", + " tfrecord-00342-of-01000_476.json\n", + " 20\n", + " 0.0\n", + " 0.034\n", + " 2.557\n", + " 0.232\n", + " 221.795\n", + " 177.375\n", + " 0.300\n", + " 0.050\n", + " 0.600\n", + " IL\n", + " \n", + " \n", + " 35\n", + " None\n", + " tfrecord-00445-of-01000_61.json\n", + " 18\n", + " 0.0\n", + " 0.124\n", + " 2.418\n", + " 0.056\n", + " 181.373\n", + " 72.507\n", + " 0.056\n", + " 0.111\n", + " 0.333\n", + " IL\n", + " \n", + " \n", + " 36\n", + " None\n", + " tfrecord-00430-of-01000_75.json\n", + " 1\n", + " 0.0\n", + " 0.062\n", + " 1.224\n", + " 0.203\n", + " 4.086\n", + " 1.092\n", + " 0.000\n", + " 1.000\n", + " 0.000\n", + " IL\n", + " \n", + " \n", + " 37\n", + " None\n", + " tfrecord-00518-of-01000_117.json\n", + " 4\n", + " 0.0\n", + " 0.078\n", + " 2.130\n", + " 0.054\n", + " 8.837\n", + " 1.530\n", + " 0.250\n", + " 0.500\n", + " 0.000\n", + " IL\n", + " \n", + " \n", + " 38\n", + " None\n", + " tfrecord-00444-of-01000_226.json\n", + " 20\n", + " 0.0\n", + " 0.038\n", + " 2.863\n", + " 0.268\n", + " 255.838\n", + " 181.061\n", + " 0.150\n", + " 0.650\n", + " 0.100\n", + " IL\n", + " \n", + " \n", + " 39\n", + " None\n", + " tfrecord-00083-of-01000_389.json\n", + " 20\n", + " 0.0\n", + " 0.011\n", + " 2.175\n", + " 0.456\n", + " 33.302\n", + " 2.794\n", + " 0.150\n", + " 0.300\n", + " 0.350\n", + " IL\n", + " \n", + " \n", + " 40\n", + " None\n", + " tfrecord-00119-of-01000_361.json\n", + " 20\n", + " 0.0\n", + " 0.124\n", + " 2.261\n", + " 0.016\n", + " 139.079\n", + " 33.757\n", + " 0.300\n", + " 0.100\n", + " 0.400\n", + " IL\n", + " \n", + " \n", + " 41\n", + " None\n", + " tfrecord-00559-of-01000_436.json\n", + " 6\n", + " 0.0\n", + " 0.029\n", + " 2.778\n", + " 0.236\n", + " 148.869\n", + " 316.467\n", + " 0.333\n", + " 0.333\n", + " 0.167\n", + " IL\n", + " \n", + " \n", + " 42\n", + " None\n", + " tfrecord-00172-of-01000_326.json\n", + " 5\n", + " 0.0\n", + " 0.022\n", + " 2.604\n", + " 0.361\n", + " 6.835\n", + " 2.275\n", + " 0.000\n", + " 0.800\n", + " 0.000\n", + " IL\n", + " \n", + " \n", + " 43\n", + " None\n", + " tfrecord-00444-of-01000_42.json\n", + " 8\n", + " 0.0\n", + " 0.142\n", + " 1.849\n", + " 0.063\n", + " 220.961\n", + " 209.358\n", + " 0.000\n", + " 0.125\n", + " 0.375\n", + " IL\n", + " \n", + " \n", + " 44\n", + " None\n", + " tfrecord-00569-of-01000_314.json\n", + " 13\n", + " 0.0\n", + " 0.016\n", + " 3.279\n", + " 0.481\n", + " 15.319\n", + " 5.635\n", + " 0.154\n", + " 0.385\n", + " 0.462\n", + " IL\n", + " \n", + " \n", + " 45\n", + " None\n", + " tfrecord-00496-of-01000_43.json\n", + " 9\n", + " 0.0\n", + " 0.094\n", + " 2.482\n", + " 0.000\n", + " 155.690\n", + " 256.373\n", + " 0.333\n", + " 0.222\n", + " 0.000\n", + " IL\n", + " \n", + " \n", + " 46\n", + " None\n", + " tfrecord-00601-of-01000_213.json\n", + " 3\n", + " 0.0\n", + " 0.050\n", + " 1.667\n", + " 0.194\n", + " 2.774\n", + " 0.411\n", + " 0.000\n", + " 0.333\n", + " 0.667\n", + " IL\n", + " \n", + " \n", + " 47\n", + " None\n", + " tfrecord-00537-of-01000_284.json\n", + " 6\n", + " 0.0\n", + " 0.040\n", + " 1.947\n", + " 0.313\n", + " 14.913\n", + " 3.075\n", + " 0.000\n", + " 0.333\n", + " 0.500\n", + " IL\n", + " \n", + " \n", + " 48\n", + " None\n", + " tfrecord-00736-of-01000_303.json\n", + " 13\n", + " 0.0\n", + " 0.066\n", + " 1.944\n", + " 0.195\n", + " 180.957\n", + " 115.203\n", + " 0.154\n", + " 0.231\n", + " 0.231\n", + " IL\n", + " \n", + " \n", + " 49\n", + " None\n", + " tfrecord-00646-of-01000_105.json\n", + " 2\n", + " 0.0\n", + " 0.038\n", + " 3.075\n", + " 0.070\n", + " 6.076\n", + " 2.975\n", + " 0.000\n", + " 0.000\n", + " 1.000\n", + " IL\n", + " \n", + " \n", + "\n", + "" + ], + "text/plain": [ + " run_id traffic_scene agents_controlled reg_coef \\\n", + "0 None tfrecord-00350-of-01000_410.json 2 0.0 \n", + "1 None tfrecord-00349-of-01000_311.json 2 0.0 \n", + "2 None tfrecord-00421-of-01000_364.json 9 0.0 \n", + "3 None tfrecord-00243-of-01000_280.json 2 0.0 \n", + "4 None tfrecord-00275-of-01000_483.json 3 0.0 \n", + "5 None tfrecord-00201-of-01000_202.json 16 0.0 \n", + "6 None tfrecord-00623-of-01000_242.json 8 0.0 \n", + "7 None tfrecord-00316-of-01000_183.json 10 0.0 \n", + "8 None tfrecord-00443-of-01000_487.json 6 0.0 \n", + "9 None tfrecord-00908-of-01000_302.json 20 0.0 \n", + "10 None tfrecord-00395-of-01000_97.json 18 0.0 \n", + "11 None tfrecord-00845-of-01000_433.json 3 0.0 \n", + "12 None tfrecord-00628-of-01000_338.json 3 0.0 \n", + "13 None tfrecord-00498-of-01000_229.json 7 0.0 \n", + "14 None tfrecord-00069-of-01000_193.json 5 0.0 \n", + "15 None tfrecord-00271-of-01000_435.json 18 0.0 \n", + "16 None tfrecord-00480-of-01000_25.json 3 0.0 \n", + "17 None tfrecord-00683-of-01000_42.json 2 0.0 \n", + "18 None tfrecord-00384-of-01000_192.json 3 0.0 \n", + "19 None tfrecord-00073-of-01000_173.json 4 0.0 \n", + "20 None tfrecord-00468-of-01000_290.json 3 0.0 \n", + "21 None tfrecord-00656-of-01000_18.json 3 0.0 \n", + "22 None tfrecord-00651-of-01000_285.json 4 0.0 \n", + "23 None tfrecord-00364-of-01000_319.json 2 0.0 \n", + "24 None tfrecord-00615-of-01000_417.json 13 0.0 \n", + "25 None tfrecord-00795-of-01000_89.json 3 0.0 \n", + "26 None tfrecord-00880-of-01000_341.json 20 0.0 \n", + "27 None tfrecord-00187-of-01000_350.json 3 0.0 \n", + "28 None tfrecord-00428-of-01000_234.json 8 0.0 \n", + "29 None tfrecord-00748-of-01000_178.json 3 0.0 \n", + "30 None tfrecord-00481-of-01000_309.json 2 0.0 \n", + "31 None tfrecord-00246-of-01000_104.json 5 0.0 \n", + "32 None tfrecord-00880-of-01000_456.json 20 0.0 \n", + "33 None tfrecord-00576-of-01000_258.json 5 0.0 \n", + "34 None tfrecord-00342-of-01000_476.json 20 0.0 \n", + "35 None tfrecord-00445-of-01000_61.json 18 0.0 \n", + "36 None tfrecord-00430-of-01000_75.json 1 0.0 \n", + "37 None tfrecord-00518-of-01000_117.json 4 0.0 \n", + "38 None tfrecord-00444-of-01000_226.json 20 0.0 \n", + "39 None tfrecord-00083-of-01000_389.json 20 0.0 \n", + "40 None tfrecord-00119-of-01000_361.json 20 0.0 \n", + "41 None tfrecord-00559-of-01000_436.json 6 0.0 \n", + "42 None tfrecord-00172-of-01000_326.json 5 0.0 \n", + "43 None tfrecord-00444-of-01000_42.json 8 0.0 \n", + "44 None tfrecord-00569-of-01000_314.json 13 0.0 \n", + "45 None tfrecord-00496-of-01000_43.json 9 0.0 \n", + "46 None tfrecord-00601-of-01000_213.json 3 0.0 \n", + "47 None tfrecord-00537-of-01000_284.json 6 0.0 \n", + "48 None tfrecord-00736-of-01000_303.json 13 0.0 \n", + "49 None tfrecord-00646-of-01000_105.json 2 0.0 \n", + "\n", + " act_acc accel_val_mae steer_val_mae pos_rmse speed_mae goal_rate \\\n", + "0 0.094 1.526 0.127 5.253 1.066 0.000 \n", + "1 0.019 2.409 0.281 3.528 1.018 0.000 \n", + "2 0.067 1.970 0.164 7.574 1.055 0.000 \n", + "3 0.012 1.974 0.239 2.759 0.555 0.000 \n", + "4 0.050 1.833 0.220 3.999 0.621 0.000 \n", + "5 0.094 2.250 0.003 140.214 57.562 0.312 \n", + "6 0.058 2.032 0.375 205.846 114.555 0.125 \n", + "7 0.162 2.260 0.013 8.700 0.760 0.200 \n", + "8 0.025 2.143 0.294 9.119 1.796 0.000 \n", + "9 0.149 2.089 0.039 12.949 1.180 0.150 \n", + "10 0.076 2.333 0.292 253.132 200.875 0.000 \n", + "11 0.133 1.311 0.190 3.827 0.630 0.000 \n", + "12 0.029 1.485 0.367 4.083 0.509 0.000 \n", + "13 0.088 2.199 0.113 195.827 120.754 0.143 \n", + "14 0.115 1.891 0.144 169.086 297.963 0.000 \n", + "15 0.178 1.959 0.024 14.123 1.407 0.056 \n", + "16 0.142 1.460 0.146 9.145 1.885 0.333 \n", + "17 0.262 1.287 0.102 8.824 2.010 0.000 \n", + "18 0.117 1.701 0.096 7.469 1.049 0.000 \n", + "19 0.178 2.130 0.046 8.065 1.571 0.000 \n", + "20 0.017 2.952 0.461 4.592 2.468 0.000 \n", + "21 0.025 2.386 0.392 6.863 3.039 0.000 \n", + "22 0.084 1.902 0.050 4.980 0.859 0.250 \n", + "23 0.088 1.411 0.288 4.417 0.519 0.000 \n", + "24 0.034 2.531 0.350 11.520 2.748 0.154 \n", + "25 0.042 2.039 0.180 171.441 274.731 0.333 \n", + "26 0.034 2.709 0.451 168.273 81.749 0.000 \n", + "27 0.104 1.837 0.169 6.644 1.400 0.000 \n", + "28 0.059 1.558 0.465 8.172 1.171 0.000 \n", + "29 0.000 2.302 0.334 3.596 1.461 0.000 \n", + "30 0.106 2.680 0.135 10.895 3.673 0.000 \n", + "31 0.202 2.069 0.028 13.658 2.539 0.000 \n", + "32 0.090 2.430 0.159 141.539 35.880 0.150 \n", + "33 0.105 1.928 0.166 6.807 1.368 0.200 \n", + "34 0.034 2.557 0.232 221.795 177.375 0.300 \n", + "35 0.124 2.418 0.056 181.373 72.507 0.056 \n", + "36 0.062 1.224 0.203 4.086 1.092 0.000 \n", + "37 0.078 2.130 0.054 8.837 1.530 0.250 \n", + "38 0.038 2.863 0.268 255.838 181.061 0.150 \n", + "39 0.011 2.175 0.456 33.302 2.794 0.150 \n", + "40 0.124 2.261 0.016 139.079 33.757 0.300 \n", + "41 0.029 2.778 0.236 148.869 316.467 0.333 \n", + "42 0.022 2.604 0.361 6.835 2.275 0.000 \n", + "43 0.142 1.849 0.063 220.961 209.358 0.000 \n", + "44 0.016 3.279 0.481 15.319 5.635 0.154 \n", + "45 0.094 2.482 0.000 155.690 256.373 0.333 \n", + "46 0.050 1.667 0.194 2.774 0.411 0.000 \n", + "47 0.040 1.947 0.313 14.913 3.075 0.000 \n", + "48 0.066 1.944 0.195 180.957 115.203 0.154 \n", + "49 0.038 3.075 0.070 6.076 2.975 0.000 \n", + "\n", + " veh_edge_cr veh_veh_cr class \n", + "0 1.000 0.000 IL \n", + "1 0.000 1.000 IL \n", + "2 0.222 0.667 IL \n", + "3 0.000 1.000 IL \n", + "4 0.667 0.333 IL \n", + "5 0.250 0.000 IL \n", + "6 0.000 0.375 IL \n", + "7 0.200 0.200 IL \n", + "8 0.333 0.500 IL \n", + "9 0.200 0.200 IL \n", + "10 0.833 0.111 IL \n", + "11 0.333 0.667 IL \n", + "12 0.333 0.333 IL \n", + "13 0.571 0.143 IL \n", + "14 0.400 0.000 IL \n", + "15 0.111 0.389 IL \n", + "16 0.000 0.000 IL \n", + "17 0.500 0.000 IL \n", + "18 0.667 0.000 IL \n", + "19 0.500 0.000 IL \n", + "20 0.333 0.667 IL \n", + "21 1.000 0.000 IL \n", + "22 0.250 0.250 IL \n", + "23 1.000 0.000 IL \n", + "24 0.385 0.385 IL \n", + "25 0.667 0.000 IL \n", + "26 0.500 0.500 IL \n", + "27 0.667 0.000 IL \n", + "28 0.875 0.125 IL \n", + "29 0.000 1.000 IL \n", + "30 1.000 0.000 IL \n", + "31 0.600 0.000 IL \n", + "32 0.450 0.200 IL \n", + "33 0.600 0.000 IL \n", + "34 0.050 0.600 IL \n", + "35 0.111 0.333 IL \n", + "36 1.000 0.000 IL \n", + "37 0.500 0.000 IL \n", + "38 0.650 0.100 IL \n", + "39 0.300 0.350 IL \n", + "40 0.100 0.400 IL \n", + "41 0.333 0.167 IL \n", + "42 0.800 0.000 IL \n", + "43 0.125 0.375 IL \n", + "44 0.385 0.462 IL \n", + "45 0.222 0.000 IL \n", + "46 0.333 0.667 IL \n", + "47 0.333 0.500 IL \n", + "48 0.231 0.231 IL \n", + "49 0.000 1.000 IL " + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_il_res.round(3)" + ] + }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -967,7282 +1235,624 @@ "text": [ "Scene tfrecord-00012-of-01000_389.json\n" ] + } + ], + "source": [ + "scene_id = 'tfrecord-00012-of-01000_389.json'\n", + "\n", + "df_scene = df_il_trajs[df_il_trajs.traffic_scene == scene_id]\n", + "print(f\"Scene {scene_id}\")\n", + "for agent_id in df_scene.agent_id.unique():\n", + " agent_df = df_scene[df_scene.agent_id == agent_id]\n", + " plot_agent_trajectory(agent_df, evaluator.env.action_space.n)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2. **Human-regularized RL**" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Select policies to evaluate\n", + "- Add the model folder path and models to use\n", + "- Use `manage_models.py` to download models from `wandb` first" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "RL_POLICY_PATH = \"../models/hr_rl\"\n", + "\n", + "# Scenes on which to evaluate the models\n", + "# rl_policy_paths = glob.glob(f\"{RL_POLICY_PATH}\" + \"/*.pt\")\n", + "# rl_policy_names = [os.path.basename(file)[:-3] for file in rl_policy_paths]\n", + "reg_weights = [0.0]\n", + "\n", + "rl_policy_names = ['nocturne-hr-ppo-01_02_11_16_0.00_S100']" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.1 Evaluate policies on train dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Evaluating policy nocturne-hr-ppo-01_02_11_16_0.00_S100\n" + ] }, { - "data": { - "image/svg+xml": [ - "\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " 2023-12-29T13:41:41.799891\n", - " image/svg+xml\n", - " \n", - " \n", - " Matplotlib v3.8.0, https://matplotlib.org/\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "\n" - ], - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - }, - { - "data": { - "image/svg+xml": [ - "\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " 2023-12-29T13:41:42.011944\n", - " image/svg+xml\n", - " \n", - " \n", - " Matplotlib v3.8.0, https://matplotlib.org/\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "\n" - ], - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - }, - { - "data": { - "image/svg+xml": [ - "\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " 2023-12-29T13:41:42.230562\n", - " image/svg+xml\n", - " \n", - " \n", - " Matplotlib v3.8.0, https://matplotlib.org/\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
goal_rateveh_edge_crveh_veh_cr
Type
train0.108470.4234770.243685
valid0.117560.4002980.235119
\n", + "" + ], + "text/plain": [ + " goal_rate veh_edge_cr veh_veh_cr\n", + "Type \n", + "train 0.10847 0.423477 0.243685\n", + "valid 0.11756 0.400298 0.235119" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_hr_rl.groupby('Type')[performance_metrics].mean()" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
act_accaccel_val_maesteer_val_maepos_rmsespeed_mae
Type
train0.0171062.2021200.43347613.270434117.405705
valid0.0139512.1984270.42922911.46556276.767855
\n", + "
" + ], + "text/plain": [ + " act_acc accel_val_mae steer_val_mae pos_rmse speed_mae\n", + "Type \n", + "train 0.017106 2.202120 0.433476 13.270434 117.405705\n", + "valid 0.013951 2.198427 0.429229 11.465562 76.767855" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_hr_rl.groupby('Type')[human_like_metrics].mean()" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
run_idreg_coeftraffic_sceneagent_idact_accaccel_val_maesteer_val_maepos_rmsespeed_maegoal_rateveh_edge_crveh_veh_crTypenum_scenes
0None0.0tfrecord-00001-of-01000_307.json00.07502.3269230.3141037.0413831.8209370.00.00.0train100
1None0.0tfrecord-00001-of-01000_307.json10.01251.9736840.3500003.8274341.1049080.01.00.0train100
2None0.0tfrecord-00001-of-01000_307.json90.00002.0886080.4120258.2453582.6181470.00.00.0train100
3None0.0tfrecord-00004-of-01000_378.json00.00002.4210530.5035095.0615391.4135940.00.00.0train100
4None0.0tfrecord-00004-of-01000_378.json20.00001.5000000.4250001.9427670.6216201.00.00.0train100
\n", + "
" + ], + "text/plain": [ + " run_id reg_coef traffic_scene agent_id act_acc \\\n", + "0 None 0.0 tfrecord-00001-of-01000_307.json 0 0.0750 \n", + "1 None 0.0 tfrecord-00001-of-01000_307.json 1 0.0125 \n", + "2 None 0.0 tfrecord-00001-of-01000_307.json 9 0.0000 \n", + "3 None 0.0 tfrecord-00004-of-01000_378.json 0 0.0000 \n", + "4 None 0.0 tfrecord-00004-of-01000_378.json 2 0.0000 \n", + "\n", + " accel_val_mae steer_val_mae pos_rmse speed_mae goal_rate veh_edge_cr \\\n", + "0 2.326923 0.314103 7.041383 1.820937 0.0 0.0 \n", + "1 1.973684 0.350000 3.827434 1.104908 0.0 1.0 \n", + "2 2.088608 0.412025 8.245358 2.618147 0.0 0.0 \n", + "3 2.421053 0.503509 5.061539 1.413594 0.0 0.0 \n", + "4 1.500000 0.425000 1.942767 0.621620 1.0 0.0 \n", + "\n", + " veh_veh_cr Type num_scenes \n", + "0 0.0 train 100 \n", + "1 0.0 train 100 \n", + "2 0.0 train 100 \n", + "3 0.0 train 100 \n", + "4 0.0 train 100 " + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_hr_rl.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3.1 What is the overall aggregated performance?\n", + "\n", + "- Look at the error distribution\n", + "- What makes solving a scene hard?" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2024-01-03T10:32:57.407483\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.8.0, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + "\" style=\"fill: none\"/>\n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", + " \n", " \n", " \n", " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", + " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + "\" clip-path=\"url(#pc7241d3713)\" style=\"fill: #5875a4; stroke: #ffffff; stroke-linejoin: miter\"/>\n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "\n" - ], - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - }, - { - "data": { - "image/svg+xml": [ - "\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " 2023-12-29T13:41:42.448435\n", - " image/svg+xml\n", - " \n", - " \n", - " Matplotlib v3.8.0, https://matplotlib.org/\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + "\" clip-path=\"url(#pc7241d3713)\" style=\"fill: #cc8963; stroke: #ffffff; stroke-linejoin: miter\"/>\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + "\" clip-path=\"url(#pd9cc63e097)\" style=\"fill: #cc8963; stroke: #ffffff; stroke-linejoin: miter\"/>\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", - " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", - " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "fig, axs = plt.subplots(1, 5, figsize=(12, 4))\n", + "\n", + "fig.suptitle('Human likeness scores for PPO policies', fontsize=16)\n", + "\n", + "sns.barplot(data=df_hr_rl, x='num_scenes', y='act_acc', ax=axs[0], hue='Type', legend=False)\n", + "\n", + "sns.barplot(data=df_hr_rl, x='num_scenes', y='accel_val_mae', ax=axs[1], hue='Type', legend=False)\n", + "\n", + "sns.barplot(data=df_hr_rl, x='num_scenes', y='steer_val_mae', ax=axs[2], hue='Type', legend=False)\n", + "\n", + "sns.barplot(data=df_hr_rl, x='num_scenes', y='speed_mae', ax=axs[3], hue='Type', legend=False)\n", + "\n", + "sns.barplot(data=df_hr_rl, x='num_scenes', y='pos_rmse', ax=axs[4], hue='Type', legend=True)\n", + "\n", + "fig.tight_layout()\n", + "sns.despine()" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2024-01-03T10:32:59.762425\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.8.0, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", - " \n", + " \n", + " \n", + " \n", + " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", + " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", - " \n", + " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", - " \n", " \n", - " \n", + " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", - " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", "\n" @@ -13100,95 +4892,63 @@ } ], "source": [ - "scene_id = 'tfrecord-00012-of-01000_389.json'\n", + "fig, axs = plt.subplots(1, 3, figsize=(12, 4))\n", "\n", - "df_scene = df_il_trajs[df_il_trajs.traffic_scene == scene_id]\n", - "print(f\"Scene {scene_id}\")\n", - "for agent_id in df_scene.agent_id.unique():\n", - " agent_df = df_scene[df_scene.agent_id == agent_id]\n", - " plot_agent_trajectory(agent_df, evaluator.env.action_space.n)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 2. **Human-regularized RL**" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "RL_POLICY_PATH = \"../models/hr_rl\"\n", + "fig.suptitle(f'Performance scores for PPO policies | DET = {DETERMINISTIC}', fontsize=16)\n", "\n", - "# Scenes on which to evaluate the models\n", - "rl_policy_paths = glob.glob(f\"{RL_POLICY_PATH}\" + \"/*.pt\")\n", - "rl_policy_names = [os.path.basename(file)[:-3] for file in rl_policy_paths]\n", - "reg_weights = [0.0]\n", + "sns.barplot(data=df_hr_rl, x='num_scenes', y='goal_rate', ax=axs[0], hue='Type', legend=False)\n", "\n", + "sns.barplot(data=df_hr_rl, x='num_scenes', y='veh_edge_cr', ax=axs[1], hue='Type', legend=False)\n", "\n", - "rl_policy_names" + "sns.barplot(data=df_hr_rl, x='num_scenes', y='veh_veh_cr', ax=axs[2], hue='Type', legend=False)\n", + "\n", + "fig.tight_layout()\n", + "sns.despine()" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 35, "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Evaluating policy nocturne-hr-ppo-12_29_07_55_0.025\n", - "Evaluating policy nocturne-hr-ppo-12_29_07_52_0.00\n" - ] + "data": { + "text/plain": [ + "199" + ] + }, + "execution_count": 35, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "df_hr_rl_all = pd.DataFrame()\n", - "\n", - "for idx, policy in enumerate(rl_policy_names):\n", - "\n", - " print(f'Evaluating policy {policy}')\n", - "\n", - " # Load trained model from artifact dir\n", - " checkpoint = torch.load(f\"{RL_POLICY_PATH}/{policy}.pt\")\n", - " policy = LateFusionMLPPolicy(\n", - " observation_space=checkpoint['data']['observation_space'],\n", - " action_space=checkpoint['data']['action_space'],\n", - " lr_schedule=checkpoint['data']['lr_schedule'],\n", - " use_sde=checkpoint['data']['use_sde'],\n", - " env_config=env_config,\n", - " mlp_class=LateFusionMLP,\n", - " #mlp_config=checkpoint['model'],\n", - " )\n", - " policy.load_state_dict(checkpoint['state_dict'])\n", - " policy.eval();\n", - "\n", - " # Evaluate on scenes\n", - " evaluator = EvaluatePolicy(\n", - " env_config=env_config, \n", - " exp_config=exp_config,\n", - " policy=policy,\n", - " eval_files=eval_files,\n", - " log_to_wandb=False, \n", - " deterministic=True,\n", - " reg_coef=reg_weights[idx],\n", - " return_trajectories=True,\n", - " )\n", - "\n", - " df_rl_res, _ = evaluator._get_scores()\n", - " df_rl_res['class'] = f\"HR_RL: {reg_weights[idx]}\"\n", - "\n", - " df_hr_rl_all = pd.concat([df_hr_rl_all, df_rl_res])" + "len(df_hr_rl['traffic_scene'].unique())" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.04001020194409743" + ] + }, + "execution_count": 39, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_hr_rl[df_hr_rl['Type'] == 'train'].groupby('traffic_scene')['goal_rate'].mean()) / 199" ] }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 25, "metadata": {}, "outputs": [ { @@ -13213,494 +4973,155 @@ " \n", " \n", " run_id\n", - " traffic_scene\n", - " agents_controlled\n", " reg_coef\n", + " traffic_scene\n", + " agent_id\n", " act_acc\n", + " accel_val_mae\n", + " steer_val_mae\n", " pos_rmse\n", " speed_mae\n", " goal_rate\n", " veh_edge_cr\n", " veh_veh_cr\n", - " num_violations\n", - " class\n", + " Type\n", + " num_scenes\n", " \n", " \n", " \n", " \n", " 0\n", " None\n", - " tfrecord-00004-of-01000_378.json\n", - " 5\n", - " 0.025\n", - " 0.142500\n", - " 8.358297\n", - " 1.440388\n", - " 0.400000\n", - " 0.000000\n", - " 0.200000\n", - " 0\n", - " HR_RL: 0.025\n", - " \n", - " \n", - " 1\n", - " None\n", - " tfrecord-00003-of-01000_109.json\n", - " 20\n", - " 0.025\n", - " 0.105000\n", - " 119.291848\n", - " 21.604723\n", - " 0.800000\n", - " 0.000000\n", - " 0.200000\n", - " 0\n", - " HR_RL: 0.025\n", - " \n", - " \n", - " 2\n", - " None\n", - " tfrecord-00004-of-01000_61.json\n", - " 6\n", - " 0.025\n", - " 0.164583\n", - " 6.606689\n", - " 1.097389\n", - " 0.666667\n", - " 0.333333\n", - " 0.000000\n", - " 0\n", - " HR_RL: 0.025\n", - " \n", - " \n", - " 3\n", - " None\n", - " tfrecord-00012-of-01000_87.json\n", - " 9\n", - " 0.025\n", - " 0.133333\n", - " 164.076051\n", - " 139.820955\n", - " 0.777778\n", - " 0.000000\n", - " 0.222222\n", - " 0\n", - " HR_RL: 0.025\n", - " \n", - " \n", - " 4\n", - " None\n", - " tfrecord-00007-of-01000_237.json\n", - " 14\n", - " 0.025\n", - " 0.184821\n", - " 13.654639\n", - " 1.436978\n", - " 0.714286\n", - " 0.142857\n", - " 0.142857\n", - " 0\n", - " HR_RL: 0.025\n", - " \n", - " \n", - " 5\n", - " None\n", - " tfrecord-00005-of-01000_423.json\n", - " 4\n", - " 0.025\n", - " 0.312500\n", - " 13.714191\n", - " 1.655821\n", - " 0.500000\n", - " 0.000000\n", - " 0.000000\n", - " 0\n", - " HR_RL: 0.025\n", - " \n", - " \n", - " 6\n", - " None\n", - " tfrecord-00012-of-01000_246.json\n", - " 20\n", - " 0.025\n", - " 0.113125\n", - " 193.501743\n", - " 98.045055\n", - " 0.750000\n", - " 0.150000\n", - " 0.100000\n", - " 0\n", - " HR_RL: 0.025\n", - " \n", - " \n", - " 7\n", - " None\n", - " tfrecord-00012-of-01000_389.json\n", - " 4\n", - " 0.025\n", - " 0.121875\n", - " 123.742438\n", - " 64.752366\n", - " 0.750000\n", - " 0.000000\n", - " 0.000000\n", - " 0\n", - " HR_RL: 0.025\n", - " \n", - " \n", - " 8\n", - " None\n", + " 0.0\n", " tfrecord-00001-of-01000_307.json\n", - " 3\n", - " 0.025\n", - " 0.133333\n", - " 7.316937\n", - " 1.267483\n", - " 0.333333\n", - " 0.333333\n", - " 0.000000\n", - " 0\n", - " HR_RL: 0.025\n", - " \n", - " \n", - " 9\n", - " None\n", - " tfrecord-00004-of-01000_157.json\n", - " 11\n", - " 0.025\n", - " 0.122727\n", - " 260.829870\n", - " 315.457867\n", - " 0.545455\n", - " 0.000000\n", - " 0.000000\n", - " 0\n", - " HR_RL: 0.025\n", - " \n", - " \n", - " 0\n", - " None\n", - " tfrecord-00004-of-01000_378.json\n", - " 5\n", - " 0.000\n", - " 0.025000\n", - " 8.214314\n", - " 1.658601\n", - " 0.600000\n", - " 0.200000\n", - " 0.000000\n", " 0\n", - " HR_RL: 0.0\n", + " 0.0750\n", + " 2.326923\n", + " 0.314103\n", + " 7.041383\n", + " 1.820937\n", + " 0.0\n", + " 0.0\n", + " 0.0\n", + " train\n", + " 100\n", " \n", " \n", " 1\n", " None\n", - " tfrecord-00003-of-01000_109.json\n", - " 20\n", - " 0.000\n", - " 0.059375\n", - " 119.300123\n", - " 23.499673\n", - " 0.850000\n", - " 0.050000\n", - " 0.100000\n", - " 0\n", - " HR_RL: 0.0\n", + " 0.0\n", + " tfrecord-00001-of-01000_307.json\n", + " 1\n", + " 0.0125\n", + " 1.973684\n", + " 0.350000\n", + " 3.827434\n", + " 1.104908\n", + " 0.0\n", + " 1.0\n", + " 0.0\n", + " train\n", + " 100\n", " \n", " \n", " 2\n", " None\n", - " tfrecord-00004-of-01000_61.json\n", - " 6\n", - " 0.000\n", - " 0.029167\n", - " 8.344646\n", - " 1.510286\n", - " 0.666667\n", - " 0.333333\n", - " 0.000000\n", - " 0\n", - " HR_RL: 0.0\n", - " \n", - " \n", - " 3\n", - " None\n", - " tfrecord-00012-of-01000_87.json\n", - " 9\n", - " 0.000\n", - " 0.044444\n", - " 164.078821\n", - " 127.039760\n", - " 0.777778\n", - " 0.000000\n", - " 0.222222\n", - " 0\n", - " HR_RL: 0.0\n", - " \n", - " \n", - " 4\n", - " None\n", - " tfrecord-00007-of-01000_237.json\n", - " 14\n", - " 0.000\n", - " 0.067857\n", - " 14.361484\n", - " 1.978035\n", - " 0.642857\n", - " 0.142857\n", - " 0.142857\n", - " 0\n", - " HR_RL: 0.0\n", - " \n", - " \n", - " 5\n", - " None\n", - " tfrecord-00005-of-01000_423.json\n", - " 4\n", - " 0.000\n", - " 0.096875\n", - " 12.347980\n", - " 2.126997\n", - " 0.750000\n", - " 0.000000\n", - " 0.000000\n", - " 0\n", - " HR_RL: 0.0\n", - " \n", - " \n", - " 6\n", - " None\n", - " tfrecord-00012-of-01000_246.json\n", - " 20\n", - " 0.000\n", - " 0.034375\n", - " 193.515814\n", - " 112.653371\n", - " 0.750000\n", - " 0.050000\n", - " 0.200000\n", - " 0\n", - " HR_RL: 0.0\n", - " \n", - " \n", - " 7\n", - " None\n", - " tfrecord-00012-of-01000_389.json\n", - " 4\n", - " 0.000\n", - " 0.034375\n", - " 123.741883\n", - " 102.708780\n", - " 0.500000\n", - " 0.000000\n", - " 0.500000\n", - " 0\n", - " HR_RL: 0.0\n", - " \n", - " \n", - " 8\n", - " None\n", + " 0.0\n", " tfrecord-00001-of-01000_307.json\n", - " 3\n", - " 0.000\n", - " 0.025000\n", - " 8.964723\n", - " 2.219035\n", - " 1.000000\n", - " 0.000000\n", - " 0.000000\n", - " 0\n", - " HR_RL: 0.0\n", - " \n", - " \n", - " 9\n", - " None\n", - " tfrecord-00004-of-01000_157.json\n", - " 11\n", - " 0.000\n", - " 0.057955\n", - " 260.830741\n", - " 314.463060\n", - " 0.545455\n", - " 0.000000\n", - " 0.000000\n", - " 0\n", - " HR_RL: 0.0\n", - " \n", - " \n", - "\n", - "" - ], - "text/plain": [ - " run_id traffic_scene agents_controlled reg_coef \\\n", - "0 None tfrecord-00004-of-01000_378.json 5 0.025 \n", - "1 None tfrecord-00003-of-01000_109.json 20 0.025 \n", - "2 None tfrecord-00004-of-01000_61.json 6 0.025 \n", - "3 None tfrecord-00012-of-01000_87.json 9 0.025 \n", - "4 None tfrecord-00007-of-01000_237.json 14 0.025 \n", - "5 None tfrecord-00005-of-01000_423.json 4 0.025 \n", - "6 None tfrecord-00012-of-01000_246.json 20 0.025 \n", - "7 None tfrecord-00012-of-01000_389.json 4 0.025 \n", - "8 None tfrecord-00001-of-01000_307.json 3 0.025 \n", - "9 None tfrecord-00004-of-01000_157.json 11 0.025 \n", - "0 None tfrecord-00004-of-01000_378.json 5 0.000 \n", - "1 None tfrecord-00003-of-01000_109.json 20 0.000 \n", - "2 None tfrecord-00004-of-01000_61.json 6 0.000 \n", - "3 None tfrecord-00012-of-01000_87.json 9 0.000 \n", - "4 None tfrecord-00007-of-01000_237.json 14 0.000 \n", - "5 None tfrecord-00005-of-01000_423.json 4 0.000 \n", - "6 None tfrecord-00012-of-01000_246.json 20 0.000 \n", - "7 None tfrecord-00012-of-01000_389.json 4 0.000 \n", - "8 None tfrecord-00001-of-01000_307.json 3 0.000 \n", - "9 None tfrecord-00004-of-01000_157.json 11 0.000 \n", - "\n", - " act_acc pos_rmse speed_mae goal_rate veh_edge_cr veh_veh_cr \\\n", - "0 0.142500 8.358297 1.440388 0.400000 0.000000 0.200000 \n", - "1 0.105000 119.291848 21.604723 0.800000 0.000000 0.200000 \n", - "2 0.164583 6.606689 1.097389 0.666667 0.333333 0.000000 \n", - "3 0.133333 164.076051 139.820955 0.777778 0.000000 0.222222 \n", - "4 0.184821 13.654639 1.436978 0.714286 0.142857 0.142857 \n", - "5 0.312500 13.714191 1.655821 0.500000 0.000000 0.000000 \n", - "6 0.113125 193.501743 98.045055 0.750000 0.150000 0.100000 \n", - "7 0.121875 123.742438 64.752366 0.750000 0.000000 0.000000 \n", - "8 0.133333 7.316937 1.267483 0.333333 0.333333 0.000000 \n", - "9 0.122727 260.829870 315.457867 0.545455 0.000000 0.000000 \n", - "0 0.025000 8.214314 1.658601 0.600000 0.200000 0.000000 \n", - "1 0.059375 119.300123 23.499673 0.850000 0.050000 0.100000 \n", - "2 0.029167 8.344646 1.510286 0.666667 0.333333 0.000000 \n", - "3 0.044444 164.078821 127.039760 0.777778 0.000000 0.222222 \n", - "4 0.067857 14.361484 1.978035 0.642857 0.142857 0.142857 \n", - "5 0.096875 12.347980 2.126997 0.750000 0.000000 0.000000 \n", - "6 0.034375 193.515814 112.653371 0.750000 0.050000 0.200000 \n", - "7 0.034375 123.741883 102.708780 0.500000 0.000000 0.500000 \n", - "8 0.025000 8.964723 2.219035 1.000000 0.000000 0.000000 \n", - "9 0.057955 260.830741 314.463060 0.545455 0.000000 0.000000 \n", - "\n", - " num_violations class \n", - "0 0 HR_RL: 0.025 \n", - "1 0 HR_RL: 0.025 \n", - "2 0 HR_RL: 0.025 \n", - "3 0 HR_RL: 0.025 \n", - "4 0 HR_RL: 0.025 \n", - "5 0 HR_RL: 0.025 \n", - "6 0 HR_RL: 0.025 \n", - "7 0 HR_RL: 0.025 \n", - "8 0 HR_RL: 0.025 \n", - "9 0 HR_RL: 0.025 \n", - "0 0 HR_RL: 0.0 \n", - "1 0 HR_RL: 0.0 \n", - "2 0 HR_RL: 0.0 \n", - "3 0 HR_RL: 0.0 \n", - "4 0 HR_RL: 0.0 \n", - "5 0 HR_RL: 0.0 \n", - "6 0 HR_RL: 0.0 \n", - "7 0 HR_RL: 0.0 \n", - "8 0 HR_RL: 0.0 \n", - "9 0 HR_RL: 0.0 " - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df_hr_rl_all" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 4. **Summary figures**" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 4.1 What is the overall performance, and how is it distributed across scenes?\n", - "\n", - "- Error distribution" - ] - }, - { - "cell_type": "code", - "execution_count": 57, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", - " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", "
act_accgoal_rateveh_edge_crveh_veh_cr90.00002.0886080.4120258.2453582.6181470.00.00.0train100
reg_coef3None0.0tfrecord-00004-of-01000_378.json00.00002.4210530.5035095.0615391.4135940.00.00.0train100
0.00.1980.2830.2650.0854None0.0tfrecord-00004-of-01000_378.json20.00001.5000000.4250001.9427670.6216201.00.00.0train100
\n", "
" ], "text/plain": [ - " act_acc goal_rate veh_edge_cr veh_veh_cr\n", - "reg_coef \n", - "0.0 0.198 0.283 0.265 0.085" + " run_id reg_coef traffic_scene agent_id act_acc \\\n", + "0 None 0.0 tfrecord-00001-of-01000_307.json 0 0.0750 \n", + "1 None 0.0 tfrecord-00001-of-01000_307.json 1 0.0125 \n", + "2 None 0.0 tfrecord-00001-of-01000_307.json 9 0.0000 \n", + "3 None 0.0 tfrecord-00004-of-01000_378.json 0 0.0000 \n", + "4 None 0.0 tfrecord-00004-of-01000_378.json 2 0.0000 \n", + "\n", + " accel_val_mae steer_val_mae pos_rmse speed_mae goal_rate veh_edge_cr \\\n", + "0 2.326923 0.314103 7.041383 1.820937 0.0 0.0 \n", + "1 1.973684 0.350000 3.827434 1.104908 0.0 1.0 \n", + "2 2.088608 0.412025 8.245358 2.618147 0.0 0.0 \n", + "3 2.421053 0.503509 5.061539 1.413594 0.0 0.0 \n", + "4 1.500000 0.425000 1.942767 0.621620 1.0 0.0 \n", + "\n", + " veh_veh_cr Type num_scenes \n", + "0 0.0 train 100 \n", + "1 0.0 train 100 \n", + "2 0.0 train 100 \n", + "3 0.0 train 100 \n", + "4 0.0 train 100 " ] }, - "execution_count": 57, + "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "metrics = ['act_acc', 'goal_rate', 'veh_edge_cr', 'veh_veh_cr', 'reg_coef']\n", - "\n", - "# Aggregate performance HR-PPO\n", - "df_hr_rl_agg = df_hr_rl_all[metrics].groupby('reg_coef').mean().round(3)\n", - "\n", - "# Aggregate performance IL\n", - "df_il_agg = df_il_res[metrics].groupby('reg_coef').mean().round(3)\n", - "\n", - "df_il_agg" + "df_hr_rl.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "df_hr_rl.groupby('traffic_scene')" ] }, { "cell_type": "code", - "execution_count": 58, + "execution_count": 27, "metadata": {}, "outputs": [ { @@ -13724,14 +5145,12 @@ " \n", " \n", " \n", - " act_acc\n", " goal_rate\n", " veh_edge_cr\n", " veh_veh_cr\n", " \n", " \n", - " reg_coef\n", - " \n", + " traffic_scene\n", " \n", " \n", " \n", @@ -13739,42 +5158,106 @@ " \n", " \n", " \n", - " 0.000\n", - " 0.047\n", - " 0.708\n", - " 0.078\n", - " 0.117\n", + " tfrecord-00001-of-01000_307.json\n", + " 0.000000\n", + " 0.333333\n", + " 0.000000\n", " \n", " \n", - " 0.025\n", - " 0.153\n", - " 0.624\n", - " 0.096\n", - " 0.087\n", + " tfrecord-00004-of-00150_246.json\n", + " 0.333333\n", + " 0.666667\n", + " 0.000000\n", + " \n", + " \n", + " tfrecord-00004-of-01000_378.json\n", + " 0.200000\n", + " 0.400000\n", + " 0.000000\n", + " \n", + " \n", + " tfrecord-00005-of-00150_192.json\n", + " 0.111111\n", + " 0.333333\n", + " 0.222222\n", + " \n", + " \n", + " tfrecord-00005-of-01000_423.json\n", + " 0.000000\n", + " 1.000000\n", + " 0.000000\n", + " \n", + " \n", + " ...\n", + " ...\n", + " ...\n", + " ...\n", + " \n", + " \n", + " tfrecord-00144-of-00150_65.json\n", + " 0.000000\n", + " 0.857143\n", + " 0.000000\n", + " \n", + " \n", + " tfrecord-00146-of-00150_228.json\n", + " 0.000000\n", + " 0.571429\n", + " 0.428571\n", + " \n", + " \n", + " tfrecord-00147-of-00150_191.json\n", + " 0.000000\n", + " 0.500000\n", + " 0.000000\n", + " \n", + " \n", + " tfrecord-00147-of-00150_6.json\n", + " 0.111111\n", + " 0.333333\n", + " 0.222222\n", + " \n", + " \n", + " tfrecord-00149-of-00150_66.json\n", + " 0.000000\n", + " 1.000000\n", + " 0.000000\n", " \n", " \n", "\n", + "

199 rows × 3 columns

\n", "" ], "text/plain": [ - " act_acc goal_rate veh_edge_cr veh_veh_cr\n", - "reg_coef \n", - "0.000 0.047 0.708 0.078 0.117\n", - "0.025 0.153 0.624 0.096 0.087" + " goal_rate veh_edge_cr veh_veh_cr\n", + "traffic_scene \n", + "tfrecord-00001-of-01000_307.json 0.000000 0.333333 0.000000\n", + "tfrecord-00004-of-00150_246.json 0.333333 0.666667 0.000000\n", + "tfrecord-00004-of-01000_378.json 0.200000 0.400000 0.000000\n", + "tfrecord-00005-of-00150_192.json 0.111111 0.333333 0.222222\n", + "tfrecord-00005-of-01000_423.json 0.000000 1.000000 0.000000\n", + "... ... ... ...\n", + "tfrecord-00144-of-00150_65.json 0.000000 0.857143 0.000000\n", + "tfrecord-00146-of-00150_228.json 0.000000 0.571429 0.428571\n", + "tfrecord-00147-of-00150_191.json 0.000000 0.500000 0.000000\n", + "tfrecord-00147-of-00150_6.json 0.111111 0.333333 0.222222\n", + "tfrecord-00149-of-00150_66.json 0.000000 1.000000 0.000000\n", + "\n", + "[199 rows x 3 columns]" ] }, - "execution_count": 58, + "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "df_hr_rl_agg" + "df_hr_rl.groupby('traffic_scene')[performance_metrics].mean()" ] }, { "cell_type": "code", - "execution_count": 80, + "execution_count": 23, "metadata": {}, "outputs": [ { @@ -13783,12 +5266,12 @@ "\n", "\n", - "\n", + "\n", " \n", " \n", " \n", " \n", - " 2023-12-29T12:03:00.410283\n", + " 2024-01-03T10:35:04.178353\n", " image/svg+xml\n", " \n", " \n", @@ -13803,80 +5286,255 @@ " \n", " \n", " \n", - " \n", + "\" style=\"fill: none\"/>\n", " \n", " \n", " \n", - " \n", + "\" style=\"fill: none\"/>\n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", - " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", " \n", - " \n", " \n", " \n", " \n", - " \n", + "\" clip-path=\"url(#p1273599569)\" style=\"fill: #4c72b0; fill-opacity: 0.75; stroke: #ffffff; stroke-linejoin: miter\"/>\n", " \n", " \n", - " \n", + "\" clip-path=\"url(#p1273599569)\" style=\"fill: #4c72b0; fill-opacity: 0.75; stroke: #ffffff; stroke-linejoin: miter\"/>\n", " \n", " \n", - " \n", + "\" clip-path=\"url(#p1273599569)\" style=\"fill: #4c72b0; fill-opacity: 0.75; stroke: #ffffff; stroke-linejoin: miter\"/>\n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + "\" clip-path=\"url(#p1273599569)\" style=\"fill: #4c72b0; fill-opacity: 0.75; stroke: #ffffff; stroke-linejoin: miter\"/>\n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", - " \n", + " \n", + " \n", + " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + "\" clip-path=\"url(#p1273599569)\" style=\"fill: #4c72b0; fill-opacity: 0.75; stroke: #ffffff; stroke-linejoin: miter\"/>\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", "\n" ], "text/plain": [ - "
" + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "sns.histplot(data=df_hr_rl, y='goal_rate',);" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "#plot_agent_trajectory(df_rl_trajs_sub[df_rl_trajs_sub['agent_id'] == 0], evaluator.env.action_space.n)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "#plot_agent_trajectory(df_rl_trajs_sub[df_rl_trajs_sub['agent_id'] == 1], evaluator.env.action_space.n)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3.2 Make videos of top 10 most difficult scenes / top 10 easiest scenes\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
run_idtraffic_sceneagents_controlledreg_coefact_accaccel_val_maesteer_val_maepos_rmsespeed_maegoal_rateveh_edge_crveh_veh_crclass
48Nonetfrecord-00061-of-01000_223.json20.00.0312501.8072290.3457835.2347001.4302120.00.50.0HR_RL: 0.0
59Nonetfrecord-00072-of-01000_20.json40.00.0031252.3546510.4151166.9955872.7233120.00.01.0HR_RL: 0.0
57Nonetfrecord-00070-of-01000_158.json30.00.0000001.3301890.5877365.2339930.2808980.01.00.0HR_RL: 0.0
56Nonetfrecord-00069-of-01000_193.json50.00.0200001.9565220.281522169.067436768.9797220.00.80.0HR_RL: 0.0
84Nonetfrecord-00095-of-01000_204.json20.00.0062502.0000000.4493835.1019350.5090490.00.50.0HR_RL: 0.0
\n", + "
" + ], + "text/plain": [ + " run_id traffic_scene agents_controlled reg_coef \\\n", + "48 None tfrecord-00061-of-01000_223.json 2 0.0 \n", + "59 None tfrecord-00072-of-01000_20.json 4 0.0 \n", + "57 None tfrecord-00070-of-01000_158.json 3 0.0 \n", + "56 None tfrecord-00069-of-01000_193.json 5 0.0 \n", + "84 None tfrecord-00095-of-01000_204.json 2 0.0 \n", + "\n", + " act_acc accel_val_mae steer_val_mae pos_rmse speed_mae goal_rate \\\n", + "48 0.031250 1.807229 0.345783 5.234700 1.430212 0.0 \n", + "59 0.003125 2.354651 0.415116 6.995587 2.723312 0.0 \n", + "57 0.000000 1.330189 0.587736 5.233993 0.280898 0.0 \n", + "56 0.020000 1.956522 0.281522 169.067436 768.979722 0.0 \n", + "84 0.006250 2.000000 0.449383 5.101935 0.509049 0.0 \n", + "\n", + " veh_edge_cr veh_veh_cr class \n", + "48 0.5 0.0 HR_RL: 0.0 \n", + "59 0.0 1.0 HR_RL: 0.0 \n", + "57 1.0 0.0 HR_RL: 0.0 \n", + "56 0.8 0.0 HR_RL: 0.0 \n", + "84 0.5 0.0 HR_RL: 0.0 " + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "NUM_VIDEOS = 5\n", + "\n", + "df_hr_rl.sort_values('goal_rate').head()" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "Tracking run with wandb version 0.16.1" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Run data is saved locally in /home/emerge/daphne/nocturne_lab/evaluation/wandb/run-20240103_070328-k4hdxw1b" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Syncing run RL_S100_worst_stoch to Weights & Biases (docs)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View project at https://wandb.ai/daphnecor/eval_hr_rl_policies" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run at https://wandb.ai/daphnecor/eval_hr_rl_policies/runs/k4hdxw1b" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Evaluating policy on tfrecord-00061-of-01000_223.json...\n", + "Evaluating policy on tfrecord-00072-of-01000_20.json...\n", + "Evaluating policy on tfrecord-00070-of-01000_158.json...\n", + "Evaluating policy on tfrecord-00069-of-01000_193.json...\n", + "Evaluating policy on tfrecord-00095-of-01000_204.json...\n" + ] + }, + { + "data": { + "text/html": [ + " View run RL_S100_worst_stoch at: https://wandb.ai/daphnecor/eval_hr_rl_policies/runs/k4hdxw1b
View job at https://wandb.ai/daphnecor/eval_hr_rl_policies/jobs/QXJ0aWZhY3RDb2xsZWN0aW9uOjEyNzEwNDI4OA==/version_details/v0
Synced 5 W&B file(s), 10 media file(s), 0 artifact file(s) and 0 other file(s)" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Find logs at: ./wandb/run-20240103_070328-k4hdxw1b/logs" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "worst_scenes = df_hr_rl_all.sort_values('goal_rate')['traffic_scene'].values[:NUM_VIDEOS]\n", + "\n", + "render_traffic_scenes(\n", + " eval_scenes=worst_scenes,\n", + " policy=policy,\n", + " run_name=\"RL_S100_worst\",\n", + " deterministic=False,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "Tracking run with wandb version 0.16.1" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Run data is saved locally in /home/emerge/daphne/nocturne_lab/evaluation/wandb/run-20240103_070415-o4cve8xh" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Syncing run RL_S100_worst_det to Weights & Biases (docs)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View project at https://wandb.ai/daphnecor/eval_hr_rl_policies" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run at https://wandb.ai/daphnecor/eval_hr_rl_policies/runs/o4cve8xh" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Evaluating policy on tfrecord-00061-of-01000_223.json...\n", + "Evaluating policy on tfrecord-00072-of-01000_20.json...\n", + "Evaluating policy on tfrecord-00070-of-01000_158.json...\n", + "Evaluating policy on tfrecord-00069-of-01000_193.json...\n", + "Evaluating policy on tfrecord-00095-of-01000_204.json...\n" + ] + }, + { + "data": { + "text/html": [ + " View run RL_S100_worst_det at: https://wandb.ai/daphnecor/eval_hr_rl_policies/runs/o4cve8xh
View job at https://wandb.ai/daphnecor/eval_hr_rl_policies/jobs/QXJ0aWZhY3RDb2xsZWN0aW9uOjEyNzEwNDI4OA==/version_details/v0
Synced 5 W&B file(s), 10 media file(s), 0 artifact file(s) and 0 other file(s)" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Find logs at: ./wandb/run-20240103_070415-o4cve8xh/logs" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "ename": "", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[1;31mThe Kernel crashed while executing code in the the current cell or a previous cell. Please review the code in the cell(s) to identify a possible cause of the failure. Click here for more info. View Jupyter log for further details." + ] + } + ], + "source": [ + "worst_scenes = df_hr_rl_all.sort_values('goal_rate')['traffic_scene'].values[:NUM_VIDEOS]\n", + "\n", + "render_traffic_scenes(\n", + " eval_scenes=worst_scenes,\n", + " policy=policy,\n", + " run_name=\"RL_S100_worst\",\n", + " deterministic=True,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "Tracking run with wandb version 0.16.1" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Run data is saved locally in /home/emerge/daphne/nocturne_lab/evaluation/wandb/run-20240103_073548-bspg5o1a" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Syncing run RL_S100_best_stoch to Weights & Biases (docs)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View project at https://wandb.ai/daphnecor/eval_hr_rl_policies" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run at https://wandb.ai/daphnecor/eval_hr_rl_policies/runs/bspg5o1a" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Evaluating policy on tfrecord-00019-of-01000_117.json...\n", + "Evaluating policy on tfrecord-00059-of-01000_12.json...\n", + "Evaluating policy on tfrecord-00111-of-01000_359.json...\n", + "Evaluating policy on tfrecord-00079-of-01000_207.json...\n", + "Evaluating policy on tfrecord-00024-of-01000_33.json...\n" + ] + }, + { + "data": { + "text/html": [ + " View run RL_S100_best_stoch at: https://wandb.ai/daphnecor/eval_hr_rl_policies/runs/bspg5o1a
View job at https://wandb.ai/daphnecor/eval_hr_rl_policies/jobs/QXJ0aWZhY3RDb2xsZWN0aW9uOjEyNzEwNDI4OA==/version_details/v1
Synced 5 W&B file(s), 10 media file(s), 0 artifact file(s) and 0 other file(s)" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Find logs at: ./wandb/run-20240103_073548-bspg5o1a/logs" + ], + "text/plain": [ + "" ] }, "metadata": {}, @@ -14565,7 +6594,28 @@ } ], "source": [ - "sns.barplot(data=df_hr_rl_all, y='goal_rate', hue='reg_coef', palette=\"deep\");" + "best_scenes = df_hr_rl_all.sort_values('goal_rate', ascending=False)['traffic_scene'].values[:NUM_VIDEOS]\n", + "\n", + "render_traffic_scenes(\n", + " eval_scenes=best_scenes,\n", + " policy=policy,\n", + " run_name=\"RL_S100_best\",\n", + " deterministic=False,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Is there a trade-off between human likeness and performance? | **Pareto plots**" ] }, { diff --git a/experiments/hr_rl/run_hr_ppo.py b/experiments/hr_rl/run_hr_ppo.py index 9b1788a9..09e90da1 100644 --- a/experiments/hr_rl/run_hr_ppo.py +++ b/experiments/hr_rl/run_hr_ppo.py @@ -37,7 +37,6 @@ def train(env_config, exp_config, video_config, model_config): # pylint: disabl env = MultiAgentAsVecEnv( config=env_config, num_envs=env_config.max_num_vehicles, - train_on_single_scene=exp_config.train_on_single_scene, ) # Set up run @@ -62,6 +61,9 @@ def train(env_config, exp_config, video_config, model_config): # pylint: disabl logging.info(f"Learning in {len(env.env.files)} scene(s): {env.env.files} | using {exp_config.ppo.device}") logging.info(f"--- obs_space: {env.observation_space.shape[0]} ---") logging.info(f"Action_space\n: {env.env.idx_to_actions}") + + if exp_config.reg_weight > 0.0: + logging.info(f"Regularization weight: {exp_config.reg_weight} with policy: {exp_config.human_policy_path}") # Initialize custom callback custom_callback = CustomMultiAgentCallback( @@ -135,7 +137,7 @@ def train(env_config, exp_config, video_config, model_config): # pylint: disabl "arch_ego_state": [8], "arch_road_objects": [64], "arch_road_graph": [128, 64], - "arch_shared_net": [], + "arch_shared_net": [128], "act_func": "tanh", "dropout": 0.0, "last_layer_dim_pi": 64, @@ -143,16 +145,26 @@ def train(env_config, exp_config, video_config, model_config): # pylint: disabl } ) - num_files_list = [10, 100, 1000] + num_files_list = [10] + #MEMORY = [4, 2] + MEMORY = [1] - for scenes in num_files_list: - # Set regularization weight - env_config.num_files = scenes - - # Train - train( - env_config=env_config, - exp_config=exp_config, - video_config=video_config, - model_config=model_config, - ) + for mem in MEMORY: + for num_scenes in num_files_list: + + # Set memory + env_config.subscriber.n_frames_stacked = mem + + # Set regularization weight + #exp_config.reg_weight = lam + + exp_config.human_policy_path = f"models/il/human_policy_S{num_scenes}_2024_01_02.pt" + env_config.num_files = num_scenes + + # Train + train( + env_config=env_config, + exp_config=exp_config, + video_config=video_config, + model_config=model_config, + ) diff --git a/experiments/hr_rl/run_hr_ppo_cli.py b/experiments/hr_rl/run_hr_ppo_cli.py index 34ca1dfd..3fb542dd 100644 --- a/experiments/hr_rl/run_hr_ppo_cli.py +++ b/experiments/hr_rl/run_hr_ppo_cli.py @@ -41,7 +41,6 @@ "large": [256, 128, 64], } - def run_hr_ppo( sweep_name: str = exp_config.group, steer_disc: int = 5, diff --git a/experiments/il/run_behavioral_cloning.py b/experiments/il/run_behavioral_cloning.py index 2812a9de..84f90880 100644 --- a/experiments/il/run_behavioral_cloning.py +++ b/experiments/il/run_behavioral_cloning.py @@ -19,7 +19,7 @@ if __name__ == "__main__": MAX_EVAL_FILES = 12 - NUM_TRAIN_FILES = 1000 + NUM_TRAIN_FILES = 50 # Create run run = wandb.init( diff --git a/networks/mlp_late_fusion.py b/networks/mlp_late_fusion.py index aed85e0b..c625628c 100644 --- a/networks/mlp_late_fusion.py +++ b/networks/mlp_late_fusion.py @@ -49,10 +49,10 @@ def __init__( self.arch_shared_net = arch_shared_net self.dropout = dropout - #TODO: write function that gets this information from config - self.input_dim_ego = 10 - self.input_dim_road_graph = 6500 - self.input_dim_road_objects = 220 + #TODO @Daphne: write function that gets this information from config + self.input_dim_ego = 10 * self.config.subscriber.n_frames_stacked + self.input_dim_road_graph = 6500 * self.config.subscriber.n_frames_stacked + self.input_dim_road_objects = 220 * self.config.subscriber.n_frames_stacked # IMPORTANT:Save output dimensions, used to create the distributions self.latent_dim_pi = last_layer_dim_pi @@ -198,10 +198,10 @@ def _unflatten_obs(self, obs_flat): # Visible state object order: road_objects, road_points, traffic_lights, stop_signs # Find the ends of each section - ROAD_OBJECTS_END = 13 * self.config.scenario.max_visible_objects - ROAD_POINTS_END = ROAD_OBJECTS_END + (13 * self.config.scenario.max_visible_road_points) - TL_END = ROAD_POINTS_END + (12 * self.config.scenario.max_visible_traffic_lights) - STOP_SIGN_END = TL_END + (3 * self.config.scenario.max_visible_stop_signs) + ROAD_OBJECTS_END = (13 * self.config.scenario.max_visible_objects) * self.config.subscriber.n_frames_stacked + ROAD_POINTS_END = (ROAD_OBJECTS_END + (13 * self.config.scenario.max_visible_road_points)) * self.config.subscriber.n_frames_stacked + TL_END = (ROAD_POINTS_END + (12 * self.config.scenario.max_visible_traffic_lights)) * self.config.subscriber.n_frames_stacked + STOP_SIGN_END = (TL_END + (3 * self.config.scenario.max_visible_stop_signs)) * self.config.subscriber.n_frames_stacked # Unflatten road_objects = vis_state[:, :ROAD_OBJECTS_END] @@ -251,12 +251,14 @@ def _build_mlp_extractor(self) -> None: # Load environment and experiment configurations env_config = load_config("env_config") exp_config = load_config("exp_config") + + env_config.subscriber.n_frames_stacked = 2 # Make environment env = MultiAgentAsVecEnv( config=env_config, num_envs=env_config.max_num_vehicles, - train_on_single_scene=exp_config.train_on_single_scene, + ) obs = env.reset() diff --git a/nocturne/envs/base_env.py b/nocturne/envs/base_env.py index ea089b7d..14eb7aff 100644 --- a/nocturne/envs/base_env.py +++ b/nocturne/envs/base_env.py @@ -43,10 +43,10 @@ def __init__( # pylint: disable=too-many-arguments self, config: Dict[str, Any], *, - img_width=1600, - img_height=1600, + img_width=1200, + img_height=1200, draw_target_positions=True, - padding=50.0, + padding=10.0, ) -> None: """Initialize a Nocturne environment. @@ -488,6 +488,10 @@ def _get_obs_space_dim(self, config, base=0): (3 * self.config.scenario.max_visible_stop_signs) + (12 * self.config.scenario.max_visible_traffic_lights) ) + + # Multiply by memory to get the final dimension + obs_space_dim = obs_space_dim * self.config.subscriber.n_frames_stacked + return (obs_space_dim,) def normalize_ego_state_by_cat(self, state): @@ -520,6 +524,8 @@ def render(self, mode: Optional[bool] = None) -> Optional[RenderType]: # pylint Optional[RenderType]: Rendered image. """ return self.scenario.getImage(**self._render_settings) + + env.scenario.getImage(**video_config.render) def render_ego(self, mode: Optional[bool] = None) -> Optional[RenderType]: # pylint: disable=unused-argument """Render the ego vehicles. diff --git a/nocturne/envs/vec_env_ma.py b/nocturne/envs/vec_env_ma.py index ea2e38e9..e5056b77 100644 --- a/nocturne/envs/vec_env_ma.py +++ b/nocturne/envs/vec_env_ma.py @@ -26,7 +26,7 @@ class MultiAgentAsVecEnv(VecEnv): VecEnv (SB3 VecEnv): SB3 VecEnv base class. """ - def __init__(self, config, num_envs, psr=False, train_on_single_scene=False): + def __init__(self, config, num_envs, psr=False): # Create Nocturne env self.env = BaseEnv(config) @@ -44,7 +44,7 @@ def __init__(self, config, num_envs, psr=False, train_on_single_scene=False): self.frac_collided = [] # Log fraction of agents that collided self.frac_goal_achieved = [] # Log fraction of agents that achieved their goal self.agents_in_scene = [] - self.filename = self.env.files[0] if train_on_single_scene else None # If provided, always use the same file + self.filename = None # If provided, always use the same file def _reset_seeds(self) -> None: """Reset all environments' seeds.""" @@ -172,6 +172,11 @@ def reset_scene_dict(self): def step_num(self) -> List[int]: """The episodic timestep.""" return self.env.step_num + @property + def render(self) -> List[int]: + """The episodic timestep.""" + img = self.env.render() + return img def seed(self, seed=None): """Set the random seeds for all environments.""" diff --git a/utils/eval.py b/utils/eval.py index 062a976b..9a2c76db 100644 --- a/utils/eval.py +++ b/utils/eval.py @@ -51,13 +51,13 @@ def _get_scores(self): logging.info(f'\n Evaluating policy on {len(self.eval_files)} files...') - # Make tables - df_eval = pd.DataFrame( + # Create tables + df_eval = pd.DataFrame( columns=[ "run_id", - "traffic_scene", - "agents_controlled", "reg_coef", + "traffic_scene", + "agent_id", "act_acc", "accel_val_mae", "steer_val_mae", @@ -103,7 +103,7 @@ def _get_scores(self): nonnan_ids = ~np.isnan(expert_actions) # Compute metrics - action_accuracy = self.get_action_accuracy( + action_accuracies = self.get_action_accuracy( policy_actions, expert_actions, nonnan_ids ) @@ -120,15 +120,15 @@ def _get_scores(self): ) # Violations of the 3-second rule - violations_matrix, num_violations = self.get_veh_to_veh_distances(policy_pos, policy_speed) + #violations_matrix, num_violations = self.get_veh_to_veh_distances(policy_pos, policy_speed) # Store metrics - scene_perf = { + scene_perf = pd.DataFrame({ "run_id": self.run.id if self.log_to_wandb else None, + "reg_coef": np.repeat(self.reg_coef, len(self.agent_names)), "traffic_scene": file, - "agents_controlled": expert_actions.shape[0], - "reg_coef": self.exp_config.reg_weight if self.reg_coef is None else self.reg_coef, - "act_acc": action_accuracy, + "agent_id": self.agent_names, + "act_acc": action_accuracies, "accel_val_mae": abs_diff_accel, "steer_val_mae": abs_diff_steer, "pos_rmse": position_rmse, @@ -136,8 +136,8 @@ def _get_scores(self): "goal_rate": policy_gr, "veh_edge_cr": policy_edge_cr, "veh_veh_cr": policy_veh_cr, - } - df_eval.loc[len(df_eval)] = scene_perf + }) + df_eval = pd.concat([df_eval, scene_perf], ignore_index=True) if self.return_trajectories: scene_trajs = pd.DataFrame({ @@ -180,6 +180,7 @@ def _step_through_scene(self, filename: str, mode: str): # Make sure the agent ids are in the same order agent_ids = np.sort([veh.id for veh in self.env.controlled_vehicles]) + self.agent_names = agent_ids agent_id_to_idx_dict = {agent_id: idx for idx, agent_id in enumerate(agent_ids)} last_info_dicts = {agent_id: {} for agent_id in agent_ids} dead_agent_ids = [] @@ -188,7 +189,7 @@ def _step_through_scene(self, filename: str, mode: str): action_indices = np.full(fill_value=np.nan, shape=(self.num_agents, num_steps)) agent_positions = np.full(fill_value=np.nan, shape=(self.num_agents, num_steps, 2)) agent_speed = np.full(fill_value=np.nan, shape=(self.num_agents, num_steps)) - goal_achieved, veh_edge_collision, veh_veh_collision = 0, 0, 0 + goal_achieved, veh_edge_collision, veh_veh_collision = np.zeros(self.num_agents), np.zeros(self.num_agents), np.zeros(self.num_agents) # Set control mode if mode == "expert": @@ -263,18 +264,19 @@ def _step_through_scene(self, filename: str, mode: str): if done_dict["__all__"]: for agent_id in agent_ids: - goal_achieved += last_info_dicts[agent_id]["goal_achieved"] - veh_edge_collision += last_info_dicts[agent_id]["veh_edge_collision"] - veh_veh_collision += last_info_dicts[agent_id]["veh_veh_collision"] + agent_idx = agent_id_to_idx_dict[agent_id] + goal_achieved[agent_idx] += last_info_dicts[agent_id]["goal_achieved"] + veh_edge_collision[agent_idx] += last_info_dicts[agent_id]["veh_edge_collision"] + veh_veh_collision[agent_idx] += last_info_dicts[agent_id]["veh_veh_collision"] break return ( action_indices, agent_positions, agent_speed, - goal_achieved/self.num_agents, - veh_edge_collision/self.num_agents, - veh_veh_collision/self.num_agents, + goal_achieved, + veh_edge_collision, + veh_veh_collision, ) def get_action_val_diff(self, pred_actions, expert_actions): @@ -291,23 +293,31 @@ def get_action_val_diff(self, pred_actions, expert_actions): np.isnan(expert_actions), ) ) - valid_expert_acts = expert_actions[nonnan_ids] - valid_pred_acts = pred_actions[nonnan_ids] + num_agents = expert_actions.shape[0] + arr = np.zeros((num_agents, 2)) + for agent_idx in range(num_agents): + not_nan = nonnan_ids[agent_idx, :] - exp_acc_vals, exp_steer_vals = np.zeros_like(valid_pred_acts), np.zeros_like(valid_pred_acts) - pred_acc_vals, pred_steer_vals = np.zeros_like(valid_pred_acts), np.zeros_like(valid_pred_acts) + valid_expert_acts = expert_actions[agent_idx, :][not_nan] + valid_pred_acts = pred_actions[agent_idx, :][not_nan] - for idx in range(valid_expert_acts.shape[0]): + exp_acc_vals, exp_steer_vals = np.zeros_like(valid_pred_acts), np.zeros_like(valid_pred_acts) + pred_acc_vals, pred_steer_vals = np.zeros_like(valid_pred_acts), np.zeros_like(valid_pred_acts) - # Get expert and predicted values - exp_acc_vals[idx], exp_steer_vals[idx] = self.env.idx_to_actions[valid_expert_acts[idx]] - pred_acc_vals[idx], pred_steer_vals[idx] = self.env.idx_to_actions[valid_pred_acts[idx]] + for idx in range(valid_expert_acts.shape[0]): + # Get expert and predicted values + exp_acc_vals[idx], exp_steer_vals[idx] = self.env.idx_to_actions[valid_expert_acts[idx]] + pred_acc_vals[idx], pred_steer_vals[idx] = self.env.idx_to_actions[valid_pred_acts[idx]] - # Get mean absolute difference - abs_accel_diff = np.abs(exp_acc_vals - pred_acc_vals).mean() - abs_steer_diff = np.abs(exp_steer_vals - pred_steer_vals).mean() + # Get mean absolute difference + abs_accel_diff = np.abs(exp_acc_vals - pred_acc_vals).mean() + abs_steer_diff = np.abs(exp_steer_vals - pred_steer_vals).mean() - return abs_accel_diff, abs_steer_diff + # Store + arr[agent_idx, 0] = abs_accel_diff + arr[agent_idx, 1] = abs_steer_diff + + return arr[:, 0], arr[:, 1] # abs_diff_accel, abs_diff_steer def get_action_accuracy(self, pred_actions, expert_actions, nonnan_ids): """Get accuracy of agent actions. @@ -316,9 +326,15 @@ def get_action_accuracy(self, pred_actions, expert_actions, nonnan_ids): expert_actions: (num_agents, num_steps_per_episode) the expert actions of the agents. nonnan_ids: (num_agents, num_steps_per_episode) the indices of non-nan actions. """ - return (expert_actions[nonnan_ids] == pred_actions[nonnan_ids]).sum() / nonnan_ids.flatten().shape[0] + num_agents = expert_actions.shape[0] + arr = np.zeros(num_agents) + for agent_idx in range(num_agents): + not_nan = nonnan_ids[agent_idx, :] + arr[agent_idx] = (expert_actions[agent_idx, :][not_nan] == pred_actions[agent_idx, :][not_nan]).sum() / not_nan.shape[0] + return arr def get_pos_rmse(self, pred_actions, expert_actions): + # Filter out invalid actions nonnan_ids = np.logical_not( np.logical_or( @@ -326,7 +342,12 @@ def get_pos_rmse(self, pred_actions, expert_actions): np.isnan(expert_actions), ) ) - return np.sqrt(np.linalg.norm(pred_actions[nonnan_ids] - expert_actions[nonnan_ids])).mean() + num_agents = expert_actions.shape[0] + arr = np.zeros(num_agents) + for agent_idx in range(num_agents): + not_nan = nonnan_ids[agent_idx, :] + arr[agent_idx] = (np.sqrt(np.linalg.norm(pred_actions[agent_idx, :][not_nan] - expert_actions[agent_idx, :][not_nan]))).mean() + return arr def get_speed_mae(self, pred_actions, expert_actions): # Filter out invalid actions @@ -336,7 +357,12 @@ def get_speed_mae(self, pred_actions, expert_actions): np.isnan(expert_actions), ) ) - return np.abs(pred_actions[nonnan_ids] - expert_actions[nonnan_ids]).mean() + num_agents = expert_actions.shape[0] + arr = np.zeros(num_agents) + for agent_idx in range(num_agents): + not_nan = nonnan_ids[agent_idx, :] + arr[agent_idx] = (np.abs(pred_actions[agent_idx, :][not_nan] - expert_actions[agent_idx, :][not_nan])).mean() + return arr def get_veh_to_veh_distances(self, positions, velocities, time_gap_in_sec=3): """Calculate distances between vehicles at each time step and track diff --git a/utils/evaluation.py b/utils/evaluation.py index 5ec2402c..9b105f73 100644 --- a/utils/evaluation.py +++ b/utils/evaluation.py @@ -2,12 +2,15 @@ import wandb from pyvirtualdisplay import Display +from utils.render import make_video + def evaluate_policy( model, env, - n_steps_per_episode, - n_eval_episodes, + n_steps_per_episode=80, + n_eval_episodes=1, eval_files=None, + eval_modes=['expert', 'policy'], deterministic = True, render = False, video_caption = None, @@ -20,8 +23,8 @@ def evaluate_policy( Args: ----- - model: The RL agent you want to evaluate. This can be any object - that implements a `predict` method, such as an RL algorithm (``BaseAlgorithm``) + model: The IL/RL policy to evaluate. This can be any object that implements + a `predict` method, such as an RL algorithm (``BaseAlgorithm``) or policy (``BasePolicy``). env: The gym environment or ``VecEnv`` environment. n_eval_episodes: Number of different traffic scenes in which to evaluate the agent @@ -39,19 +42,25 @@ def evaluate_policy( if verbose == 1: print(f"Evaluating policy on {traffic_scene}...") - for episode_i in range(n_eval_episodes): + for eval_mode in eval_modes: # Reset env observations = env.reset(filename=traffic_scene) num_agents_controlled = len(env.agent_ids) curr_rewards = np.zeros(num_agents_controlled) frames = [] - for step_j in range(n_steps_per_episode): + for timestep in range(n_steps_per_episode): + + if eval_mode == 'policy': + # Predict actions + actions, _ = model.predict( + observations, + deterministic=deterministic, + ) + elif eval_mode == 'expert': + actions = None - actions, _ = model.predict( - observations, - deterministic=deterministic, - ) + # Step environment new_observations, rewards, dones, infos = env.step(actions) for agent_idx, agent_id in enumerate(env.agent_ids): @@ -62,7 +71,7 @@ def evaluate_policy( # Render if render: - if step_j % video_config.logging.render_interval == 0: + if timestep % video_config.logging.render_interval == 0: if video_config.logging.where_am_i == "headless_machine": with Display() as disp: render_scene = env.env.scenario.getImage(**video_config.render) @@ -71,9 +80,9 @@ def evaluate_policy( render_scene = env.scenario.getImage(**video_config) frames.append(render_scene.T) - if sum(dones) == env.num_agents or step_j == (n_steps_per_episode-1): - episode_rewards[episode_i] = curr_rewards.sum() / env.num_agents - episode_lengths[episode_i] = step_j + if sum(dones) == env.num_agents or timestep == (n_steps_per_episode-1): + episode_rewards = curr_rewards.sum() / env.num_agents + episode_lengths = timestep break # Log video to wandb @@ -84,7 +93,7 @@ def evaluate_policy( { video_key: wandb.Video(movie_frames, fps=video_config.logging.fps, - caption=f'{video_caption} | norm_ep_return = {episode_rewards[episode_i]:.2f}'), + caption=f'{video_caption}_mode_:{eval_mode}'), }, ) diff --git a/utils/manage_models.py b/utils/manage_models.py index f2433cf8..a686ea69 100644 --- a/utils/manage_models.py +++ b/utils/manage_models.py @@ -9,7 +9,7 @@ entity = "daphnecor" project = "scaling_ppo" - collection_name = "nocturne-hr-ppo-12_30_21_43" + collection_name = "nocturne-hr-ppo-01_02_11_16" # Always initialize a W&B run to start tracking wandb.init() diff --git a/utils/render.py b/utils/render.py index b611bcd4..bc8a6702 100644 --- a/utils/render.py +++ b/utils/render.py @@ -25,8 +25,8 @@ def make_video( filenames = None, deterministic: bool = True, max_steps: int = 80, - snap_interval: int = 4, - frames_per_second: int = 4, + snap_interval: int = 3, + frames_per_second: int = 3, ) -> Tuple[np.ndarray, pd.DataFrame]: """Make a video of policy in traffic scene. diff --git a/utils/wrappers.py b/utils/wrappers.py index 67d956e2..c149c4ce 100644 --- a/utils/wrappers.py +++ b/utils/wrappers.py @@ -17,6 +17,7 @@ def __init__(self, config): self.observation_space = gym.spaces.Box(-np.inf, np.inf, self.env.observation_space.shape, np.float32) def step(self, actions=None): + """If actions is None, vehicles are stepped in expert control mode.""" obs = np.zeros((self.num_agents, self.observation_space.shape[0])) rews, dones, infos = np.zeros((self.num_agents)), np.zeros((self.num_agents)), [] @@ -26,7 +27,9 @@ def step(self, actions=None): agent_id: actions[idx] for idx, agent_id in enumerate(self.agent_ids) if agent_id not in self.dead_agent_ids } - else: + else: # Set in expert control mode + for veh_obj in self.env.controlled_vehicles: + veh_obj.expert_control = True agent_actions = {} # Take a step to obtain dicts