From b5cf5896558439f1b7ae6254e9176e1be1fa75f2 Mon Sep 17 00:00:00 2001
From: 1pha <1phantasmas@korea.ac.kr>
Date: Mon, 18 Mar 2024 05:01:45 +0000
Subject: [PATCH] Add occlusion training wrap-up code to notebooks

---
 notebooks/occlusion_training_wrapup.ipynb | 359 ++++++++++++++++++++++
 1 file changed, 359 insertions(+)
 create mode 100644 notebooks/occlusion_training_wrapup.ipynb

diff --git a/notebooks/occlusion_training_wrapup.ipynb b/notebooks/occlusion_training_wrapup.ipynb
new file mode 100644
index 0000000..e246ae1
--- /dev/null
+++ b/notebooks/occlusion_training_wrapup.ipynb
@@ -0,0 +1,359 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Wrap-up code for Occlusion Training"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "83"
+      ]
+     },
+     "execution_count": 1,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "from collections import defaultdict\n",
+    "\n",
+    "import wandb\n",
+    "\n",
+    "# Load target RoI indices\n",
+    "with open(\"assets/dkt_indices.txt\", mode=\"r\") as f:\n",
+    "    gt_rois = {int(roi.strip()) for roi in f.readlines()}\n",
+    "\n",
+    "api = wandb.Api()\n",
+    "runs = api.runs(path=\"1pha/brain-age\",\n",
+    "                filters={\"config.dataloader.dataset._target_\": \"sage.data.mask.UKB_MaskDataset\"})\n",
+    "len(runs)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "runs_dict = defaultdict(dict)\n",
+    "runs_idx = defaultdict(set)\n",
+    "for run in runs:\n",
+    "    state = run.state\n",
+    "    name: str = run.name.split(\" \")[1]\n",
+    "    if not name.isnumeric():\n",
+    "        # Skip outlier run names that do not follow the `M XXXX | SEED` pattern\n",
+    "        continue\n",
+    "    idx = int(name)\n",
+    "    if state == \"finished\":\n",
+    "        runs_dict[\"success\"][idx] = run\n",
+    "        runs_idx[\"success\"].add(idx)\n",
+    "    elif state == \"failed\":\n",
+    "        runs_dict[\"fail\"][idx] = run\n",
+    "        runs_idx[\"fail\"].add(idx)\n",
+    "    else:\n",
+    "        runs_dict[\"etc\"][idx] = (run, state)\n",
+    "        runs_idx[\"etc\"].add(idx)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "(66, 11, 3)"
+      ]
+     },
+     "execution_count": 3,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "len(runs_idx[\"success\"]), len(runs_idx[\"fail\"]), len(runs_idx[\"etc\"])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "assert gt_rois.issuperset(runs_idx[\"success\"])\n",
+    "leftover = gt_rois - runs_idx[\"success\"]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "34"
+      ]
+     },
+     "execution_count": 5,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "len(leftover)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Actual failed runs: 9\n",
+      "1025\n",
+      "2018\n",
+      "2024\n",
+      "1002\n",
+      "77\n",
+      "2030\n",
+      "1009\n",
+      "54\n",
+      "24\n"
+     ]
+    }
+   ],
+   "source": [
+    "true_fail = {idx for idx in runs_idx[\"fail\"] if idx not in runs_idx[\"success\"]}\n",
+    "print(f\"Actual failed runs: {len(true_fail)}\")\n",
+    "for fail in true_fail:\n",
+    "    print(fail)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# 
Runs checked manually that have final test performance\n",
+    "manual_check_success: set = {1025, 2018, 2024, 1002, 77, 54, 24}\n",
+    "\n",
+    "# Could not retrieve test performance, but a final checkpoint exists\n",
+    "manual_check_ckpt: set = {2030, 1009}\n",
+    "\n",
+    "# Sanity check: the manually checked sets exactly cover the true failures\n",
+    "assert len(true_fail - (manual_check_ckpt | manual_check_success)) == 0\n",
+    "assert len((manual_check_ckpt | manual_check_success) - true_fail) == 0"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "25"
+      ]
+     },
+     "execution_count": 8,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "leftover = leftover - true_fail\n",
+    "len(leftover)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{2017}"
+      ]
+     },
+     "execution_count": 9,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "leftover & runs_idx[\"etc\"]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Mask idx 2017 turned out to have crashed. Crashed runs should be re-batched."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{2,\n",
+       " 46,\n",
+       " 47,\n",
+       " 49,\n",
+       " 50,\n",
+       " 251,\n",
+       " 252,\n",
+       " 253,\n",
+       " 254,\n",
+       " 255,\n",
+       " 1010,\n",
+       " 1011,\n",
+       " 1012,\n",
+       " 1013,\n",
+       " 1014,\n",
+       " 1015,\n",
+       " 1016,\n",
+       " 2005,\n",
+       " 2006,\n",
+       " 2007,\n",
+       " 2008,\n",
+       " 2017,\n",
+       " 2031,\n",
+       " 2034,\n",
+       " 2035}"
+      ]
+     },
+     "execution_count": 10,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "leftover"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Split the leftover runs into chunks and re-batch them across machines."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import random\n",
+    "\n",
+    "def split_into_chunks(data_set, num_chunks):\n",
+    "    # Convert the set to a list so it can be sliced into ordered chunks\n",
+    "    data_list = list(data_set)\n",
+    "    \n",
+    "    # Base chunk size; the first `remainder` chunks get one extra element\n",
+    "    chunk_size = len(data_list) // num_chunks\n",
+    "    remainder = len(data_list) % num_chunks\n",
+    "    \n",
+    "    # Split the list into chunks\n",
+    "    chunks = []\n",
+    "    start = 0\n",
+    "    for i in range(num_chunks):\n",
+    "        chunk_length = chunk_size + (1 if i < remainder else 0)\n",
+    "        chunks.append(data_list[start:start + chunk_length])\n",
+    "        start += chunk_length\n",
+    "    \n",
+    "    flattened_chunks = [item for sublist in chunks for item in sublist]\n",
+    "    assert set(flattened_chunks) == data_set, \"Chunks do not contain all elements of the data set\"\n",
+    "    return chunks\n",
+    "\n",
+    "\n",
+    "machines = [\"185-0\", \"185-1\", \"245-0\", \"245-1\"]\n",
+    "chunks = split_into_chunks(data_set=leftover, num_chunks=4)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "[[2, 46, 47, 49, 50, 2005, 2006],\n",
+       " [2007, 2008, 2017, 2031, 2034, 2035],\n",
+       " [1010, 1011, 1012, 1013, 1016, 1014],\n",
+       " [1015, 251, 252, 253, 254, 255]]"
+      ]
+     },
+     "execution_count": 13,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "chunks"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "for chunk, machine in zip(chunks, machines):\n",
+    "    with open(f\"assets/dkt_leftover_{machine}.txt\", 
mode=\"w\") as f:\n", + " for roi in chunk:\n", + " f.write(f\"{roi}\\n\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "age", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}