diff --git a/notebook.ipynb b/notebook.ipynb index 6034d6b..639cd7b 100644 --- a/notebook.ipynb +++ b/notebook.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -14,32 +14,124 @@ "from jax import jit, grad\n", "import jax.numpy as jnp\n", "from functools import partial\n", + "from IPython.display import display, HTML, clear_output\n", + "import time\n", + "\n", "from src.data import load_subject, make_kfolds\n", "from src.model import loss_fn, init, apply\n", - "from src.plots import plot_brain, plot_decoding\n", - "from src.utils import CONFIG\n", + "from src.plots import plot_brain\n", + "from src.utils import CONFIG, matrix_to_image\n", "from src.train import train_folds, hyperparam_fn" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-09-16 09:15:32.502197: W pjrt_plugin/src/mps_client.cc:535] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Metal device set to: Apple M1 Pro\n", + "\n", + "systemMemory: 16.00 GB\n", + "maxCacheSize: 5.33 GB\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 131/131 [00:07<00:00, 16.66it/s]\n" + ] + } + ], "source": [ - "subject = load_subject('subj05', image_size=CONFIG['image_size'])" + "subject = load_subject('subj05', image_size=CONFIG['image_size'])\n", + "hyperparams = hyperparam_fn()\n", + "kfolds = make_kfolds(subject, hyperparams)\n", + "metrics, params = train_folds(kfolds, hyperparams)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 38, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "
\n", + "
\n", + "
\n", + "
\n", + "
\n", + "
\n", + "
\n", + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m/Users/syrkis/code/neuroscope/notebook.ipynb Cell 3\u001b[0m line \u001b[0;36m1\n\u001b[1;32m 15\u001b[0m matrix_lst \u001b[39m=\u001b[39m [np\u001b[39m.\u001b[39mrandom\u001b[39m.\u001b[39mrand(\u001b[39m100\u001b[39m, \u001b[39m100\u001b[39m) \u001b[39mfor\u001b[39;00m _ \u001b[39min\u001b[39;00m \u001b[39mrange\u001b[39m(\u001b[39m4\u001b[39m)]\n\u001b[1;32m 16\u001b[0m display_image(matrix_lst)\n\u001b[0;32m---> 17\u001b[0m time\u001b[39m.\u001b[39;49msleep(\u001b[39m1\u001b[39;49m)\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], "source": [ - "hyperparams = hyperparam_fn()\n", - "kfolds = make_kfolds(subject, hyperparams)\n", - "train_folds(kfolds, hyperparams)" + "def display_image(matrix_lst):\n", + " html = '
'\n", + " for matrix in matrix_lst:\n", + " image = matrix_to_image(matrix)\n", + " html += f\"\"\"\n", + "
\n", + "
\"\"\"\n", + "\n", + " html += '
'\n", + " clear_output(wait=True)\n", + " display(HTML(html))\n", + "\n", + "# Example usage with a random 100x100 matrix\n", + "for i in range(10):\n", + " matrix_lst = [np.random.rand(100, 100) for _ in range(4)]\n", + " display_image(matrix_lst)\n", + " time.sleep(1)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/notebooks/alex.ipynb b/notebooks/alex.ipynb deleted file mode 100644 index 874c22b..0000000 --- a/notebooks/alex.ipynb +++ /dev/null @@ -1,150 +0,0 @@ -{ - "cells": [ - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Alexnet based feature extractor" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# imports\n", - "import warnings; warnings.filterwarnings('ignore')\n", - "import os\n", - "import torch\n", - "from multiprocessing import Pool\n", - "from torchvision.models.feature_extraction import create_feature_extractor\n", - "from sklearn.decomposition import PCA\n", - "import numpy as np\n", - "from tqdm import tqdm\n", - "from PIL import Image\n", - "from src.utils import DATA_DIR" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# get data and model\n", - "subjs = ['subj01', 'subj02', 'subj03', 'subj04', 'subj05', 'subj06', 'subj07', 'subj08']\n", - "N_SAMPLES = 0\n", - "model = torch.hub.load('pytorch/vision:v0.10.0', 'alexnet')\n", - "feature_extractor = create_feature_extractor(model, return_nodes=[\"features.2\"])" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Get image data" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def get_img_files(subj):\n", - " subj_img_dir = os.path.join(DATA_DIR, subj, 'training_split/training_images')\n", - " subj_img_files = [os.path.join(subj_img_dir, f) for f in os.listdir(subj_img_dir) if f.endswith('.png')]\n", - " return sorted(subj_img_files)\n", - "\n", - "def load_img_files(subj):\n", - " # images are pngs\n", - " img_files = get_img_files(subj)\n", - " img_files = img_files[:N_SAMPLES] if N_SAMPLES else img_files\n", - " imgs = []\n", - " for f in tqdm(img_files): # make sure not to have too many files open\n", - " with Image.open(f) as img:\n", - " img = img.convert('RGB').resize((224, 224))\n", - " img = torch.from_numpy(np.array(img))\n", - " imgs.append(img)\n", - " imgs = torch.stack(imgs)\n", - " imgs = imgs / 255.0\n", - " imgs = imgs.permute(0, 3, 1, 2)\n", - " imgs = normalize(imgs)\n", - " return imgs\n", - "\n", - "def normalize(imgs):\n", - " means = [0.485, 0.456, 0.406]\n", - " stds = [0.229, 0.224, 0.225]\n", - " imgs = imgs.float()\n", - " for i in range(3):\n", - " imgs[:, i, :, :] = (imgs[:, i, :, :] - means[i]) / stds[i]\n", - " return imgs\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def run_subj(subj):\n", - " pca = PCA(n_components=100)\n", - " data = load_img_files(subj)\n", - " feats = feature_extractor(data)\n", - " feats = torch.hstack([torch.flatten(l, start_dim=1) for l in feats.values()])\n", - " feats = feats.detach().numpy()\n", - " feats = feats.reshape(feats.shape[0], -1)\n", - " feats = pca.fit_transform(feats)\n", - " np.save(os.path.join(DATA_DIR, subj, 'training_split', 'alexnet_pca.npy'), feats)\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# with Pool(2) as p:\n", - "# p.map(run_subj, subjs)\n", - "# run last 4 subjects in parallel\n", - "for subj in subjs[4:]:\n", - " print(f'running {subj}')\n", - " run_subj(subj)\n", - " print()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.3" - }, - "orig_nbformat": 4 - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/notebooks/data.ipynb b/notebooks/data.ipynb deleted file mode 100644 index eec9937..0000000 --- a/notebooks/data.ipynb +++ /dev/null @@ -1,58 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from matplotlib import pyplot as plt\n", - "from src.utils import get_args_and_config\n", - "from src.data import get_data" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "args, config = get_args_and_config()\n", - "data = get_data(args, config)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "for subject, (folds, test_data) in data.items():\n", - " print(subject)\n", - " for img, cat, lh, rh in folds:\n", - " print(img.shape, cat.shape, lh.shape, rh.shape)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "name": "python", - "version": "3.11.3" - }, - "orig_nbformat": 4 - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/notebooks/dropout.ipynb b/notebooks/dropout.ipynb deleted file mode 100644 index 13bf9c2..0000000 --- a/notebooks/dropout.ipynb +++ /dev/null @@ -1,98 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import jax\n", - "import jax.numpy as jnp\n", - "import haiku as hk" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def network_fn(x, training=True):\n", - " mlp = hk.nets.MLP([10, 10])\n", - " # apply dropout if training\n", - " x = hk.dropout(hk.next_rng_key(), 0.5, x) if training else x\n", - " x = mlp(x)\n", - " x = hk.dropout(hk.next_rng_key(), 0.5, x) if training else x\n", - " return x\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "x = jnp.ones((8, 28*28))\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "forward = hk.transform(network_fn)\n", - "rng = jax.random.PRNGKey(42)\n", - "params = forward.init(rng, x)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "for i in range(10):\n", - " pred = forward.apply(params, rng, x, training=True)\n", - " print(pred)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.3" - }, - "orig_nbformat": 4 - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/notebooks/eval.ipynb b/notebooks/eval.ipynb deleted file mode 100644 index ae5b8ed..0000000 --- a/notebooks/eval.ipynb +++ /dev/null @@ -1,173 +0,0 @@ -{ - "cells": [ - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# eval" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|██████████| 6/6 [00:00<00:00, 41.18it/s]\n", - "100%|██████████| 6/6 [00:07<00:00, 1.19s/it]\n" - ] - } - ], - "source": [ - "import numpy as np\n", - "from sklearn.linear_model import LinearRegression\n", - "import pickle\n", - "import yaml\n", - "from tqdm import tqdm\n", - "import jax.numpy as jnp\n", - "from src.utils import get_args_and_config, SUBJECTS\n", - "from src.data import get_data\n", - "from src.eval import corr\n", - "from src.fmri import plot_brain" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 0%| | 0/6 [00:00