Skip to content

Commit

Permalink
prep for comples model
Browse files Browse the repository at this point in the history
  • Loading branch information
syrkis committed Sep 16, 2023
1 parent 3afc7d8 commit 8642673
Show file tree
Hide file tree
Showing 11 changed files with 94 additions and 223 deletions.
1 change: 0 additions & 1 deletion config/coco.yaml

This file was deleted.

21 changes: 0 additions & 21 deletions config/config.yaml

This file was deleted.

25 changes: 0 additions & 25 deletions config/fmri.yaml

This file was deleted.

5 changes: 0 additions & 5 deletions config/model.yaml

This file was deleted.

65 changes: 0 additions & 65 deletions config/rois.yaml

This file was deleted.

29 changes: 0 additions & 29 deletions config/sweep.yaml

This file was deleted.

97 changes: 47 additions & 50 deletions notebook.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,72 +2,69 @@
"cells": [
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from tqdm import tqdm\n",
"import optax\n",
"import haiku as hk\n",
"import jax\n",
"import numpy as np\n",
"from jax import jit, grad\n",
"import jax.numpy as jnp\n",
"from functools import partial\n",
"from IPython.display import display, HTML, clear_output\n",
"import time\n",
"\n",
"from src.data import load_subject, make_kfolds\n",
"from src.model import loss_fn, init, apply\n",
"from src.plots import plot_brain\n",
"from src.utils import CONFIG, matrix_to_image\n",
"from src.utils import CONFIG\n",
"from src.train import train_folds, hyperparam_fn"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2023-09-16 11:29:37.420742: W pjrt_plugin/src/mps_client.cc:535] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Metal device set to: Apple M1 Pro\n",
"\n",
"systemMemory: 16.00 GB\n",
"maxCacheSize: 5.33 GB\n",
"\n"
]
},
{
"ename": "TypeError",
"evalue": "sub got incompatible shapes for broadcasting: (64, 512, 512, 3), (64, 32, 32, 3).",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m/Users/syrkis/code/neuroscope/notebook.ipynb Cell 2\u001b[0m line \u001b[0;36m4\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/syrkis/code/neuroscope/notebook.ipynb#W4sZmlsZQ%3D%3D?line=1'>2</a>\u001b[0m hyperparams \u001b[39m=\u001b[39m hyperparam_fn()\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/syrkis/code/neuroscope/notebook.ipynb#W4sZmlsZQ%3D%3D?line=2'>3</a>\u001b[0m kfolds \u001b[39m=\u001b[39m make_kfolds(subject, hyperparams)\n\u001b[0;32m----> <a href='vscode-notebook-cell:/Users/syrkis/code/neuroscope/notebook.ipynb#W4sZmlsZQ%3D%3D?line=3'>4</a>\u001b[0m metrics, params \u001b[39m=\u001b[39m train_folds(kfolds, hyperparams)\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/syrkis/code/neuroscope/notebook.ipynb#W4sZmlsZQ%3D%3D?line=4'>5</a>\u001b[0m \u001b[39m# loader, _ = next(kfolds)\u001b[39;00m\n",
"File \u001b[0;32m~/code/neuroscope/src/train.py:77\u001b[0m, in \u001b[0;36mtrain_folds\u001b[0;34m(kfolds, hyperparams, seed)\u001b[0m\n\u001b[1;32m 75\u001b[0m plot_batch \u001b[39m=\u001b[39m \u001b[39mnext\u001b[39m(train_loader) \u001b[39mif\u001b[39;00m plot_batch \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39melse\u001b[39;00m plot_batch\n\u001b[1;32m 76\u001b[0m rng, key \u001b[39m=\u001b[39m jax\u001b[39m.\u001b[39mrandom\u001b[39m.\u001b[39msplit(rng)\n\u001b[0;32m---> 77\u001b[0m fold_metrics, fold_params \u001b[39m=\u001b[39m train_loop(key, opt, train_loader, val_loader, plot_batch, hyperparams)\n\u001b[1;32m 78\u001b[0m metrics[idx] \u001b[39m=\u001b[39m fold_metrics\n\u001b[1;32m 79\u001b[0m \u001b[39mreturn\u001b[39;00m metrics, fold_params\n",
"File \u001b[0;32m~/code/neuroscope/src/train.py:46\u001b[0m, in \u001b[0;36mtrain_loop\u001b[0;34m(rng, opt, train_loader, val_loader, plot_batch, hyperparams)\u001b[0m\n\u001b[1;32m 44\u001b[0m rng, key \u001b[39m=\u001b[39m jax\u001b[39m.\u001b[39mrandom\u001b[39m.\u001b[39msplit(rng)\n\u001b[1;32m 45\u001b[0m lh, rh, img \u001b[39m=\u001b[39m \u001b[39mnext\u001b[39m(train_loader)\n\u001b[0;32m---> 46\u001b[0m params, opt_state \u001b[39m=\u001b[39m update(params, key, lh, img, opt_state)\n\u001b[1;32m 47\u001b[0m \u001b[39mif\u001b[39;00m (step \u001b[39m%\u001b[39m (hyperparams[\u001b[39m'\u001b[39m\u001b[39mn_steps\u001b[39m\u001b[39m'\u001b[39m] \u001b[39m/\u001b[39m\u001b[39m/\u001b[39m \u001b[39m100\u001b[39m)) \u001b[39m==\u001b[39m \u001b[39m0\u001b[39m:\n\u001b[1;32m 48\u001b[0m rng, key \u001b[39m=\u001b[39m jax\u001b[39m.\u001b[39mrandom\u001b[39m.\u001b[39msplit(rng)\n",
"File \u001b[0;32m~/code/neuroscope/src/train.py:30\u001b[0m, in \u001b[0;36mupdate_fn\u001b[0;34m(params, rng, fmri, img, opt_state, opt, dropout_rate)\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mupdate_fn\u001b[39m(params, rng, fmri, img, opt_state, opt, dropout_rate):\n\u001b[1;32m 29\u001b[0m rng, key \u001b[39m=\u001b[39m jax\u001b[39m.\u001b[39mrandom\u001b[39m.\u001b[39msplit(rng)\n\u001b[0;32m---> 30\u001b[0m grads \u001b[39m=\u001b[39m grad(loss_fn)(params, key, fmri, img, dropout_rate\u001b[39m=\u001b[39;49mdropout_rate)\n\u001b[1;32m 31\u001b[0m updates, opt_state \u001b[39m=\u001b[39m opt\u001b[39m.\u001b[39mupdate(grads, opt_state, params)\n\u001b[1;32m 32\u001b[0m params \u001b[39m=\u001b[39m optax\u001b[39m.\u001b[39mapply_updates(params, updates)\n",
" \u001b[0;31m[... skipping hidden 10 frame]\u001b[0m\n",
"File \u001b[0;32m~/code/neuroscope/src/model.py:47\u001b[0m, in \u001b[0;36mloss_fn\u001b[0;34m(params, rng, fmri, img, dropout_rate)\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mloss_fn\u001b[39m(params: hk\u001b[39m.\u001b[39mParams, rng: jnp\u001b[39m.\u001b[39mndarray, fmri: jnp\u001b[39m.\u001b[39mndarray, img: jnp\u001b[39m.\u001b[39mndarray, dropout_rate: Optional[\u001b[39mfloat\u001b[39m] \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m jnp\u001b[39m.\u001b[39mndarray:\n\u001b[1;32m 46\u001b[0m pred \u001b[39m=\u001b[39m apply(params, rng, fmri, dropout_rate\u001b[39m=\u001b[39mdropout_rate)\n\u001b[0;32m---> 47\u001b[0m loss \u001b[39m=\u001b[39m jnp\u001b[39m.\u001b[39mmean((pred \u001b[39m-\u001b[39;49m img) \u001b[39m*\u001b[39m\u001b[39m*\u001b[39m \u001b[39m2\u001b[39m)\n\u001b[1;32m 48\u001b[0m \u001b[39mreturn\u001b[39;00m loss\n",
"File \u001b[0;32m~/.pyenv/versions/3.11.4/lib/python3.11/site-packages/jax/_src/numpy/array_methods.py:791\u001b[0m, in \u001b[0;36m_forward_operator_to_aval.<locals>.op\u001b[0;34m(self, *args)\u001b[0m\n\u001b[1;32m 790\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mop\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39m*\u001b[39margs):\n\u001b[0;32m--> 791\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mgetattr\u001b[39;49m(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49maval, \u001b[39mf\u001b[39;49m\u001b[39m\"\u001b[39;49m\u001b[39m_\u001b[39;49m\u001b[39m{\u001b[39;49;00mname\u001b[39m}\u001b[39;49;00m\u001b[39m\"\u001b[39;49m)(\u001b[39mself\u001b[39;49m, \u001b[39m*\u001b[39;49margs)\n",
"File \u001b[0;32m~/.pyenv/versions/3.11.4/lib/python3.11/site-packages/jax/_src/numpy/array_methods.py:258\u001b[0m, in \u001b[0;36m_defer_to_unrecognized_arg.<locals>.deferring_binary_op\u001b[0;34m(self, other)\u001b[0m\n\u001b[1;32m 256\u001b[0m args \u001b[39m=\u001b[39m (other, \u001b[39mself\u001b[39m) \u001b[39mif\u001b[39;00m swap \u001b[39melse\u001b[39;00m (\u001b[39mself\u001b[39m, other)\n\u001b[1;32m 257\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(other, _accepted_binop_types):\n\u001b[0;32m--> 258\u001b[0m \u001b[39mreturn\u001b[39;00m binary_op(\u001b[39m*\u001b[39;49margs)\n\u001b[1;32m 259\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(other, _rejected_binop_types):\n\u001b[1;32m 260\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mTypeError\u001b[39;00m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39munsupported operand type(s) for \u001b[39m\u001b[39m{\u001b[39;00mopchar\u001b[39m}\u001b[39;00m\u001b[39m: \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 261\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mtype\u001b[39m(args[\u001b[39m0\u001b[39m])\u001b[39m.\u001b[39m\u001b[39m__name__\u001b[39m\u001b[39m!r}\u001b[39;00m\u001b[39m and \u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mtype\u001b[39m(args[\u001b[39m1\u001b[39m])\u001b[39m.\u001b[39m\u001b[39m__name__\u001b[39m\u001b[39m!r}\u001b[39;00m\u001b[39m\"\u001b[39m)\n",
" \u001b[0;31m[... skipping hidden 12 frame]\u001b[0m\n",
"File \u001b[0;32m~/.pyenv/versions/3.11.4/lib/python3.11/site-packages/jax/_src/numpy/ufuncs.py:82\u001b[0m, in \u001b[0;36m_one_to_one_binop.<locals>.<lambda>\u001b[0;34m(x1, x2)\u001b[0m\n\u001b[1;32m 80\u001b[0m fn \u001b[39m=\u001b[39m \u001b[39mlambda\u001b[39;00m x1, x2, \u001b[39m/\u001b[39m: lax_fn(\u001b[39m*\u001b[39mpromote_args_numeric(numpy_fn\u001b[39m.\u001b[39m\u001b[39m__name__\u001b[39m, x1, x2))\n\u001b[1;32m 81\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m---> 82\u001b[0m fn \u001b[39m=\u001b[39m \u001b[39mlambda\u001b[39;00m x1, x2, \u001b[39m/\u001b[39m: lax_fn(\u001b[39m*\u001b[39;49mpromote_args(numpy_fn\u001b[39m.\u001b[39;49m\u001b[39m__name__\u001b[39;49m, x1, x2))\n\u001b[1;32m 83\u001b[0m fn\u001b[39m.\u001b[39m\u001b[39m__qualname__\u001b[39m \u001b[39m=\u001b[39m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mjax.numpy.\u001b[39m\u001b[39m{\u001b[39;00mnumpy_fn\u001b[39m.\u001b[39m\u001b[39m__name__\u001b[39m\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m\n\u001b[1;32m 84\u001b[0m fn \u001b[39m=\u001b[39m jit(fn, inline\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m)\n",
" \u001b[0;31m[... skipping hidden 7 frame]\u001b[0m\n",
"File \u001b[0;32m~/.pyenv/versions/3.11.4/lib/python3.11/site-packages/jax/_src/lax/lax.py:1577\u001b[0m, in \u001b[0;36mbroadcasting_shape_rule\u001b[0;34m(name, *avals)\u001b[0m\n\u001b[1;32m 1575\u001b[0m result_shape\u001b[39m.\u001b[39mappend(non_1s[\u001b[39m0\u001b[39m])\n\u001b[1;32m 1576\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m-> 1577\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mTypeError\u001b[39;00m(\u001b[39mf\u001b[39m\u001b[39m'\u001b[39m\u001b[39m{\u001b[39;00mname\u001b[39m}\u001b[39;00m\u001b[39m got incompatible shapes for broadcasting: \u001b[39m\u001b[39m'\u001b[39m\n\u001b[1;32m 1578\u001b[0m \u001b[39mf\u001b[39m\u001b[39m'\u001b[39m\u001b[39m{\u001b[39;00m\u001b[39m\"\u001b[39m\u001b[39m, \u001b[39m\u001b[39m\"\u001b[39m\u001b[39m.\u001b[39mjoin(\u001b[39mmap\u001b[39m(\u001b[39mstr\u001b[39m,\u001b[39m \u001b[39m\u001b[39mmap\u001b[39m(\u001b[39mtuple\u001b[39m,\u001b[39m \u001b[39mshapes)))\u001b[39m}\u001b[39;00m\u001b[39m.\u001b[39m\u001b[39m'\u001b[39m)\n\u001b[1;32m 1580\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mtuple\u001b[39m(result_shape)\n",
"\u001b[0;31mTypeError\u001b[0m: sub got incompatible shapes for broadcasting: (64, 512, 512, 3), (64, 32, 32, 3)."
]
}
],
"source": [
"subject = load_subject('subj05', image_size=CONFIG['image_size'])\n",
"hyperparams = hyperparam_fn()\n",
"kfolds = make_kfolds(subject, hyperparams)\n",
"metrics, params = train_folds(kfolds, hyperparams)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def display_image(matrix_lst):\n",
" html = '<div style=\"width: 300px; height: 300px; display: flex; flex-wrap: wrap;\">'\n",
" for matrix in matrix_lst:\n",
" image = matrix_to_image(matrix)\n",
" html += f\"\"\"\n",
" <div style=\"display: flex; justify-content: center; align-items: center; width: 140px; height: 140px; margin: 5px; background-image: url('data:image/png;base64,{image}'); background-size: cover; background-position: center; background-repeat: no-repeat;\">\n",
" </div>\"\"\"\n",
"\n",
" html += '</div>'\n",
" clear_output(wait=True)\n",
" display(HTML(html))\n",
"\n",
"# Example usage with a random 100x100 matrix\n",
"for i in range(10):\n",
" matrix_lst = [np.random.rand(100, 100) for _ in range(4)]\n",
" display_image(matrix_lst)\n",
" time.sleep(1)"
"metrics, params = train_folds(kfolds, hyperparams)\n",
"# loader, _ = next(kfolds)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
Expand Down
28 changes: 21 additions & 7 deletions src/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,32 @@

# Define your function
def network_fn(fmri: jnp.ndarray, dropout_rate: Optional[float] = None) -> jnp.ndarray:
layers = [hk.Linear(300), jax.nn.relu,
hk.Linear(100), jax.nn.relu]
fc = [hk.Linear(128), jax.nn.gelu,
hk.Linear(256), jax.nn.gelu,
hk.Linear(256), jax.nn.gelu,
hk.Linear(CONFIG['image_size'] * CONFIG['image_size'] * 3)]

fmri = hk.Sequential(layers)(fmri)
# apply fc and reshape to image
z = hk.Sequential(fc)(fmri)

# apply dropout if training
if dropout_rate is not None:
rng = hk.next_rng_key()
fmri = hk.dropout(rng, dropout_rate, fmri)
z = hk.dropout(rng, dropout_rate, z)

fmri = hk.Linear(CONFIG['image_size'] * CONFIG['image_size'] * 3)(fmri)
fmri = fmri.reshape(-1, CONFIG['image_size'], CONFIG['image_size'], 3)
return fmri
# reshape to image
z = z.reshape(-1, CONFIG['image_size'], CONFIG['image_size'], 3)

# deconv (transpose conv) layers
deconv = [hk.Conv2DTranspose(output_channels=21, kernel_shape=3, stride=2, padding='SAME'), jax.nn.gelu,
hk.Conv2DTranspose(output_channels=64, kernel_shape=3, stride=2, padding='SAME'), jax.nn.gelu,
hk.Conv2DTranspose(output_channels=32, kernel_shape=3, stride=2, padding='SAME'), jax.nn.gelu,
hk.Conv2DTranspose(output_channels=3, kernel_shape=3, stride=2, padding='SAME')]

# apply deconv
z = hk.Sequential(deconv)(z)
z = jax.nn.sigmoid(z)
return z


def loss_fn(params: hk.Params, rng: jnp.ndarray, fmri: jnp.ndarray, img: jnp.ndarray, dropout_rate: Optional[float] = None) -> jnp.ndarray:
Expand Down
33 changes: 16 additions & 17 deletions src/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,27 @@
import base64
from PIL import Image as PILImage
from io import BytesIO
from jinja2 import Template, Environment, FileSystemLoader
from jax import vmap
from src.fmri import ATLAS, fsaverage_vec

from src.utils import matrix_to_image, CONFIG


# globals
plt.style.use("dark_background")
env = Environment(loader=FileSystemLoader('templates'))


def plot_small_multiples_html(pred_batch, target_batch, n_cols=3):
pred_batch, target_batch = np.array(pred_batch[: n_cols ** 2]), np.array(target_batch[: n_cols ** 2])
batch = np.zeros_like(pred_batch)
batch[:, :, : CONFIG['image_size'] // 2, :] = pred_batch[:, :, : CONFIG['image_size'] // 2, :]
batch[:, :, CONFIG['image_size'] // 2 :, :] = target_batch[:, :, CONFIG['image_size'] // 2 :, :]
images = [ matrix_to_image(pred) for pred in batch ]
template = env.get_template('images.html')
html = template.render(images=images)
clear_output(wait=True)
display(HTML(html))


# functions
Expand All @@ -42,19 +57,3 @@ def plot_brain(challenge_vec, subject, hem, roi=None):
black_bg=True,
)
return view.resize(height=900, width=1200)

# plot decodings
def monitor_decoding(decodings, n=3):
"""small multiple gifs of decodings at differnt stages of training"""
decodings = decodings[: n * n]
fig, axs = plt.subplots(n, n, figsize=(n * 2, n * 2))
for i, ax in enumerate(axs.flatten()):
ax.imshow(decodings[i])
ax.axis("off")
plt.tight_layout()
plt.close()
return fig


def plot_decoding_progress():
pass
7 changes: 4 additions & 3 deletions src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from functools import partial
from tqdm import tqdm
from src.model import loss_fn, init, apply
from src.plots import plot_small_multiples_html


# functions
Expand All @@ -39,15 +40,15 @@ def train_loop(rng, opt, train_loader, val_loader, plot_batch, hyperparams):
params = init(key, lh)
opt_state = opt.init(params)
update = partial(update_fn, opt=opt, dropout_rate=hyperparams['dropout_rate'])
for step in tqdm(range(hyperparams['n_steps'])):
for step in range(hyperparams['n_steps']):
rng, key = jax.random.split(rng)
lh, rh, img = next(train_loader)
params, opt_state = update(params, key, lh, img, opt_state)
if (step % (hyperparams['n_steps'] // 100)) == 0:
rng, key = jax.random.split(rng)
metrics.append(evaluate(params, key, train_loader, val_loader))
# plot_pred = apply(params, key, plot_batch[0])
# plot_decodings(plot_pred)
plot_pred = apply(params, key, plot_batch[0])
plot_small_multiples_html(plot_pred, plot_batch[2])
return metrics, params


Expand Down
6 changes: 6 additions & 0 deletions templates/images.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
<div style="display: grid; grid-template-columns: repeat(3, 1fr); gap: 60px; width: 1200px; height: 1200px; margin: 0; background-color: black; padding: 60px;">
{% for image in images %}
<div style="display: flex; justify-content: center; align-items: center; background-image: url('data:image/png;base64,{{image}}'); background-size: cover; background-position: center; background-repeat: no-repeat;">
</div>
{% endfor %}
</div>

0 comments on commit 8642673

Please sign in to comment.