From a2606e382f8d36be3c44c4bd9c0a35e5cb217056 Mon Sep 17 00:00:00 2001 From: Noah Syrkis Date: Sat, 16 Sep 2023 08:52:27 -0300 Subject: [PATCH] code to train.py --- notebook.ipynb | 440 +------------------------------------------------ src/train.py | 54 +++--- 2 files changed, 41 insertions(+), 453 deletions(-) diff --git a/notebook.ipynb b/notebook.ipynb index e0ac396..6034d6b 100644 --- a/notebook.ipynb +++ b/notebook.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -17,12 +17,13 @@ "from src.data import load_subject, make_kfolds\n", "from src.model import loss_fn, init, apply\n", "from src.plots import plot_brain, plot_decoding\n", - "from src.utils import CONFIG" + "from src.utils import CONFIG\n", + "from src.train import train_folds, hyperparam_fn" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -31,443 +32,14 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "def hyperparam_fn():\n", - " return {\n", - " 'batch_size': np.random.choice([32, 64]),\n", - " 'n_steps': np.random.randint(low=100, high=200),\n", - " 'dropout_rate': np.random.uniform(low=0.1, high=0.5),\n", - " }\n", - "\n", - "def update_fn(params, rng, fmri, img, opt_state, opt, dropout_rate):\n", - " rng, key = jax.random.split(rng)\n", - " grads = grad(loss_fn)(params, key, fmri, img, dropout_rate=dropout_rate)\n", - " updates, opt_state = opt.update(grads, opt_state, params)\n", - " params = optax.apply_updates(params, updates)\n", - " return params, opt_state\n", - "\n", - "\n", - "def train_loop(rng, opt, train_loader, val_loader, plot_batch, hyperparams):\n", - " metrics = []\n", - " rng, key = jax.random.split(rng, 2)\n", - " lh, rh, img = next(train_loader)\n", - " params = init(key, lh)\n", - " opt_state = opt.init(params)\n", - " update = partial(update_fn, opt=opt, dropout_rate=hyperparams['dropout_rate'])\n", - " for step in tqdm(range(hyperparams['n_steps'])):\n", - " rng, key = jax.random.split(rng)\n", - " lh, rh, img = next(train_loader)\n", - " params, opt_state = update(params, key, lh, img, opt_state)\n", - " if (step % (hyperparams['n_steps'] // 100)) == 0:\n", - " rng, key = jax.random.split(rng)\n", - " metrics.append(evaluate(params, key, train_loader, val_loader))\n", - " # plot_pred = apply(params, key, plot_batch[0])\n", - " # plot_decodings(plot_pred)\n", - " return metrics\n", - "\n", - "\n", - "def evaluate(params, rng, train_loader, val_loader, n_steps=10):\n", - " # each batch is a tuple(lh, rh, img). Connect n_steps batches into 1\n", - " train_loss, val_loss = 0, 0\n", - " for _ in range(n_steps):\n", - " rng, key_train, key_val = jax.random.split(rng, 3)\n", - " lh, rh, img = next(train_loader)\n", - " train_loss += loss_fn(params, key_train, lh, img)\n", - " lh, rh, img = next(val_loader)\n", - " val_loss += loss_fn(params, key_val, lh, img)\n", - " train_loss /= n_steps\n", - " val_loss /= n_steps\n", - " return(f'train_loss: {train_loss}, val_loss: {val_loss}')\n", - "\n", - "\n", - "def train_folds(kfolds, hyperparams, seed=0):\n", - " metrics = {}\n", - " rng = jax.random.PRNGKey(seed)\n", - " opt = optax.lion(1e-3)\n", - " plot_batch = None\n", - " for idx, (train_loader, val_loader) in enumerate(kfolds):\n", - " plot_batch = next(train_loader) if plot_batch is None else plot_batch\n", - " rng, key = jax.random.split(rng)\n", - " fold_metrics = train_loop(key, opt, train_loader, val_loader, plot_batch, hyperparams)\n", - " metrics[idx] = fold_metrics\n", - " return metrics" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 23%|██▎ | 44/191 [00:06<00:18, 7.91it/s]" - ] - } - ], "source": [ "hyperparams = hyperparam_fn()\n", "kfolds = make_kfolds(subject, hyperparams)\n", - "train_folds(kfolds, hyperparams)\n", - "# train_loader, _ = next(kfolds)" + "train_folds(kfolds, hyperparams)" ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [], - "source": [ - "rng = jax.random.PRNGKey(0)\n", - "rng, key = jax.random.split(rng)\n", - "lh, rh, img = next(train_loader)" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [], - "source": [ - "params = init(key, lh)" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([[[[-3.16308439e-02, -3.31890076e-01, -5.12075238e-02],\n", - " [ 4.85625491e-02, 2.17759639e-01, 5.90159558e-02],\n", - " [-2.78418064e-01, -6.86557889e-02, 2.25962013e-01],\n", - " ...,\n", - " [-9.51540470e-02, -2.32540250e-01, -4.32052374e-01],\n", - " [-2.11033791e-01, -2.48100042e-01, 4.59478647e-01],\n", - " [ 3.46645892e-01, -2.35105738e-01, -1.11824542e-01]],\n", - "\n", - " [[-2.35176325e-01, -1.32810012e-01, -9.88911763e-02],\n", - " [-4.55371052e-01, 1.92810908e-01, -2.92965651e-01],\n", - " [-6.06114268e-02, -4.25028712e-01, 1.31005749e-01],\n", - " ...,\n", - " [-1.79364175e-01, 9.62745026e-02, 4.54996139e-01],\n", - " [-8.48789662e-02, -7.97832534e-02, -1.35977253e-01],\n", - " [-9.35600474e-02, 1.97868705e-01, 3.33004594e-01]],\n", - "\n", - " [[ 1.24057353e-01, -1.16122976e-01, 2.40907833e-01],\n", - " [ 2.87088841e-01, -2.07734168e-01, -6.42771600e-03],\n", - " [ 5.59472516e-02, 2.47780457e-02, -2.08534971e-02],\n", - " ...,\n", - " [ 2.19584569e-01, -1.29181109e-02, -2.01986358e-01],\n", - " [-8.34162161e-02, -5.34862995e-01, -1.60237160e-02],\n", - " [-3.82530898e-01, 1.32596970e-01, 4.07823861e-01]],\n", - "\n", - " ...,\n", - "\n", - " [[-2.12461054e-01, -3.35605502e-01, -1.54761687e-01],\n", - " [-3.17773636e-04, 2.72257682e-02, 9.25498828e-02],\n", - " [-2.20756739e-01, 2.38331810e-01, 1.72837913e-01],\n", - " ...,\n", - " [ 2.98125744e-01, 5.27467439e-03, -1.08079553e-01],\n", - " [ 1.90505220e-04, 2.02768460e-01, 4.82104048e-02],\n", - " [-7.06977621e-02, 1.87110364e-01, -1.25799716e-01]],\n", - "\n", - " [[-4.53040078e-02, 4.22944576e-01, -3.33101451e-02],\n", - " [-1.60117090e-01, -2.00824633e-01, 2.15869457e-01],\n", - " [ 2.85510749e-01, 2.02360347e-01, -1.50161162e-01],\n", - " ...,\n", - " [-4.87241745e-01, 1.85682207e-01, 1.49394423e-01],\n", - " [-7.20767826e-02, -2.46882334e-01, 2.29621725e-03],\n", - " [-3.19331050e-01, 1.13522656e-01, -1.77007407e-01]],\n", - "\n", - " [[ 5.49485050e-02, 2.58143514e-01, 1.13540359e-01],\n", - " [ 3.29844654e-03, 1.05308942e-01, 5.04658185e-02],\n", - " [-3.63346100e-01, 1.55933246e-01, -1.13624088e-01],\n", - " ...,\n", - " [ 1.47531992e-02, -8.28974228e-03, 5.53916022e-02],\n", - " [ 5.96236289e-02, 2.53428161e-01, 9.67664346e-02],\n", - " [-3.48339140e-01, -7.61434110e-03, 3.02523285e-01]]],\n", - "\n", - "\n", - " [[[ 2.76848316e-01, -4.47192788e-01, 3.38576250e-02],\n", - " [-1.68515861e-01, 1.60554066e-01, 1.48149788e-01],\n", - " [-1.75228313e-01, -3.34572136e-01, 3.56876135e-01],\n", - " ...,\n", - " [-1.98685303e-01, 3.53242159e-01, -1.85697779e-01],\n", - " [ 1.14692040e-01, -1.28888153e-03, 3.90966356e-01],\n", - " [-8.34309831e-02, -2.24249214e-01, -1.84188619e-01]],\n", - "\n", - " [[-1.60889119e-01, -3.85369003e-01, 1.66815639e-01],\n", - " [-1.50478408e-01, 1.07559137e-01, -1.12522632e-01],\n", - " [-9.73116830e-02, -3.47812712e-01, 1.50650060e-02],\n", - " ...,\n", - " [-1.34173436e-02, 6.63158000e-02, 3.82588655e-01],\n", - " [-1.01495586e-01, -2.42465153e-01, -2.20987216e-01],\n", - " [-1.81468889e-01, 4.55672115e-01, 1.02318460e-02]],\n", - "\n", - " [[ 1.33297309e-01, -6.23867922e-02, 5.86281478e-01],\n", - " [ 4.22343284e-01, -8.80789757e-02, 5.06852157e-02],\n", - " [-1.93326734e-02, -1.49723655e-02, -1.50692791e-01],\n", - " ...,\n", - " [ 1.16948098e-01, 1.12659566e-01, 1.20082840e-01],\n", - " [-1.67915568e-01, -5.01734793e-01, -2.75213748e-01],\n", - " [-1.80195436e-01, -1.62525415e-01, 6.80903867e-02]],\n", - "\n", - " ...,\n", - "\n", - " [[-1.98876753e-01, -4.53408539e-01, -2.69167662e-01],\n", - " [ 1.64142698e-02, 1.20938823e-01, -9.10500661e-02],\n", - " [ 5.92849180e-02, -2.84007967e-01, 3.32212418e-01],\n", - " ...,\n", - " [ 4.58866388e-01, -6.33756071e-02, -1.36650100e-01],\n", - " [-1.45580471e-02, 3.32152963e-01, -2.14704290e-01],\n", - " [-3.23088497e-01, -1.41888469e-01, -1.97360858e-01]],\n", - "\n", - " [[ 4.76914614e-01, -2.64918841e-02, 2.04132780e-01],\n", - " [-2.77870417e-01, -1.87925458e-01, 5.09443104e-01],\n", - " [ 1.13214955e-01, -3.06309983e-02, -2.05622107e-01],\n", - " ...,\n", - " [-2.16346070e-01, 1.33156583e-01, 5.91033213e-02],\n", - " [-7.09464820e-03, -3.32767218e-02, 3.14093120e-02],\n", - " [-6.30219877e-01, 2.77944654e-01, -1.82400178e-02]],\n", - "\n", - " [[ 2.89052725e-01, 7.94600397e-02, 1.20491117e-01],\n", - " [-8.38417746e-03, -3.32417458e-01, -3.08977455e-01],\n", - " [-3.73546124e-01, 3.50247741e-01, -1.86779842e-01],\n", - " ...,\n", - " [ 3.58485669e-01, 2.18735158e-01, 2.54556119e-01],\n", - " [ 3.49774390e-01, 3.19396019e-01, 1.25494283e-02],\n", - " [-3.15304637e-01, 1.32932171e-01, -7.74939060e-02]]],\n", - "\n", - "\n", - " [[[-1.71573013e-02, -4.58791018e-01, 7.03756437e-02],\n", - " [ 7.93122575e-02, 3.64519320e-02, 7.19713047e-02],\n", - " [-2.87017196e-01, -4.36477810e-01, 1.91330105e-01],\n", - " ...,\n", - " [-2.10230917e-01, 2.79428869e-01, -3.70985828e-02],\n", - " [ 2.63381340e-02, -1.99434206e-01, 2.02702522e-01],\n", - " [-2.78595567e-01, -1.05767079e-01, -1.36459142e-01]],\n", - "\n", - " [[-4.64002818e-01, -4.62651439e-02, 8.87793303e-02],\n", - " [-3.94162357e-01, -1.44905254e-01, -1.28615558e-01],\n", - " [-4.20472980e-01, -2.06376210e-01, -5.72985224e-02],\n", - " ...,\n", - " [ 3.73510993e-03, 1.95806429e-01, 2.36873165e-01],\n", - " [-1.93745330e-01, -2.09315792e-01, -1.50508523e-01],\n", - " [ 9.82454121e-02, -2.10798144e-01, -3.03285718e-02]],\n", - "\n", - " [[ 1.87803265e-02, -8.80425517e-03, 5.99161200e-02],\n", - " [ 1.31371215e-01, -2.28759170e-01, -1.22479014e-01],\n", - " [-9.31648985e-02, 5.53092500e-03, 1.04010507e-01],\n", - " ...,\n", - " [ 2.71606058e-01, -2.71918569e-02, -4.23889995e-01],\n", - " [ 1.95655212e-01, -2.12338530e-02, -1.77653953e-01],\n", - " [-2.18370885e-01, -5.74656166e-02, 3.32462251e-01]],\n", - "\n", - " ...,\n", - "\n", - " [[-1.18326530e-01, -2.91892052e-01, 5.14426902e-02],\n", - " [-3.30696627e-02, 2.68478602e-01, 3.11449111e-01],\n", - " [-1.32025406e-01, 1.68051064e-01, 1.29379466e-01],\n", - " ...,\n", - " [ 2.43585438e-01, -1.81304380e-01, -1.31767228e-01],\n", - " [-2.36473545e-01, 1.46659732e-01, -1.13205895e-01],\n", - " [-1.30029768e-02, 1.70662552e-01, -2.19250724e-01]],\n", - "\n", - " [[-2.76047349e-01, -1.82572603e-01, 1.48242354e-01],\n", - " [ 2.62171835e-01, 1.53515786e-01, 4.14435208e-01],\n", - " [ 2.93871582e-01, -1.43329531e-01, -1.22132860e-01],\n", - " ...,\n", - " [-2.59676486e-01, -4.78881970e-02, -2.05224723e-01],\n", - " [-2.57028461e-01, -5.01709640e-01, 1.73464730e-01],\n", - " [-3.65213513e-01, 3.65435511e-01, 8.07523727e-02]],\n", - "\n", - " [[-1.06481001e-01, 2.88241804e-01, 9.65256989e-02],\n", - " [ 2.06365392e-01, -5.35896532e-02, -6.13915861e-01],\n", - " [-1.01075783e-01, 1.57415792e-01, -2.68994689e-01],\n", - " ...,\n", - " [ 8.75082314e-02, -2.41302587e-02, 1.14921056e-01],\n", - " [ 1.36913918e-02, 1.82227239e-01, -1.42613798e-01],\n", - " [ 3.77787091e-02, 9.85990465e-02, 1.84514984e-01]]],\n", - "\n", - "\n", - " ...,\n", - "\n", - "\n", - " [[[ 7.29908124e-02, -4.32248294e-01, -3.88291515e-02],\n", - " [-2.02636570e-01, 2.70098746e-01, 1.13987558e-01],\n", - " [-1.64443985e-01, -5.11896551e-01, 3.45437199e-01],\n", - " ...,\n", - " [-1.79636836e-01, 1.14088193e-01, -1.19496807e-01],\n", - " [-1.00764163e-01, 1.03761740e-01, 4.91529822e-01],\n", - " [-2.88839906e-01, -2.40335748e-01, -1.47642002e-01]],\n", - "\n", - " [[-4.88939494e-01, -2.41623804e-01, 1.68584287e-01],\n", - " [-2.85874009e-01, -8.13421309e-02, -9.32594240e-02],\n", - " [-1.73792705e-01, -4.44709301e-01, 3.49871069e-01],\n", - " ...,\n", - " [ 2.56149992e-02, 1.99611023e-01, 4.49322194e-01],\n", - " [ 4.72315401e-03, -3.36297005e-01, -2.34886527e-01],\n", - " [-3.39235097e-01, 2.71154344e-01, -8.37586727e-03]],\n", - "\n", - " [[ 1.38872311e-01, -4.89961803e-02, 4.80804145e-01],\n", - " [ 5.01828134e-01, -5.23126721e-02, 2.92063225e-03],\n", - " [ 2.05624506e-01, -3.38365254e-03, 1.13075525e-02],\n", - " ...,\n", - " [ 1.99026197e-01, -7.00076669e-02, 7.80869797e-02],\n", - " [-1.06316973e-02, -4.05847669e-01, -3.05183142e-01],\n", - " [-1.64455622e-01, 1.37957418e-03, 2.71631151e-01]],\n", - "\n", - " ...,\n", - "\n", - " [[-1.87628552e-01, -4.04370576e-01, -2.96329707e-01],\n", - " [-9.68069211e-02, 3.60369593e-01, 1.59806326e-01],\n", - " [ 3.78509350e-02, -6.53520301e-02, 5.20098805e-01],\n", - " ...,\n", - " [ 3.86202842e-01, -1.01383179e-01, -1.00357272e-01],\n", - " [ 1.49571806e-01, 3.66316825e-01, -9.77360830e-02],\n", - " [-1.17045559e-01, 8.46182406e-02, -4.25682664e-02]],\n", - "\n", - " [[ 2.32568517e-01, 1.01576909e-01, -1.01913858e-04],\n", - " [-2.90077418e-01, -1.59666672e-01, 4.34943825e-01],\n", - " [-3.25486548e-02, -1.60638407e-01, -1.87192962e-01],\n", - " ...,\n", - " [-2.77157098e-01, 7.05263987e-02, 7.31498227e-02],\n", - " [-2.50827312e-01, -2.10799441e-01, 4.89038043e-02],\n", - " [-7.94779360e-01, 2.63429910e-01, 4.65564756e-03]],\n", - "\n", - " [[ 2.31905580e-01, 2.57656649e-02, -1.62076965e-01],\n", - " [ 7.33081624e-02, -2.61856198e-01, -8.50606263e-02],\n", - " [-1.84318498e-01, 1.84499398e-01, -3.74886617e-02],\n", - " ...,\n", - " [ 3.71837944e-01, -1.01691209e-01, 9.09964666e-02],\n", - " [ 5.74446730e-02, 2.65218019e-01, 1.78798605e-02],\n", - " [-2.16180146e-01, 8.21374170e-03, -6.59115165e-02]]],\n", - "\n", - "\n", - " [[[-2.52042115e-01, -7.11863279e-01, 5.89845106e-02],\n", - " [-2.21859813e-01, -1.95343390e-01, -2.69230932e-01],\n", - " [-1.67375103e-01, -2.31902316e-01, 5.00592768e-01],\n", - " ...,\n", - " [-2.46498451e-01, 1.00462347e-01, 2.57418640e-02],\n", - " [-1.54497430e-01, -5.60016781e-02, 9.00640339e-02],\n", - " [ 3.39719594e-01, -4.48350549e-01, -1.40127003e-01]],\n", - "\n", - " [[-3.38311255e-01, -4.79807138e-01, 2.61082631e-02],\n", - " [-3.18907619e-01, -1.02070808e-01, -1.19623773e-01],\n", - " [-5.66370785e-01, -5.35722971e-02, -6.71801120e-02],\n", - " ...,\n", - " [ 6.53107837e-02, 1.58424288e-01, 4.63038504e-01],\n", - " [-6.43283874e-02, -2.28271618e-01, -3.38626772e-01],\n", - " [ 2.28519469e-01, 4.05774973e-02, 9.44428593e-02]],\n", - "\n", - " [[ 3.34999524e-02, 1.34975627e-01, 4.42358136e-01],\n", - " [ 1.36479780e-01, -7.69354776e-02, -1.56208873e-01],\n", - " [ 8.63366574e-02, -7.57174045e-02, -1.23119846e-01],\n", - " ...,\n", - " [ 3.25773418e-01, 1.07700257e-02, -3.11167151e-01],\n", - " [ 1.02954991e-02, -7.46230558e-02, -3.89628448e-02],\n", - " [-5.69392622e-01, 6.84801042e-02, 5.96192002e-01]],\n", - "\n", - " ...,\n", - "\n", - " [[-3.35957378e-01, -2.48730823e-01, -2.05005273e-01],\n", - " [ 1.70171678e-01, 7.18605220e-02, 1.81563482e-01],\n", - " [-2.69290924e-01, 7.37955980e-03, 2.50166982e-01],\n", - " ...,\n", - " [ 5.13379931e-01, -5.25573492e-01, -2.19751149e-01],\n", - " [ 3.29484269e-02, 1.67924121e-01, -1.29641220e-01],\n", - " [-1.68855652e-01, 1.71709552e-01, -1.01016313e-02]],\n", - "\n", - " [[-1.02149457e-01, 1.06720082e-01, 2.32731193e-01],\n", - " [ 1.18524078e-02, -1.05101420e-02, 2.96982616e-01],\n", - " [ 2.63219804e-01, 1.52253821e-01, 3.09205782e-02],\n", - " ...,\n", - " [-3.63187909e-01, -1.18413448e-01, -2.79962063e-01],\n", - " [-1.25339881e-01, -2.67861068e-01, 1.99888483e-01],\n", - " [-4.07818049e-01, 2.46473923e-02, 2.97356337e-01]],\n", - "\n", - " [[-1.21665947e-01, 1.10766761e-01, 1.90921292e-01],\n", - " [ 1.72671199e-01, -2.80128568e-01, -5.12248397e-01],\n", - " [-6.35129819e-03, 2.95040756e-01, -1.44337133e-01],\n", - " ...,\n", - " [-4.03462024e-03, -5.17651774e-02, -3.15233879e-02],\n", - " [-2.14578044e-02, 1.35702223e-01, -3.22658598e-01],\n", - " [-2.47183263e-01, 1.64212957e-01, 7.46603459e-02]]],\n", - "\n", - "\n", - " [[[-1.40348166e-01, -4.02131945e-01, -2.66431749e-01],\n", - " [-1.40490159e-01, 5.47040179e-02, -1.76301878e-02],\n", - " [-3.17811161e-01, -4.72789928e-02, 2.83692591e-02],\n", - " ...,\n", - " [-1.47473574e-01, -6.10140525e-02, -1.86585739e-01],\n", - " [-3.40489328e-01, -4.47973758e-02, 3.39297622e-01],\n", - " [ 3.59431729e-02, -3.39788616e-01, -1.93661228e-02]],\n", - "\n", - " [[-7.13095784e-01, -3.79408956e-01, -1.57573476e-01],\n", - " [-3.27915639e-01, -1.67500286e-03, -4.80078794e-02],\n", - " [-1.56435639e-01, -2.72406757e-01, 7.88685158e-02],\n", - " ...,\n", - " [ 1.61855876e-01, -1.31664068e-01, 4.37129319e-01],\n", - " [ 3.61562707e-02, -4.02420312e-01, -2.95527577e-01],\n", - " [ 2.09122580e-02, 2.08237693e-01, 1.27728313e-01]],\n", - "\n", - " [[ 1.64767861e-01, 1.69824496e-01, 3.50465983e-01],\n", - " [ 2.77423859e-01, -1.76073000e-01, 1.59218553e-02],\n", - " [ 7.22590312e-02, -1.16188765e-01, -3.31168883e-02],\n", - " ...,\n", - " [ 2.78191477e-01, 7.04992041e-02, 7.45742582e-03],\n", - " [ 4.02962677e-02, -1.98125705e-01, -1.09192230e-01],\n", - " [-3.96155953e-01, 7.06734583e-02, 2.26026312e-01]],\n", - "\n", - " ...,\n", - "\n", - " [[-1.41948834e-01, -3.87322664e-01, -4.27537948e-01],\n", - " [-2.19580293e-01, 2.15253457e-02, -9.88570526e-02],\n", - " [-5.46731502e-02, -1.30170107e-01, 1.33895233e-01],\n", - " ...,\n", - " [ 2.00435042e-01, -2.03811720e-01, -2.48642564e-01],\n", - " [ 1.78164870e-01, 1.67930827e-01, 4.01013754e-02],\n", - " [-3.63528252e-01, 8.78631398e-02, 2.00340524e-01]],\n", - "\n", - " [[ 2.71228731e-01, 1.40770316e-01, -5.98722836e-04],\n", - " [-3.91015001e-02, 5.72595708e-02, 9.59905088e-02],\n", - " [-2.78714187e-02, 1.50697842e-01, 3.43184359e-02],\n", - " ...,\n", - " [-2.88357377e-01, -2.69743823e-03, -2.54112091e-02],\n", - " [-4.65802960e-02, -4.25636828e-01, 1.90108031e-01],\n", - " [-2.41023764e-01, 9.39113200e-02, -5.46947569e-02]],\n", - "\n", - " [[ 6.73992187e-02, 5.80136590e-02, 6.63957745e-02],\n", - " [ 2.44472787e-01, -1.66147128e-01, -1.88459471e-01],\n", - " [ 2.63425559e-01, 4.23022240e-01, -4.00576741e-02],\n", - " ...,\n", - " [ 1.52654707e-01, 1.47170663e-01, -1.72685385e-01],\n", - " [-2.48365670e-01, -7.92512819e-02, -1.21184878e-01],\n", - " [-9.70430002e-02, 1.27249956e-01, -2.13855192e-01]]]], dtype=float32)" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "apply(params, key, lh)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/src/train.py b/src/train.py index d9a1461..e021928 100644 --- a/src/train.py +++ b/src/train.py @@ -13,50 +13,66 @@ from typing import List, Tuple, Dict from functools import partial from tqdm import tqdm -from src.model import loss_fn, network_fn +from src.model import loss_fn, init, apply -def hyperparam_fn(): +# functions +def hyperparam_fn(): # TODO: perhaps have hyperparam ranges be in config.yaml return { - 'lr': np.random.choice([1e-3, 1e-4, 1e-5]), 'batch_size': np.random.choice([32, 64]), 'n_steps': np.random.randint(low=100, high=200), 'dropout_rate': np.random.uniform(low=0.1, high=0.5), } -def update_fn(params, fmri, img, opt_state, rng, opt): +def update_fn(params, rng, fmri, img, opt_state, opt, dropout_rate): rng, key = jax.random.split(rng) - grads = grad(loss_fn)(params, key, fmri, img) + grads = grad(loss_fn)(params, key, fmri, img, dropout_rate=dropout_rate) updates, opt_state = opt.update(grads, opt_state, params) params = optax.apply_updates(params, updates) return params, opt_state -def train_loop(opt, init, train_loader, val_loader, plot_batch, hyperparams, rng): - rng, key = jax.random.split(rng) +def train_loop(rng, opt, train_loader, val_loader, plot_batch, hyperparams): + metrics = [] + rng, key = jax.random.split(rng, 2) lh, rh, img = next(train_loader) params = init(key, lh) opt_state = opt.init(params) - update = partial(update_fn, opt=opt) + update = partial(update_fn, opt=opt, dropout_rate=hyperparams['dropout_rate']) for step in tqdm(range(hyperparams['n_steps'])): rng, key = jax.random.split(rng) lh, rh, img = next(train_loader) - params, opt_state = update(params, lh, img, opt_state, key) + params, opt_state = update(params, key, lh, img, opt_state) if (step % (hyperparams['n_steps'] // 100)) == 0: - evaluate(params, train_loader, val_loader) - # plot_decodings(apply(params, key, plot_batch[0]), plot_batch[2]) - return params + rng, key = jax.random.split(rng) + metrics.append(evaluate(params, key, train_loader, val_loader)) + # plot_pred = apply(params, key, plot_batch[0]) + # plot_decodings(plot_pred) + return metrics + -def evaluate(params, train_loader, val_loader, n_steps=4): - pass +def evaluate(params, rng, train_loader, val_loader, n_steps=10): + # each batch is a tuple(lh, rh, img). Connect n_steps batches into 1 + train_loss, val_loss = 0, 0 + for _ in range(n_steps): + rng, key_train, key_val = jax.random.split(rng, 3) + lh, rh, img = next(train_loader) + train_loss += loss_fn(params, key_train, lh, img) + lh, rh, img = next(val_loader) + val_loss += loss_fn(params, key_val, lh, img) + train_loss /= n_steps + val_loss /= n_steps + return(f'train_loss: {train_loss}, val_loss: {val_loss}') -def train_folds(kfolds, hyperparams, args, seed=0): - init, apply = hk.transform(partial(network_fn, image_size=args.image_size)) +def train_folds(kfolds, hyperparams, seed=0): + metrics = {} rng = jax.random.PRNGKey(seed) - opt = optax.lion(hyperparams['lr']) + opt = optax.lion(1e-3) plot_batch = None - for train_loader, val_loader in kfolds: + for idx, (train_loader, val_loader) in enumerate(kfolds): plot_batch = next(train_loader) if plot_batch is None else plot_batch rng, key = jax.random.split(rng) - params = train_loop(opt, init, apply, train_loader, val_loader, plot_batch, hyperparams, key) \ No newline at end of file + fold_metrics = train_loop(key, opt, train_loader, val_loader, plot_batch, hyperparams) + metrics[idx] = fold_metrics + return metrics \ No newline at end of file