From a0b67f62e3493d8a69d7b6925c14abcac047e29f Mon Sep 17 00:00:00 2001 From: Noah Syrkis Date: Tue, 21 Nov 2023 17:43:54 -0300 Subject: [PATCH] test fmri modality --- notebook.ipynb | 457 ++++++++++++++++++++++++++++++++++--------------- src/data.py | 59 +++++-- 2 files changed, 362 insertions(+), 154 deletions(-) diff --git a/notebook.ipynb b/notebook.ipynb index 3a5f95f..2ef42c4 100644 --- a/notebook.ipynb +++ b/notebook.ipynb @@ -11,7 +11,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 37, "metadata": {}, "outputs": [], "source": [ @@ -29,37 +29,37 @@ "import matplotlib.pyplot as plt\n", "\n", "import syrkis\n", - "from src.data import load_subject, make_kfolds" + "from src.data import load_subjects, make_kfolds" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 38, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Metal device set to: Apple M1 Pro\n", - "\n", - "systemMemory: 16.00 GB\n", - "maxCacheSize: 5.33 GB\n", - "\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2023-11-20 22:08:51.202607: W pjrt_plugin/src/mps_client.cc:534] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!\n" - ] - } - ], + "outputs": [], "source": [ "# GLOBALS\n", "rng = random.PRNGKey(0)\n", - "cfg = syrkis.training.load_config()" + "cfg = syrkis.training.load_config()\n", + "cfg['image_size'] = 28 if cfg['dataset'] == 'mnist' else cfg['image_size']" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [], + "source": [ + "def mnist_loader(rng, data_split):\n", + " split_n_samples = data_split.shape[0]\n", + " n_batches = split_n_samples // cfg['batch_size']\n", + " \n", + " while True:\n", + " rng, key = jax.random.split(rng)\n", + " idxs = jax.random.permutation(key, split_n_samples)\n", + " for i in range(n_batches):\n", + " batch_idxs = idxs[i * cfg['batch_size']:(i + 1) * cfg['batch_size']]\n", + " yield None, None, data_split[batch_idxs], None" ] }, { @@ -71,22 +71,62 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 40, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of samples: 7872\n", + "Number of steps: 12300\n", + "Number of batches: 384\n", + "Number of epochs: 50\n" + ] + } + ], "source": [ - "if 'subject' not in locals():\n", - " subject = load_subject('subj07', image_size=cfg['image_size'])\n", - " kfolds = make_kfolds(subject, cfg)\n", - " batches, neuroscope_eval_batches = next(kfolds) # type: ignore\n", - " lh_sample, rh_sample, img_sample = next(batches)" + "if 'subjects' not in locals() and cfg['dataset'] == 'neuroscope':\n", + " subjects = load_subjects(['subj07'], cfg['image_size'])\n", + " n_samples = sum([len(s[0]) for s in subjects.values()])\n", + " kfolds = make_kfolds(subjects, cfg)\n", + " train_batches, eval_batches = next(kfolds)\n", + " lh_sample, rh_sample, img_sample, subject_idx_sample = next(train_batches)\n", + "\n", + "if 'mnist_data' not in locals() and cfg['dataset'] == 'mnist':\n", + " mnist_data = syrkis.data.mnist()\n", + " n_samples = mnist_data[0].shape[0]\n", + " mnist_train, mnist_eval = mnist_data[0][:55000], mnist_data[0][55000:]\n", + " rng, train_rng, eval_rng = jax.random.split(rng, 3)\n", + " train_batches = mnist_loader(train_rng, mnist_train)\n", + " eval_batches = mnist_loader(eval_rng, mnist_eval)\n", + " lh_sample, rh_sample, img_sample, subject_idx_sample = next(train_batches)\n", + "\n", + "\n", + "\n", + "print(f'Number of samples: {n_samples}')\n", + "n_steps = n_samples // cfg['batch_size'] * cfg['epochs']\n", + "print(f'Number of steps: {n_steps}')\n", + "print(f'Number of batches: {n_steps // cfg[\"batch_size\"]}')\n", + "print(f'Number of epochs: {cfg[\"epochs\"]}')" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 41, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "12288\n", + "24576\n", + "49152\n", + "64\n" + ] + } + ], "source": [ "def glorot_init(rng, shape):\n", " if len(shape) == 2: # Dense layer\n", @@ -101,22 +141,25 @@ "def latent_side_fn(cfg):\n", " return cfg['image_size'] // cfg['stride'] ** cfg['conv_layers']\n", "\n", + "\n", "def latent_dim_fn(cfg):\n", " # should return the size of the latente dim depending on initial image size, stride, and number of layers, and channels\n", " image_size = cfg['image_size']\n", " stride = cfg['stride']\n", " layers = cfg['conv_layers']\n", " channels = 3 if cfg['dataset'] == 'neuroscope' else 1\n", - " if stride == 1:\n", - " latent_dim = image_size ** 2 * channels * 2 ** layers\n", - " if stride == 2:\n", - " latent_dim = image_size ** 2 * channels // (stride ** layers)\n", - " # Multiply by 2 because of bug (this is a temporary fix)\n", - " return latent_dim\n", + " dim = image_size ** 2 * channels\n", + " for i in range(layers):\n", + " print(dim)\n", + " dim = (dim * 2) // (stride ** 2)\n", + " return dim\n", + " \n", "\n", "\n", "latent_dim = latent_dim_fn(cfg)\n", - "latent_side = latent_side_fn(cfg)\n" + "latent_side = latent_side_fn(cfg)\n", + "print(latent_dim)\n", + "print(latent_side)" ] }, { @@ -128,12 +171,14 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 42, "metadata": {}, "outputs": [], "source": [ "@jit\n", "def batch_norm(x, gamma, beta, eps=1e-5):\n", + " if not cfg['batch_norm']:\n", + " return x\n", " # x: batch x height x width x channels\n", " axis = tuple(range(len(x.shape) - 1))\n", " mean = jnp.mean(x, axis=axis, keepdims=True)\n", @@ -159,32 +204,44 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 43, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[-1.9884775 0.638638 1.3899286 -0.11827302]\n" + ] + } + ], "source": [ - "def init_linear_layer(rng, in_dim, out_dim):\n", + "def init_linear_layer(rng, in_dim, out_dim, tensor_dim):\n", + " # tensor dim is for having fmri embedding in same array, but seperate layers.\n", " rng, key = jax.random.split(rng, 2)\n", " w_shape = (in_dim, out_dim)\n", " b_shape = (out_dim,)\n", " w = glorot_init(key, w_shape)\n", + " if tensor_dim > 0:\n", + " w = w.reshape((-1, out_dim, tensor_dim))\n", " b = jnp.zeros(b_shape)\n", " gamma, beta = init_batch_norm(b_shape)\n", " return w, b, gamma, beta\n", "\n", - "def init_linear_layers(rng, in_dim, out_dim, cfg):\n", + "def init_linear_layers(rng, in_dim, out_dim, cfg, tensor_dim=0):\n", " # first layer goes from in_dim to embed_dim, rest are embed to embed, and last is embed to out\n", " rngs = jax.random.split(rng, cfg['fc_layers'])\n", " params = []\n", " for idx, rng in enumerate(rngs):\n", " layer_in_dim = cfg['embed_dim'] if idx != 0 else in_dim\n", " layer_out_dim = cfg['embed_dim'] if idx != cfg['fc_layers'] - 1 else out_dim\n", - " params.append(init_linear_layer(rng, layer_in_dim, layer_out_dim))\n", + " params.append(init_linear_layer(rng, layer_in_dim, layer_out_dim, tensor_dim))\n", " return params\n", "\n", "def linear(params, x):\n", " for idx, (w, b, gamma, beta) in enumerate(params):\n", - " x = jax.nn.gelu(x @ w + b)\n", + " x = x @ w + b\n", + " x = jax.nn.gelu(x) if idx != len(params) - 1 else x\n", " x = batch_norm(x, gamma, beta) if idx != len(params) - 1 else x\n", " return x\n", "\n", @@ -195,7 +252,9 @@ " params = init_linear_layers(rng, 2, 4, cfg)\n", " y = linear(params, x)\n", " assert y.shape == (4,)\n", - " print(y)" + " print(y)\n", + "\n", + "test_linear()" ] }, { @@ -207,13 +266,14 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 44, "metadata": {}, "outputs": [], "source": [ "# Global constants for common parameters\n", "DIMENSION_NUMBERS = (\"NHWC\", \"HWIO\", \"NHWC\")\n", "\n", + "\n", "@jit\n", "def conv2d(x, w):\n", " return jax.lax.conv_general_dilated(\n", @@ -222,6 +282,7 @@ " padding='SAME',\n", " dimension_numbers=DIMENSION_NUMBERS)\n", "\n", + "\n", "@jit\n", "def upscale_nearest_neighbor(x, scale_factor=cfg['stride']):\n", " # Assuming x has shape (batch, height, width, channels)\n", @@ -230,6 +291,7 @@ " x = lax.tie_in(x, jnp.broadcast_to(x, (b, h, scale_factor, w, scale_factor, c)))\n", " return x.reshape(b, h * scale_factor, w * scale_factor, c)\n", "\n", + "\n", "@jit\n", "def deconv2d(x, w):\n", " x_upscaled = upscale_nearest_neighbor(x)\n", @@ -239,18 +301,21 @@ " padding='SAME',\n", " dimension_numbers=DIMENSION_NUMBERS) \n", "\n", + "\n", "def conv_fn(fn):\n", " def apply_fn(params, x):\n", " for i, (w, b, gamma, beta) in enumerate(params):\n", " x = fn(x, w, b)\n", - " x = batch_norm(x, gamma, beta)\n", + " x = batch_norm(x, gamma, beta) if i != len(params) - 1 else x\n", " x = jax.nn.gelu(x) if i != len(params) - 1 else x\n", " return x\n", " return apply_fn\n", "\n", + "\n", "conv = conv_fn(lambda x, w, b: conv2d(x, w) + b)\n", "deconv = conv_fn(lambda x, w, b: deconv2d(x, w) + b)\n", "\n", + "\n", "def init_conv_params(rng, chan, cfg, deconv=False):\n", " rng, key = jax.random.split(rng, 2)\n", " in_chan = chan * 2 if deconv else chan\n", @@ -261,6 +326,7 @@ " gamma, beta = init_batch_norm(b.shape)\n", " return w, b, gamma, beta\n", "\n", + "\n", "def init_conv_layers(rng, cfg, deconv=False):\n", " ds_channels = 3 if cfg['dataset'] == 'neuroscope' else 1\n", " rngs = jax.random.split(rng, cfg['conv_layers'])\n", @@ -271,9 +337,7 @@ " return params[::-1] if deconv else params\n", "\n", "\n", - "\n", "def test_conv():\n", - " cfg = syrkis.training.load_config()\n", " channels = 3 if cfg['dataset'] == 'neuroscope' else 1\n", " x = jnp.ones((cfg['batch_size'], cfg['image_size'], cfg['image_size'], channels))\n", " conv_params = init_conv_layers(jax.random.PRNGKey(0), cfg)\n", @@ -293,17 +357,10 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 45, "metadata": {}, "outputs": [], "source": [ - "def reparametrize(mu, logvar, rng):\n", - " # mu, logvar: (batch, embed_dim)\n", - " std = jnp.exp(0.5 * logvar)\n", - " eps = jax.random.normal(rng, mu.shape)\n", - " return mu + eps * std\n", - "\n", - "\n", "def print_model(params):\n", " print(f'{syrkis.training.n_params(params)} total params', end='\\n\\n')\n", " print('\\tconv params')\n", @@ -320,27 +377,31 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 46, "metadata": {}, "outputs": [], "source": [ - "def kl_divergence(mu, logvar):\n", - " return -0.5 * jnp.sum(1 + logvar - mu ** 2 - jnp.exp(logvar), axis=1).mean()\n", - "\n", - "def cross_entropy(logits, labels, epsilon=1e-12):\n", - " max_logits = jnp.max(logits, axis=1, keepdims=True)\n", - " stabilized_logits = logits - max_logits\n", - " log_sum_exp = jnp.log(jnp.sum(jnp.exp(stabilized_logits), axis=1, keepdims=True) + epsilon)\n", - " labels_one_hot = jnp.eye(logits.shape[1])[labels]\n", - " loss = -jnp.mean(labels_one_hot * (stabilized_logits - log_sum_exp))\n", - " return loss\n", - "\n", - "def encode_fn(params, x):\n", - " z = jax.nn.gelu(conv(params['conv'], x))\n", + "\n", + "\n", + "def encode_img_fn(params, img):\n", + " z = jax.nn.gelu(conv(params['conv'], img))\n", " z = z.reshape(cfg['batch_size'], -1)\n", " z = jax.nn.gelu(linear(params['linear_encode'], z))\n", - " mu, logvar = z, z\n", - " return mu, logvar\n", + " return z\n", + "\n", + "def matmul_slice(A, B_slice):\n", + " return jnp.dot(A, B_slice)\n", + "batched_matmul = vmap(matmul_slice, in_axes=(None, 2))\n", + "\n", + "def encode_fmri_fn(params, fmri, subj):\n", + " embed_cube = params['fmri_embed'][0][0] # fmri_dim x embed_dim x n_subjects\n", + " z = batched_matmul(fmri, embed_cube) # n_subjects x batch_size x embed_dim\n", + " z = z.transpose((2, 0, 1)) # embed_dim x n_subjects x batch_size # prep for broadcast\n", + " one_hot = jax.nn.one_hot(subj, len(subjects)).T # n_subjects x batch_size # prep for broadcast\n", + " z = one_hot * z # embed_dim x n_subjects x batch_size # one subject dimension has zeros after broadcast\n", + " z = z.sum(axis=1).T # batch_size x embed_dim # sum over subjects\n", + " z = jax.nn.gelu(linear(params['fmri_dense'], z))\n", + " return z\n", "\n", "def decode_fn(params, z):\n", " z = jax.nn.gelu(linear(params['linear_decode'], z))\n", @@ -349,29 +410,68 @@ " return jax.nn.sigmoid(z)\n", "\n", "@jit\n", - "def apply_fn(params, x):\n", - " mu, logvar = encode_fn(params, x)\n", - " z = decode_fn(params, mu) # reparametrize(mu, logvar, rng))\n", - " return jax.nn.sigmoid(z)\n", - "\n", - "def loss_fn(params, x):\n", - " x_hat = apply_fn(params, x)\n", - " return jnp.mean((x - x_hat) ** 2)\n", - "\n", - "def update_fn(params, x, opt_state, opt):\n", - " loss, grads = value_and_grad(loss_fn)(params, x)\n", + "def apply_fn(params, fmri, img, subj):\n", + " if cfg['dataset'] == 'mnist' or cfg['source'] == 'img':\n", + " z = encode_img_fn(params, img)\n", + " else:\n", + " z = encode_fmri_fn(params, fmri, subj)\n", + " x_hat = decode_fn(params, z) # reparametrize(mu, logvar, rng))\n", + " return x_hat\n", + "\n", + "def loss_fn(params, fmri, img, subj):\n", + " x_img = apply_fn(params, fmri, img, subj)\n", + " return jnp.mean((img - x_img) ** 2)\n", + "\n", + "def update_fn(params, fmri, img, subj, opt_state, opt):\n", + " loss, grads = value_and_grad(loss_fn)(params, fmri, img, subj)\n", " updates, opt_state = opt.update(grads, opt_state, params)\n", " params = optax.apply_updates(params, updates)\n", " return params, opt_state, loss\n", "\n", "def init_fn(rng, cfg):\n", - " conv_params = init_conv_layers(rng, cfg)\n", - " deconv_params = init_conv_layers(rng, cfg, deconv=True)\n", - " linear_encode_params = init_linear_layers(rng, latent_dim, cfg['embed_dim'], cfg) # linear layers\n", - " linear_decode_params = init_linear_layers(rng, cfg['embed_dim'], latent_dim, cfg) # linear layers\n", - " params = {\"conv\": conv_params, \"deconv\": deconv_params,\n", - " \"linear_encode\": linear_encode_params, \"linear_decode\": linear_decode_params}\n", + "\n", + " if cfg['dataset'] == 'mnist' or cfg['source'] == 'img':\n", + " conv_params = init_conv_layers(rng, cfg)\n", + " linear_encode_params = init_linear_layers(rng, latent_dim, cfg['embed_dim'], cfg) # linear layers\n", + " linear_decode_params = init_linear_layers(rng, cfg['embed_dim'], latent_dim, cfg) # linear layers\n", + " deconv_params = init_conv_layers(rng, cfg, deconv=True)\n", + " params = {\"conv\": conv_params, \"deconv\": deconv_params,\n", + " \"linear_encode\": linear_encode_params, \"linear_decode\": linear_decode_params}\n", + "\n", + " if cfg['dataset'] != 'mnist' and cfg['source'] == 'fmri':\n", + " fmri_embed_params = init_linear_layers(rng, 19004 * len(subjects), cfg['embed_dim'], cfg, tensor_dim=len(subjects)) # linear layers\n", + " fmri_dense_params = init_linear_layers(rng, cfg['embed_dim'], cfg['embed_dim'], cfg) # linear layers\n", + " linear_decode_params = init_linear_layers(rng, cfg['embed_dim'], latent_dim, cfg) # linear layers\n", + " deconv_params = init_conv_layers(rng, cfg, deconv=True)\n", + " params = {\"fmri_embed\": fmri_embed_params, \"fmri_dense\": fmri_dense_params,\n", + " \"deconv\": deconv_params, \"linear_decode\": linear_decode_params}\n", + "\n", " return params\n", + "\n", + "\n", + "def print_model(params):\n", + " print(f'{syrkis.training.n_params(params)} total params', end='\\n\\n')\n", + " if cfg['dataset'] == 'mnist' or cfg['source'] == 'img':\n", + " print('\\tconv params')\n", + " for idx, (w, b, gamma, beta) in enumerate(params['conv']):\n", + " print(f'\\t\\tconv_{idx}\\t\\t: {w.size + b.size}\\t|\\t{str(w.shape).replace(\", \", \" x \")[1:-1]}')\n", + " for idx, (w, b, gamma, beta) in enumerate(params['deconv']):\n", + " print(f'\\t\\tdeconv_{idx}\\t: {w.size + b.size}\\t|\\t{str(w.shape).replace(\", \", \" x \")[1:-1]}')\n", + " print('\\n\\tlinear params')\n", + " for idx, (w, b, gamma, beta) in enumerate(params['linear_encode']):\n", + " print(f'\\t\\tencode_{idx}\\t: {w.size + b.size}\\t|\\t{str(w.shape).replace(\", \", \" x \")[1:-1]}')\n", + " for idx, (w, b, gamma, beta) in enumerate(params['linear_decode']):\n", + " print(f'\\t\\tdecode_{idx}\\t: {w.size + b.size}\\t|\\t{str(w.shape).replace(\", \", \" x \")[1:-1]}')\n", + " if cfg['dataset'] != 'mnist' and cfg['source'] == 'fmri':\n", + " print('\\tfmri params')\n", + " for idx, (w, b, gamma, beta) in enumerate(params['fmri_embed']):\n", + " print(f'\\t\\tfmri_embed_{idx}\\t: {w.size + b.size}\\t|\\t{str(w.shape).replace(\", \", \" x \")[1:-1]}')\n", + " for idx, (w, b, gamma, beta) in enumerate(params['fmri_dense']):\n", + " print(f'\\t\\tfmri_dense_{idx}\\t: {w.size + b.size}\\t|\\t{str(w.shape).replace(\", \", \" x \")[1:-1]}')\n", + " for idx, (w, b, gamma, beta) in enumerate(params['linear_decode']):\n", + " print(f'\\t\\tdecode_{idx}\\t: {w.size + b.size}\\t|\\t{str(w.shape).replace(\", \", \" x \")[1:-1]}')\n", + " for idx, (w, b, gamma, beta) in enumerate(params['deconv']):\n", + " print(f'\\t\\tdeconv_{idx}\\t: {w.size + b.size}\\t|\\t{str(w.shape).replace(\", \", \" x \")[1:-1]}')\n", "\n" ] }, @@ -384,45 +484,36 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 47, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "840249 total params\n", + "12732453 total params\n", "\n", "\tconv params\n", "\t\tconv_0\t\t: 168\t|\t3 x 3 x 3 x 6\n", "\t\tconv_1\t\t: 660\t|\t3 x 3 x 6 x 12\n", - "\t\tconv_2\t\t: 2616\t|\t3 x 3 x 12 x 24\n", - "\t\tconv_3\t\t: 10416\t|\t3 x 3 x 24 x 48\n", - "\t\tconv_4\t\t: 41568\t|\t3 x 3 x 48 x 96\n", - "\t\tconv_5\t\t: 166080\t|\t3 x 3 x 96 x 192\n", - "\t\tdeconv_0\t: 165984\t|\t3 x 3 x 192 x 96\n", - "\t\tdeconv_1\t: 41520\t|\t3 x 3 x 96 x 48\n", - "\t\tdeconv_2\t: 10392\t|\t3 x 3 x 48 x 24\n", - "\t\tdeconv_3\t: 2604\t|\t3 x 3 x 24 x 12\n", - "\t\tdeconv_4\t: 654\t|\t3 x 3 x 12 x 6\n", - "\t\tdeconv_5\t: 165\t|\t3 x 3 x 6 x 3\n", + "\t\tdeconv_0\t: 654\t|\t3 x 3 x 12 x 6\n", + "\t\tdeconv_1\t: 165\t|\t3 x 3 x 6 x 3\n", "\n", "\tlinear params\n", - "\t\tencode_0\t: 196864\t|\t768 x 256\n", - "\t\tdecode_0\t: 197376\t|\t256 x 768\n" + "\t\tencode_0\t: 6291584\t|\t49152 x 128\n", + "\t\tdecode_0\t: 6340608\t|\t128 x 49152\n" ] } ], "source": [ "rng = jax.random.PRNGKey(0)\n", "params = init_fn(rng, cfg)\n", - "losses = []\n", "print_model(params)" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 48, "metadata": {}, "outputs": [ { @@ -486,51 +577,99 @@ "
\n", " \n", "
\n", - " \"image\"\n", + " \"image\"\n", + "
\n", + " \n", + "
\n", + " \"image\"\n", + "
\n", + " \n", + "
\n", + " \"image\"\n", + "
\n", + " \n", + "
\n", + " \"image\"\n", + "
\n", + " \n", + "
\n", + " \"image\"\n", + "
\n", + " \n", + "
\n", + " \"image\"\n", + "
\n", + " \n", + "
\n", + " \"image\"\n", + "
\n", + " \n", + "
\n", + " \"image\"\n", + "
\n", + " \n", + "
\n", + " \"image\"\n", + "
\n", + " \n", + "
\n", + " \"image\"\n", "
\n", " \n", "
\n", - " \"image\"\n", + " \"image\"\n", "
\n", " \n", "
\n", - " \"image\"\n", + " \"image\"\n", "
\n", " \n", "
\n", - " \"image\"\n", + " \"image\"\n", "
\n", " \n", "
\n", - " \"image\"\n", + " \"image\"\n", "
\n", " \n", "
\n", - " \"image\"\n", + " \"image\"\n", "
\n", " \n", "
\n", - " \"image\"\n", + " \"image\"\n", "
\n", " \n", "
\n", - " \"image\"\n", + " \"image\"\n", "
\n", " \n", "
\n", - " \"image\"\n", + " \"image\"\n", "
\n", " \n", "
\n", - " \"image\"\n", + " \"image\"\n", "
\n", " \n", "
\n", - " \"image\"\n", + " \"image\"\n", "
\n", " \n", "
\n", - " \"image\"\n", + " \"image\"\n", + "
\n", + " \n", + "
\n", + " \"image\"\n", + "
\n", + " \n", + "
\n", + " \"image\"\n", + "
\n", + " \n", + "
\n", + " \"image\"\n", "
\n", " \n", "
\n", @@ -538,7 +677,11 @@ " \n", @@ -553,30 +696,70 @@ }, "metadata": {}, "output_type": "display_data" + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m/Users/syrkis/code/neuroscope/notebook.ipynb Cell 19\u001b[0m line \u001b[0;36m3\n\u001b[1;32m 29\u001b[0m wandb\u001b[39m.\u001b[39mlog({\u001b[39m\"\u001b[39m\u001b[39mtrain_loss\u001b[39m\u001b[39m\"\u001b[39m: train_loss, \u001b[39m\"\u001b[39m\u001b[39meval_loss\u001b[39m\u001b[39m\"\u001b[39m: eval_loss})\n\u001b[1;32m 30\u001b[0m wandb\u001b[39m.\u001b[39mfinish()\n\u001b[0;32m---> 32\u001b[0m train(params, opt_state, cfg, train_batches, eval_batches, monitor\u001b[39m=\u001b[39;49m\u001b[39mFalse\u001b[39;49;00m)\n\u001b[1;32m 33\u001b[0m z \u001b[39m=\u001b[39m conv(params[\u001b[39m'\u001b[39m\u001b[39mconv\u001b[39m\u001b[39m'\u001b[39m], img_sample)\n\u001b[1;32m 34\u001b[0m z \u001b[39m=\u001b[39m z\u001b[39m.\u001b[39mreshape(cfg[\u001b[39m'\u001b[39m\u001b[39mbatch_size\u001b[39m\u001b[39m'\u001b[39m], \u001b[39m-\u001b[39m\u001b[39m1\u001b[39m)\n", + "\u001b[1;32m/Users/syrkis/code/neuroscope/notebook.ipynb Cell 19\u001b[0m line \u001b[0;36m2\n\u001b[1;32m 20\u001b[0m lh, rh, img, subj \u001b[39m=\u001b[39m \u001b[39mnext\u001b[39m(train_batches)\n\u001b[1;32m 21\u001b[0m params, opt_state, loss \u001b[39m=\u001b[39m update(params, lh, img, subj, opt_state)\n\u001b[0;32m---> 22\u001b[0m info_bar \u001b[39m=\u001b[39m [\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mloss : \u001b[39m\u001b[39m{\u001b[39;00mloss\u001b[39m:\u001b[39;00m\u001b[39m.4f\u001b[39m\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m, \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mstep : \u001b[39m\u001b[39m{\u001b[39;00mstep\u001b[39m}\u001b[39;00m\u001b[39m / \u001b[39m\u001b[39m{\u001b[39;00mn_steps\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m]\n\u001b[1;32m 23\u001b[0m imgs \u001b[39m=\u001b[39m [img_sample[:\u001b[39m6\u001b[39m], apply_fn(params, lh, img_sample, subject_idx_sample)[:\u001b[39m6\u001b[39m],\n\u001b[1;32m 24\u001b[0m train_eval_img[:\u001b[39m6\u001b[39m], apply_fn(params, train_eval_lh, train_eval_img, train_eval_subj)[:\u001b[39m6\u001b[39m]]\n\u001b[1;32m 25\u001b[0m syrkis\u001b[39m.\u001b[39mtraining\u001b[39m.\u001b[39mplot_multiples(jnp\u001b[39m.\u001b[39mconcatenate(imgs, axis\u001b[39m=\u001b[39m\u001b[39m0\u001b[39m), n_rows\u001b[39m=\u001b[39m\u001b[39m4\u001b[39m, info_bar\u001b[39m=\u001b[39minfo_bar, invertable\u001b[39m=\u001b[39mcfg[\u001b[39m'\u001b[39m\u001b[39mdataset\u001b[39m\u001b[39m'\u001b[39m] \u001b[39m==\u001b[39m \u001b[39m'\u001b[39m\u001b[39mmnist\u001b[39m\u001b[39m'\u001b[39m)\n", + "File \u001b[0;32m~/.pyenv/versions/3.11.4/lib/python3.11/site-packages/jax/_src/array.py:291\u001b[0m, in \u001b[0;36mArrayImpl.__format__\u001b[0;34m(self, format_spec)\u001b[0m\n\u001b[1;32m 288\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__format__\u001b[39m(\u001b[39mself\u001b[39m, format_spec):\n\u001b[1;32m 289\u001b[0m \u001b[39m# Simulates behavior of https://github.com/numpy/numpy/pull/9883\u001b[39;00m\n\u001b[1;32m 290\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mndim \u001b[39m==\u001b[39m \u001b[39m0\u001b[39m:\n\u001b[0;32m--> 291\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mformat\u001b[39m(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_value[()], format_spec)\n\u001b[1;32m 292\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 293\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mformat\u001b[39m(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_value, format_spec)\n", + "File \u001b[0;32m~/.pyenv/versions/3.11.4/lib/python3.11/site-packages/jax/_src/profiler.py:314\u001b[0m, in \u001b[0;36mannotate_function..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 311\u001b[0m \u001b[39m@wraps\u001b[39m(func)\n\u001b[1;32m 312\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mwrapper\u001b[39m(\u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs):\n\u001b[1;32m 313\u001b[0m \u001b[39mwith\u001b[39;00m TraceAnnotation(name, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mdecorator_kwargs):\n\u001b[0;32m--> 314\u001b[0m \u001b[39mreturn\u001b[39;00m func(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 315\u001b[0m \u001b[39mreturn\u001b[39;00m wrapper\n", + "File \u001b[0;32m~/.pyenv/versions/3.11.4/lib/python3.11/site-packages/jax/_src/array.py:516\u001b[0m, in \u001b[0;36mArrayImpl._value\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 514\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_npy_value \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m 515\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mis_fully_replicated:\n\u001b[0;32m--> 516\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_npy_value \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_single_device_array_to_np_array() \u001b[39m# type: ignore\u001b[39;00m\n\u001b[1;32m 517\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_npy_value\u001b[39m.\u001b[39mflags\u001b[39m.\u001b[39mwriteable \u001b[39m=\u001b[39m \u001b[39mFalse\u001b[39;00m\n\u001b[1;32m 518\u001b[0m \u001b[39mreturn\u001b[39;00m cast(np\u001b[39m.\u001b[39mndarray, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_npy_value)\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] } ], "source": [ "\n", - "opt = optax.lion(cfg['lr']) if cfg['optimizer'] == 'lion' else optax.adamw(cfg['lr'])\n", + "opt = optax.adamw(cfg['lr'], weight_decay=0.02)\n", "opt_state = opt.init(params)\n", "update = jit(partial(update_fn, opt=opt))\n", "\n", - "hof = 0\n", - "\n", - "def train(params, opt_state, cfg, batches):\n", - " wandb.init(project=\"neuroscope\", entity=\"syrkis\", config=cfg)\n", - " for step in range(10_000):\n", - " lh, rh, img = next(batches)\n", - " params, opt_state, loss = update(params, img, opt_state)\n", - " losses.append(loss.item())\n", - " info_bar = [f\"loss : {loss:.4f}\", f\"step : {step}\", f\"hof : {hof}\"]\n", - " syrkis.training.plot_multiples(jnp.concatenate([img_sample[:6], apply_fn(params, img_sample)[:6]], axis=0), n_rows=2)\n", - " if step % 20 == 0:\n", - " wandb.log({\"loss\": loss})\n", - " \n", - "train(params, opt_state, cfg, batches)" + "def evaluate(params, train_batches, eval_batches, eval_steps=5):\n", + " train_loss, eval_loss = 0, 0\n", + " for i in range(eval_steps):\n", + " lh_train, rh_train, train_img, subj_train = next(train_batches)\n", + " lh_eval, rh_eval, eval_img, subj_eval = next(eval_batches)\n", + " train_loss += loss_fn(params, lh_train, train_img, subj_train) / eval_steps\n", + " eval_loss += loss_fn(params, lh_eval, eval_img, subj_eval) / eval_steps\n", + " return train_loss, eval_loss\n", + "\n", + "\n", + "train_eval_lh, train_eval_rh, train_eval_img, train_eval_subj = next(train_batches)\n", + "def train(params, opt_state, cfg, train_batches, eval_batches, eval_step=100, monitor=False):\n", + " if monitor:\n", + " wandb.init(project=\"neuroscope\", entity=\"syrkis\", config=cfg)\n", + " for step in range(n_steps):\n", + " lh, rh, img, subj = next(train_batches)\n", + " params, opt_state, loss = update(params, lh, img, subj, opt_state)\n", + " info_bar = [f\"loss : {loss:.4f}\", f\"step : {step} / {n_steps}\"]\n", + " imgs = [img_sample[:6], apply_fn(params, lh, img_sample, subject_idx_sample)[:6],\n", + " train_eval_img[:6], apply_fn(params, train_eval_lh, train_eval_img, train_eval_subj)[:6]]\n", + " syrkis.training.plot_multiples(jnp.concatenate(imgs, axis=0), n_rows=4, info_bar=info_bar, invertable=cfg['dataset'] == 'mnist')\n", + " # monitor eval_bin times throughout training\n", + " if monitor and step % eval_step == 0:\n", + " train_loss, eval_loss = evaluate(params, train_batches, eval_batches)\n", + " wandb.log({\"train_loss\": train_loss, \"eval_loss\": eval_loss})\n", + " wandb.finish()\n", + "\n", + "train(params, opt_state, cfg, train_batches, eval_batches, monitor=False)\n", + "z = conv(params['conv'], img_sample)\n", + "z = z.reshape(cfg['batch_size'], -1)\n", + "print(z.shape)\n", + "print(params['linear_encode'][0][0].shape)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "code", "execution_count": null, diff --git a/src/data.py b/src/data.py index 26d1664..a83f656 100644 --- a/src/data.py +++ b/src/data.py @@ -14,6 +14,10 @@ # functions +def load_subjects(subjects: list, image_size: int=32, precision=jnp.float32) -> dict: + return {subject: load_subject(subject, image_size, precision) for subject in subjects} + + def load_subject(subject: str, image_size: int=32, precision=jnp.float32) -> tuple: path = os.path.join(DATA_DIR, 'algonauts', subject, 'training_split') n_samples = len([f for f in os.listdir(os.path.join(path, 'training_images')) if f.endswith('.png')]) @@ -74,7 +78,7 @@ def get_metadata(image_file: str, metadata_sources: tuple) -> tuple: return coco_id, categories, captions -def make_batches(lh_fmri: list, rh_fmri: list, images: list , batch_size: int) -> tuple: +def make_batches(lh_fmri: list, rh_fmri: list, images: list , subject_idx, batch_size: int, n_subjects: int) -> tuple: lh_fmri = np.array(lh_fmri) rh_fmri = np.array(rh_fmri) images = np.array(images) @@ -82,20 +86,41 @@ def make_batches(lh_fmri: list, rh_fmri: list, images: list , batch_size: int) - perm = np.random.permutation(len(images) // batch_size * batch_size) for i in range(0, len(perm), batch_size): batch_perm = perm[i:i + batch_size] - lh_batch = jnp.array(lh_fmri[batch_perm]) - rh_batch = jnp.array(rh_fmri[batch_perm]) - image_batch = jnp.array(images[batch_perm]) - yield lh_batch, rh_batch, image_batch - - -def make_kfolds(subject: tuple, hyperparams: dict, n_splits: int=5) -> tuple: - lh_fmri, rh_fmri, images, _ = subject + lh_batch = lh_fmri[batch_perm] + #expanded_lh = expand(lh_batch, subject_idx[batch_perm], n_subjects) + rh_batch = rh_fmri[batch_perm] + #expanded_rh = expand(rh_batch, subject_idx[batch_perm], n_subjects) + image_batch = images[batch_perm] + yield lh_batch, rh_batch, image_batch, subject_idx[batch_perm] + +""" def expand(A, v, k): + N, M = A.shape + B = jnp.zeros((N, M, k)) + mask = jnp.arange(k) == v[:, None] + B = jnp.where(mask[:, None, :], A[:, :, None], B) + return B """ + +def combine_subjects(subjects, cfg): + # subjects is a dict of (lh_fmri, rh_fmri, images) tuples + lh_fmri = np.concatenate([subject[0] for subject in subjects.values()]) + rh_fmri = np.concatenate([subject[1] for subject in subjects.values()]) + images = np.concatenate([subject[2] for subject in subjects.values()]) + subject_idx = np.concatenate([np.ones(len(subject[0])) * i for i, subject in enumerate(subjects.values())]) + return lh_fmri, rh_fmri, images, subject_idx + + +def make_kfolds(subjects_data, cfg, n_splits=5): + # subject data is a dict of (lh_fmri, rh_fmri, images) tuples kf = KFold(n_splits=n_splits, shuffle=True, random_state=42) - for train_idx, val_idx in kf.split(images): - train_lh, train_rh = lh_fmri[train_idx], rh_fmri[train_idx] - train_images = [images[i] for i in train_idx] - val_lh, val_rh = lh_fmri[val_idx], rh_fmri[val_idx] - val_images = [images[i] for i in val_idx] - train_batches = make_batches(train_lh, train_rh, train_images, hyperparams['batch_size']) - val_batches = make_batches(val_lh, val_rh, val_images, hyperparams['batch_size']) - yield train_batches, val_batches + combined_subjects = combine_subjects(subjects_data, cfg) + for fold in kf.split(combined_subjects[0]): + train_idx, val_idx = fold + train_lh, train_rh = combined_subjects[0][train_idx], combined_subjects[1][train_idx] + train_images = combined_subjects[2][train_idx] + train_subject_idx = combined_subjects[3][train_idx] + val_lh, val_rh = combined_subjects[0][val_idx], combined_subjects[1][val_idx] + val_images = combined_subjects[2][val_idx] + val_subject_idx = combined_subjects[3][val_idx] + train_batches = make_batches(train_lh, train_rh, train_images, train_subject_idx, cfg['batch_size'], len(subjects_data)) + val_batches = make_batches(val_lh, val_rh, val_images, val_subject_idx, cfg['batch_size'], len(subjects_data)) + yield train_batches, val_batches \ No newline at end of file